-- Use `copcall` in place of the builtin `pcall` when the environment defines it.
local pcall = copcall or pcall

--- Collect varargs into a table, recording the exact argument count in `n`
--- so that embedded or trailing nils are not lost.
--- @param ... any
--- @return {[integer]: any, n: integer}
local function pack_len(...)
  local packed = { ... }
  packed.n = select('#', ...)
  return packed
end

--- Like unpack(), but uses the length stored by pack_len() when present and
--- falls back to table.maxn() otherwise. Returns nothing when `t` is nil.
--- @param t? { [integer]: any, n?: integer }
--- @param first? integer Index of the first value to return (defaults to 1)
--- @return ...any
local function unpack_len(t, first)
  if not t then
    return
  end

  local last = t.n or table.maxn(t)
  return unpack(t, first or 1, last)
end

--- @class async
local M = {}

--- Maps a coroutine to the task that owns it. Keys are weak, so an entry
--- does not keep its coroutine alive.
--- @type table
local threads = setmetatable({}, { __mode = 'k' })

--- Task owning the coroutine that is currently executing, provided that task
--- has neither completed nor started closing.
--- @return async.Task?
local function running()
  local task = threads[coroutine.running()]
  if not task then
    return
  end

  if task:_completed() or task._closing then
    return
  end

  return task
end

--- Base class for async tasks. Async functions should return a subclass of
--- this. This is designed specifically to be a base class of uv_handle_t
--- @class async.Handle
--- @field close fun(self: async.Handle, callback?: fun())
--- @field is_closing? fun(self: async.Handle): boolean

--- @alias async.CallbackFn fun(...: any): async.Handle?

--- @class async.Task : async.Handle
--- @field package _callbacks table
--- @field package _callback_pos integer
--- @field private _thread thread
---
--- Tasks can call other async functions (tasks or callback functions).
--- While we are waiting on a child, we store the handle to it here so we can
--- cancel it.
--- @field private _current_child? async.Handle
---
--- Error result of the task if an error occurs.
--- Must use `await` to get the result.
--- @field private _err? any
---
--- Result of the task.
--- Must use `await` to get the result.
--- @field private _result? any[]
local Task = {}
Task.__index = Task

--- Create a task around a fresh coroutine for `func` and register it in
--- `threads`. The coroutine is not started here.
--- @private
--- @param func function
--- @return async.Task
function Task._new(func)
  local co = coroutine.create(func)

  local task = {
    _closing = false,
    _thread = co,
    _callbacks = {},
    _callback_pos = 1,
  }
  setmetatable(task, Task)

  threads[co] = task

  return task
end

--- Register `callback` to receive the outcome of the task.
---
--- The callback is invoked immediately with `'closing'` when the task has
--- been asked to close, immediately with the stored outcome when the task has
--- already completed, and otherwise is queued until the task finishes.
--- @param callback fun(err?: any, ...: any)
function Task:await(callback)
  if self._closing then
    callback('closing')
    return
  end

  if self:_completed() then -- TODO(lewis6991): test
    -- Already finished or closed
    callback(self._err, unpack_len(self._result))
    return
  end

  local slot = self._callback_pos
  self._callbacks[slot] = callback
  self._callback_pos = slot + 1
end

--- Whether an error or a result has been recorded for this task.
--- @package
--- @return boolean
function Task:_completed()
  local outcome = self._err or self._result
  return outcome ~= nil
end

-- Use max 32-bit signed int value to avoid overflow on 32-bit systems.
-- Do not use `math.huge` as it is not interpreted as a positive integer on all
-- platforms.
local MAX_TIMEOUT = 2 ^ 31 - 1

--- Synchronously wait (protected) for a task to finish (blocking).
---
--- If an error is returned, `Task:traceback()` can be used to get the
--- stack trace of the error.
---
--- Example:
--- ```lua
--- local ok, err_or_result = task:pwait(10)
---
--- if not ok then
---   error(task:traceback(err_or_result))
--- end
---
--- local _, result = assert(task:pwait(10))
--- ```
---
--- Can be called if a task is closing.
--- @param timeout? integer Passed to `vim.wait()`; defaults to MAX_TIMEOUT
--- @return boolean status
--- @return any ... result or error
function Task:pwait(timeout)
  -- Note we poll self:_completed() instead of using self:await() to avoid
  -- creating a callback. This avoids having to cleanup/unregister any
  -- callback in the case of a timeout.
  local function settled()
    return self:_completed()
  end

  if not vim.wait(timeout or MAX_TIMEOUT, settled) then
    return false, 'timeout'
  end

  if self._err then
    return false, self._err
  end

  return true, unpack_len(self._result)
