From 92c302b862a8fa52f0c6b6b223f1ef7aee28f4fd Mon Sep 17 00:00:00 2001 From: TheLeoP Date: Thu, 15 Jun 2023 07:36:52 -0500 Subject: [PATCH] fix: make locals work again --- lua/nvim-treesitter/locals.lua | 121 ++++++++++++++++++++++----------- 1 file changed, 82 insertions(+), 39 deletions(-) diff --git a/lua/nvim-treesitter/locals.lua b/lua/nvim-treesitter/locals.lua index f045cae19..8ea264e4d 100644 --- a/lua/nvim-treesitter/locals.lua +++ b/lua/nvim-treesitter/locals.lua @@ -73,8 +73,8 @@ function M.iter_scope_tree(node, bufnr) end -- Gets a table of all nodes and their 'kinds' from a locals list ----@param local_def any: the local list result ----@return table: a list of node entries +---@param local_def TSLocal[] the local list result +---@return TSLocal[] a list of node entries function M.get_local_nodes(local_def) local result = {} @@ -91,7 +91,7 @@ end -- * The node -- * The full definition match `@definition.var.something` -> 'var.something' -- * The last definition match `@definition.var.something` -> 'something' ----@param local_def any The locals result +---@param local_def TSLocal The locals result ---@param accumulator function The accumulator function ---@param full_match? string The full match path to append to ---@param last_match? string The last match @@ -125,66 +125,107 @@ local function memoize(fn, hash_fn) return function(...) local key = hash_fn(...) if cache[key] == nil then - local v = fn(...) ---@type any - cache[key] = v ~= nil and v or vim.NIL + local v = { fn(...) } ---@type any + + for k, value in pairs(v) do + if value == nil then + value[k] = vim.NIL + end + end + + cache[key] = v end local v = cache[key] - return v ~= vim.NIL and v or nil + + for k, value in pairs(v) do + if value == vim.NIL then + value[k] = nil + end + end + + return unpack(v) end end - -local function get_query(bufnr) +---@param bufnr integer: the buffer +---@return TSNode|nil root: root node of the buffer +local function get_root(bufnr) local parser = assert(ts.get_parser(bufnr)) if not parser then return end + parser:parse() + return parser:trees()[1]:root() +end + +---@param bufnr integer: the buffer +---@return Query|nil query: `locals` query +---@return TSNode|nil root: root node of the bufferocal function get_query(bufnr) +local function get_query(bufnr) + local root = get_root(bufnr) local ft = vim.bo[bufnr].filetype local lang = ts.language.get_lang(ft) or ft local query = (ts.query.get(lang, 'locals')) - parser:parse() - local root = parser:trees():root() - return query, root end +---@alias TSScope "parent"|"local"|"global" + +---@class TSLocal +---@field kind string +---@field node TSNode +---@field scope TSScope + -- Return all locals for the buffer -- -- memoized by buffer tick -- ---@param bufnr integer buffer ----@return table? definitions ----@return table? references ----@return table? scopes +---@return TSLocal[] definitions +---@return TSLocal[] references +---@return TSNode[] scopes M.get = memoize(function(bufnr) local query, root = get_query(bufnr) - if not query then - return + if not query or not root then + return {}, {}, {} end local definitions = {} local scopes = {} local references = {} - for _, loc in query:iter_captures(root, bufnr) do - if loc.definition then - table.insert(definitions, loc.definition) + for id, node, metadata in query:iter_captures(root, bufnr) do + local kind = query.captures[id] + + local scope = 'local' ---@type string + for k, v in pairs(metadata) do + if type(k) == 'string' and vim.endswith(k, 'scope') then + scope = v + end end - if loc.scope and loc.scope.node then - table.insert(scopes, loc.scope.node) + if node and vim.startswith(kind, 'definition') then + table.insert(definitions, { kind = kind, node = node, scope = scope }) end - if loc.reference and loc.reference.node then - table.insert(references, loc.reference.node) + if node and kind == 'scope' then + table.insert(scopes, node) + end + + if node and kind == 'reference' then + table.insert(references, { kind = kind, node = node, scope = scope }) end end return definitions, references, scopes end, function(bufnr) - return tostring(bufnr) + local root = get_root(bufnr) + if not root then + return tostring(bufnr) + end + return tostring(root:id()) end) -- Get a single dimension table to look definition nodes. @@ -199,7 +240,7 @@ end) -- is called very frequently, which is why this lookup must be fast as possible. -- ---@param bufnr integer: the buffer ----@return table result: a table for looking up definitions +---@return TSLocal[] result: a table for looking up definitions M.get_definitions_lookup_table = memoize(function(bufnr) local definitions, _, _ = M.get(bufnr) if not definitions then @@ -221,7 +262,11 @@ M.get_definitions_lookup_table = memoize(function(bufnr) return result end, function(bufnr) - return tostring(bufnr) + local root = get_root(bufnr) + if not root then + return tostring(bufnr) + end + return tostring(root:id()) end) -- Gets all the scopes of a definition based on the scope type @@ -233,7 +278,7 @@ end) -- ---@param node TSNode: the definition node ---@param bufnr integer: the buffer ----@param scope_type string: the scope type +---@param scope_type TSScope: the scope type function M.get_definition_scopes(node, bufnr, scope_type) local scopes = {} local scope_count = 1 ---@type integer|nil @@ -248,8 +293,8 @@ function M.get_definition_scopes(node, bufnr, scope_type) end local i = 0 - for scope in M.iter_scope_tree(node, bufnr) do - table.insert(scopes, scope) + for scope_node in M.iter_scope_tree(node, bufnr) do + table.insert(scopes, scope_node) i = i + 1 if scope_count and i >= scope_count then @@ -284,7 +329,8 @@ end -- Finds usages of a node in a given scope. ---@param node TSNode the node to find usages for ----@param scope_node TSNode the node to look within +---@param scope_node TSNode|nil the node to look within +---@param bufnr integer|nil the bufnr to look into ---@return TSNode[]: a list of nodes function M.find_usages(node, scope_node, bufnr) bufnr = bufnr or api.nvim_get_current_buf() @@ -302,17 +348,14 @@ function M.find_usages(node, scope_node, bufnr) return {} end - for match in query:iter_matches(scope_node, bufnr) do + for id, node_capture in query:iter_captures(scope_node, bufnr) do + local kind = query.captures[id] if - match.reference - and match.reference.node - and ts.get_node_text(match.reference.node, bufnr) == node_text + node_capture + and kind == 'reference' + and ts.get_node_text(node_capture, bufnr) == node_text then - local def_node, _, kind = M.find_definition(match.reference.node, bufnr) - - if kind == nil or def_node == node then - table.insert(usages, match.reference.node) - end + table.insert(usages, node_capture) end end