mirror of
https://github.com/nvim-treesitter/nvim-treesitter.git
synced 2026-07-01 19:17:02 -04:00
refactor: separate query to utils/query
This commit is contained in:
parent
a2ba854001
commit
42e4c625b6
6 changed files with 219 additions and 165 deletions
|
|
@ -5,6 +5,7 @@ local queries = require'nvim-treesitter.query'
|
|||
local locals = require'nvim-treesitter.locals'
|
||||
local highlight = require'nvim-treesitter.highlight'
|
||||
local parsers = require'nvim-treesitter.parsers'
|
||||
local qutils = require'nvim-treesitter.utils.query'
|
||||
|
||||
local health_start = vim.fn["health#report_start"]
|
||||
local health_ok = vim.fn['health#report_ok']
|
||||
|
|
@ -36,7 +37,7 @@ local function install_health()
|
|||
end
|
||||
|
||||
local function highlight_health(lang)
|
||||
if not queries.get_query(lang, "highlights") then
|
||||
if not qutils.get_query(lang, "highlights") then
|
||||
health_warn("No `highlights.scm` query found for " .. lang, {
|
||||
"Open an issue at https://github.com/nvim-treesitter/nvim-treesitter"
|
||||
})
|
||||
|
|
@ -46,7 +47,7 @@ local function highlight_health(lang)
|
|||
end
|
||||
|
||||
local function locals_health(lang)
|
||||
if not queries.get_query(lang, "locals") then
|
||||
if not qutils.get_query(lang, "locals") then
|
||||
health_warn("No `locals.scm` query found for " .. lang, {
|
||||
"Open an issue at https://github.com/nvim-treesitter/nvim-treesitter"
|
||||
})
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
local api = vim.api
|
||||
local ts = vim.treesitter
|
||||
|
||||
local queries = require'nvim-treesitter.query'
|
||||
local queries = require'nvim-treesitter.utils.query'
|
||||
local parsers = require'nvim-treesitter.parsers'
|
||||
|
||||
local M = {
|
||||
|
|
|
|||
|
|
@ -4,8 +4,9 @@
|
|||
local api = vim.api
|
||||
local ts = vim.treesitter
|
||||
|
||||
local queries = require'nvim-treesitter.query'
|
||||
local parsers = require'nvim-treesitter.parsers'
|
||||
local qutils = require'nvim-treesitter.utils.query'
|
||||
local tutils = require'nvim-treesitter.ts_utils'
|
||||
|
||||
local M = {
|
||||
locals = {}
|
||||
|
|
@ -15,7 +16,7 @@ function M.collect_locals(bufnr)
|
|||
local lang = parsers.ft_to_lang(api.nvim_buf_get_option(bufnr, "ft"))
|
||||
if not lang then return end
|
||||
|
||||
local query = queries.get_query(lang, 'locals')
|
||||
local query = qutils.get_query(lang, 'locals')
|
||||
if not query then return end
|
||||
|
||||
local parser = parsers.get_parser(bufnr, lang)
|
||||
|
|
@ -26,7 +27,7 @@ function M.collect_locals(bufnr)
|
|||
|
||||
local locals = {}
|
||||
|
||||
for prepared_match in queries.iter_prepared_matches(query, root, bufnr, start_row, end_row) do
|
||||
for prepared_match in qutils.iter_prepared_matches(function(_) return true end, query, root, bufnr, start_row, end_row) do
|
||||
table.insert(locals, prepared_match)
|
||||
end
|
||||
|
||||
|
|
@ -89,9 +90,12 @@ function M.get_references(bufnr)
|
|||
return refs
|
||||
end
|
||||
|
||||
-- Is and is not predicates
|
||||
function M.is(node, deftype, bufnr)
|
||||
for def in M.get_definitions(bufnr) do
|
||||
if def[deftype] and def[deftype].node == node then
|
||||
if def[deftype]
|
||||
and tutils.get_node_text(def[deftype].node, bufnr) == tutils.get_node_text(node, bufnr)
|
||||
then
|
||||
return true
|
||||
end
|
||||
end
|
||||
|
|
@ -99,4 +103,101 @@ function M.is(node, deftype, bufnr)
|
|||
return false
|
||||
end
|
||||
|
||||
-- Some utils
|
||||
function M.parent_scope(node, cursor_pos)
|
||||
local bufnr = api.nvim_get_current_buf()
|
||||
|
||||
local scopes = locals.get_scopes(bufnr)
|
||||
if not node or not scopes then return end
|
||||
|
||||
local row = cursor_pos.row
|
||||
local col = cursor_pos.col
|
||||
local iter_node = node
|
||||
|
||||
while iter_node ~= nil do
|
||||
local row_, col_ = iter_node:start()
|
||||
if vim.tbl_contains(scopes, iter_node) and (row_+1 ~= row or col_ ~= col) then
|
||||
return iter_node
|
||||
end
|
||||
iter_node = iter_node:parent()
|
||||
end
|
||||
end
|
||||
|
||||
function M.containing_scope(node)
|
||||
local bufnr = api.nvim_get_current_buf()
|
||||
|
||||
local scopes = locals.get_scopes(bufnr)
|
||||
if not node or not scopes then return end
|
||||
|
||||
local iter_node = node
|
||||
|
||||
while iter_node ~= nil and not vim.tbl_contains(scopes, iter_node) do
|
||||
iter_node = iter_node:parent()
|
||||
end
|
||||
|
||||
return iter_node or node
|
||||
end
|
||||
|
||||
|
||||
function M.nested_scope(node, cursor_pos)
|
||||
local bufnr = api.nvim_get_current_buf()
|
||||
|
||||
local scopes = locals.get_scopes(bufnr)
|
||||
if not node or not scopes then return end
|
||||
|
||||
local row = cursor_pos.row
|
||||
local col = cursor_pos.col
|
||||
local scope = M.containing_scope(node)
|
||||
|
||||
for _, child in ipairs(M.get_named_children(scope)) do
|
||||
local row_, col_ = child:start()
|
||||
if vim.tbl_contains(scopes, child) and ((row_+1 == row and col_ > col) or row_+1 > row) then
|
||||
return child
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function M.next_scope(node)
|
||||
local bufnr = api.nvim_get_current_buf()
|
||||
|
||||
local scopes = locals.get_scopes(bufnr)
|
||||
if not node or not scopes then return end
|
||||
|
||||
local scope = M.containing_scope(node)
|
||||
|
||||
local parent = scope:parent()
|
||||
if not parent then return end
|
||||
|
||||
local is_prev = true
|
||||
for _, child in ipairs(M.get_named_children(parent)) do
|
||||
if child == scope then
|
||||
is_prev = false
|
||||
elseif not is_prev and vim.tbl_contains(scopes, child) then
|
||||
return child
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function M.previous_scope(node)
|
||||
local bufnr = api.nvim_get_current_buf()
|
||||
|
||||
local scopes = locals.get_scopes(bufnr)
|
||||
if not node or not scopes then return end
|
||||
|
||||
local scope = M.containing_scope(node)
|
||||
|
||||
local parent = scope:parent()
|
||||
if not parent then return end
|
||||
|
||||
local is_prev = true
|
||||
local children = M.get_named_children(parent)
|
||||
for i=#children,1,-1 do
|
||||
if children[i] == scope then
|
||||
is_prev = false
|
||||
elseif not is_prev and vim.tbl_contains(scopes, children[i]) then
|
||||
return children[i]
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return M
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
local api = vim.api
|
||||
local qutils = require'nvim-treesitter.utils.query'
|
||||
local ts = vim.treesitter
|
||||
-- local locals = require'nvim-treesitter.locals'
|
||||
local locals = require'nvim-treesitter.locals'
|
||||
|
||||
local M = {}
|
||||
|
||||
|
|
@ -73,67 +74,7 @@ function M.get_query(lang, query_name)
|
|||
end
|
||||
|
||||
function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row)
|
||||
-- A function that splits a string on '.'
|
||||
local function split(string)
|
||||
local t = {}
|
||||
for str in string.gmatch(string, "([^.]+)") do
|
||||
table.insert(t, str)
|
||||
end
|
||||
|
||||
return t
|
||||
end
|
||||
|
||||
-- Given a path (i.e. a List(String)) this functions inserts value at path
|
||||
local function insert_to_path(object, path, value)
|
||||
local curr_obj = object
|
||||
|
||||
for index=1,(#path -1) do
|
||||
if curr_obj[path[index]] == nil then
|
||||
curr_obj[path[index]] = {}
|
||||
end
|
||||
|
||||
curr_obj = curr_obj[path[index]]
|
||||
end
|
||||
|
||||
curr_obj[path[#path]] = value
|
||||
end
|
||||
|
||||
local matches = query:iter_matches(qnode, bufnr, start_row, end_row)
|
||||
|
||||
local function iter()
|
||||
local pattern, match = matches()
|
||||
if pattern ~= nil then
|
||||
local prepared_match = {}
|
||||
|
||||
-- Extract capture names from each match
|
||||
for id, node in pairs(match) do
|
||||
local name = query.captures[id] -- name of the capture in the query
|
||||
if name ~= nil then
|
||||
local path = split(name)
|
||||
insert_to_path(prepared_match, path, { node=node })
|
||||
end
|
||||
end
|
||||
|
||||
-- Add some predicates for testing
|
||||
local preds = query.info.patterns[pattern]
|
||||
if preds then
|
||||
for _, pred in pairs(preds) do
|
||||
if pred[1] == "set!" and type(pred[2]) == "string" then
|
||||
insert_to_path(prepared_match, split(pred[2]), pred[3])
|
||||
end
|
||||
if pred[1] == "is?" and type(pred[3]) == "string" then
|
||||
if not locals.is(pred[2], pred[3], bufnr) then
|
||||
return iter() -- We should ignore this one, tail call
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return prepared_match
|
||||
end
|
||||
end
|
||||
|
||||
return iter
|
||||
return qutils.iter_prepared_matches(match_pred, query, qnode, bufnr, start_row, end_row)
|
||||
end
|
||||
|
||||
return M
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
local api = vim.api
|
||||
|
||||
local locals = require'nvim-treesitter.locals'
|
||||
local parsers = require'nvim-treesitter.parsers'
|
||||
|
||||
local M = {}
|
||||
|
|
@ -105,40 +104,6 @@ function M.get_previous_node(node, allow_switch_parents, allow_previous_parent)
|
|||
return destination_node
|
||||
end
|
||||
|
||||
function M.parent_scope(node, cursor_pos)
|
||||
local bufnr = api.nvim_get_current_buf()
|
||||
|
||||
local scopes = locals.get_scopes(bufnr)
|
||||
if not node or not scopes then return end
|
||||
|
||||
local row = cursor_pos.row
|
||||
local col = cursor_pos.col
|
||||
local iter_node = node
|
||||
|
||||
while iter_node ~= nil do
|
||||
local row_, col_ = iter_node:start()
|
||||
if vim.tbl_contains(scopes, iter_node) and (row_+1 ~= row or col_ ~= col) then
|
||||
return iter_node
|
||||
end
|
||||
iter_node = iter_node:parent()
|
||||
end
|
||||
end
|
||||
|
||||
function M.containing_scope(node)
|
||||
local bufnr = api.nvim_get_current_buf()
|
||||
|
||||
local scopes = locals.get_scopes(bufnr)
|
||||
if not node or not scopes then return end
|
||||
|
||||
local iter_node = node
|
||||
|
||||
while iter_node ~= nil and not vim.tbl_contains(scopes, iter_node) do
|
||||
iter_node = iter_node:parent()
|
||||
end
|
||||
|
||||
return iter_node or node
|
||||
end
|
||||
|
||||
function M.get_named_children(node)
|
||||
local nodes = {}
|
||||
for i=0,node:named_child_count() - 1,1 do
|
||||
|
|
@ -147,67 +112,6 @@ function M.get_named_children(node)
|
|||
return nodes
|
||||
end
|
||||
|
||||
function M.nested_scope(node, cursor_pos)
|
||||
local bufnr = api.nvim_get_current_buf()
|
||||
|
||||
local scopes = locals.get_scopes(bufnr)
|
||||
if not node or not scopes then return end
|
||||
|
||||
local row = cursor_pos.row
|
||||
local col = cursor_pos.col
|
||||
local scope = M.containing_scope(node)
|
||||
|
||||
for _, child in ipairs(M.get_named_children(scope)) do
|
||||
local row_, col_ = child:start()
|
||||
if vim.tbl_contains(scopes, child) and ((row_+1 == row and col_ > col) or row_+1 > row) then
|
||||
return child
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function M.next_scope(node)
|
||||
local bufnr = api.nvim_get_current_buf()
|
||||
|
||||
local scopes = locals.get_scopes(bufnr)
|
||||
if not node or not scopes then return end
|
||||
|
||||
local scope = M.containing_scope(node)
|
||||
|
||||
local parent = scope:parent()
|
||||
if not parent then return end
|
||||
|
||||
local is_prev = true
|
||||
for _, child in ipairs(M.get_named_children(parent)) do
|
||||
if child == scope then
|
||||
is_prev = false
|
||||
elseif not is_prev and vim.tbl_contains(scopes, child) then
|
||||
return child
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function M.previous_scope(node)
|
||||
local bufnr = api.nvim_get_current_buf()
|
||||
|
||||
local scopes = locals.get_scopes(bufnr)
|
||||
if not node or not scopes then return end
|
||||
|
||||
local scope = M.containing_scope(node)
|
||||
|
||||
local parent = scope:parent()
|
||||
if not parent then return end
|
||||
|
||||
local is_prev = true
|
||||
local children = M.get_named_children(parent)
|
||||
for i=#children,1,-1 do
|
||||
if children[i] == scope then
|
||||
is_prev = false
|
||||
elseif not is_prev and vim.tbl_contains(scopes, children[i]) then
|
||||
return children[i]
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function M.get_node_at_cursor(winnr)
|
||||
local cursor = api.nvim_win_get_cursor(winnr or 0)
|
||||
local root = parsers.get_parser().tree:root()
|
||||
|
|
|
|||
107
lua/nvim-treesitter/utils/query.lua
Normal file
107
lua/nvim-treesitter/utils/query.lua
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
local api = vim.api
|
||||
local fn = vim.fn
|
||||
local luv = vim.loop
|
||||
local ts = vim.treesitter
|
||||
|
||||
local M = {}
|
||||
|
||||
local function read_query_files(filenames)
|
||||
local contents = {}
|
||||
|
||||
for _,filename in ipairs(filenames) do
|
||||
vim.list_extend(contents, vim.fn.readfile(filename))
|
||||
end
|
||||
|
||||
return table.concat(contents, '\n')
|
||||
end
|
||||
|
||||
-- Some treesitter grammars extend others.
|
||||
-- We can use that to import the queries of the base language
|
||||
M.base_language_map = {
|
||||
cpp = {'c'},
|
||||
typescript = {'javascript'},
|
||||
tsx = {'typescript', 'javascript'},
|
||||
}
|
||||
|
||||
function M.get_query(lang, query_name)
|
||||
local query_files = api.nvim_get_runtime_file(string.format('queries/%s/%s.scm', lang, query_name), true)
|
||||
local query_string = ''
|
||||
|
||||
if #query_files > 0 then
|
||||
query_string = read_query_files(query_files)..query_string
|
||||
end
|
||||
|
||||
for _, base_lang in ipairs(M.base_language_map[lang] or {}) do
|
||||
local base_files = api.nvim_get_runtime_file(string.format('queries/%s/%s.scm', base_lang, query_name), true)
|
||||
if base_files and #base_files > 0 then
|
||||
query_string = read_query_files(base_files)..query_string
|
||||
end
|
||||
end
|
||||
|
||||
if #query_string > 0 then
|
||||
return ts.parse_query(lang, query_string)
|
||||
end
|
||||
end
|
||||
|
||||
-- Given a path (i.e. a List(String)) this functions inserts value at path
|
||||
function M.insert_to_path(object, path, value)
|
||||
local curr_obj = object
|
||||
|
||||
for index=1,(#path -1) do
|
||||
if curr_obj[path[index]] == nil then
|
||||
curr_obj[path[index]] = {}
|
||||
end
|
||||
|
||||
curr_obj = curr_obj[path[index]]
|
||||
end
|
||||
|
||||
curr_obj[path[#path]] = value
|
||||
end
|
||||
|
||||
|
||||
function M.iter_prepared_matches(match_func, query, qnode, bufnr, start_row, end_row)
|
||||
-- A function that splits a string on '.'
|
||||
local function split(string)
|
||||
local t = {}
|
||||
for str in string.gmatch(string, "([^.]+)") do
|
||||
table.insert(t, str)
|
||||
end
|
||||
|
||||
return t
|
||||
end
|
||||
|
||||
|
||||
local matches = query:iter_matches(qnode, bufnr, start_row, end_row)
|
||||
|
||||
local function iter()
|
||||
local pattern, match = matches()
|
||||
if pattern ~= nil then
|
||||
local prepared_match = {}
|
||||
|
||||
-- Extract capture names from each match
|
||||
for id, node in pairs(match) do
|
||||
local name = query.captures[id] -- name of the capture in the query
|
||||
if name ~= nil then
|
||||
local path = split(name)
|
||||
M.insert_to_path(prepared_match, path, { node=node })
|
||||
end
|
||||
end
|
||||
|
||||
-- Add some predicates for testing
|
||||
local preds = query.info.patterns[pattern]
|
||||
if preds then
|
||||
for _, pred in pairs(preds) do
|
||||
if not match_func(pred) then
|
||||
return iter()
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return prepared_match
|
||||
end
|
||||
end
|
||||
|
||||
return iter
|
||||
end
|
||||
|
||||
return M
|
||||
Loading…
Add table
Add a link
Reference in a new issue