end

--- Synchronously wait for a task to finish (blocking). Raises the task's
--- error (with traceback) if it failed or the wait timed out.
---
--- Example:
--- ```lua
--- local result = task:wait(10) -- wait for 10ms or else error
---
--- local result = task:wait() -- wait indefinitely
--- ```
--- @param timeout? integer Timeout in milliseconds
--- @return any ... result
function Task:wait(timeout)
  local outcome = pack_len(self:pwait(timeout))

  if not outcome[1] then
    error(self:traceback(outcome[2]))
  end

  return unpack_len(outcome, 2)
end

--- Build a traceback for this task, prefixing each frame with the owning
--- thread, and prepending the traceback of the awaited child when that child
--- is itself a Task.
--- @private
--- @param msg? string
--- @param _lvl? integer Nesting depth; 0 (the default) marks the outermost call
--- @return string
function Task:_traceback(msg, _lvl)
  local depth = _lvl or 0
  local prefix = ('[%s] '):format(self._thread)

  local child = self._current_child
  local tblvl = nil
  if getmetatable(child) == Task then
    --- @cast child async.Task
    msg = child:_traceback(msg, depth + 1)
    tblvl = 2
  end

  local own = debug.traceback(self._thread, '', tblvl):gsub('\n\t', '\n\t' .. prefix)
  msg = (msg or '') .. own

  if depth == 0 then
    -- Keep only the first "stack traceback:" header: protect it, strip the
    -- remaining ones, then restore it.
    --- @type string
    msg = msg:gsub('\nstack traceback:\n', '\nSTACK TRACEBACK:\n', 1)
    msg = msg:gsub('\nstack traceback:\n', '\n')
    msg = msg:gsub('\nSTACK TRACEBACK:\n', '\nstack traceback:\n', 1)
  end

  return msg
end

--- Get the traceback of a task when it is not active.
--- Will also get the traceback of nested tasks.
---
--- @param msg? string
--- @return string
function Task:traceback(msg)
  return self:_traceback(msg)
end

--- If a task completes with an error, raise the error.
--- @return async.Task self
function Task:raise_on_error()
  local function rethrow(err)
    if err then
      error(self:_traceback(err), 0)
    end
  end

  self:await(rethrow)
  return self
