refactor: separate query to utils/query

This commit is contained in:
Thomas Vigouroux 2020-06-28 21:28:02 +02:00
parent a2ba854001
commit 42e4c625b6
6 changed files with 219 additions and 165 deletions

View file

@ -5,6 +5,7 @@ local queries = require'nvim-treesitter.query'
local locals = require'nvim-treesitter.locals'
local highlight = require'nvim-treesitter.highlight'
local parsers = require'nvim-treesitter.parsers'
local qutils = require'nvim-treesitter.utils.query'
local health_start = vim.fn["health#report_start"]
local health_ok = vim.fn['health#report_ok']
@ -36,7 +37,7 @@ local function install_health()
end
local function highlight_health(lang)
if not queries.get_query(lang, "highlights") then
if not qutils.get_query(lang, "highlights") then
health_warn("No `highlights.scm` query found for " .. lang, {
"Open an issue at https://github.com/nvim-treesitter/nvim-treesitter"
})
@ -46,7 +47,7 @@ local function highlight_health(lang)
end
local function locals_health(lang)
if not queries.get_query(lang, "locals") then
if not qutils.get_query(lang, "locals") then
health_warn("No `locals.scm` query found for " .. lang, {
"Open an issue at https://github.com/nvim-treesitter/nvim-treesitter"
})

View file

