diff --git a/lua/nvim-treesitter/health.lua b/lua/nvim-treesitter/health.lua index e73ce7fb5..52a2e3fcc 100644 --- a/lua/nvim-treesitter/health.lua +++ b/lua/nvim-treesitter/health.lua @@ -5,6 +5,7 @@ local queries = require'nvim-treesitter.query' local locals = require'nvim-treesitter.locals' local highlight = require'nvim-treesitter.highlight' local parsers = require'nvim-treesitter.parsers' +local qutils = require'nvim-treesitter.utils.query' local health_start = vim.fn["health#report_start"] local health_ok = vim.fn['health#report_ok'] @@ -36,7 +37,7 @@ local function install_health() end local function highlight_health(lang) - if not queries.get_query(lang, "highlights") then + if not qutils.get_query(lang, "highlights") then health_warn("No `highlights.scm` query found for " .. lang, { "Open an issue at https://github.com/nvim-treesitter/nvim-treesitter" }) @@ -46,7 +47,7 @@ local function highlight_health(lang) end local function locals_health(lang) - if not queries.get_query(lang, "locals") then + if not qutils.get_query(lang, "locals") then health_warn("No `locals.scm` query found for " .. lang, { "Open an issue at https://github.com/nvim-treesitter/nvim-treesitter" }) diff --git a/lua/nvim-treesitter/highlight.lua b/lua/nvim-treesitter/highlight.lua index 1b4722435..a30c5c988 100644 --- a/lua/nvim-treesitter/highlight.lua +++ b/lua/nvim-treesitter/highlight.lua @@ -1,7 +1,7 @@ local api = vim.api local ts = vim.treesitter -local queries = require'nvim-treesitter.query' +local queries = require'nvim-treesitter.utils.query' local parsers = require'nvim-treesitter.parsers' local M = { diff --git a/lua/nvim-treesitter/locals.lua b/lua/nvim-treesitter/locals.lua index ee2de810a..9086d2262 100644 --- a/lua/nvim-treesitter/locals.lua +++ b/lua/nvim-treesitter/locals.lua @@ -4,8 +4,9 @@ local api = vim.api local ts = vim.treesitter -local queries = require'nvim-treesitter.query' local parsers = require'nvim-treesitter.parsers' +local qutils = require'nvim-treesitter.utils.query' +local tutils = require'nvim-treesitter.ts_utils' local M = { locals = {} @@ -15,7 +16,7 @@ function M.collect_locals(bufnr) local lang = parsers.ft_to_lang(api.nvim_buf_get_option(bufnr, "ft")) if not lang then return end - local query = queries.get_query(lang, 'locals') + local query = qutils.get_query(lang, 'locals') if not query then return end local parser = parsers.get_parser(bufnr, lang) @@ -26,7 +27,7 @@ function M.collect_locals(bufnr) local locals = {} - for prepared_match in queries.iter_prepared_matches(query, root, bufnr, start_row, end_row) do + for prepared_match in qutils.iter_prepared_matches(function(_) return true end, query, root, bufnr, start_row, end_row) do table.insert(locals, prepared_match) end @@ -89,9 +90,12 @@ function M.get_references(bufnr) return refs end +-- Is and is not predicates function M.is(node, deftype, bufnr) for def in M.get_definitions(bufnr) do - if def[deftype] and def[deftype].node == node then + if def[deftype] + and tutils.get_node_text(def[deftype].node, bufnr) == tutils.get_node_text(node, bufnr) + then return true end end @@ -99,4 +103,101 @@ function M.is(node, deftype, bufnr) return false end +-- Some utils +function M.parent_scope(node, cursor_pos) + local bufnr = api.nvim_get_current_buf() + + local scopes = locals.get_scopes(bufnr) + if not node or not scopes then return end + + local row = cursor_pos.row + local col = cursor_pos.col + local iter_node = node + + while iter_node ~= nil do + local row_, col_ = iter_node:start() + if vim.tbl_contains(scopes, iter_node) and (row_+1 ~= row or col_ ~= col) then + return iter_node + end + iter_node = iter_node:parent() + end +end + +function M.containing_scope(node) + local bufnr = api.nvim_get_current_buf() + + local scopes = locals.get_scopes(bufnr) + if not node or not scopes then return end + + local iter_node = node + + while iter_node ~= nil and not vim.tbl_contains(scopes, iter_node) do + iter_node = iter_node:parent() + end + + return iter_node or node +end + + +function M.nested_scope(node, cursor_pos) + local bufnr = api.nvim_get_current_buf() + + local scopes = locals.get_scopes(bufnr) + if not node or not scopes then return end + + local row = cursor_pos.row + local col = cursor_pos.col + local scope = M.containing_scope(node) + + for _, child in ipairs(M.get_named_children(scope)) do + local row_, col_ = child:start() + if vim.tbl_contains(scopes, child) and ((row_+1 == row and col_ > col) or row_+1 > row) then + return child + end + end +end + +function M.next_scope(node) + local bufnr = api.nvim_get_current_buf() + + local scopes = locals.get_scopes(bufnr) + if not node or not scopes then return end + + local scope = M.containing_scope(node) + + local parent = scope:parent() + if not parent then return end + + local is_prev = true + for _, child in ipairs(M.get_named_children(parent)) do + if child == scope then + is_prev = false + elseif not is_prev and vim.tbl_contains(scopes, child) then + return child + end + end +end + +function M.previous_scope(node) + local bufnr = api.nvim_get_current_buf() + + local scopes = locals.get_scopes(bufnr) + if not node or not scopes then return end + + local scope = M.containing_scope(node) + + local parent = scope:parent() + if not parent then return end + + local is_prev = true + local children = M.get_named_children(parent) + for i=#children,1,-1 do + if children[i] == scope then + is_prev = false + elseif not is_prev and vim.tbl_contains(scopes, children[i]) then + return children[i] + end + end +end + return M diff --git a/lua/nvim-treesitter/query.lua b/lua/nvim-treesitter/query.lua index 358e9f194..17a976beb 100644 --- a/lua/nvim-treesitter/query.lua +++ b/lua/nvim-treesitter/query.lua @@ -1,6 +1,7 @@ local api = vim.api +local qutils = require'nvim-treesitter.utils.query' local ts = vim.treesitter --- local locals = require'nvim-treesitter.locals' +local locals = require'nvim-treesitter.locals' local M = {} @@ -73,67 +74,7 @@ function M.get_query(lang, query_name) end function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row) - -- A function that splits a string on '.' - local function split(string) - local t = {} - for str in string.gmatch(string, "([^.]+)") do - table.insert(t, str) - end - - return t - end - - -- Given a path (i.e. a List(String)) this functions inserts value at path - local function insert_to_path(object, path, value) - local curr_obj = object - - for index=1,(#path -1) do - if curr_obj[path[index]] == nil then - curr_obj[path[index]] = {} - end - - curr_obj = curr_obj[path[index]] - end - - curr_obj[path[#path]] = value - end - - local matches = query:iter_matches(qnode, bufnr, start_row, end_row) - - local function iter() - local pattern, match = matches() - if pattern ~= nil then - local prepared_match = {} - - -- Extract capture names from each match - for id, node in pairs(match) do - local name = query.captures[id] -- name of the capture in the query - if name ~= nil then - local path = split(name) - insert_to_path(prepared_match, path, { node=node }) - end - end - - -- Add some predicates for testing - local preds = query.info.patterns[pattern] - if preds then - for _, pred in pairs(preds) do - if pred[1] == "set!" and type(pred[2]) == "string" then - insert_to_path(prepared_match, split(pred[2]), pred[3]) - end - if pred[1] == "is?" and type(pred[3]) == "string" then - if not locals.is(pred[2], pred[3], bufnr) then - return iter() -- We should ignore this one, tail call - end - end - end - end - - return prepared_match - end - end - - return iter + return qutils.iter_prepared_matches(match_pred, query, qnode, bufnr, start_row, end_row) end return M diff --git a/lua/nvim-treesitter/ts_utils.lua b/lua/nvim-treesitter/ts_utils.lua index 06f92c885..01a4ef2fe 100644 --- a/lua/nvim-treesitter/ts_utils.lua +++ b/lua/nvim-treesitter/ts_utils.lua @@ -1,6 +1,5 @@ local api = vim.api -local locals = require'nvim-treesitter.locals' local parsers = require'nvim-treesitter.parsers' local M = {} @@ -105,40 +104,6 @@ function M.get_previous_node(node, allow_switch_parents, allow_previous_parent) return destination_node end -function M.parent_scope(node, cursor_pos) - local bufnr = api.nvim_get_current_buf() - - local scopes = locals.get_scopes(bufnr) - if not node or not scopes then return end - - local row = cursor_pos.row - local col = cursor_pos.col - local iter_node = node - - while iter_node ~= nil do - local row_, col_ = iter_node:start() - if vim.tbl_contains(scopes, iter_node) and (row_+1 ~= row or col_ ~= col) then - return iter_node - end - iter_node = iter_node:parent() - end -end - -function M.containing_scope(node) - local bufnr = api.nvim_get_current_buf() - - local scopes = locals.get_scopes(bufnr) - if not node or not scopes then return end - - local iter_node = node - - while iter_node ~= nil and not vim.tbl_contains(scopes, iter_node) do - iter_node = iter_node:parent() - end - - return iter_node or node -end - function M.get_named_children(node) local nodes = {} for i=0,node:named_child_count() - 1,1 do @@ -147,67 +112,6 @@ function M.get_named_children(node) return nodes end -function M.nested_scope(node, cursor_pos) - local bufnr = api.nvim_get_current_buf() - - local scopes = locals.get_scopes(bufnr) - if not node or not scopes then return end - - local row = cursor_pos.row - local col = cursor_pos.col - local scope = M.containing_scope(node) - - for _, child in ipairs(M.get_named_children(scope)) do - local row_, col_ = child:start() - if vim.tbl_contains(scopes, child) and ((row_+1 == row and col_ > col) or row_+1 > row) then - return child - end - end -end - -function M.next_scope(node) - local bufnr = api.nvim_get_current_buf() - - local scopes = locals.get_scopes(bufnr) - if not node or not scopes then return end - - local scope = M.containing_scope(node) - - local parent = scope:parent() - if not parent then return end - - local is_prev = true - for _, child in ipairs(M.get_named_children(parent)) do - if child == scope then - is_prev = false - elseif not is_prev and vim.tbl_contains(scopes, child) then - return child - end - end -end - -function M.previous_scope(node) - local bufnr = api.nvim_get_current_buf() - - local scopes = locals.get_scopes(bufnr) - if not node or not scopes then return end - - local scope = M.containing_scope(node) - - local parent = scope:parent() - if not parent then return end - - local is_prev = true - local children = M.get_named_children(parent) - for i=#children,1,-1 do - if children[i] == scope then - is_prev = false - elseif not is_prev and vim.tbl_contains(scopes, children[i]) then - return children[i] - end - end -end - function M.get_node_at_cursor(winnr) local cursor = api.nvim_win_get_cursor(winnr or 0) local root = parsers.get_parser().tree:root() diff --git a/lua/nvim-treesitter/utils/query.lua b/lua/nvim-treesitter/utils/query.lua new file mode 100644 index 000000000..5b6254df8 --- /dev/null +++ b/lua/nvim-treesitter/utils/query.lua @@ -0,0 +1,107 @@ +local api = vim.api +local fn = vim.fn +local luv = vim.loop +local ts = vim.treesitter + +local M = {} + +local function read_query_files(filenames) + local contents = {} + + for _,filename in ipairs(filenames) do + vim.list_extend(contents, vim.fn.readfile(filename)) + end + + return table.concat(contents, '\n') +end + +-- Some treesitter grammars extend others. +-- We can use that to import the queries of the base language +M.base_language_map = { + cpp = {'c'}, + typescript = {'javascript'}, + tsx = {'typescript', 'javascript'}, +} + +function M.get_query(lang, query_name) + local query_files = api.nvim_get_runtime_file(string.format('queries/%s/%s.scm', lang, query_name), true) + local query_string = '' + + if #query_files > 0 then + query_string = read_query_files(query_files)..query_string + end + + for _, base_lang in ipairs(M.base_language_map[lang] or {}) do + local base_files = api.nvim_get_runtime_file(string.format('queries/%s/%s.scm', base_lang, query_name), true) + if base_files and #base_files > 0 then + query_string = read_query_files(base_files)..query_string + end + end + + if #query_string > 0 then + return ts.parse_query(lang, query_string) + end +end + +-- Given a path (i.e. a List(String)) this functions inserts value at path +function M.insert_to_path(object, path, value) + local curr_obj = object + + for index=1,(#path -1) do + if curr_obj[path[index]] == nil then + curr_obj[path[index]] = {} + end + + curr_obj = curr_obj[path[index]] + end + + curr_obj[path[#path]] = value +end + + +function M.iter_prepared_matches(match_func, query, qnode, bufnr, start_row, end_row) + -- A function that splits a string on '.' + local function split(string) + local t = {} + for str in string.gmatch(string, "([^.]+)") do + table.insert(t, str) + end + + return t + end + + + local matches = query:iter_matches(qnode, bufnr, start_row, end_row) + + local function iter() + local pattern, match = matches() + if pattern ~= nil then + local prepared_match = {} + + -- Extract capture names from each match + for id, node in pairs(match) do + local name = query.captures[id] -- name of the capture in the query + if name ~= nil then + local path = split(name) + M.insert_to_path(prepared_match, path, { node=node }) + end + end + + -- Add some predicates for testing + local preds = query.info.patterns[pattern] + if preds then + for _, pred in pairs(preds) do + if not match_func(pred) then + return iter() + end + end + end + + return prepared_match + end + end + + return iter +end + +return M