end
--- Record the outcome of the task and notify every registered awaiter.
---
--- A task settles at most once: if an outcome has already been recorded the
--- call is ignored, so awaiters are never notified twice and the first
--- outcome is never overwritten.
---
--- Errors raised by awaiter callbacks are collected and re-raised together
--- after all callbacks have run.
--- @private
--- @param err? any
--- @param result? {[integer]: any, n: integer}
function Task:_finish(err, result)
  if self:_completed() then
    return
  end

  if not err and result == nil then
    -- The task failed with a nil/false error value (e.g. `error()`).
    -- Completion is detected through `_err`/`_result`, so store a truthy
    -- message; otherwise the task would never be seen as completed.
    err = tostring(err)
  end

  self._current_child = nil
  self._err = err
  self._result = result
  threads[self._thread] = nil

  local errs = {} --- @type string[]
  for _, cb in pairs(self._callbacks) do
    local ok, cb_err = pcall(cb, err, unpack_len(result))
    if not ok then
      -- tostring() so that non-string error objects cannot break the
      -- table.concat() below.
      errs[#errs + 1] = tostring(cb_err)
    end
  end

  if #errs > 0 then
    error(table.concat(errs, '\n'), 0)
  end
end

--- @return boolean
function Task:is_closing()
  return self._closing
end

--- Close the task and all its children.
--- If callback is provided it will run asynchronously,
--- else it will run synchronously.
---
--- If the task has already completed, `callback` is invoked immediately.
--- If a close is already in progress, `callback` is queued and runs once
--- that close finishes the task.
---
--- @param callback? fun()
function Task:close(callback)
  if self:_completed() then
    if callback then
      callback()
    end
    return
  end

  if self._closing then
    if callback then
      -- Another close is in flight: run the callback when it finishes the
      -- task instead of silently dropping it.
      self._callbacks[self._callback_pos] = function()
        callback()
      end
      self._callback_pos = self._callback_pos + 1
    end
    return
  end

  self._closing = true

  if callback then -- async
    if self._current_child then
      self._current_child:close(function()
        self:_finish('closed')
        callback()
      end)
    else
      self:_finish('closed')
      callback()
    end
  else -- sync
    if self._current_child then
      self._current_child:close(function()
        self:_finish('closed')
      end)
    else
      self:_finish('closed')
    end
    vim.wait(0, function()
      return self:_completed()
    end)
  end
end

--- Whether `obj` is a table or userdata exposing a callable `close` field.
--- @param obj any
--- @return boolean
local function is_async_handle(obj)
  local ty = type(obj)
  return (ty == 'table' or ty == 'userdata') and vim.is_callable(obj.close)
end

--- Resume the task's coroutine with the given values and act on what it
--- yields or returns:
---  - an error finishes the task with that error;
---  - a normal return finishes the task with the returned values;
---  - a yielded function is called with a continuation that resumes the task
---    again; if that function returns a handle, the handle is remembered as
---    the current child and is closed before the task is resumed.
---
--- Resuming a task that has already completed (finished, failed or closed)
--- is a no-op, so a continuation firing late cannot restart the task body or
--- settle the task a second time.
--- @param ... any
function Task:_resume(...)
  if self:_completed() then
    return
  end

  --- @type [boolean, string|async.CallbackFn]
  local ret = pack_len(coroutine.resume(self._thread, ...))
  local stat = ret[1]

  if not stat then
    -- Coroutine had error
    self:_finish(ret[2])
  elseif coroutine.status(self._thread) == 'dead' then
    -- Coroutine finished
    local result = pack_len(unpack_len(ret, 2))
    self:_finish(nil, result)
  else
    local fn = ret[2]
    --- @cast fn -string

    -- TODO(lewis6991): refine error handler to be more specific
    -- `r` is declared before the call so the continuation below can see the
    -- handle returned by `fn` (it is still nil if the continuation runs
    -- before `fn` returns).
    local ok, r
    ok, r = pcall(fn, function(...)
      if is_async_handle(r) then
        --- @cast r async.Handle
        -- We must close children before we resume to ensure
        -- all resources are collected.
        local args = pack_len(...)
        r:close(function()
          self:_resume(unpack_len(args))
        end)
      else
        self:_resume(...)
      end
    end)

    if not ok then
      self:_finish(r)
    elseif is_async_handle(r) then
      self._current_child = r
    end
  end
end

--- Status of the task's coroutine, as reported by coroutine.status().
--- @return 'running'|'suspended'|'normal'|'dead'?
function Task:status()
  return coroutine.status(self._thread)
end

--- Run a function in an async context, asynchronously.
---
--- Examples:
--- ```lua
--- -- The two below blocks are equivalent:
---
--- -- Run a uv function and wait for it
--- local stat = async.arun(function()
---     return async.await(2, vim.uv.fs_stat, 'foo.txt')
--- end):wait()
---
--- -- Since uv functions have sync versions. You can just do:
--- local stat = vim.uv.fs_stat('foo.txt')
--- ```
--- @param func function
--- @param ... any Arguments passed to `func`
--- @return async.Task
function M.arun(func, ...)
  local task = Task._new(func)
  task:_resume(...)
  return task
end

--- @class async.TaskFun
--- @field package _fun fun(...: any): any
--- @operator call(...): any
local TaskFun = {}
TaskFun.__index = TaskFun

--- Calling a TaskFun starts its function as a new task.
function TaskFun:__call(...)
  return M.arun(self._fun, ...)
end

--- Create an async function
--- @param fun function
--- @return async.TaskFun
function M.async(fun)
  return setmetatable({ _fun = fun }, TaskFun)
end

--- Returns the status of a task's thread. Defaults to the currently running
--- task; returns nothing when there is none.
---
--- @param task? async.Task
--- @return 'running'|'suspended'|'normal'|'dead'?
function M.status(task)
  task = task or running()
  if task then
    assert(getmetatable(task) == Task, 'Expected Task')
    return task:status()
  end
end

--- Yield `fun` to the task driver and return the values the task is resumed
--- with.
--- @async
--- @generic R1, R2, R3, R4
--- @param fun fun(callback: fun(r1: R1, r2: R2, r3: R3, r4: R4)): any?
--- @return R1, R2, R3, R4
local function yield(fun)
  assert(type(fun) == 'function', 'Expected function')
  return coroutine.yield(fun)
end
--- Suspend the current task until `task` settles; raise its error if it
--- failed, otherwise return its results.
--- @async
--- @param task async.Task
--- @return any ...
local function await_task(task)
  -- Returning `task` hands it to the driver as the handle being waited on.
  --- @param callback fun(err?: string, ...: any)
  local function subscribe(callback)
    task:await(callback)
    return task
  end

  local res = pack_len(yield(subscribe))

  if res[1] then
    -- TODO(lewis6991): what is the correct level to pass?
    error(res[1], 0)
  end

  return unpack_len(res, 2)
end

--- Asynchronous blocking wait on a callback-style function. The continuation
--- is stored at argument position `argc` before `fun` is called.
--- @param argc integer
--- @param fun async.CallbackFn
--- @param ... any func arguments
--- @return any ...
local function await_cbfun(argc, fun, ...)
  local args = pack_len(...)

  --- @param callback fun(...:any)
  --- @return any?
  local function start(callback)
    args[argc] = callback
    args.n = math.max(args.n, argc)
    return fun(unpack_len(args))
  end

  return yield(start)
end

--- Run the function wrapped by `taskfun` directly in the current coroutine.
--- @param taskfun async.TaskFun
--- @param ... any
--- @return any ...
local function await_taskfun(taskfun, ...)
  return taskfun._fun(...)
end

--- Asynchronous blocking wait
---
--- Example:
--- ```lua
--- local task = async.arun(function()
---    return 1, 'a'
--- end)
---
--- local task_fun = async.async(function(arg)
---    return 2, 'b', arg
--- end)
---
--- async.arun(function()
---   do -- await a callback function
---     async.await(1, vim.schedule)
---   end
---
---   do -- await a task (new async context)
---     local n, s = async.await(task)
---     assert(n == 1 and s == 'a')
---   end
---
---   do -- await a started task function (new async context)
---     local n, s, arg = async.await(task_fun('A'))
---     assert(n == 2)
---     assert(s == 'b')
---     assert(arg == 'A')
---   end
---
---   do -- await a task function (re-using the current async context)
---     local n, s, arg = async.await(task_fun, 'B')
---     assert(n == 2)
---     assert(s == 'b')
---     assert(arg == 'B')
---   end
--- end)
--- ```
--- @async
--- @overload fun(argc: integer, func: async.CallbackFn, ...:any): any ...
--- @overload fun(task: async.Task): any ...
--- @overload fun(taskfun: async.TaskFun, ...: any): any ...
function M.await(...)
  assert(running(), 'Not in async context')

  local first = (...)

  if type(first) == 'number' then
    return await_cbfun(...)
  end

  local kind = getmetatable(first)
  if kind == Task then
    return await_task(...)
  end
  if kind == TaskFun then
    return await_taskfun(...)
  end

  error('Invalid arguments, expected Task or (argc, func) got: ' .. type(first), 2)
end

--- Creates an async function with a callback style function.
---
--- Example:
---
--- ```lua
--- --- Note the callback argument is not present in the return function
--- --- @type fun(timeout: integer)
--- local sleep = async.awrap(2, function(timeout, callback)
---   local timer = vim.uv.new_timer()
---   timer:start(timeout * 1000, 0, callback)
---   -- uv_timer_t provides a close method so timer will be
---   -- cleaned up when this function finishes
---   return timer
--- end)
---
--- async.arun(function()
---   print('hello')
---   sleep(2)
---   print('world')
--- end)
--- ```
---
--- @param argc integer Position of the callback argument of `func`
--- @param func async.CallbackFn
--- @return async function
function M.awrap(argc, func)
  assert(type(argc) == 'number')
  assert(type(func) == 'function')

  --- @async
  local function wrapped(...)
    return M.await(argc, func, ...)
  end

  return wrapped
end

if vim.schedule then
  --- An async function that when called will yield to the Neovim scheduler to be
  --- able to call the API.
  M.schedule = M.awrap(1, vim.schedule)
end

--- Create a callable proxy for `f` that runs `gc` when the proxy is garbage
--- collected.
--- @generic F
--- @param f F
--- @param gc fun()
--- @return F
local function gc_fun(f, gc)
  local proxy = newproxy(true)
  local mt = getmetatable(proxy)

  mt.__call = function(_, ...)
    return f(...)
  end
  mt.__gc = gc

  return proxy
end

--- Unregister, from each task, the awaiter callback recorded for it.
--- @param task_cbs table Maps a task to the callback registered on it
local function gc_cbs(task_cbs)
  for task, registered in pairs(task_cbs) do
    local callbacks = task._callbacks
    for slot, cb in pairs(callbacks) do
      if cb == registered then
        callbacks[slot] = nil
        break
      end
    end
  end
end

--- Iterate over the outcomes of `tasks` as they become available. Each call
--- of the returned iterator yields `index, err, ...results` for one task, and
--- nothing once every task has been reported.
---
--- Example:
--- ```lua
--- local task1 = async.arun(function()
---   return 1, 'a'
--- end)
---
--- local task2 = async.arun(function()
---   return 2, 'b'
--- end)
---
--- local task3 = async.arun(function()
---   error('task3 error')
--- end)
---
--- async.arun(function()
---   for i, err, r1, r2 in async.iter({task1, task2, task3}) do
---     print(i, err, r1, r2)
---   end
--- end)
--- ```
---
--- Prints:
--- ```
--- 1 nil 1 'a'
--- 2 nil 2 'b'
--- 3 'task3 error' nil nil
--- ```
---
--- @async
--- @param tasks async.Task[]
--- @return fun(): (integer?, any?, ...)
function M.iter(tasks)
  assert(running(), 'Not in async context')

  -- Outcomes that arrived while nobody was waiting, in arrival order.
  local ready = {} --- @type [integer, any, ...][]

  -- Iter blocks in an async context so only one waiter is needed
  local waiter = nil

  local task_cbs = {} --- @type table
  local remaining = #tasks

  --- If can_gc_cbs is true, then the iterator function has been garbage
  --- collected and means any awaiters can also be garbage collected. The
  --- only time we can't do this is the special case when iter() is
  --- called anonymously (`local i = async.iter(tasks)()`), so we should not
  --- garbage collect the callbacks until at least one awaiter is called.
  local can_gc_cbs = false

  for i, task in ipairs(tasks) do
    local function on_done(err, ...)
      if can_gc_cbs then
        gc_cbs(task_cbs)
      end

      -- Take the waiter and clear it before calling it
      local resume_iter = waiter
      waiter = nil

      remaining = remaining - 1

      if not resume_iter then
        -- Task finished before Iterator was called. Store results.
        ready[#ready + 1] = pack_len(i, err, ...)
        return
      end

      -- Iterator is waiting, yield to it
      resume_iter(i, err, ...)
    end

    task_cbs[task] = on_done
    task:await(on_done)
  end

  local next_outcome = M.awrap(1, function(callback)
    if #ready > 0 then
      local res = table.remove(ready, 1)
      callback(unpack_len(res))
    elseif remaining == 0 then
      callback() -- finish
    else
      assert(not waiter, 'internal error: waiter already set')
      waiter = callback
    end
  end)

  return gc_fun(next_outcome, function()
    -- Don't gc callbacks just yet. Wait until at least one of them is called.
    can_gc_cbs = true
  end)
end

do -- join()
  --- Store one iterator step in `results`; reports whether there was a step.
  --- @param results table
  --- @param i integer
  --- @param ... any
  --- @return boolean
  local function collect(results, i, ...)
    if not i then
      return i ~= nil
    end

    results[i] = pack_len(...)
    return true
  end

  --- Run `iter` to exhaustion, gathering every step keyed by its index.
  --- @param iter fun(): ...
  --- @return table
  local function drain_iter(iter)
    local results = {} --- @type table
    repeat
      local more = collect(results, iter())
    until not more
    return results
  end

  --- Wait for all tasks to finish and return their results. Each entry is
  --- packed as `{ err, ...results }` (with the value count in `n`) and stored
  --- at the index of its task.
  ---
  --- Example:
  --- ```lua
  --- local task1 = async.arun(function()
  ---   return 1, 'a'
  --- end)
  ---
  --- local task2 = async.arun(function()
  ---   return 2, 'b'
  --- end)
  ---
  --- local task3 = async.arun(function()
  ---   error('task3 error')
  --- end)
  ---
  --- async.arun(function()
  ---   local results = async.join({task1, task2, task3})
  ---   print(vim.inspect(results))
  --- end)
  --- ```
  ---
  --- Prints:
  --- ```
  --- {
  ---   [1] = { nil, 1, 'a' },
  ---   [2] = { nil, 2, 'b' },
  ---   [3] = { 'task3 error' },
  --- }
  --- ```
  --- @async
  --- @param tasks async.Task[]
  --- @return table
  function M.join(tasks)
    assert(running(), 'Not in async context')
    return drain_iter(M.iter(tasks))
  end

  --- Wait for the first outcome among `tasks` and return it as
  --- `index, err, ...results`.
  --- @async
  --- @param tasks async.Task[]
  --- @return integer?, any?, ...?
  function M.joinany(tasks)
    return M.iter(tasks)()
  end
end

return M