diff --git a/doc/lua_api.md b/doc/lua_api.md index 07f70c36f..100a75f0c 100644 --- a/doc/lua_api.md +++ b/doc/lua_api.md @@ -5971,6 +5971,9 @@ Utilities form. If the ColorSpec is invalid, returns `nil`. You can use this to parse ColorStrings. * `colorspec`: The ColorSpec to convert +* `core.colorspec_to_int(colorspec)`: Converts a ColorSpec to integer form. + If the ColorSpec is invalid, returns `nil`. + * `colorspec`: The ColorSpec to convert * `core.time_to_day_night_ratio(time_of_day)`: Returns a "day-night ratio" value (as accepted by `ObjectRef:override_day_night_ratio`) that is equivalent to the given "time of day" value (as returned by `core.get_timeofday`). @@ -5979,8 +5982,8 @@ Utilities * `width`: Width of the image * `height`: Height of the image * `data`: Image data, one of: - * array table of ColorSpec, length must be width*height - * string with raw RGBA pixels, length must be width*height*4 + * array table of ColorSpec, length must be `width * height` + * string with raw RGBA pixels, length must be `width * height * 4` * `compression`: Optional zlib compression level, number in range 0 to 9. The data is one-dimensional, starting in the upper left corner of the image and laid out in scanlines going from left to right, then top to bottom. @@ -7660,6 +7663,49 @@ Misc. * Example: `deserialize('print("foo")')`, returns `nil` (function call fails), returns `error:[string "print("foo")"]:1: attempt to call global 'print' (a nil value)` +* `core.encode_network(format, ...)`: Encodes numbers and strings in binary + format suitable for network transfer according to a format string. + * Each character in the format string corresponds to an argument to the + function. Possible format characters: + * `b`: Signed 8-bit integer + * `h`: Signed 16-bit integer + * `i`: Signed 32-bit integer + * `l`: Signed 64-bit integer + * `B`: Unsigned 8-bit integer + * `H`: Unsigned 16-bit integer + * `I`: Unsigned 32-bit integer + * `L`: Unsigned 64-bit integer + * `f`: Single-precision floating point number + * `s`: 16-bit size-prefixed string. Max 64 KB in size + * `S`: 32-bit size-prefixed string. Max 64 MB in size + * `z`: Null-terminated string. Cannot have embedded null characters + * `Z`: Verbatim string with no size or terminator + * ` `: Spaces are ignored + * Integers are encoded in big-endian format, and floating point numbers are + encoded in IEEE-754 format. Note that the full range of 64-bit integers + cannot be represented in Lua's doubles. + * If integers outside of the range of the corresponding type are encoded, + integer wraparound will occur. + * If a string that is too long for a size-prefixed string is encoded, it + will be truncated. + * If a string with an embedded null character is encoded as a null + terminated string, it is truncated to the first null character. + * Verbatim strings are added directly to the output as-is and can therefore + have any size or contents, but the code on the decoding end cannot + automatically detect its length. +* `core.decode_network(format, data, ...)`: Decodes numbers and strings from + binary format made by `core.encode_network()` according to a format string. + * The format string follows the same rules as `core.encode_network()`. + The decoded values are returned as individual values from the function. + * `Z` has special behavior; an extra argument has to be passed to the + function for every `Z` specifier denoting how many characters to read. + To read all remaining characters, use a size of `-1`. + * If the end of the data is encountered while still reading values from the + string, values of the correct type will still be returned, but strings of + variable length will be truncated, and numbers and verbatim strings will + use zeros for the missing bytes. + * If a size-prefixed string has a size that is greater than the maximum, it + will be truncated and the rest of the characters skipped. * `core.compress(data, method, ...)`: returns `compressed_data` * Compress a string of data. * `method` is a string identifying the compression method to be used. diff --git a/games/devtest/mods/unittests/misc.lua b/games/devtest/mods/unittests/misc.lua index 28cc2c1eb..132cc02ba 100644 --- a/games/devtest/mods/unittests/misc.lua +++ b/games/devtest/mods/unittests/misc.lua @@ -341,3 +341,202 @@ local function test_ipc_poll(cb) print("delta: " .. (core.get_us_time() - t0) .. "us") end unittests.register("test_ipc_poll", test_ipc_poll) + +unittests.register("test_encode_network", function() + -- 8-bit integers + assert(core.encode_network("bbbbbbb", 0, 1, -1, -128, 127, 255, 256) == + "\x00\x01\xFF\x80\x7F\xFF\x00") + assert(core.encode_network("BBBBBBB", 0, 1, -1, -128, 127, 255, 256) == + "\x00\x01\xFF\x80\x7F\xFF\x00") + + -- 16-bit integers + assert(core.encode_network("hhhhhhhh", + 0, 1, 257, -1, + -32768, 32767, 65535, 65536) == + "\x00\x00".."\x00\x01".."\x01\x01".."\xFF\xFF".. + "\x80\x00".."\x7F\xFF".."\xFF\xFF".."\x00\x00") + assert(core.encode_network("HHHHHHHH", + 0, 1, 257, -1, + -32768, 32767, 65535, 65536) == + "\x00\x00".."\x00\x01".."\x01\x01".."\xFF\xFF".. + "\x80\x00".."\x7F\xFF".."\xFF\xFF".."\x00\x00") + + -- 32-bit integers + assert(core.encode_network("iiiiiiii", + 0, 257, 2^24-1, -1, + -2^31, 2^31-1, 2^32-1, 2^32) == + "\x00\x00\x00\x00".."\x00\x00\x01\x01".."\x00\xFF\xFF\xFF".."\xFF\xFF\xFF\xFF".. + "\x80\x00\x00\x00".."\x7F\xFF\xFF\xFF".."\xFF\xFF\xFF\xFF".."\x00\x00\x00\x00") + assert(core.encode_network("IIIIIIII", + 0, 257, 2^24-1, -1, + -2^31, 2^31-1, 2^32-1, 2^32) == + "\x00\x00\x00\x00".."\x00\x00\x01\x01".."\x00\xFF\xFF\xFF".."\xFF\xFF\xFF\xFF".. + "\x80\x00\x00\x00".."\x7F\xFF\xFF\xFF".."\xFF\xFF\xFF\xFF".."\x00\x00\x00\x00") + + -- 64-bit integers + assert(core.encode_network("llllll", + 0, 1, + 511, -1, + 2^53-1, -2^53) == + "\x00\x00\x00\x00\x00\x00\x00\x00".."\x00\x00\x00\x00\x00\x00\x00\x01".. + "\x00\x00\x00\x00\x00\x00\x01\xFF".."\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF".. + "\x00\x1F\xFF\xFF\xFF\xFF\xFF\xFF".."\xFF\xE0\x00\x00\x00\x00\x00\x00") + assert(core.encode_network("LLLLLL", + 0, 1, + 511, -1, + 2^53-1, -2^53) == + "\x00\x00\x00\x00\x00\x00\x00\x00".."\x00\x00\x00\x00\x00\x00\x00\x01".. + "\x00\x00\x00\x00\x00\x00\x01\xFF".."\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF".. + "\x00\x1F\xFF\xFF\xFF\xFF\xFF\xFF".."\xFF\xE0\x00\x00\x00\x00\x00\x00") + + -- Strings + local max_16 = string.rep("*", 2^16 - 1) + local max_32 = string.rep("*", 2^26) + + assert(core.encode_network("ssss", + "", "hello", + max_16, max_16.."too long") == + "\x00\x00".. "\x00\x05hello".. + "\xFF\xFF"..max_16.."\xFF\xFF"..max_16) + assert(core.encode_network("SSSS", + "", "hello", + max_32, max_32.."too long") == + "\x00\x00\x00\x00".. "\x00\x00\x00\x05hello".. + "\x04\x00\x00\x00"..max_32.."\x04\x00\x00\x00"..max_32) + assert(core.encode_network("zzzz", + "", "hello", "hello\0embedded", max_16.."longer") == + "\0".."hello\0".."hello\0".. max_16.."longer\0") + assert(core.encode_network("ZZZZ", + "", "hello", "hello\0embedded", max_16.."longer") == + "".."hello".."hello\0embedded"..max_16.."longer") + + -- Spaces + assert(core.encode_network("B I", 255, 2^31) == "\xFF\x80\x00\x00\x00") + assert(core.encode_network(" B Zz ", 15, "abc", "xyz") == "\x0Fabcxyz\0") + + -- Empty format strings + assert(core.encode_network("") == "") + assert(core.encode_network(" ", 5, "extra args") == "") +end) + +unittests.register("test_decode_network", function() + local d + + -- 8-bit integers + d = {core.decode_network("bbbbb", "\x00\x01\x7F\x80\xFF")} + assert(#d == 5) + assert(d[1] == 0 and d[2] == 1 and d[3] == 127 and d[4] == -128 and d[5] == -1) + + d = {core.decode_network("BBBBB", "\x00\x01\x7F\x80\xFF")} + assert(#d == 5) + assert(d[1] == 0 and d[2] == 1 and d[3] == 127 and d[4] == 128 and d[5] == 255) + + -- 16-bit integers + d = {core.decode_network("hhhhhh", + "\x00\x00".."\x00\x01".."\x01\x01".. + "\x7F\xFF".."\x80\x00".."\xFF\xFF")} + assert(#d == 6) + assert(d[1] == 0 and d[2] == 1 and d[3] == 257 and + d[4] == 32767 and d[5] == -32768 and d[6] == -1) + + d = {core.decode_network("HHHHHH", + "\x00\x00".."\x00\x01".."\x01\x01".. + "\x7F\xFF".."\x80\x00".."\xFF\xFF")} + assert(#d == 6) + assert(d[1] == 0 and d[2] == 1 and d[3] == 257 and + d[4] == 32767 and d[5] == 32768 and d[6] == 65535) + + -- 32-bit integers + d = {core.decode_network("iiiiii", + "\x00\x00\x00\x00".."\x00\x00\x00\x01".."\x00\xFF\xFF\xFF".. + "\x7F\xFF\xFF\xFF".."\x80\x00\x00\x00".."\xFF\xFF\xFF\xFF")} + assert(#d == 6) + assert(d[1] == 0 and d[2] == 1 and d[3] == 2^24-1 and + d[4] == 2^31-1 and d[5] == -2^31 and d[6] == -1) + + d = {core.decode_network("IIIIII", + "\x00\x00\x00\x00".."\x00\x00\x00\x01".."\x00\xFF\xFF\xFF".. + "\x7F\xFF\xFF\xFF".."\x80\x00\x00\x00".."\xFF\xFF\xFF\xFF")} + assert(#d == 6) + assert(d[1] == 0 and d[2] == 1 and d[3] == 2^24-1 and + d[4] == 2^31-1 and d[5] == 2^31 and d[6] == 2^32-1) + + -- 64-bit integers + d = {core.decode_network("llllll", + "\x00\x00\x00\x00\x00\x00\x00\x00".."\x00\x00\x00\x00\x00\x00\x00\x01".. + "\x00\x00\x00\x00\x00\x00\x01\xFF".."\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF".. + "\x00\x1F\xFF\xFF\xFF\xFF\xFF\xFF".."\xFF\xE0\x00\x00\x00\x00\x00\x00")} + assert(#d == 6) + assert(d[1] == 0 and d[2] == 1 and d[3] == 511 and + d[4] == -1 and d[5] == 2^53-1 and d[6] == -2^53) + + d = {core.decode_network("LLLLLL", + "\x00\x00\x00\x00\x00\x00\x00\x00".."\x00\x00\x00\x00\x00\x00\x00\x01".. + "\x00\x00\x00\x00\x00\x00\x01\xFF".."\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF".. + "\x00\x1F\xFF\xFF\xFF\xFF\xFF\xFF".."\xFF\xE0\x00\x00\x00\x00\x00\x00")} + assert(#d == 6) + assert(d[1] == 0 and d[2] == 1 and d[3] == 511 and + d[4] == 2^64-1 and d[5] == 2^53-1 and d[6] == 2^64 - 2^53) + + -- Floating point numbers + local enc = core.encode_network("fff", + 0.0, 123.456, -987.654) + assert(#enc == 3 * 4) + + d = {core.decode_network("fff", enc)} + assert(#d == 3) + assert(d[1] == 0.0 and d[2] > 123.45 and d[2] < 123.46 and + d[3] > -987.66 and d[3] < -987.65) + + -- Strings + local max_16 = string.rep("*", 2^16 - 1) + local max_32 = string.rep("*", 2^26) + + d = {core.decode_network("ssss", + "\x00\x00".."\x00\x05hello".."\xFF\xFF"..max_16.."\x00\xFFtoo short")} + assert(#d == 4) + assert(d[1] == "" and d[2] == "hello" and d[3] == max_16 and d[4] == "too short") + + d = {core.decode_network("SSSSS", + "\x00\x00\x00\x00".."\x00\x00\x00\x05hello".. + "\x04\x00\x00\x00"..max_32.."\x04\x00\x00\x08"..max_32.."too long".. + "\x00\x00\x00\xFFtoo short")} + assert(#d == 5) + assert(d[1] == "" and d[2] == "hello" and + d[3] == max_32 and d[4] == max_32 and d[5] == "too short") + + d = {core.decode_network("zzzz", "\0".."hello\0".."missing end")} + assert(#d == 4) + assert(d[1] == "" and d[2] == "hello" and d[3] == "missing end" and d[4] == "") + + -- Verbatim strings + d = {core.decode_network("ZZZZ", "xxxyyyyyzzz", 3, 0, 5, -1)} + assert(#d == 4) + assert(d[1] == "xxx" and d[2] == "" and d[3] == "yyyyy" and d[4] == "zzz") + + -- Read past end + d = {core.decode_network("bhilBHILf", "")} + assert(#d == 9) + assert(d[1] == 0 and d[2] == 0 and d[3] == 0 and d[4] == 0 and + d[5] == 0 and d[6] == 0 and d[7] == 0 and d[8] == 0 and d[9] == 0.0) + + d = {core.decode_network("ZsSzZ", "xx", 4, 4)} + assert(#d == 5) + assert(d[1] == "xx\0\0" and d[2] == "" and d[3] == "" and + d[4] == "" and d[5] == "\0\0\0\0") + + -- Spaces + d = {core.decode_network("B I", "\xFF\x80\x00\x00\x00")} + assert(#d == 2) + assert(d[1] == 255 and d[2] == 2^31) + + d = {core.decode_network(" B Zz ", "\x0Fabcxyz\0", 3)} + assert(#d == 3) + assert(d[1] == 15 and d[2] == "abc" and d[3] == "xyz") + + -- Empty format strings + d = {core.decode_network("", "some random data")} + assert(#d == 0) + d = {core.decode_network(" ", "some random data", 3, 5)} + assert(#d == 0) +end) diff --git a/src/script/lua_api/l_util.cpp b/src/script/lua_api/l_util.cpp index 5ac290b2e..740c6b2e4 100644 --- a/src/script/lua_api/l_util.cpp +++ b/src/script/lua_api/l_util.cpp @@ -585,10 +585,10 @@ int ModApiUtil::l_colorspec_to_colorstring(lua_State *L) snprintf(colorstring, 10, "#%02X%02X%02X%02X", color.getRed(), color.getGreen(), color.getBlue(), color.getAlpha()); lua_pushstring(L, colorstring); - return 1; + } else { + lua_pushnil(L); } - - return 0; + return 1; } // colorspec_to_bytes(colorspec) @@ -605,10 +605,10 @@ int ModApiUtil::l_colorspec_to_bytes(lua_State *L) (u8) color.getAlpha(), }; lua_pushlstring(L, (const char*) colorbytes, 4); - return 1; + } else { + lua_pushnil(L); } - - return 0; + return 1; } // colorspec_to_table(colorspec) @@ -619,10 +619,201 @@ int ModApiUtil::l_colorspec_to_table(lua_State *L) video::SColor color(0); if (read_color(L, 1, &color)) { push_ARGB8(L, color); - return 1; + } else { + lua_pushnil(L); + } + return 1; +} + +// colorspec_to_int(colorspec) +int ModApiUtil::l_colorspec_to_int(lua_State *L) +{ + NO_MAP_LOCK_REQUIRED; + + video::SColor color(0); + if (read_color(L, 1, &color)) { + lua_pushnumber(L, color.color); + } else { + lua_pushnil(L); + } + return 1; +} + +// encode_network(format, ...) +int ModApiUtil::l_encode_network(lua_State *L) +{ + NO_MAP_LOCK_REQUIRED; + + std::string format = readParam(L, 1); + std::ostringstream os(std::ios_base::binary); + + int arg = 2; + for (size_t i = 0; i < format.size(); i++) { + switch (format[i]) { + case 'b': + // Casting the double to a signed integer larger than the target + // integer results in proper integer wraparound behavior. + writeS8(os, (s64)luaL_checknumber(L, arg)); + break; + case 'h': + writeS16(os, (s64)luaL_checknumber(L, arg)); + break; + case 'i': + writeS32(os, (s64)luaL_checknumber(L, arg)); + break; + case 'l': + writeS64(os, (s64)luaL_checknumber(L, arg)); + break; + case 'B': + // Casting to an unsigned integer doesn't result in the proper + // integer conversions being applied, so we still use signed. + writeU8(os, (s64)luaL_checknumber(L, arg)); + break; + case 'H': + writeU16(os, (s64)luaL_checknumber(L, arg)); + break; + case 'I': + writeU32(os, (s64)luaL_checknumber(L, arg)); + break; + case 'L': + // For the 64-bit integers, we can never experience integer + // overflow due to the limited range of Lua's doubles, but we can + // have underflow, hence why we cast to s64 first. + writeU64(os, (s64)luaL_checknumber(L, arg)); + break; + case 'f': + writeF32(os, luaL_checknumber(L, arg)); + break; + case 's': { + std::string str = readParam(L, arg); + os << serializeString16(str, true); + break; + } + case 'S': { + std::string str = readParam(L, arg); + os << serializeString32(str, true); + break; + } + case 'z': { + std::string str = readParam(L, arg); + os << std::string_view(str.c_str(), strlen(str.c_str())) << '\0'; + break; + } + case 'Z': + os << readParam(L, arg); + break; + case ' ': + // Continue because we don't want to increment arg. + continue; + default: + throw LuaError("Invalid format string"); + } + + arg++; } - return 0; + std::string data = os.str(); + lua_pushlstring(L, data.c_str(), data.size()); + return 1; +} + +// decode_network(format, data) +int ModApiUtil::l_decode_network(lua_State *L) +{ + NO_MAP_LOCK_REQUIRED; + + std::string format = readParam(L, 1); + std::string data = readParam(L, 2); + std::istringstream is(data, std::ios_base::binary); + + // Make sure we have space for all our returned arguments. + lua_checkstack(L, format.size()); + + // Set up tracking for verbatim strings and the number of return values. + int num_args = lua_gettop(L); + int arg = 3; + int ret = 0; + + for (size_t i = 0; i < format.size(); i++) { + switch (format[i]) { + case 'b': + lua_pushnumber(L, readS8(is)); + break; + case 'h': + lua_pushnumber(L, readS16(is)); + break; + case 'i': + lua_pushnumber(L, readS32(is)); + break; + case 'l': + lua_pushnumber(L, readS64(is)); + break; + case 'B': + lua_pushnumber(L, readU8(is)); + break; + case 'H': + lua_pushnumber(L, readU16(is)); + break; + case 'I': + lua_pushnumber(L, readU32(is)); + break; + case 'L': + lua_pushnumber(L, readU64(is)); + break; + case 'f': + lua_pushnumber(L, readF32(is)); + break; + case 's': { + std::string str = deSerializeString16(is, true); + lua_pushlstring(L, str.c_str(), str.size()); + break; + } + case 'S': { + std::string str = deSerializeString32(is, true); + lua_pushlstring(L, str.c_str(), str.size()); + break; + } + case 'z': { + std::string str; + std::getline(is, str, '\0'); + + lua_pushlstring(L, str.c_str(), str.size()); + break; + } + case 'Z': { + if (arg > num_args) { + throw LuaError("Missing verbatim string size"); + } + + double size = luaL_checknumber(L, arg); + std::string str; + + if (size < 0) { + // Read the entire rest of the input stream. + std::ostringstream os(std::ios_base::binary); + os << is.rdbuf(); + str = os.str(); + } else if (size != 0) { + // Read the specified number of characters. + str.resize(size); + is.read(&str[0], size); + } + + lua_pushlstring(L, str.c_str(), str.size()); + arg++; + break; + } + case ' ': + // Continue because we don't want to increment ret. + continue; + default: + throw LuaError("Invalid format string"); + } + + ret++; + } + + return ret; } // time_to_day_night_ratio(time_of_day) @@ -737,8 +928,12 @@ void ModApiUtil::Initialize(lua_State *L, int top) API_FCT(colorspec_to_colorstring); API_FCT(colorspec_to_bytes); API_FCT(colorspec_to_table); - API_FCT(time_to_day_night_ratio); + API_FCT(colorspec_to_int); + API_FCT(encode_network); + API_FCT(decode_network); + + API_FCT(time_to_day_night_ratio); API_FCT(encode_png); API_FCT(get_last_run_mod); @@ -774,6 +969,11 @@ void ModApiUtil::InitializeClient(lua_State *L, int top) API_FCT(colorspec_to_colorstring); API_FCT(colorspec_to_bytes); API_FCT(colorspec_to_table); + API_FCT(colorspec_to_int); + + API_FCT(encode_network); + API_FCT(decode_network); + API_FCT(time_to_day_night_ratio); API_FCT(get_last_run_mod); @@ -820,8 +1020,12 @@ void ModApiUtil::InitializeAsync(lua_State *L, int top) API_FCT(colorspec_to_colorstring); API_FCT(colorspec_to_bytes); API_FCT(colorspec_to_table); - API_FCT(time_to_day_night_ratio); + API_FCT(colorspec_to_int); + API_FCT(encode_network); + API_FCT(decode_network); + + API_FCT(time_to_day_night_ratio); API_FCT(encode_png); API_FCT(get_last_run_mod); diff --git a/src/script/lua_api/l_util.h b/src/script/lua_api/l_util.h index 0df2c3ae4..3686255a6 100644 --- a/src/script/lua_api/l_util.h +++ b/src/script/lua_api/l_util.h @@ -5,6 +5,7 @@ #pragma once #include "lua_api/l_base.h" +#include "util/serialize.h" class AsyncEngine; @@ -110,6 +111,15 @@ private: // colorspec_to_table(colorspec) static int l_colorspec_to_table(lua_State *L); + // colorspec_to_int(colorspec) + static int l_colorspec_to_int(lua_State *L); + + // encode_network(format, ...) + static int l_encode_network(lua_State *L); + + // decode_network(format, data) + static int l_decode_network(lua_State *L); + // time_to_day_night_ratio(time_of_day) static int l_time_to_day_night_ratio(lua_State *L); diff --git a/src/util/serialize.cpp b/src/util/serialize.cpp index e7a002662..179e84104 100644 --- a/src/util/serialize.cpp +++ b/src/util/serialize.cpp @@ -19,40 +19,56 @@ FloatType g_serialize_f32_type = FLOATTYPE_UNKNOWN; //// String //// -std::string serializeString16(std::string_view plain) +std::string serializeString16(std::string_view plain, bool truncate) { std::string s; - char buf[2]; + size_t size = plain.size(); static_assert(STRING_MAX_LEN <= U16_MAX); - if (plain.size() > STRING_MAX_LEN) - throw SerializationError("String too long for serializeString16"); - s.reserve(2 + plain.size()); - writeU16((u8 *)&buf[0], plain.size()); - s.append(buf, 2); + if (size > STRING_MAX_LEN) { + if (truncate) { + size = STRING_MAX_LEN; + } else { + throw SerializationError("String too long for serializeString16"); + } + } + + char size_buf[2]; + writeU16((u8 *)size_buf, size); + + s.reserve(2 + size); + s.append(size_buf, 2); + s.append(plain.substr(0, size)); - s.append(plain); return s; } -std::string deSerializeString16(std::istream &is) +std::string deSerializeString16(std::istream &is, bool truncate) { std::string s; - char buf[2]; + char size_buf[2]; - is.read(buf, 2); - if (is.gcount() != 2) + is.read(size_buf, 2); + if (is.gcount() != 2) { + if (truncate) { + return s; + } throw SerializationError("deSerializeString16: size not read"); + } - u16 s_size = readU16((u8 *)buf); - if (s_size == 0) + u16 size = readU16((u8 *)size_buf); + if (size == 0) { return s; + } - s.resize(s_size); - is.read(&s[0], s_size); - if (is.gcount() != s_size) + s.resize(size); + is.read(&s[0], size); + if (truncate) { + s.resize(is.gcount()); + } else if (is.gcount() != size) { throw SerializationError("deSerializeString16: truncated"); + } return s; } @@ -62,45 +78,74 @@ std::string deSerializeString16(std::istream &is) //// Long String //// -std::string serializeString32(std::string_view plain) +std::string serializeString32(std::string_view plain, bool truncate) { std::string s; - char buf[4]; + size_t size = plain.size(); static_assert(LONG_STRING_MAX_LEN <= U32_MAX); - if (plain.size() > LONG_STRING_MAX_LEN) - throw SerializationError("String too long for serializeLongString"); - s.reserve(4 + plain.size()); - writeU32((u8*)&buf[0], plain.size()); - s.append(buf, 4); - s.append(plain); + if (size > LONG_STRING_MAX_LEN) { + if (truncate) { + size = LONG_STRING_MAX_LEN; + } else { + throw SerializationError("String too long for serializeString32"); + } + } + + char size_buf[4]; + writeU32((u8 *)size_buf, size); + + s.reserve(4 + size); + s.append(size_buf, 4); + s.append(plain.substr(0, size)); + return s; } -std::string deSerializeString32(std::istream &is) +std::string deSerializeString32(std::istream &is, bool truncate) { std::string s; - char buf[4]; + char size_buf[4]; - is.read(buf, 4); - if (is.gcount() != 4) - throw SerializationError("deSerializeLongString: size not read"); - - u32 s_size = readU32((u8 *)buf); - if (s_size == 0) - return s; - - // We don't really want a remote attacker to force us to allocate 4GB... - if (s_size > LONG_STRING_MAX_LEN) { - throw SerializationError("deSerializeLongString: " - "string too long: " + itos(s_size) + " bytes"); + is.read(size_buf, 4); + if (is.gcount() != 4) { + if (truncate) { + return s; + } + throw SerializationError("deSerializeString32: size not read"); } - s.resize(s_size); - is.read(&s[0], s_size); - if ((u32)is.gcount() != s_size) - throw SerializationError("deSerializeLongString: truncated"); + u32 size = readU32((u8 *)size_buf); + u32 ignore = 0; + if (size == 0) { + return s; + } + + if (size > LONG_STRING_MAX_LEN) { + if (truncate) { + ignore = size - LONG_STRING_MAX_LEN; + size = LONG_STRING_MAX_LEN; + } else { + // We don't really want a remote attacker to force us to allocate 4GB... + throw SerializationError("deSerializeString32: " + "string too long: " + itos(size) + " bytes"); + } + } + + s.resize(size); + is.read(&s[0], size); + if (truncate) { + s.resize(is.gcount()); + } else if (is.gcount() != size) { + throw SerializationError("deSerializeString32: truncated"); + } + + // If the string was truncated due to exceeding the string max length, we + // need to ignore the rest of the characters. + if (truncate) { + is.seekg(ignore, std::ios_base::cur); + } return s; } diff --git a/src/util/serialize.h b/src/util/serialize.h index 7da5f44d6..1edce6f79 100644 --- a/src/util/serialize.h +++ b/src/util/serialize.h @@ -446,16 +446,16 @@ MAKE_STREAM_WRITE_FXN(video::SColor, ARGB8, 4); } // Creates a string with the length as the first two bytes -std::string serializeString16(std::string_view plain); +std::string serializeString16(std::string_view plain, bool truncate = false); // Reads a string with the length as the first two bytes -std::string deSerializeString16(std::istream &is); +std::string deSerializeString16(std::istream &is, bool truncate = false); // Creates a string with the length as the first four bytes -std::string serializeString32(std::string_view plain); +std::string serializeString32(std::string_view plain, bool truncate = false); // Reads a string with the length as the first four bytes -std::string deSerializeString32(std::istream &is); +std::string deSerializeString32(std::istream &is, bool truncate = false); // Creates a string encoded in JSON format (almost equivalent to a C string literal) std::string serializeJsonString(std::string_view plain);