mirror of
https://github.com/luanti-org/luanti.git
synced 2025-06-27 16:36:03 +00:00
Implement safe require
This commit is contained in:
parent
78293404c7
commit
7273b664ae
9 changed files with 165 additions and 20 deletions
|
@ -49,6 +49,72 @@ dofile(commonpath .. "serialize.lua")
|
|||
dofile(commonpath .. "misc_helpers.lua")
|
||||
|
||||
if INIT == "game" then
|
||||
local function mod_loader(module_name)
|
||||
local parts = module_name:split(".")
|
||||
local modname = parts[1]
|
||||
local modpath = core.get_modpath(modname)
|
||||
if not modpath then
|
||||
return "no mod called " .. modname
|
||||
end
|
||||
parts[1] = modpath
|
||||
local base_path = table.concat(parts, "/")
|
||||
local errors = {}
|
||||
for _, suffix in ipairs({"/init.lua", #parts > 1 and ".lua" or nil}) do
|
||||
local source_path = base_path .. suffix
|
||||
local f, err = io.open(source_path, "r")
|
||||
if f then
|
||||
local source = f:read("*all")
|
||||
local chunk, load_err = loadstring(source, source_path)
|
||||
f:close()
|
||||
assert(chunk, load_err)
|
||||
return chunk
|
||||
else
|
||||
table.insert(errors, err)
|
||||
end
|
||||
end
|
||||
return table.concat(errors, "\n")
|
||||
end
|
||||
|
||||
local function preprocess_module_name(module_name)
|
||||
if module_name:sub(1, 1) == "." then
|
||||
module_name = assert(core.get_current_modname()) .. module_name
|
||||
end
|
||||
assert(module_name:find("^[A-Za-z0-9_%.]+$") and
|
||||
module_name:find("%.%.") == nil and
|
||||
module_name:sub(-1) ~= ".",
|
||||
"invalid module name")
|
||||
return module_name
|
||||
end
|
||||
|
||||
local loaded = {} -- [module_name] = function() return module end
|
||||
package = {
|
||||
loaders = {mod_loader},
|
||||
unload = function(module_name)
|
||||
loaded[preprocess_module_name(module_name)] = nil
|
||||
end,
|
||||
set = function(module_name, module)
|
||||
loaded[preprocess_module_name(module_name)] = function() return module end
|
||||
end,
|
||||
}
|
||||
function require(module_name)
|
||||
module_name = preprocess_module_name(module_name)
|
||||
local module_func = loaded[module_name]
|
||||
if module_func then
|
||||
return module_func()
|
||||
end
|
||||
local errors = {}
|
||||
for _, loader in ipairs(package.loaders) do
|
||||
local res = loader(module_name)
|
||||
if type(res) == "function" then
|
||||
local module = res()
|
||||
package.set(module_name, module)
|
||||
return module
|
||||
end
|
||||
table.insert(errors, res)
|
||||
end
|
||||
error("failed to load module '" .. module_name .. '":\n' .. table.concat(errors, "\n"))
|
||||
end
|
||||
|
||||
dofile(scriptdir .. "game" .. DIR_DELIM .. "init.lua")
|
||||
assert(not core.get_http_api)
|
||||
elseif INIT == "mainmenu" then
|
||||
|
|
|
@ -4121,6 +4121,36 @@ For example:
|
|||
* `core.dir_to_wallmounted` (Involves wallmounted param2 values.)
|
||||
|
||||
|
||||
Loading files
|
||||
=============
|
||||
|
||||
The typical way to load files until 5.13.0 was to use `dofile`
|
||||
and run files in order of dependencies in `init.lua`, e.g.
|
||||
|
||||
```lua
|
||||
local modname = core.get_current_modname()
|
||||
local modpath = core.get_modpath(modname)
|
||||
dofile(modname .. "/stuff.lua")
|
||||
dofile(modname .. "/more_stuff.lua")
|
||||
```
|
||||
|
||||
This is clunky and has several drawbacks.
|
||||
As of version 5.13.0, Luanti supports a custom version of Lua's `require`:
|
||||
|
||||
* `require("mymod")` gives you whatever `init.lua` of `mymod` returns.
|
||||
You should have a dependency (optional or not) on `mymod` if you call this.
|
||||
* `require("mod.dir.file")` loads `dir/file.lua` or `dir/file/init.lua` from the mod folder of `mod`.
|
||||
* For convenience, `require(".dir.file")` is equivalent to `require(core.get_current_modname() .. ".dir.file")`.
|
||||
This is also supported by the `package.*` functions.
|
||||
|
||||
The implementation is customizable via the `package` table:
|
||||
|
||||
* `package.loaders` is a list of loaders which are tried in order.
|
||||
A loader is a `function(module_name)` which returns `nil`,
|
||||
a string explaining why it couldn't load the module,
|
||||
or a function that when called returns the module.
|
||||
* `package.unload(module_name)` can be used to forcibly unload a module.
|
||||
* `package.set(module_name, module)` can be used to override a module.
|
||||
|
||||
|
||||
Helper functions
|
||||
|
|
|
@ -33,6 +33,7 @@ read_globals = {
|
|||
string = {fields = {"split", "trim"}},
|
||||
table = {fields = {"copy", "getn", "indexof", "insert_all", "key_value_swap"}},
|
||||
math = {fields = {"hypot", "round"}},
|
||||
package = {fields = {"loaders", "unload", "set"}}
|
||||
}
|
||||
|
||||
globals = {
|
||||
|
|
|
@ -246,3 +246,9 @@ else
|
|||
end,
|
||||
})
|
||||
end
|
||||
|
||||
local t = {}
|
||||
unittests.register("test_mod_require", function()
|
||||
assert(require("unittests") == t)
|
||||
end)
|
||||
return t
|
||||
|
|
|
@ -353,3 +353,27 @@ local function test_ipc_poll(cb)
|
|||
print("delta: " .. (core.get_us_time() - t0) .. "us")
|
||||
end
|
||||
unittests.register("test_ipc_poll", test_ipc_poll)
|
||||
|
||||
do
|
||||
local t = require(".require")
|
||||
assert(t.foo == "bar")
|
||||
assert(t == require("unittests.require"))
|
||||
package.unload(".require")
|
||||
assert(t ~= require(".require"))
|
||||
package.set(".require", "test")
|
||||
assert(require(".require") == "test")
|
||||
end
|
||||
|
||||
do
|
||||
local status, err = xpcall(function()
|
||||
table.insert(package.loaders, function()
|
||||
return function()
|
||||
return 42
|
||||
end
|
||||
end)
|
||||
local answer = require("the_answer_to_life_the_universe_and_all_the_rest")
|
||||
assert(answer == 42)
|
||||
end, debug.traceback)
|
||||
table.remove(package.loaders)
|
||||
assert(status, err)
|
||||
end
|
||||
|
|
1
games/devtest/mods/unittests/require.lua
Normal file
1
games/devtest/mods/unittests/require.lua
Normal file
|
@ -0,0 +1 @@
|
|||
return {foo = "bar"}
|
|
@ -13,6 +13,7 @@
|
|||
#include "porting.h"
|
||||
#include "util/string.h"
|
||||
#include "server.h"
|
||||
#include <lua.h>
|
||||
#if CHECK_CLIENT_BUILD()
|
||||
#include "client/client.h"
|
||||
#endif
|
||||
|
@ -229,38 +230,53 @@ std::string ScriptApiBase::getCurrentModNameInsecure(lua_State *L)
|
|||
return ret;
|
||||
}
|
||||
|
||||
void ScriptApiBase::loadMod(const std::string &script_path,
|
||||
const std::string &mod_name)
|
||||
{
|
||||
ModNameStorer mod_name_storer(getStack(), mod_name);
|
||||
|
||||
loadScript(script_path);
|
||||
}
|
||||
|
||||
void ScriptApiBase::loadScript(const std::string &script_path)
|
||||
static void load_script(lua_State *L, const char *script_path, int nresults)
|
||||
{
|
||||
verbosestream << "Loading and running script from " << script_path << std::endl;
|
||||
|
||||
lua_State *L = getStack();
|
||||
|
||||
int error_handler = PUSH_ERROR_HANDLER(L);
|
||||
|
||||
bool ok;
|
||||
if (ScriptApiSecurity::isSecure(L)) {
|
||||
ok = ScriptApiSecurity::safeLoadFile(L, script_path.c_str());
|
||||
ok = ScriptApiSecurity::safeLoadFile(L, script_path);
|
||||
} else {
|
||||
ok = !luaL_loadfile(L, script_path.c_str());
|
||||
ok = !luaL_loadfile(L, script_path);
|
||||
}
|
||||
ok = ok && !lua_pcall(L, 0, 0, error_handler);
|
||||
ok = ok && !lua_pcall(L, 0, nresults, error_handler);
|
||||
if (!ok) {
|
||||
const char *error_msg = lua_tostring(L, -1);
|
||||
if (!error_msg)
|
||||
error_msg = "(error object is not a string)";
|
||||
lua_pop(L, 2); // Pop error message and error handler
|
||||
throw ModError("Failed to load and run script from " +
|
||||
throw ModError(std::string("Failed to load and run script from ") +
|
||||
script_path + ":\n" + error_msg);
|
||||
}
|
||||
lua_pop(L, 1); // Pop error handler
|
||||
lua_remove(L, error_handler);
|
||||
// leave the return values from loading the file on the stack
|
||||
}
|
||||
|
||||
void ScriptApiBase::loadMod(const std::string &script_path,
|
||||
const std::string &mod_name, bool package_set)
|
||||
{
|
||||
lua_State *L = getStack();
|
||||
int top = lua_gettop(L);
|
||||
ModNameStorer mod_name_storer(L, mod_name);
|
||||
|
||||
load_script(L, script_path.c_str(), 1);
|
||||
if (package_set) {
|
||||
int module = lua_gettop(L);
|
||||
lua_getglobal(L, "package");
|
||||
lua_getfield(L, -1, "set");
|
||||
lua_pushstring(L, mod_name.c_str());
|
||||
lua_pushvalue(L, module);
|
||||
lua_call(L, 2, 0);
|
||||
}
|
||||
lua_settop(L, top);
|
||||
}
|
||||
|
||||
void ScriptApiBase::loadScript(const std::string &script_path)
|
||||
{
|
||||
load_script(getStack(), script_path.c_str(), 0);
|
||||
}
|
||||
|
||||
#if CHECK_CLIENT_BUILD()
|
||||
|
|
|
@ -73,7 +73,8 @@ public:
|
|||
DISABLE_CLASS_COPY(ScriptApiBase);
|
||||
|
||||
// These throw a ModError on failure
|
||||
void loadMod(const std::string &script_path, const std::string &mod_name);
|
||||
void loadMod(const std::string &script_path, const std::string &mod_name,
|
||||
bool package_set = false);
|
||||
void loadScript(const std::string &script_path);
|
||||
|
||||
#if CHECK_CLIENT_BUILD()
|
||||
|
|
|
@ -44,7 +44,7 @@ void ServerModManager::loadMods(ServerScripting &script)
|
|||
|
||||
auto t1 = porting::getTimeMs();
|
||||
std::string script_path = mod.path + DIR_DELIM + "init.lua";
|
||||
script.loadMod(script_path, mod.name);
|
||||
script.loadMod(script_path, mod.name, true);
|
||||
infostream << "Mod \"" << mod.name << "\" loaded after "
|
||||
<< (porting::getTimeMs() - t1) << " ms" << std::endl;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue