diff --git a/builtin/game/async.lua b/builtin/game/async.lua index fc286ef25..79d1aea9b 100644 --- a/builtin/game/async.lua +++ b/builtin/game/async.lua @@ -8,15 +8,24 @@ function core.async_event_handler(jobid, retval) core.async_jobs[jobid] = nil end +local job_metatable = {__index = {}} + +function job_metatable.__index:cancel() + local cancelled = core.cancel_async_callback(self.id) + if cancelled then + core.async_jobs[self.id] = nil + end + return cancelled +end + function core.handle_async(func, callback, ...) assert(type(func) == "function" and type(callback) == "function", "Invalid core.handle_async invocation") local args = {n = select("#", ...), ...} local mod_origin = core.get_last_run_mod() - local jobid = core.do_async_callback(func, args, mod_origin) - core.async_jobs[jobid] = callback + local id = core.do_async_callback(func, args, mod_origin) + core.async_jobs[id] = callback - return true + return setmetatable({id = id}, job_metatable) end - diff --git a/doc/lua_api.md b/doc/lua_api.md index b34e571c2..a15415a03 100644 --- a/doc/lua_api.md +++ b/doc/lua_api.md @@ -7166,12 +7166,12 @@ This allows you easy interoperability for delegating work to jobs. * When `func` returns the callback is called (in the normal environment) with all of the return values as arguments. * Optional: Variable number of arguments that are passed to `func` + * Returns an `AsyncJob` async job. * `core.register_async_dofile(path)`: * Register a path to a Lua file to be imported when an async environment is initialized. You can use this to preload code which you can then call later using `core.handle_async()`. - ### List of APIs available in an async environment Classes: @@ -8002,6 +8002,15 @@ use the provided load and write functions for this. * `from_file(filename)`: Experimental. Like `from_string()`, but reads the data from a file. +`AsyncJob` +---------- +An `AsyncJob` is a reference to a job to be run in an async environment. + +### Methods +* `cancel()`: try to cancel the job + * Returns whether the job was cancelled. + * A job can only be cancelled if it has not started. + `InvRef` -------- diff --git a/games/devtest/mods/unittests/async_env.lua b/games/devtest/mods/unittests/async_env.lua index b00deb3b6..d17cef1ed 100644 --- a/games/devtest/mods/unittests/async_env.lua +++ b/games/devtest/mods/unittests/async_env.lua @@ -207,3 +207,32 @@ local function test_vector_preserve(cb) end, {vec}) end unittests.register("test_async_vector", test_vector_preserve, {async=true}) + +local function test_async_job_replacement(cb) + core.ipc_set("unittests:end_blocking", nil) + local capacity = core.get_async_threading_capacity() + for _ = 1, capacity do + core.handle_async(function() + core.ipc_poll("unittests:end_blocking", 1000) + end, function() end) + end + local job = core.handle_async(function() + end, function() + return cb("Canceled async job ran") + end) + if not job:cancel() then + return cb("AsyncJob:cancel sanity check failed") + end + core.ipc_set("unittests:end_blocking", true) + + -- Try to cancel a job that is already run. + job = core.handle_async(function(x) + return x + end, function(ret) + if job:cancel() then + return cb("AsyncJob:cancel canceled a completed job") + end + cb() + end, 1) +end +unittests.register("test_async_job_replacement", test_async_job_replacement, {async=true}) diff --git a/src/script/cpp_api/s_async.cpp b/src/script/cpp_api/s_async.cpp index 982fb825e..0a4d82a60 100644 --- a/src/script/cpp_api/s_async.cpp +++ b/src/script/cpp_api/s_async.cpp @@ -96,38 +96,44 @@ void AsyncEngine::addWorkerThread() } /******************************************************************************/ -u32 AsyncEngine::queueAsyncJob(std::string &&func, std::string &¶ms, - const std::string &mod_origin) + +u32 AsyncEngine::queueAsyncJob(LuaJobInfo &&job) { MutexAutoLock autolock(jobQueueMutex); u32 jobId = jobIdCounter++; - jobQueue.emplace_back(); - auto &to_add = jobQueue.back(); - to_add.id = jobId; - to_add.function = std::move(func); - to_add.params = std::move(params); - to_add.mod_origin = mod_origin; + assert(!job.function.empty()); + job.id = jobId; + jobQueue.push_back(std::move(job)); jobQueueCounter.post(); return jobId; } +u32 AsyncEngine::queueAsyncJob(std::string &&func, std::string &¶ms, + const std::string &mod_origin) +{ + LuaJobInfo to_add(std::move(func), std::move(params), mod_origin); + return queueAsyncJob(std::move(to_add)); +} + u32 AsyncEngine::queueAsyncJob(std::string &&func, PackedValue *params, const std::string &mod_origin) +{ + LuaJobInfo to_add(std::move(func), params, mod_origin); + return queueAsyncJob(std::move(to_add)); +} + +bool AsyncEngine::cancelAsyncJob(u32 id) { MutexAutoLock autolock(jobQueueMutex); - u32 jobId = jobIdCounter++; - - jobQueue.emplace_back(); - auto &to_add = jobQueue.back(); - to_add.id = jobId; - to_add.function = std::move(func); - to_add.params_ext.reset(params); - to_add.mod_origin = mod_origin; - - jobQueueCounter.post(); - return jobId; + for (auto job = jobQueue.begin(); job != jobQueue.end(); job++) { + if (job->id == id) { + jobQueue.erase(job); + return true; + } + } + return false; } /******************************************************************************/ @@ -419,3 +425,19 @@ void* AsyncWorkerThread::run() return 0; } +u32 ScriptApiAsync::queueAsync(std::string &&serialized_func, + PackedValue *param, const std::string &mod_origin) +{ + return asyncEngine.queueAsyncJob(std::move(serialized_func), + param, mod_origin); +} + +bool ScriptApiAsync::cancelAsync(u32 id) +{ + return asyncEngine.cancelAsyncJob(id); +} + +void ScriptApiAsync::stepAsync() +{ + asyncEngine.step(getStack()); +} diff --git a/src/script/cpp_api/s_async.h b/src/script/cpp_api/s_async.h index 1b6743dea..fd82f6f36 100644 --- a/src/script/cpp_api/s_async.h +++ b/src/script/cpp_api/s_async.h @@ -26,6 +26,12 @@ class AsyncEngine; struct LuaJobInfo { LuaJobInfo() = default; + LuaJobInfo(std::string &&func, std::string &¶ms, const std::string &mod_origin = "") : + function(func), params(params), mod_origin(mod_origin) {} + LuaJobInfo(std::string &&func, PackedValue *params, const std::string &mod_origin = "") : + function(func), mod_origin(mod_origin) { + params_ext.reset(params); + } // Function to be called in async environment (from string.dump) std::string function; @@ -102,12 +108,26 @@ public: u32 queueAsyncJob(std::string &&func, PackedValue *params, const std::string &mod_origin = ""); + /** + * Try to cancel an async job + * @param id The ID of the job + * @return Whether the job was cancelled + */ + bool cancelAsyncJob(u32 id); + /** * Engine step to process finished jobs * @param L The Lua stack */ void step(lua_State *L); + /** + * Get the maximum number of threads that can be used by the async environment + */ + unsigned int getThreadingCapacity() const { + return MYMAX(workerThreads.size(), autoscaleMaxWorkers); + } + protected: /** * Get a Job from queue to be processed @@ -117,6 +137,13 @@ protected: */ bool getJob(LuaJobInfo *job); + /** + * Queue an async job + * @param job The job to queue (takes ownership!) + * @return Id of the queued job + */ + u32 queueAsyncJob(LuaJobInfo &&job); + /** * Put a Job result back to result queue * @param result result of completed job @@ -206,3 +233,23 @@ private: // Counter semaphore for job dispatching Semaphore jobQueueCounter; }; + +class ScriptApiAsync: + virtual public ScriptApiBase +{ +public: + ScriptApiAsync(Server *server): asyncEngine(server) {} + + virtual void initAsync() = 0; + void stepAsync(); + + u32 queueAsync(std::string &&serialized_func, + PackedValue *param, const std::string &mod_origin); + bool cancelAsync(u32 id); + unsigned int getThreadingCapacity() const { + return asyncEngine.getThreadingCapacity(); + } + +protected: + AsyncEngine asyncEngine; +}; diff --git a/src/script/lua_api/CMakeLists.txt b/src/script/lua_api/CMakeLists.txt index ef1be9525..c1f71ec68 100644 --- a/src/script/lua_api/CMakeLists.txt +++ b/src/script/lua_api/CMakeLists.txt @@ -1,5 +1,6 @@ set(common_SCRIPT_LUA_API_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/l_areastore.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/l_async.cpp ${CMAKE_CURRENT_SOURCE_DIR}/l_auth.cpp ${CMAKE_CURRENT_SOURCE_DIR}/l_base.cpp ${CMAKE_CURRENT_SOURCE_DIR}/l_craft.cpp diff --git a/src/script/lua_api/l_async.cpp b/src/script/lua_api/l_async.cpp new file mode 100644 index 000000000..00c5e0eb4 --- /dev/null +++ b/src/script/lua_api/l_async.cpp @@ -0,0 +1,64 @@ +// Luanti +// SPDX-License-Identifier: LGPL-2.1-or-later + +#include "lua_api/l_internal.h" +#include "lua_api/l_async.h" +#include "cpp_api/s_async.h" + +static std::string get_serialized_function(lua_State *L, int index) +{ + luaL_checktype(L, index, LUA_TFUNCTION); + call_string_dump(L, index); + size_t func_length; + const char *serialized_func_raw = lua_tolstring(L, -1, &func_length); + std::string serialized_func(serialized_func_raw, func_length); + lua_pop(L, 1); + return serialized_func; +} + +// do_async_callback(func, params, mod_origin) +int ModApiAsync::l_do_async_callback(lua_State *L) +{ + NO_MAP_LOCK_REQUIRED; + ScriptApiAsync *script = getScriptApi(L); + + luaL_checktype(L, 2, LUA_TTABLE); + luaL_checktype(L, 3, LUA_TSTRING); + + auto serialized_func = get_serialized_function(L, 1); + PackedValue *param = script_pack(L, 2); + std::string mod_origin = readParam(L, 3); + + u32 jobId = script->queueAsync( + std::move(serialized_func), + param, mod_origin); + + lua_pushinteger(L, jobId); + return 1; +} + +// cancel_async_callback(id) +int ModApiAsync::l_cancel_async_callback(lua_State *L) +{ + NO_MAP_LOCK_REQUIRED; + ScriptApiAsync *script = getScriptApi(L); + u32 id = luaL_checkinteger(L, 1); + lua_pushboolean(L, script->cancelAsync(id)); + return 1; +} + +// get_async_capacity() +int ModApiAsync::l_get_async_threading_capacity(lua_State *L) +{ + NO_MAP_LOCK_REQUIRED; + ScriptApiAsync *script = getScriptApi(L); + lua_pushinteger(L, script->getThreadingCapacity()); + return 1; +} + +void ModApiAsync::Initialize(lua_State *L, int top) +{ + API_FCT(do_async_callback); + API_FCT(cancel_async_callback); + API_FCT(get_async_threading_capacity); +} diff --git a/src/script/lua_api/l_async.h b/src/script/lua_api/l_async.h new file mode 100644 index 000000000..1632a9e2f --- /dev/null +++ b/src/script/lua_api/l_async.h @@ -0,0 +1,19 @@ +// Luanti +// SPDX-License-Identifier: LGPL-2.1-or-later + +#pragma once + +#include "lua_api/l_base.h" + +class ModApiAsync : public ModApiBase +{ +public: + static void Initialize(lua_State *L, int top); +private: + // do_async_callback(func, params, mod_origin) + static int l_do_async_callback(lua_State *L); + // cancel_async_callback(id) + static int l_cancel_async_callback(lua_State *L); + // get_async_threading_capacity() + static int l_get_async_threading_capacity(lua_State *L); +}; diff --git a/src/script/lua_api/l_server.cpp b/src/script/lua_api/l_server.cpp index 698b2dba6..8d49e15ed 100644 --- a/src/script/lua_api/l_server.cpp +++ b/src/script/lua_api/l_server.cpp @@ -625,33 +625,6 @@ int ModApiServer::l_notify_authentication_modified(lua_State *L) return 0; } -// do_async_callback(func, params, mod_origin) -int ModApiServer::l_do_async_callback(lua_State *L) -{ - NO_MAP_LOCK_REQUIRED; - ServerScripting *script = getScriptApi(L); - - luaL_checktype(L, 1, LUA_TFUNCTION); - luaL_checktype(L, 2, LUA_TTABLE); - luaL_checktype(L, 3, LUA_TSTRING); - - call_string_dump(L, 1); - size_t func_length; - const char *serialized_func_raw = lua_tolstring(L, -1, &func_length); - - PackedValue *param = script_pack(L, 2); - - std::string mod_origin = readParam(L, 3); - - u32 jobId = script->queueAsync( - std::string(serialized_func_raw, func_length), - param, mod_origin); - - lua_settop(L, 0); - lua_pushinteger(L, jobId); - return 1; -} - // register_async_dofile(path) int ModApiServer::l_register_async_dofile(lua_State *L) { @@ -747,7 +720,6 @@ void ModApiServer::Initialize(lua_State *L, int top) API_FCT(unban_player_or_ip); API_FCT(notify_authentication_modified); - API_FCT(do_async_callback); API_FCT(register_async_dofile); API_FCT(serialize_roundtrip); diff --git a/src/script/lua_api/l_server.h b/src/script/lua_api/l_server.h index 6de7de363..0d2253e57 100644 --- a/src/script/lua_api/l_server.h +++ b/src/script/lua_api/l_server.h @@ -100,9 +100,6 @@ private: // notify_authentication_modified(name) static int l_notify_authentication_modified(lua_State *L); - // do_async_callback(func, params, mod_origin) - static int l_do_async_callback(lua_State *L); - // register_async_dofile(path) static int l_register_async_dofile(lua_State *L); diff --git a/src/script/scripting_server.cpp b/src/script/scripting_server.cpp index f30def03d..624d1b91d 100644 --- a/src/script/scripting_server.cpp +++ b/src/script/scripting_server.cpp @@ -9,6 +9,7 @@ #include "filesys.h" #include "cpp_api/s_internal.h" #include "lua_api/l_areastore.h" +#include "lua_api/l_async.h" #include "lua_api/l_auth.h" #include "lua_api/l_base.h" #include "lua_api/l_craft.h" @@ -39,7 +40,7 @@ extern "C" { ServerScripting::ServerScripting(Server* server): ScriptApiBase(ScriptingType::Server), - asyncEngine(server) + ScriptApiAsync(server) { setGameDef(server); @@ -115,18 +116,6 @@ void ServerScripting::initAsync() asyncEngine.initialize(0); } -void ServerScripting::stepAsync() -{ - asyncEngine.step(getStack()); -} - -u32 ServerScripting::queueAsync(std::string &&serialized_func, - PackedValue *param, const std::string &mod_origin) -{ - return asyncEngine.queueAsyncJob(std::move(serialized_func), - param, mod_origin); -} - void ServerScripting::InitializeModApi(lua_State *L, int top) { // Register reference classes (userdata) @@ -150,6 +139,7 @@ void ServerScripting::InitializeModApi(lua_State *L, int top) ModChannelRef::Register(L); // Initialize mod api modules + ModApiAsync::Initialize(L, top); ModApiAuth::Initialize(L, top); ModApiCraft::Initialize(L, top); ModApiEnv::Initialize(L, top); diff --git a/src/script/scripting_server.h b/src/script/scripting_server.h index 6c0583553..8b661a476 100644 --- a/src/script/scripting_server.h +++ b/src/script/scripting_server.h @@ -22,6 +22,7 @@ struct PackedValue; class ServerScripting: virtual public ScriptApiBase, + public ScriptApiAsync, public ScriptApiDetached, public ScriptApiEntity, public ScriptApiEnv, @@ -41,14 +42,7 @@ public: void saveGlobals(); // Initialize async engine, call this AFTER loading all mods - void initAsync(); - - // Global step handler to collect async results - void stepAsync(); - - // Pass job to async threads - u32 queueAsync(std::string &&serialized_func, - PackedValue *param, const std::string &mod_origin); + void initAsync() override; protected: // from ScriptApiSecurity: @@ -63,6 +57,4 @@ private: void InitializeModApi(lua_State *L, int top); static void InitializeAsync(lua_State *L, int top); - - AsyncEngine asyncEngine; };