From 7273b664ae346d006d184d5be3d9eab3a7744a80 Mon Sep 17 00:00:00 2001 From: Lars Mueller Date: Fri, 18 Apr 2025 03:25:15 +0200 Subject: [PATCH] Implement safe `require` --- builtin/init.lua | 66 ++++++++++++++++++++++++ doc/lua_api.md | 30 +++++++++++ games/devtest/.luacheckrc | 5 +- games/devtest/mods/unittests/init.lua | 6 +++ games/devtest/mods/unittests/misc.lua | 24 +++++++++ games/devtest/mods/unittests/require.lua | 1 + src/script/cpp_api/s_base.cpp | 48 +++++++++++------ src/script/cpp_api/s_base.h | 3 +- src/server/mods.cpp | 2 +- 9 files changed, 165 insertions(+), 20 deletions(-) create mode 100644 games/devtest/mods/unittests/require.lua diff --git a/builtin/init.lua b/builtin/init.lua index 59d1558fc..c40dd01b4 100644 --- a/builtin/init.lua +++ b/builtin/init.lua @@ -49,6 +49,72 @@ dofile(commonpath .. "serialize.lua") dofile(commonpath .. "misc_helpers.lua") if INIT == "game" then + local function mod_loader(module_name) + local parts = module_name:split(".") + local modname = parts[1] + local modpath = core.get_modpath(modname) + if not modpath then + return "no mod called " .. modname + end + parts[1] = modpath + local base_path = table.concat(parts, "/") + local errors = {} + for _, suffix in ipairs({"/init.lua", #parts > 1 and ".lua" or nil}) do + local source_path = base_path .. suffix + local f, err = io.open(source_path, "r") + if f then + local source = f:read("*all") + local chunk, load_err = loadstring(source, source_path) + f:close() + assert(chunk, load_err) + return chunk + else + table.insert(errors, err) + end + end + return table.concat(errors, "\n") + end + + local function preprocess_module_name(module_name) + if module_name:sub(1, 1) == "." then + module_name = assert(core.get_current_modname()) .. module_name + end + assert(module_name:find("^[A-Za-z0-9_%.]+$") and + module_name:find("%.%.") == nil and + module_name:sub(-1) ~= ".", + "invalid module name") + return module_name + end + + local loaded = {} -- [module_name] = function() return module end + package = { + loaders = {mod_loader}, + unload = function(module_name) + loaded[preprocess_module_name(module_name)] = nil + end, + set = function(module_name, module) + loaded[preprocess_module_name(module_name)] = function() return module end + end, + } + function require(module_name) + module_name = preprocess_module_name(module_name) + local module_func = loaded[module_name] + if module_func then + return module_func() + end + local errors = {} + for _, loader in ipairs(package.loaders) do + local res = loader(module_name) + if type(res) == "function" then + local module = res() + package.set(module_name, module) + return module + end + table.insert(errors, res) + end + error("failed to load module '" .. module_name .. '":\n' .. table.concat(errors, "\n")) + end + dofile(scriptdir .. "game" .. DIR_DELIM .. "init.lua") assert(not core.get_http_api) elseif INIT == "mainmenu" then diff --git a/doc/lua_api.md b/doc/lua_api.md index 882bfe341..c6d9a3f17 100644 --- a/doc/lua_api.md +++ b/doc/lua_api.md @@ -4121,6 +4121,36 @@ For example: * `core.dir_to_wallmounted` (Involves wallmounted param2 values.) +Loading files +============= + +The typical way to load files until 5.13.0 was to use `dofile` +and run files in order of dependencies in `init.lua`, e.g. + +```lua +local modname = core.get_current_modname() +local modpath = core.get_modpath(modname) +dofile(modname .. "/stuff.lua") +dofile(modname .. "/more_stuff.lua") +``` + +This is clunky and has several drawbacks. +As of version 5.13.0, Luanti supports a custom version of Lua's `require`: + +* `require("mymod")` gives you whatever `init.lua` of `mymod` returns. + You should have a dependency (optional or not) on `mymod` if you call this. +* `require("mod.dir.file")` loads `dir/file.lua` or `dir/file/init.lua` from the mod folder of `mod`. +* For convenience, `require(".dir.file")` is equivalent to `require(core.get_current_modname() .. ".dir.file")`. + This is also supported by the `package.*` functions. + +The implementation is customizable via the `package` table: + +* `package.loaders` is a list of loaders which are tried in order. + A loader is a `function(module_name)` which returns `nil`, + a string explaining why it couldn't load the module, + or a function that when called returns the module. +* `package.unload(module_name)` can be used to forcibly unload a module. +* `package.set(module_name, module)` can be used to override a module. Helper functions diff --git a/games/devtest/.luacheckrc b/games/devtest/.luacheckrc index 2ef36d209..d3994bb2f 100644 --- a/games/devtest/.luacheckrc +++ b/games/devtest/.luacheckrc @@ -31,8 +31,9 @@ read_globals = { "PcgRandom", string = {fields = {"split", "trim"}}, - table = {fields = {"copy", "getn", "indexof", "insert_all", "key_value_swap"}}, - math = {fields = {"hypot", "round"}}, + table = {fields = {"copy", "getn", "indexof", "insert_all", "key_value_swap"}}, + math = {fields = {"hypot", "round"}}, + package = {fields = {"loaders", "unload", "set"}} } globals = { diff --git a/games/devtest/mods/unittests/init.lua b/games/devtest/mods/unittests/init.lua index 22057f26a..cda4ba0eb 100644 --- a/games/devtest/mods/unittests/init.lua +++ b/games/devtest/mods/unittests/init.lua @@ -246,3 +246,9 @@ else end, }) end + +local t = {} +unittests.register("test_mod_require", function() + assert(require("unittests") == t) +end) +return t diff --git a/games/devtest/mods/unittests/misc.lua b/games/devtest/mods/unittests/misc.lua index 65dc3259e..11ca47963 100644 --- a/games/devtest/mods/unittests/misc.lua +++ b/games/devtest/mods/unittests/misc.lua @@ -353,3 +353,27 @@ local function test_ipc_poll(cb) print("delta: " .. (core.get_us_time() - t0) .. "us") end unittests.register("test_ipc_poll", test_ipc_poll) + +do + local t = require(".require") + assert(t.foo == "bar") + assert(t == require("unittests.require")) + package.unload(".require") + assert(t ~= require(".require")) + package.set(".require", "test") + assert(require(".require") == "test") +end + +do + local status, err = xpcall(function() + table.insert(package.loaders, function() + return function() + return 42 + end + end) + local answer = require("the_answer_to_life_the_universe_and_all_the_rest") + assert(answer == 42) + end, debug.traceback) + table.remove(package.loaders) + assert(status, err) +end diff --git a/games/devtest/mods/unittests/require.lua b/games/devtest/mods/unittests/require.lua new file mode 100644 index 000000000..b1b1ec73a --- /dev/null +++ b/games/devtest/mods/unittests/require.lua @@ -0,0 +1 @@ +return {foo = "bar"} diff --git a/src/script/cpp_api/s_base.cpp b/src/script/cpp_api/s_base.cpp index 9022cd8c3..e2f0c96c3 100644 --- a/src/script/cpp_api/s_base.cpp +++ b/src/script/cpp_api/s_base.cpp @@ -13,6 +13,7 @@ #include "porting.h" #include "util/string.h" #include "server.h" +#include #if CHECK_CLIENT_BUILD() #include "client/client.h" #endif @@ -229,38 +230,53 @@ std::string ScriptApiBase::getCurrentModNameInsecure(lua_State *L) return ret; } -void ScriptApiBase::loadMod(const std::string &script_path, - const std::string &mod_name) -{ - ModNameStorer mod_name_storer(getStack(), mod_name); - - loadScript(script_path); -} - -void ScriptApiBase::loadScript(const std::string &script_path) +static void load_script(lua_State *L, const char *script_path, int nresults) { verbosestream << "Loading and running script from " << script_path << std::endl; - lua_State *L = getStack(); - int error_handler = PUSH_ERROR_HANDLER(L); bool ok; if (ScriptApiSecurity::isSecure(L)) { - ok = ScriptApiSecurity::safeLoadFile(L, script_path.c_str()); + ok = ScriptApiSecurity::safeLoadFile(L, script_path); } else { - ok = !luaL_loadfile(L, script_path.c_str()); + ok = !luaL_loadfile(L, script_path); } - ok = ok && !lua_pcall(L, 0, 0, error_handler); + ok = ok && !lua_pcall(L, 0, nresults, error_handler); if (!ok) { const char *error_msg = lua_tostring(L, -1); if (!error_msg) error_msg = "(error object is not a string)"; lua_pop(L, 2); // Pop error message and error handler - throw ModError("Failed to load and run script from " + + throw ModError(std::string("Failed to load and run script from ") + script_path + ":\n" + error_msg); } - lua_pop(L, 1); // Pop error handler + lua_remove(L, error_handler); + // leave the return values from loading the file on the stack +} + +void ScriptApiBase::loadMod(const std::string &script_path, + const std::string &mod_name, bool package_set) +{ + lua_State *L = getStack(); + int top = lua_gettop(L); + ModNameStorer mod_name_storer(L, mod_name); + + load_script(L, script_path.c_str(), 1); + if (package_set) { + int module = lua_gettop(L); + lua_getglobal(L, "package"); + lua_getfield(L, -1, "set"); + lua_pushstring(L, mod_name.c_str()); + lua_pushvalue(L, module); + lua_call(L, 2, 0); + } + lua_settop(L, top); +} + +void ScriptApiBase::loadScript(const std::string &script_path) +{ + load_script(getStack(), script_path.c_str(), 0); } #if CHECK_CLIENT_BUILD() diff --git a/src/script/cpp_api/s_base.h b/src/script/cpp_api/s_base.h index b532e9cd9..7f1a3be87 100644 --- a/src/script/cpp_api/s_base.h +++ b/src/script/cpp_api/s_base.h @@ -73,7 +73,8 @@ public: DISABLE_CLASS_COPY(ScriptApiBase); // These throw a ModError on failure - void loadMod(const std::string &script_path, const std::string &mod_name); + void loadMod(const std::string &script_path, const std::string &mod_name, + bool package_set = false); void loadScript(const std::string &script_path); #if CHECK_CLIENT_BUILD() diff --git a/src/server/mods.cpp b/src/server/mods.cpp index e2d5debba..64862e6bf 100644 --- a/src/server/mods.cpp +++ b/src/server/mods.cpp @@ -44,7 +44,7 @@ void ServerModManager::loadMods(ServerScripting &script) auto t1 = porting::getTimeMs(); std::string script_path = mod.path + DIR_DELIM + "init.lua"; - script.loadMod(script_path, mod.name); + script.loadMod(script_path, mod.name, true); infostream << "Mod \"" << mod.name << "\" loaded after " << (porting::getTimeMs() - t1) << " ms" << std::endl; }