diff --git a/builtin/common/metatable.lua b/builtin/common/metatable.lua index debc4d3c5..ccfa7c5c3 100644 --- a/builtin/common/metatable.lua +++ b/builtin/common/metatable.lua @@ -1,14 +1,28 @@ -- Registered metatables, used by the C++ packer +local serializable_metatables = {} local known_metatables = {} -function core.register_portable_metatable(name, mt) + +local function dummy_serializer(x) + return x +end + +function core.register_portable_metatable(name, mt, serializer, deserializer) + serializer = serializer or dummy_serializer + deserializer = deserializer or function(x) return setmetatable(x, mt) end assert(type(name) == "string", ("attempt to use %s value as metatable name"):format(type(name))) assert(type(mt) == "table", ("attempt to register a %s value as metatable"):format(type(mt))) + assert(type(serializer), ("attempt to use a %s value as serializer"):format(type(serializer))) + assert(type(deserializer), ("attempt to use a %s value as serialier"):format(type(deserializer))) assert(known_metatables[name] == nil or known_metatables[name] == mt, ("attempt to override metatable %s"):format(name)) known_metatables[name] = mt known_metatables[mt] = name + serializable_metatables[mt] = serializer + serializable_metatables[name] = deserializer end + core.known_metatables = known_metatables +core.serializable_metatables = serializable_metatables function core.register_async_metatable(...) core.log("deprecated", "core.register_async_metatable is deprecated. " .. @@ -17,3 +31,9 @@ function core.register_async_metatable(...) end core.register_portable_metatable("__builtin:vector", vector.metatable) + +if ItemStack then + local item = ItemStack() + local itemstack_mt = getmetatable(item) + core.register_portable_metatable("__itemstack", itemstack_mt, item.to_table, ItemStack) +end diff --git a/builtin/common/serialize.lua b/builtin/common/serialize.lua index 9ebece6d0..b4ebeff7b 100644 --- a/builtin/common/serialize.lua +++ b/builtin/common/serialize.lua @@ -8,22 +8,38 @@ local next, rawget, pairs, pcall, error, type, setfenv, loadstring local table_concat, string_dump, string_format, string_match, math_huge = table.concat, string.dump, string.format, string.match, math.huge --- Recursively counts occurrences of objects (non-primitives including strings) in a table. -local function count_objects(value) +local function pack_args(...) + return {n = select("#", ...), ...} +end + +-- Recursively +-- (1) reads metatables from tables; +-- (2) counts occurrences of objects (non-primitives including strings) in a table. +local function prepare_objects(value) local counts = {} + local type_lookup = {} if value == nil then -- Early return for nil; tables can't contain nil - return counts + return counts, type_lookup end - local function count_values(val) + local function count_values(val, recount) local type_ = type(val) if type_ == "boolean" or type_ == "number" then return end local count = counts[val] - counts[val] = (count or 0) + 1 - if type_ == "table" then - if not count then + if not recount then + counts[val] = (count or 0) + 1 + end + local mt = (not count) and (type_ == "table" or type_ == "userdata") and getmetatable(val) + if mt and core.serializable_metatables[mt] then + local args = pack_args(core.known_metatables[mt], core.serializable_metatables[mt](val)) + type_lookup[val] = args + for _, v in ipairs(args) do + count_values(v, rawequal(v, val)) + end + elseif type_ == "table" then + if recount or not count then for k, v in pairs(val) do count_values(k) count_values(v) @@ -34,7 +50,7 @@ local function count_objects(value) end end count_values(value) - return counts + return counts, type_lookup end -- Build a "set" of Lua keywords. These can't be used as short key names. @@ -66,7 +82,15 @@ local function serialize(value, write) local references = {} -- Circular tables that must be filled using `table[key] = value` statements local to_fill = {} - for object, count in pairs(count_objects(value)) do + local counts, typeinfo = prepare_objects(value) + if next(typeinfo) then + write [[ + if not (core and core.serializable_metatables) then + core = { known_metatables = {}, serializable_metatables = {}} + end; + ]] + end + for object, count in pairs(counts) do local type_ = type(object) -- Object must appear more than once. If it is a string, the reference has to be shorter than the string. if count >= 2 and (type_ ~= "string" or #reference + 5 < #object) then @@ -96,7 +120,22 @@ local function serialize(value, write) local function use_short_key(key) return not references[key] and type(key) == "string" and (not keywords[key]) and string_match(key, "^[%a_][%a%d_]*$") end - local function dump(value) + local dump + local function dump_serialized(value) + local serialized = assert(typeinfo[value]) + write "(core.serializable_metatables[" + dump(serialized[1]) + write "])(" + for k = 2, serialized.n do + if k ~= 2 then + write "," + end + local v = serialized[k] + dump(v, rawequal(v, value)) + end + write ")" + end + dump = function(value, skip_mt) -- Primitive types if value == nil then return write("nil") @@ -126,6 +165,10 @@ local function serialize(value, write) write(ref) return write"]" end + if (not skip_mt) and typeinfo[value] then + dump_serialized(value) + return + end if type_ == "string" then return write(quote(value)) end @@ -168,8 +211,8 @@ local function serialize(value, write) end end -- Write the statements to fill circular tables - for table, ref in pairs(to_fill) do - for k, v in pairs(table) do + for tbl, ref in pairs(to_fill) do + for k, v in pairs(tbl) do write("_[") write(ref) write("]") @@ -185,6 +228,13 @@ local function serialize(value, write) dump(v) write(";") end + if typeinfo[tbl] then + write("_[") + write(ref) + write("]=") + dump_serialized(tbl) + write(";") + end end write("return ") dump(value) @@ -246,7 +296,14 @@ function core.deserialize(str, safe) if not func then return nil, err end -- math.huge was serialized to inf and NaNs to nan by Lua in engine version 5.6, so we have to support this here - local env = {inf = math_huge, nan = 0/0} + local env = { + inf = math_huge, + nan = 0/0, + core = { + known_metatables = core.known_metatables, + serializable_metatables = core.serializable_metatables, + }, + } if safe then env.loadstring = dummy_func else @@ -266,3 +323,4 @@ function core.deserialize(str, safe) end return nil, value_or_err end + diff --git a/builtin/common/tests/serialize_spec.lua b/builtin/common/tests/serialize_spec.lua index 2a7a0f3ce..1c8997b16 100644 --- a/builtin/common/tests/serialize_spec.lua +++ b/builtin/common/tests/serialize_spec.lua @@ -4,6 +4,7 @@ _G.setfenv = require 'busted.compatibility'.setfenv dofile("builtin/common/serialize.lua") dofile("builtin/common/vector.lua") +dofile("builtin/common/metatable.lua") -- Supports circular tables; does not support table keys -- Correctly checks whether a mapping of references ("same") exists @@ -40,11 +41,32 @@ local t1, t2 = {x, x, y, y}, {x, y, x, y} assert.same(t1, t2) -- will succeed because it only checks whether the depths match assert(not pcall(assert_same, t1, t2)) -- will correctly fail because it checks whether the refs match +local pair_mt = { + __eq = function(x, y) + return x[1] == y[1] and x[2] == y[2] + end, +} +local function pair(x, y) + return setmetatable({x, y}, pair_mt) +end +-- Use our own serialization functions to avoid incorrectly passing test related to references. +core.register_portable_metatable("pair", pair_mt) +assert.equals(pair(1, 2), pair(1, 2)) +assert.not_equals(pair(1, 2), pair(3, 4)) + describe("serialize", function() local function assert_preserves(value) local preserved_value = core.deserialize(core.serialize(value)) assert_same(value, preserved_value) end + local function assert_strictly_preserves(value) + local preserved_value = core.deserialize(core.serialize(value)) + assert.equals(value, preserved_value) + end + local function assert_compatibly_preserves(value) + local preserved_value = loadstring(core.serialize(value))() + assert_same(value, preserved_value) + end it("works", function() assert_preserves({cat={sound="nyan", speed=400}, dog={sound="woof"}}) end) @@ -53,6 +75,10 @@ describe("serialize", function() assert_preserves({escape_chars="\n\r\t\v\\\"\'", non_european="θשׁ٩∂"}) end) + it("handles nil", function() + assert_strictly_preserves(nil) + end) + it("handles NaN & infinities", function() local nan = core.deserialize(core.serialize(0/0)) assert(nan ~= nan) @@ -141,7 +167,10 @@ describe("serialize", function() it("vectors work", function() local v = vector.new(1, 2, 3) assert_preserves({v}) - assert_preserves(v) + assert_compatibly_preserves({v}) + assert_strictly_preserves(v) + assert_compatibly_preserves(v) + assert(core.deserialize(core.serialize(v)):check()) -- abuse v = vector.new(1, 2, 3) @@ -149,6 +178,43 @@ describe("serialize", function() assert_preserves(v) end) + it("correctly handles typed objects with multiple references", function() + local x, y = pair(1, 2), pair(1, 2) + local t = core.deserialize(core.serialize{x, x, y}) + assert.equals(x, t[1]) + assert.equals(x, t[3]) + assert(rawequal(t[1], t[2])) + assert(not rawequal(t[1], t[3])) + end) + + it("correctly handles recursive typed objects with the identity function as serializer", function() + local mt = { + __eq = function(x, y) + return x[1] == y[1] + end, + } + core.register_portable_metatable("test_recursive_typed", mt) + local t = setmetatable({1}, mt) + t[2] = t + assert_strictly_preserves(t) + end) + + it("correctly handles binary trees", function() + local child = {pair(1, 1)} + local layers = 4 + for i = 2, layers do + child[i] = pair(child[i-1], child[i-1]) + end + local tree = child[layers] + assert_strictly_preserves(tree) + local node = core.deserialize(core.serialize(tree)) + for i = 2, layers do + assert(rawequal(node[1], node[2])) + node = node[1] + end + assert_compatibly_preserves(tree) + end) + it("handles keywords as keys", function() assert_preserves({["and"] = "keyword", ["for"] = "keyword"}) end) diff --git a/builtin/common/vector.lua b/builtin/common/vector.lua index 7a8558cbd..bbe4e6ad4 100644 --- a/builtin/common/vector.lua +++ b/builtin/common/vector.lua @@ -12,6 +12,10 @@ vector = {} local metatable = {} vector.metatable = metatable +if core and core.register_serializable then + core.register_serializable("__builtin:vector", metatable) +end + local xyz = {"x", "y", "z"} -- only called when rawget(v, key) returns nil diff --git a/doc/lua_api.md b/doc/lua_api.md index b34e571c2..2515c56f0 100644 --- a/doc/lua_api.md +++ b/doc/lua_api.md @@ -7812,12 +7812,18 @@ Misc. * `core.global_exists(name)` * Checks if a global variable has been set, without triggering a warning. -* `core.register_portable_metatable(name, mt)`: +* `core.register_portable_metatable(name, mt, serializer, deserializer)`: * Register a metatable that should be preserved when Lua data is transferred - between environments (via IPC or `handle_async`). + between environments (via IPC, `handle_async`, or `core.serialize`). * `name` is a string that identifies the metatable. It is recommended to follow the `modname:name` convention for this identifier. * `mt` is the metatable to register. + * `serializer` is a function used by `core.serialize` to serialize data with + the given metatable. It may return multiple values, but the return values should not + contain the input datum unless `serializer` is the identity function. The default + value for `serializer` is the identity function. + * `deserializer` is a function used by `core.deserialize` to deserialize data from + values returned by `serializer`. The default value is a wrapper around `setmetatable`. * Note that the same metatable can be registered under multiple names, but multiple metatables must not be registered under the same name. * You must register the metatable in both the main environment diff --git a/games/devtest/mods/unittests/itemstack_equals.lua b/games/devtest/mods/unittests/itemstack_equals.lua index 561e612c4..ff74e562b 100644 --- a/games/devtest/mods/unittests/itemstack_equals.lua +++ b/games/devtest/mods/unittests/itemstack_equals.lua @@ -72,3 +72,10 @@ local function test_itemstack_equals_metadata() end unittests.register("test_itemstack_equals_metadata", test_itemstack_equals_metadata) + +local function test_itemstack_serialization_preservation() + local i = ItemStack("basenodes:stone 20 1000") + assert(i:equals(core.deserialize(core.serialize(i)))) +end + +unittests.register("test_itemstack_serialization_preservation", test_itemstack_serialization_preservation)