refactor(all): language tree adaption (#1105)

This commit is contained in:
Steven Sojka 2021-03-30 08:18:24 -05:00 committed by GitHub
parent 0df7c4aa39
commit 6863f79118
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 198 additions and 61 deletions

View file

@ -7,6 +7,8 @@ local caching = require'nvim-treesitter.caching'
local M = {}
local EMPTY_ITER = function() end
M.built_in_query_groups = {'highlights', 'locals', 'folds', 'indents'}
-- Creates a function that checks whether a given query exists
@ -166,7 +168,7 @@ end
--- Return all nodes corresponding to a specific capture path (like @definition.var, @reference.type)
-- Works like M.get_references or M.get_scopes except you can choose the capture
-- Can also be a nested capture like @definition.function to get all nodes defining a function
function M.get_capture_matches(bufnr, capture_string, query_group)
function M.get_capture_matches(bufnr, capture_string, query_group, root, lang)
if not string.sub(capture_string, 1,2) == '@' then
print('capture_string must start with "@"')
return
@ -176,7 +178,7 @@ function M.get_capture_matches(bufnr, capture_string, query_group)
capture_string = string.sub(capture_string, 2)
local matches = {}
for match in M.iter_group_results(bufnr, query_group) do
for match in M.iter_group_results(bufnr, query_group, root, lang) do
local insert = utils.get_at_path(match, capture_string)
if insert then
@ -186,7 +188,7 @@ function M.get_capture_matches(bufnr, capture_string, query_group)
return matches
end
function M.find_best_match(bufnr, capture_string, query_group, filter_predicate, scoring_function)
function M.find_best_match(bufnr, capture_string, query_group, filter_predicate, scoring_function, root)
if not string.sub(capture_string, 1,2) == '@' then
api.nvim_err_writeln('capture_string must start with "@"')
return
@ -198,7 +200,7 @@ function M.find_best_match(bufnr, capture_string, query_group, filter_predicate,
local best
local best_score
for maybe_match in M.iter_group_results(bufnr, query_group) do
for maybe_match in M.iter_group_results(bufnr, query_group, root) do
local match = utils.get_at_path(maybe_match, capture_string)
if match and filter_predicate(match) then
@ -220,31 +222,82 @@ end
-- @param bufnr the buffer
-- @param query_group the query file to use
-- @param root the root node
function M.iter_group_results(bufnr, query_group, root)
local lang = parsers.get_buf_lang(bufnr)
if not lang then return function() end end
-- @param root the root node lang, if known
function M.iter_group_results(bufnr, query_group, root, root_lang)
local buf_lang = parsers.get_buf_lang(bufnr)
local query = M.get_query(lang, query_group)
if not query then return function() end end
if not buf_lang then return EMPTY_ITER end
local parser = parsers.get_parser(bufnr, lang)
if not parser then return function() end end
local parser = parsers.get_parser(bufnr, buf_lang)
if not parser then return EMPTY_ITER end
local root = root or parser:parse()[1]:root()
local start_row, _, end_row, _ = root:range()
if not root then
local first_tree = parser:trees()[1]
if first_tree then
root = first_tree:root()
end
end
if not root then return EMPTY_ITER end
local range = {root:range()}
if not root_lang then
local lang_tree = parser:language_for_range(range)
if lang_tree then
root_lang = lang_tree:lang()
end
end
if not root_lang then return EMPTY_ITER end
local query = M.get_query(root_lang, query_group)
if not query then return EMPTY_ITER end
-- The end row is exclusive so we need to add 1 to it.
return M.iter_prepared_matches(query, root, bufnr, start_row, end_row + 1)
return M.iter_prepared_matches(query, root, bufnr, range[1], range[3] + 1)
end
function M.collect_group_results(bufnr, query_group, root)
function M.collect_group_results(bufnr, query_group, root, lang)
local matches = {}
for prepared_match in M.iter_group_results(bufnr, query_group, root) do
for prepared_match in M.iter_group_results(bufnr, query_group, root, lang) do
table.insert(matches, prepared_match)
end
return matches
end
--- Same as get_capture_matches except this will recursively get matches for every language in the tree.
-- @param bufnr The bufnr
-- @param capture_or_fn The capture to get. If a function is provided then that
-- function will be used to resolve both the capture and query argument.
-- The function can return `nil` to ignore that tree.
-- @param query_type The query to get the capture from. This is ignore if a function is provided
-- for the captuer argument.
function M.get_capture_matches_recursively(bufnr, capture_or_fn, query_type)
local type_fn = type(capture_or_fn) == 'function'
and capture_or_fn
or function()
return capture_or_fn, query_type
end
local parser = parsers.get_parser(bufnr)
local matches = {}
if parser then
parser:for_each_tree(function(tree, lang_tree)
local lang = lang_tree:lang()
local capture, type_ = type_fn(lang, tree, lang_tree)
if capture then
vim.list_extend(matches, M.get_capture_matches(bufnr, capture, type_, tree:root(), lang))
end
end)
end
return matches
end
return M