@ -1,7 +1,7 @@
local api = vim.api
local ts = vim.treesitter
local queries = require'nvim-treesitter.query'
local queries = require'nvim-treesitter.utils.query'
local parsers = require'nvim-treesitter.parsers'
local M = {

View file

@ -4,8 +4,9 @@
local api = vim.api
local ts = vim.treesitter
local queries = require'nvim-treesitter.query'
local parsers = require'nvim-treesitter.parsers'
local qutils = require'nvim-treesitter.utils.query'
local tutils = require'nvim-treesitter.ts_utils'
local M = {
locals = {}
@ -15,7 +16,7 @@ function M.collect_locals(bufnr)
local lang = parsers.ft_to_lang(api.nvim_buf_get_option(bufnr, "ft"))
if not lang then return end
local query = queries.get_query(lang, 'locals')
local query = qutils.get_query(lang, 'locals')
if not query then return end
local parser = parsers.get_parser(bufnr, lang)
@ -26,7 +27,7 @@ function M.collect_locals(bufnr)
local locals = {}
for prepared_match in queries.iter_prepared_matches(query, root, bufnr, start_row, end_row) do
for prepared_match in qutils.iter_prepared_matches(function(_) return true end, query, root, bufnr, start_row, end_row) do
table.insert(locals, prepared_match)
end
@ -89,9 +90,12 @@ function M.get_references(bufnr)
return refs
end
-- Is and is not predicates
function M.is(node, deftype, bufnr)
for def in M.get_definitions(bufnr) do
if def[deftype] and def[deftype].node == node then
if def[deftype]
and tutils.get_node_text(def[deftype].node, bufnr) == tutils.get_node_text(node, bufnr)
then
return true
end
end
@ -99,4 +103,101 @@ function M.is(node, deftype, bufnr)
return false
end
-- Some utils
function M.parent_scope(node, cursor_pos)
local bufnr = api.nvim_get_current_buf()
local scopes = locals.get_scopes(bufnr)
if not node or not scopes then return end
local row = cursor_pos.row
local col = cursor_pos.col
local iter_node = node
while iter_node ~= nil do
local row_, col_ = iter_node:start()
if vim.tbl_contains(scopes, iter_node) and (row_+1 ~= row or col_ ~= col) then
return iter_node
end
iter_node = iter_node:parent()
end
end
function M.containing_scope(node)
local bufnr = api.nvim_get_current_buf()
local scopes = locals.get_scopes(bufnr)
if not node or not scopes then return end
local iter_node = node
while iter_node ~= nil and not vim.tbl_contains(scopes, iter_node) do
iter_node = iter_node:parent()
end
return iter_node or node
end
function M.nested_scope(node, cursor_pos)
local bufnr = api.nvim_get_current_buf()
local scopes = locals.get_scopes(bufnr)
if not node or not scopes then return end
local row = cursor_pos.row
local col = cursor_pos.col
local scope = M.containing_scope(node)
for _, child in ipairs(M.get_named_children(scope)) do
local row_, col_ = child:start()
if vim.tbl_contains(scopes, child) and ((row_+1 == row and col_ > col) or row_+1 > row) then
return child
end
end
end
function M.next_scope(node)
local bufnr = api.nvim_get_current_buf()
local scopes = locals.get_scopes(bufnr)
if not node or not scopes then return end
local scope = M.containing_scope(node)
local parent = scope:parent()
if not parent then return end
local is_prev = true
for _, child in ipairs(M.get_named_children(parent)) do
if child == scope then
is_prev = false
elseif not is_prev and vim.tbl_contains(scopes, child) then
return child
end
end
end
function M.previous_scope(node)
local bufnr = api.nvim_get_current_buf()
local scopes = locals.get_scopes(bufnr)
if not node or not scopes then return end
local scope = M.containing_scope(node)
local parent = scope:parent()
if not parent then return end
local is_prev = true
local children = M.get_named_children(parent)
for i=#children,1,-1 do
if children[i] == scope then
is_prev = false
elseif not is_prev and vim.tbl_contains(scopes, children[i]) then
return children[i]
end
end
end
return M

View file

@ -1,6 +1,7 @@
local api = vim.api
local qutils = require'nvim-treesitter.utils.query'
local ts = vim.treesitter
-- local locals = require'nvim-treesitter.locals'
local locals = require'nvim-treesitter.locals'
local M = {}
@ -73,67 +74,7 @@ function M.get_query(lang, query_name)
end
function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row)
-- A function that splits a string on '.'
local function split(string)
local t = {}
for str in string.gmatch(string, "([^.]+)") do
table.insert(t, str)
end
return t
end
-- Given a path (i.e. a List(String)) this functions inserts value at path
local function insert_to_path(object, path, value)
local curr_obj = object
for index=1,(#path -1) do
if curr_obj[path[index]] == nil then
curr_obj[path[index]] = {}
end
curr_obj = curr_obj[path[index]]
end
curr_obj[path[#path]] = value
end
local matches = query:iter_matches(qnode, bufnr, start_row, end_row)
local function iter()
local pattern, match = matches()
if pattern ~= nil then
local prepared_match = {}
-- Extract capture names from each match
for id, node in pairs(match) do
local name = query.captures[id] -- name of the capture in the query
if name ~= nil then
local path = split(name)
insert_to_path(prepared_match, path, { node=node })
end
end
-- Add some predicates for testing
local preds = query.info.patterns[pattern]
if preds then
for _, pred in pairs(preds) do
if pred[1] == "set!" and type(pred[2]) == "string" then
insert_to_path(prepared_match, split(pred[2]), pred[3])
end
if pred[1] == "is?" and type(pred[3]) == "string" then
if not locals.is(pred[2], pred[3], bufnr) then
return iter() -- We should ignore this one, tail call
end
end
end
end
return prepared_match
end
end
return iter
return qutils.iter_prepared_matches(match_pred, query, qnode, bufnr, start_row, end_row)
end
return M

View file

@ -1,6 +1,5 @@
local api = vim.api
local locals = require'nvim-treesitter.locals'
local parsers = require'nvim-treesitter.parsers'
local M = {}
@ -105,40 +104,6 @@ function M.get_previous_node(node, allow_switch_parents, allow_previous_parent)
return destination_node
end
function M.parent_scope(node, cursor_pos)
local bufnr = api.nvim_get_current_buf()
local scopes = locals.get_scopes(bufnr)
if not node or not scopes then return end
local row = cursor_pos.row
local col = cursor_pos.col
local iter_node = node
while iter_node ~= nil do
local row_, col_ = iter_node:start()
if vim.tbl_contains(scopes, iter_node) and (row_+1 ~= row or col_ ~= col) then
return iter_node
end
iter_node = iter_node:parent()
end
end
function M.containing_scope(node)
local bufnr = api.nvim_get_current_buf()
local scopes = locals.get_scopes(bufnr)
if not node or not scopes then return end
local iter_node = node
while iter_node ~= nil and not vim.tbl_contains(scopes, iter_node) do
iter_node = iter_node:parent()
end
return iter_node or node
end
function M.get_named_children(node)
local nodes = {}
for i=0,node:named_child_count() - 1,1 do
@ -147,67 +112,6 @@ function M.get_named_children(node)
return nodes
end
function M.nested_scope(node, cursor_pos)
local bufnr = api.nvim_get_current_buf()
local scopes = locals.get_scopes(bufnr)
if not node or not scopes then return end
local row = cursor_pos.row
local col = cursor_pos.col
local scope = M.containing_scope(node)
for _, child in ipairs(M.get_named_children(scope)) do
local row_, col_ = child:start()
if vim.tbl_contains(scopes, child) and ((row_+1 == row and col_ > col) or row_+1 > row) then
return child
end
end
end
function M.next_scope(node)
local bufnr = api.nvim_get_current_buf()
local scopes = locals.get_scopes(bufnr)
if not node or not scopes then return end
local scope = M.containing_scope(node)
local parent = scope:parent()
if not parent then return end
local is_prev = true
for _, child in ipairs(M.get_named_children(parent)) do
if child == scope then
is_prev = false
elseif not is_prev and vim.tbl_contains(scopes, child) then
return child
end
end
end
function M.previous_scope(node)
local bufnr = api.nvim_get_current_buf()
local scopes = locals.get_scopes(bufnr)
if not node or not scopes then return end
local scope = M.containing_scope(node)
local parent = scope:parent()
if not parent then return end
local is_prev = true
local children = M.get_named_children(parent)
for i=#children,1,-1 do
if children[i] == scope then
is_prev = false
elseif not is_prev and vim.tbl_contains(scopes, children[i]) then
return children[i]
end
end
end
function M.get_node_at_cursor(winnr)
local cursor = api.nvim_win_get_cursor(winnr or 0)
local root = parsers.get_parser().tree:root()

View file

@ -0,0 +1,107 @@
local api = vim.api
local fn = vim.fn
local luv = vim.loop
local ts = vim.treesitter
local M = {}
local function read_query_files(filenames)
local contents = {}
for _,filename in ipairs(filenames) do
vim.list_extend(contents, vim.fn.readfile(filename))
end
return table.concat(contents, '\n')
end
-- Some treesitter grammars extend others.
-- We can use that to import the queries of the base language
M.base_language_map = {
cpp = {'c'},
typescript = {'javascript'},
tsx = {'typescript', 'javascript'},
}
function M.get_query(lang, query_name)
local query_files = api.nvim_get_runtime_file(string.format('queries/%s/%s.scm', lang, query_name), true)
local query_string = ''
if #query_files > 0 then
query_string = read_query_files(query_files)..query_string
end
for _, base_lang in ipairs(M.base_language_map[lang] or {}) do
local base_files = api.nvim_get_runtime_file(string.format('queries/%s/%s.scm', base_lang, query_name), true)
if base_files and #base_files > 0 then
query_string = read_query_files(base_files)..query_string
end
end
if #query_string > 0 then
return ts.parse_query(lang, query_string)
end
end
-- Given a path (i.e. a List(String)) this functions inserts value at path
function M.insert_to_path(object, path, value)
local curr_obj = object
for index=1,(#path -1) do
if curr_obj[path[index]] == nil then
curr_obj[path[index]] = {}
end
curr_obj = curr_obj[path[index]]
end
curr_obj[path[#path]] = value
end
function M.iter_prepared_matches(match_func, query, qnode, bufnr, start_row, end_row)
-- A function that splits a string on '.'
local function split(string)
local t = {}
for str in string.gmatch(string, "([^.]+)") do
table.insert(t, str)
end
return t
end
local matches = query:iter_matches(qnode, bufnr, start_row, end_row)
local function iter()
local pattern, match = matches()
if pattern ~= nil then
local prepared_match = {}
-- Extract capture names from each match
for id, node in pairs(match) do
local name = query.captures[id] -- name of the capture in the query
if name ~= nil then
local path = split(name)
M.insert_to_path(prepared_match, path, { node=node })
end
end
-- Add some predicates for testing
local preds = query.info.patterns[pattern]
if preds then
for _, pred in pairs(preds) do
if not match_func(pred) then
return iter()
end
end
end
return prepared_match
end
end
return iter
end
return M