feat!: drop modules, general refactor and cleanup

This commit is contained in:
Christian Clason 2023-06-12 09:54:30 -06:00
parent c13e28f894
commit 2c8f2f2fad
829 changed files with 4905 additions and 8010 deletions

View file

@ -1,71 +0,0 @@
local api = vim.api
local M = {}
-- Creates a cache table for buffers keyed by a type name.
-- Cache entries attach to the buffer and cleanup entries
-- as buffers are detached.
function M.create_buffer_cache()
local cache = {}
---@type table<integer, table<string, any>>
local items = setmetatable({}, {
__index = function(tbl, key)
rawset(tbl, key, {})
return rawget(tbl, key)
end,
})
---@type table<integer, boolean>
local loaded_buffers = {}
---@param type_name string
---@param bufnr integer
---@param value any
function cache.set(type_name, bufnr, value)
if not loaded_buffers[bufnr] then
loaded_buffers[bufnr] = true
-- Clean up the cache if the buffer is detached
-- to avoid memory leaks
api.nvim_buf_attach(bufnr, false, {
on_detach = function()
cache.clear_buffer(bufnr)
loaded_buffers[bufnr] = nil
return true
end,
on_reload = function() end, -- this is needed to prevent on_detach being called on buffer reload
})
end
items[bufnr][type_name] = value
end
---@param type_name string
---@param bufnr integer
---@return any
function cache.get(type_name, bufnr)
return items[bufnr][type_name]
end
---@param type_name string
---@param bufnr integer
---@return boolean
function cache.has(type_name, bufnr)
return cache.get(type_name, bufnr) ~= nil
end
---@param type_name string
---@param bufnr integer
function cache.remove(type_name, bufnr)
items[bufnr][type_name] = nil
end
---@param bufnr integer
function cache.clear_buffer(bufnr)
items[bufnr] = nil
end
return cache
end
return M

View file

@ -1,27 +0,0 @@
-- Shim module to address deprecations across nvim versions
local ts = vim.treesitter
local tsq = ts.query
local M = {}
function M.get_query_files(lang, query_group, is_included)
return (tsq.get_files or tsq.get_query_files)(lang, query_group, is_included)
end
function M.get_query(lang, query_name)
return (tsq.get or tsq.get_query)(lang, query_name)
end
function M.parse_query(lang, query)
return (tsq.parse or tsq.parse_query)(lang, query)
end
function M.get_range(node, source, metadata)
return (ts.get_range or tsq.get_range)(node, source, metadata)
end
function M.get_node_text(node, bufnr)
return (ts.get_node_text or tsq.get_node_text)(node, bufnr)
end
return M

View file

@ -0,0 +1,144 @@
local utils = require('nvim-treesitter.utils')
local M = {}
---@class TSConfig
---@field sync_install boolean
---@field auto_install boolean
---@field ensure_install string[]
---@field ignore_install string[]
---@field install_dir string
---@type TSConfig
local config = {
sync_install = false,
auto_install = false,
ensure_install = {},
ignore_install = {},
install_dir = utils.join_path(vim.fn.stdpath('data'), 'site'),
}
---Setup call for users to override configuration configurations.
---@param user_data TSConfig|nil user configuration table
function M.setup(user_data)
if user_data then
if user_data.install_dir then
user_data.install_dir = vim.fs.normalize(user_data.install_dir)
--TODO(clason): insert after/before site, or leave to user?
vim.opt.runtimepath:append(user_data.install_dir)
end
config = vim.tbl_deep_extend('force', config, user_data)
end
if config.auto_install then
vim.api.nvim_create_autocmd('FileType', {
callback = function(args)
local ft = vim.bo[args.buf].filetype
local lang = vim.treesitter.language.get_lang(ft) or ft
if
require('nvim-treesitter.parsers').configs[lang]
and not vim.list_contains(M.installed_parsers(), lang)
and not vim.list_contains(config.ignore_install, lang)
then
require('nvim-treesitter.install').install(lang)
end
end,
})
end
if #config.ensure_install > 0 then
local to_install = M.norm_languages(config.ensure_install, { ignored = true, installed = true })
if #to_install > 0 then
require('nvim-treesitter.install').install(to_install, {
with_sync = config.sync_install,
})
end
end
end
-- Returns the install path for parsers, parser info, and queries.
-- If the specified directory does not exist, it is created.
---@param dir_name string
---@return string
function M.get_install_dir(dir_name)
local dir = utils.join_path(config.install_dir, dir_name)
if not vim.loop.fs_stat(dir) then
local ok, error = pcall(vim.fn.mkdir, dir, 'p', '0755')
if not ok then
vim.notify(error, vim.log.levels.ERROR)
end
end
return dir
end
---@return string[]
function M.installed_parsers()
local install_dir = M.get_install_dir('parser')
local installed = {} --- @type string[]
for f in vim.fs.dir(install_dir) do
local lang = assert(f:match('(.*)%..*'))
installed[#installed + 1] = lang
end
return installed
end
---Normalize languages
---@param languages? string[]|string
---@param skip? table
---@return string[]
function M.norm_languages(languages, skip)
if not languages then
return {}
end
local parsers = require('nvim-treesitter.parsers')
-- Turn into table
if type(languages) == 'string' then
languages = { languages }
end
if vim.list_contains(languages, 'all') then
if skip and skip.missing then
return M.installed_parsers()
end
languages = parsers.get_available()
end
for i, tier in ipairs(parsers.tiers) do
if vim.list_contains(languages, tier) then
languages = vim.iter.filter(function(l)
return l ~= tier
end, languages)
vim.list_extend(languages, parsers.get_available(i))
end
end
if skip and skip.ignored then
local ignored = config.ignore_install
languages = vim.iter.filter(function(v)
return not vim.list_contains(ignored, v)
end, languages)
end
if skip and skip.installed then
local installed = M.installed_parsers()
languages = vim.iter.filter(function(v)
return not vim.list_contains(installed, v)
end, languages)
end
if skip and skip.missing then
local installed = M.installed_parsers()
languages = vim.iter.filter(function(v)
return vim.list_contains(installed, v)
end, languages)
end
return languages
end
return M

View file

@ -1,616 +0,0 @@
local api = vim.api
local queries = require "nvim-treesitter.query"
local ts = require "nvim-treesitter.compat"
local parsers = require "nvim-treesitter.parsers"
local utils = require "nvim-treesitter.utils"
local caching = require "nvim-treesitter.caching"
local M = {}
---@class TSConfig
---@field modules {[string]:TSModule}
---@field sync_install boolean
---@field ensure_installed string[]|string
---@field ignore_install string[]
---@field auto_install boolean
---@field parser_install_dir string|nil
---@type TSConfig
local config = {
modules = {},
sync_install = false,
ensure_installed = {},
auto_install = false,
ignore_install = {},
parser_install_dir = nil,
}
-- List of modules that need to be setup on initialization.
---@type TSModule[][]
local queued_modules_defs = {}
-- Whether we've initialized the plugin yet.
local is_initialized = false
---@class TSModule
---@field module_path string
---@field enable boolean|string[]|function(string): boolean
---@field disable boolean|string[]|function(string): boolean
---@field keymaps table<string, string>
---@field is_supported function(string): boolean
---@field attach function(string)
---@field detach function(string)
---@field enabled_buffers table<integer, boolean>
---@field additional_vim_regex_highlighting boolean|string[]
---@type {[string]: TSModule}
local builtin_modules = {
highlight = {
module_path = "nvim-treesitter.highlight",
-- @deprecated: use `highlight.set_custom_captures` instead
custom_captures = {},
enable = false,
is_supported = function(lang)
return queries.has_highlights(lang)
end,
additional_vim_regex_highlighting = false,
},
incremental_selection = {
module_path = "nvim-treesitter.incremental_selection",
enable = false,
keymaps = {
init_selection = "gnn", -- set to `false` to disable one of the mappings
node_incremental = "grn",
scope_incremental = "grc",
node_decremental = "grm",
},
is_supported = function()
return true
end,
},
indent = {
module_path = "nvim-treesitter.indent",
enable = false,
is_supported = queries.has_indents,
},
}
local attached_buffers_by_module = caching.create_buffer_cache()
---Resolves a module by requiring the `module_path` or using the module definition.
---@param mod_name string
---@return TSModule|nil
local function resolve_module(mod_name)
local config_mod = M.get_module(mod_name)
if not config_mod then
return
end
if type(config_mod.attach) == "function" and type(config_mod.detach) == "function" then
return config_mod
elseif type(config_mod.module_path) == "string" then
return require(config_mod.module_path)
end
end
---Enables and attaches the module to a buffer for lang.
---@param mod string path to module
---@param bufnr integer|nil buffer number, defaults to current buffer
---@param lang string|nil language, defaults to current language
local function enable_module(mod, bufnr, lang)
local module = M.get_module(mod)
if not module then
return
end
bufnr = bufnr or api.nvim_get_current_buf()
lang = lang or parsers.get_buf_lang(bufnr)
if not module.enable then
if module.enabled_buffers then
module.enabled_buffers[bufnr] = true
else
module.enabled_buffers = { [bufnr] = true }
end
end
M.attach_module(mod, bufnr, lang)
end
---Enables autocomands for the module.
---After the module is loaded `loaded` will be set to true for the module.
---@param mod string path to module
local function enable_mod_conf_autocmd(mod)
local config_mod = M.get_module(mod)
if not config_mod or config_mod.loaded then
return
end
api.nvim_create_autocmd("FileType", {
group = api.nvim_create_augroup("NvimTreesitter-" .. mod, {}),
callback = function(args)
require("nvim-treesitter.configs").reattach_module(mod, args.buf)
end,
desc = "Reattach module",
})
config_mod.loaded = true
end
---Enables the module globally and for all current buffers.
---After enabled, `enable` will be set to true for the module.
---@param mod string path to module
local function enable_all(mod)
local config_mod = M.get_module(mod)
if not config_mod then
return
end
enable_mod_conf_autocmd(mod)
config_mod.enable = true
config_mod.enabled_buffers = nil
for _, bufnr in pairs(api.nvim_list_bufs()) do
enable_module(mod, bufnr)
end
end
---Disables and detaches the module for a buffer.
---@param mod string path to module
---@param bufnr integer buffer number, defaults to current buffer
local function disable_module(mod, bufnr)
local module = M.get_module(mod)
if not module then
return
end
bufnr = bufnr or api.nvim_get_current_buf()
if module.enabled_buffers then
module.enabled_buffers[bufnr] = false
end
M.detach_module(mod, bufnr)
end
---Disables autocomands for the module.
---After the module is unloaded `loaded` will be set to false for the module.
---@param mod string path to module
local function disable_mod_conf_autocmd(mod)
local config_mod = M.get_module(mod)
if not config_mod or not config_mod.loaded then
return
end
api.nvim_clear_autocmds { event = "FileType", group = "NvimTreesitter-" .. mod }
config_mod.loaded = false
end
---Disables the module globally and for all current buffers.
---After disabled, `enable` will be set to false for the module.
---@param mod string path to module
local function disable_all(mod)
local config_mod = M.get_module(mod)
if not config_mod then
return
end
config_mod.enabled_buffers = nil
disable_mod_conf_autocmd(mod)
config_mod.enable = false
for _, bufnr in pairs(api.nvim_list_bufs()) do
disable_module(mod, bufnr)
end
end
---Toggles a module for a buffer
---@param mod string path to module
---@param bufnr integer buffer number, defaults to current buffer
---@param lang string language, defaults to current language
local function toggle_module(mod, bufnr, lang)
bufnr = bufnr or api.nvim_get_current_buf()
lang = lang or parsers.get_buf_lang(bufnr)
if attached_buffers_by_module.has(mod, bufnr) then
disable_module(mod, bufnr)
else
enable_module(mod, bufnr, lang)
end
end
-- Toggles the module globally and for all current buffers.
-- @param mod path to module
local function toggle_all(mod)
local config_mod = M.get_module(mod)
if not config_mod then
return
end
if config_mod.enable then
disable_all(mod)
else
enable_all(mod)
end
end
---Recurses through all modules including submodules
---@param accumulator function called for each module
---@param root {[string]: TSModule}|nil root configuration table to start at
---@param path string|nil prefix path
local function recurse_modules(accumulator, root, path)
root = root or config.modules
for name, module in pairs(root) do
local new_path = path and (path .. "." .. name) or name
if M.is_module(module) then
accumulator(name, module, new_path, root)
elseif type(module) == "table" then
recurse_modules(accumulator, module, new_path)
end
end
end
-- Shows current configuration of all nvim-treesitter modules
---@param process_function function used as the `process` parameter
--- for vim.inspect (https://github.com/kikito/inspect.lua#optionsprocess)
local function config_info(process_function)
process_function = process_function
or function(item, path)
if path[#path] == vim.inspect.METATABLE then
return
end
if path[#path] == "is_supported" then
return
end
return item
end
print(vim.inspect(config, { process = process_function }))
end
---@param query_group string
---@param lang string
function M.edit_query_file(query_group, lang)
lang = lang or parsers.get_buf_lang()
local files = ts.get_query_files(lang, query_group, true)
if #files == 0 then
utils.notify "No query file found! Creating a new one!"
M.edit_query_file_user_after(query_group, lang)
elseif #files == 1 then
vim.cmd(":edit " .. files[1])
else
vim.ui.select(files, { prompt = "Select a file:" }, function(file)
if file then
vim.cmd(":edit " .. file)
end
end)
end
end
---@param query_group string
---@param lang string
function M.edit_query_file_user_after(query_group, lang)
lang = lang or parsers.get_buf_lang()
local folder = utils.join_path(vim.fn.stdpath "config", "after", "queries", lang)
local file = utils.join_path(folder, query_group .. ".scm")
if vim.fn.isdirectory(folder) ~= 1 then
vim.ui.select({ "Yes", "No" }, { prompt = '"' .. folder .. '" does not exist. Create it?' }, function(choice)
if choice == "Yes" then
vim.fn.mkdir(folder, "p", "0755")
vim.cmd(":edit " .. file)
end
end)
else
vim.cmd(":edit " .. file)
end
end
M.commands = {
TSBufEnable = {
run = enable_module,
args = {
"-nargs=1",
"-complete=custom,nvim_treesitter#available_modules",
},
},
TSBufDisable = {
run = disable_module,
args = {
"-nargs=1",
"-complete=custom,nvim_treesitter#available_modules",
},
},
TSBufToggle = {
run = toggle_module,
args = {
"-nargs=1",
"-complete=custom,nvim_treesitter#available_modules",
},
},
TSEnable = {
run = enable_all,
args = {
"-nargs=+",
"-complete=custom,nvim_treesitter#available_modules",
},
},
TSDisable = {
run = disable_all,
args = {
"-nargs=+",
"-complete=custom,nvim_treesitter#available_modules",
},
},
TSToggle = {
run = toggle_all,
args = {
"-nargs=+",
"-complete=custom,nvim_treesitter#available_modules",
},
},
TSConfigInfo = {
run = config_info,
args = {
"-nargs=0",
},
},
TSEditQuery = {
run = M.edit_query_file,
args = {
"-nargs=+",
"-complete=custom,nvim_treesitter#available_query_groups",
},
},
TSEditQueryUserAfter = {
run = M.edit_query_file_user_after,
args = {
"-nargs=+",
"-complete=custom,nvim_treesitter#available_query_groups",
},
},
}
---@param mod string module
---@param lang string the language of the buffer
---@param bufnr integer the buffer
function M.is_enabled(mod, lang, bufnr)
if not parsers.has_parser(lang) then
return false
end
local module_config = M.get_module(mod)
if not module_config then
return false
end
local buffer_enabled = module_config.enabled_buffers and module_config.enabled_buffers[bufnr]
local config_enabled = module_config.enable or buffer_enabled
if not config_enabled or not module_config.is_supported(lang) then
return false
end
local disable = module_config.disable
if type(disable) == "function" then
if disable(lang, bufnr) then
return false
end
elseif type(disable) == "table" then
-- Otherwise it's a list of languages
for _, parser in pairs(disable) do
if lang == parser then
return false
end
end
end
return true
end
---Setup call for users to override module configurations.
---@param user_data TSConfig module overrides
function M.setup(user_data)
config.modules = vim.tbl_deep_extend("force", config.modules, user_data)
config.ignore_install = user_data.ignore_install or {}
config.parser_install_dir = user_data.parser_install_dir or nil
if config.parser_install_dir then
config.parser_install_dir = vim.fn.expand(config.parser_install_dir, ":p")
end
config.auto_install = user_data.auto_install or false
if config.auto_install then
require("nvim-treesitter.install").setup_auto_install()
end
local ensure_installed = user_data.ensure_installed or {}
if #ensure_installed > 0 then
if user_data.sync_install then
require("nvim-treesitter.install").ensure_installed_sync(ensure_installed)
else
require("nvim-treesitter.install").ensure_installed(ensure_installed)
end
end
config.modules.ensure_installed = nil
config.ensure_installed = ensure_installed
recurse_modules(function(_, _, new_path)
local data = utils.get_at_path(config.modules, new_path)
if data.enable then
enable_all(new_path)
end
end, config.modules)
end
-- Defines a table of modules that can be attached/detached to buffers
-- based on language support. A module consist of the following properties:
---* @enable Whether the modules is enabled. Can be true or false.
---* @disable A list of languages to disable the module for. Only relevant if enable is true.
---* @keymaps A list of user mappings for a given module if relevant.
---* @is_supported A function which, given a ft, will return true if the ft works on the module.
---* @module_path A string path to a module file using `require`. The exported module must contain
--- an `attach` and `detach` function. This path is not required if `attach` and `detach`
--- functions are provided directly on the module definition.
---* @attach An attach function that is called for each buffer that the module is enabled for. This is required
--- if a `module_path` is not specified.
---* @detach A detach function that is called for each buffer that the module is enabled for. This is required
--- if a `module_path` is not specified.
--
-- Modules are not setup until `init` is invoked by the plugin. This allows modules to be defined in any order
-- and can be loaded lazily.
--
---* @example
---require"nvim-treesitter".define_modules {
--- my_cool_module = {
--- attach = function()
--- do_some_cool_setup()
--- end,
--- detach = function()
--- do_some_cool_teardown()
--- end
--- }
---}
---@param mod_defs TSModule[]
function M.define_modules(mod_defs)
if not is_initialized then
table.insert(queued_modules_defs, mod_defs)
return
end
recurse_modules(function(key, mod, _, group)
group[key] = vim.tbl_extend("keep", mod, {
enable = false,
disable = {},
is_supported = function()
return true
end,
})
end, mod_defs)
config.modules = vim.tbl_deep_extend("keep", config.modules, mod_defs)
for _, mod in ipairs(M.available_modules(mod_defs)) do
local module_config = M.get_module(mod)
if module_config and module_config.enable then
enable_mod_conf_autocmd(mod)
end
end
end
---Attaches a module to a buffer
---@param mod_name string the module name
---@param bufnr integer the buffer
---@param lang string the language of the buffer
function M.attach_module(mod_name, bufnr, lang)
bufnr = bufnr or api.nvim_get_current_buf()
lang = lang or parsers.get_buf_lang(bufnr)
local resolved_mod = resolve_module(mod_name)
if resolved_mod and not attached_buffers_by_module.has(mod_name, bufnr) and M.is_enabled(mod_name, lang, bufnr) then
attached_buffers_by_module.set(mod_name, bufnr, true)
resolved_mod.attach(bufnr, lang)
end
end
-- Detaches a module to a buffer
---@param mod_name string the module name
---@param bufnr integer the buffer
function M.detach_module(mod_name, bufnr)
local resolved_mod = resolve_module(mod_name)
bufnr = bufnr or api.nvim_get_current_buf()
if resolved_mod and attached_buffers_by_module.has(mod_name, bufnr) then
attached_buffers_by_module.remove(mod_name, bufnr)
resolved_mod.detach(bufnr)
end
end
-- Same as attach_module, but if the module is already attached, detach it first.
---@param mod_name string the module name
---@param bufnr integer the buffer
---@param lang string the language of the buffer
function M.reattach_module(mod_name, bufnr, lang)
M.detach_module(mod_name, bufnr)
M.attach_module(mod_name, bufnr, lang)
end
-- Gets available modules
---@param root {[string]:TSModule}|nil table to find modules
---@return string[] modules list of module paths
function M.available_modules(root)
local modules = {}
recurse_modules(function(_, _, path)
table.insert(modules, path)
end, root)
return modules
end
---Gets a module config by path
---@param mod_path string path to the module
---@return TSModule|nil: the module or nil
function M.get_module(mod_path)
local mod = utils.get_at_path(config.modules, mod_path)
return M.is_module(mod) and mod or nil
end
-- Determines whether the provided table is a module.
-- A module should contain an attach and detach function.
---@param mod table|nil the module table
---@return boolean
function M.is_module(mod)
return type(mod) == "table"
and ((type(mod.attach) == "function" and type(mod.detach) == "function") or type(mod.module_path) == "string")
end
-- Initializes built-in modules and any queued modules
-- registered by plugins or the user.
function M.init()
is_initialized = true
M.define_modules(builtin_modules)
for _, mod_def in ipairs(queued_modules_defs) do
M.define_modules(mod_def)
end
end
-- If parser_install_dir is not nil is used or created.
-- If parser_install_dir is nil try the package dir of the nvim-treesitter
-- plugin first, followed by the "site" dir from "runtimepath". "site" dir will
-- be created if it doesn't exist. Using only the package dir won't work when
-- the plugin is installed with Nix, since the "/nix/store" is read-only.
---@param folder_name string|nil
---@return string|nil, string|nil
function M.get_parser_install_dir(folder_name)
folder_name = folder_name or "parser"
local install_dir = config.parser_install_dir or utils.get_package_path()
local parser_dir = utils.join_path(install_dir, folder_name)
return utils.create_or_reuse_writable_dir(
parser_dir,
utils.join_space("Could not create parser dir '", parser_dir, "': "),
utils.join_space(
"Parser dir '",
parser_dir,
"' should be read/write (see README on how to configure an alternative install location)"
)
)
end
function M.get_parser_info_dir()
return M.get_parser_install_dir "parser-info"
end
function M.get_ignored_parser_installs()
return config.ignore_install or {}
end
function M.get_ensure_installed_parsers()
if type(config.ensure_installed) == "string" then
return { config.ensure_installed }
end
return config.ensure_installed or {}
end
return M

View file

@ -1,123 +0,0 @@
local api = vim.api
local tsutils = require "nvim-treesitter.ts_utils"
local query = require "nvim-treesitter.query"
local parsers = require "nvim-treesitter.parsers"
local M = {}
-- This is cached on buf tick to avoid computing that multiple times
-- Especially not for every line in the file when `zx` is hit
local folds_levels = tsutils.memoize_by_buf_tick(function(bufnr)
local max_fold_level = api.nvim_win_get_option(0, "foldnestmax")
local trim_level = function(level)
if level > max_fold_level then
return max_fold_level
end
return level
end
local parser = parsers.get_parser(bufnr)
if not parser then
return {}
end
local matches = query.get_capture_matches_recursively(bufnr, function(lang)
if query.has_folds(lang) then
return "@fold", "folds"
elseif query.has_locals(lang) then
return "@scope", "locals"
end
end)
-- start..stop is an inclusive range
---@type table<number, number>
local start_counts = {}
---@type table<number, number>
local stop_counts = {}
local prev_start = -1
local prev_stop = -1
local min_fold_lines = api.nvim_win_get_option(0, "foldminlines")
for _, match in ipairs(matches) do
local start, stop, stop_col ---@type integer, integer, integer
if match.metadata and match.metadata.range then
start, _, stop, stop_col = unpack(match.metadata.range) ---@type integer, integer, integer, integer
else
start, _, stop, stop_col = match.node:range() ---@type integer, integer, integer, integer
end
if stop_col == 0 then
stop = stop - 1
end
local fold_length = stop - start + 1
local should_fold = fold_length > min_fold_lines
-- Fold only multiline nodes that are not exactly the same as previously met folds
-- Checking against just the previously found fold is sufficient if nodes
-- are returned in preorder or postorder when traversing tree
if should_fold and not (start == prev_start and stop == prev_stop) then
start_counts[start] = (start_counts[start] or 0) + 1
stop_counts[stop] = (stop_counts[stop] or 0) + 1
prev_start = start
prev_stop = stop
end
end
---@type string[]
local levels = {}
local current_level = 0
-- We now have the list of fold opening and closing, fill the gaps and mark where fold start
for lnum = 0, api.nvim_buf_line_count(bufnr) do
local prefix = ""
local last_trimmed_level = trim_level(current_level)
current_level = current_level + (start_counts[lnum] or 0)
local trimmed_level = trim_level(current_level)
current_level = current_level - (stop_counts[lnum] or 0)
local next_trimmed_level = trim_level(current_level)
-- Determine if it's the start/end of a fold
-- NB: vim's fold-expr interface does not have a mechanism to indicate that
-- two (or more) folds start at this line, so it cannot distinguish between
-- ( \n ( \n )) \n (( \n ) \n )
-- versus
-- ( \n ( \n ) \n ( \n ) \n )
-- If it did have such a mechanism, (trimmed_level - last_trimmed_level)
-- would be the correct number of starts to pass on.
if trimmed_level - last_trimmed_level > 0 then
prefix = ">"
elseif trimmed_level - next_trimmed_level > 0 then
-- Ending marks tend to confuse vim more than it helps, particularly when
-- the fold level changes by at least 2; we can uncomment this if
-- vim's behavior gets fixed.
-- prefix = "<"
prefix = ""
end
levels[lnum + 1] = prefix .. tostring(trimmed_level)
end
return levels
end)
---@param lnum integer
---@return string
function M.get_fold_indic(lnum)
if not parsers.has_parser() or not lnum then
return "0"
end
local buf = api.nvim_get_current_buf()
local levels = folds_levels(buf) or {}
return levels[lnum] or "0"
end
return M

View file

@ -1,116 +1,119 @@
local api = vim.api
local fn = vim.fn
local queries = require "nvim-treesitter.query"
local info = require "nvim-treesitter.info"
local shell = require "nvim-treesitter.shell_command_selectors"
local install = require "nvim-treesitter.install"
local utils = require "nvim-treesitter.utils"
local ts = require "nvim-treesitter.compat"
local health = vim.health or require "health"
-- "report_" prefix has been deprecated, use the recommended replacements if they exist.
local _start = health.start or health.report_start
local _ok = health.ok or health.report_ok
local _warn = health.warn or health.report_warn
local _error = health.error or health.report_error
local shell = require('nvim-treesitter.shell_cmds')
local install = require('nvim-treesitter.install')
local config = require('nvim-treesitter.config')
local tsq = vim.treesitter.query
local M = {}
local NVIM_TREESITTER_MINIMUM_ABI = 13
local function install_health()
_start "Installation"
if fn.has "nvim-0.9.1" ~= 1 then
_error "Nvim-treesitter requires Neovim 0.9.1+"
end
if fn.executable "tree-sitter" == 0 then
_warn(
"`tree-sitter` executable not found (parser generator, only needed for :TSInstallFromGrammar,"
.. " not required for :TSInstall)"
)
else
_ok(
"`tree-sitter` found "
.. (utils.ts_cli_version() or "(unknown version)")
.. " (parser generator, only needed for :TSInstallFromGrammar)"
)
end
if fn.executable "node" == 0 then
_warn("`node` executable not found (only needed for :TSInstallFromGrammar," .. " not required for :TSInstall)")
else
local handle = io.popen "node --version"
local result = handle:read "*a"
---@return string|nil
local function ts_cli_version()
if vim.fn.executable('tree-sitter') == 1 then
local handle = io.popen('tree-sitter -V')
if not handle then
return
end
local result = handle:read('*a')
handle:close()
local version = vim.split(result, "\n")[1]
_ok("`node` found " .. version .. " (only needed for :TSInstallFromGrammar)")
return vim.split(result, '\n')[1]:match('[^tree%psitter ].*')
end
end
local function install_health()
vim.health.start('Installation')
if vim.fn.has('nvim-0.10') ~= 1 then
vim.health.error('Nvim-treesitter requires Neovim Nightly')
end
if fn.executable "git" == 0 then
_error("`git` executable not found.", {
"Install it with your package manager.",
"Check that your `$PATH` is set correctly.",
if vim.fn.executable('tree-sitter') == 0 then
vim.health.warn(
'`tree-sitter` executable not found (parser generator, only needed for :TSInstallFromGrammar,'
.. ' not required for :TSInstall)'
)
else
vim.health.ok(
'`tree-sitter` found '
.. (ts_cli_version() or '(unknown version)')
.. ' (parser generator, only needed for :TSInstallFromGrammar)'
)
end
if vim.fn.executable('node') == 0 then
vim.health.warn(
'`node` executable not found (only needed for :TSInstallFromGrammar,'
.. ' not required for :TSInstall)'
)
else
local handle = assert(io.popen('node --version'))
local result = handle:read('*a')
handle:close()
local version = vim.split(result, '\n')[1]
vim.health.ok('`node` found ' .. version .. ' (only needed for :TSInstallFromGrammar)')
end
if vim.fn.executable('git') == 0 then
vim.health.error('`git` executable not found.', {
'Install it with your package manager.',
'Check that your `$PATH` is set correctly.',
})
else
_ok "`git` executable found."
vim.health.ok('`git` executable found.')
end
local cc = shell.select_executable(install.compilers)
if not cc then
_error("`cc` executable not found.", {
"Check that any of "
vim.health.error('`cc` executable not found.', {
'Check that any of '
.. vim.inspect(install.compilers)
.. " is in your $PATH"
.. ' is in your $PATH'
.. ' or set the environment variable CC or `require"nvim-treesitter.install".compilers` explicitly!',
})
else
local version = vim.fn.systemlist(cc .. (cc == "cl" and "" or " --version"))[1]
_ok(
"`"
local version = vim.fn.systemlist(cc .. (cc == 'cl' and '' or ' --version'))[1]
vim.health.ok(
'`'
.. cc
.. "` executable found. Selected from "
.. '` executable found. Selected from '
.. vim.inspect(install.compilers)
.. (version and ("\nVersion: " .. version) or "")
.. (version and ('\nVersion: ' .. version) or '')
)
end
if vim.treesitter.language_version then
if vim.treesitter.language_version >= NVIM_TREESITTER_MINIMUM_ABI then
_ok(
"Neovim was compiled with tree-sitter runtime ABI version "
vim.health.ok(
'Neovim was compiled with tree-sitter runtime ABI version '
.. vim.treesitter.language_version
.. " (required >="
.. ' (required >='
.. NVIM_TREESITTER_MINIMUM_ABI
.. "). Parsers must be compatible with runtime ABI."
.. '). Parsers must be compatible with runtime ABI.'
)
else
_error(
"Neovim was compiled with tree-sitter runtime ABI version "
vim.health.error(
'Neovim was compiled with tree-sitter runtime ABI version '
.. vim.treesitter.language_version
.. ".\n"
.. "nvim-treesitter expects at least ABI version "
.. '.\n'
.. 'nvim-treesitter expects at least ABI version '
.. NVIM_TREESITTER_MINIMUM_ABI
.. "\n"
.. "Please make sure that Neovim is linked against are recent tree-sitter runtime when building"
.. " or raise an issue at your Neovim packager. Parsers must be compatible with runtime ABI."
.. '\n'
.. 'Please make sure that Neovim is linked against are recent tree-sitter runtime when building'
.. ' or raise an issue at your Neovim packager. Parsers must be compatible with runtime ABI.'
)
end
end
_start("OS Info:\n" .. vim.inspect(vim.loop.os_uname()))
vim.health.start('OS Info:\n' .. vim.inspect(vim.loop.os_uname()))
end
local function query_status(lang, query_group)
local ok, err = pcall(queries.get_query, lang, query_group)
local ok, err = pcall(tsq.get, lang, query_group)
if not ok then
return "x", err
return 'x', err
elseif not err then
return "."
return '.'
else
return ""
return ''
end
end
@ -118,59 +121,53 @@ function M.check()
local error_collection = {}
-- Installation dependency checks
install_health()
queries.invalidate_query_cache()
-- Parser installation checks
local parser_installation = { "Parser/Features" .. string.rep(" ", 9) .. "H L F I J" }
for _, parser_name in pairs(info.installed_parsers()) do
local installed = #api.nvim_get_runtime_file("parser/" .. parser_name .. ".so", false)
-- Only append information about installed parsers
if installed >= 1 then
local multiple_parsers = installed > 1 and "+" or ""
local out = " - " .. parser_name .. multiple_parsers .. string.rep(" ", 20 - (#parser_name + #multiple_parsers))
for _, query_group in pairs(queries.built_in_query_groups) do
local status, err = query_status(parser_name, query_group)
out = out .. status .. " "
if err then
table.insert(error_collection, { parser_name, query_group, err })
end
local parser_installation = { 'Parser/Features' .. string.rep(' ', 9) .. 'H L F I J' }
for _, parser_name in pairs(config.installed_parsers()) do
local out = ' - ' .. parser_name .. string.rep(' ', 20 - #parser_name)
for _, query_group in pairs(M.bundled_queries) do
local status, err = query_status(parser_name, query_group)
out = out .. status .. ' '
if err then
table.insert(error_collection, { parser_name, query_group, err })
end
table.insert(parser_installation, vim.fn.trim(out, " ", 2))
end
table.insert(parser_installation, vim.fn.trim(out, ' ', 2))
end
local legend = [[
Legend: H[ighlight], L[ocals], F[olds], I[ndents], In[j]ections
+) multiple parsers found, only one will be used
x) errors found in the query, try to run :TSUpdate {lang}]]
Legend: H[ighlight], L[ocals], F[olds], I[ndents], In[J]ections
x) errors found in the query, try to run :TSUpdate {lang}]]
table.insert(parser_installation, legend)
-- Finally call the report function
_start(table.concat(parser_installation, "\n"))
vim.health.start(table.concat(parser_installation, '\n'))
if #error_collection > 0 then
_start "The following errors have been detected:"
vim.health.start('The following errors have been detected:')
for _, p in ipairs(error_collection) do
local lang, type, err = unpack(p)
local lines = {}
table.insert(lines, lang .. "(" .. type .. "): " .. err)
local files = ts.get_query_files(lang, type)
table.insert(lines, lang .. '(' .. type .. '): ' .. err)
local files = tsq.get_files(lang, type)
if #files > 0 then
table.insert(lines, lang .. "(" .. type .. ") is concatenated from the following files:")
table.insert(lines, lang .. '(' .. type .. ') is concatenated from the following files:')
for _, file in ipairs(files) do
local fd = io.open(file, "r")
local fd = io.open(file, 'r')
if fd then
local ok, file_err = pcall(ts.parse_query, lang, fd:read "*a")
local ok, file_err = pcall(tsq.parse, lang, fd:read('*a'))
if ok then
table.insert(lines, '| [OK]:"' .. file .. '"')
table.insert(lines, '| [OK]:"' .. file .. '"')
else
table.insert(lines, '| [ERROR]:"' .. file .. '", failed to load: ' .. file_err)
table.insert(lines, '| [ERR]:"' .. file .. '", failed to load: ' .. file_err)
end
fd:close()
end
end
end
_error(table.concat(lines, "\n"))
vim.health.error(table.concat(lines, '\n'))
end
end
end
M.bundled_queries = { 'highlights', 'locals', 'folds', 'indents', 'injections' }
return M

View file

@ -1,49 +0,0 @@
local configs = require "nvim-treesitter.configs"
local M = {}
---@param config TSModule
---@param lang string
---@return boolean
local function should_enable_vim_regex(config, lang)
local additional_hl = config.additional_vim_regex_highlighting
local is_table = type(additional_hl) == "table"
---@diagnostic disable-next-line: param-type-mismatch
return additional_hl and (not is_table or vim.tbl_contains(additional_hl, lang))
end
---@param bufnr integer
---@param lang string
function M.attach(bufnr, lang)
local config = configs.get_module "highlight"
vim.treesitter.start(bufnr, lang)
if config and should_enable_vim_regex(config, lang) then
vim.bo[bufnr].syntax = "ON"
end
end
---@param bufnr integer
function M.detach(bufnr)
vim.treesitter.stop(bufnr)
end
---@deprecated
function M.start(...)
vim.notify(
"`nvim-treesitter.highlight.start` is deprecated: use `nvim-treesitter.highlight.attach` or `vim.treesitter.start`",
vim.log.levels.WARN
)
M.attach(...)
end
---@deprecated
function M.stop(...)
vim.notify(
"`nvim-treesitter.highlight.stop` is deprecated: use `nvim-treesitter.highlight.detach` or `vim.treesitter.stop`",
vim.log.levels.WARN
)
M.detach(...)
end
return M

View file

@ -1,176 +0,0 @@
local api = vim.api
local configs = require "nvim-treesitter.configs"
local ts_utils = require "nvim-treesitter.ts_utils"
local locals = require "nvim-treesitter.locals"
local parsers = require "nvim-treesitter.parsers"
local queries = require "nvim-treesitter.query"
local M = {}
---@type table<integer, table<TSNode|nil>>
local selections = {}
function M.init_selection()
local buf = api.nvim_get_current_buf()
local node = ts_utils.get_node_at_cursor()
selections[buf] = { [1] = node }
ts_utils.update_selection(buf, node)
end
-- Get the range of the current visual selection.
--
-- The range starts with 1 and the ending is inclusive.
---@return integer, integer, integer, integer
local function visual_selection_range()
local _, csrow, cscol, _ = unpack(vim.fn.getpos "'<") ---@type integer, integer, integer, integer
local _, cerow, cecol, _ = unpack(vim.fn.getpos "'>") ---@type integer, integer, integer, integer
local start_row, start_col, end_row, end_col ---@type integer, integer, integer, integer
if csrow < cerow or (csrow == cerow and cscol <= cecol) then
start_row = csrow
start_col = cscol
end_row = cerow
end_col = cecol
else
start_row = cerow
start_col = cecol
end_row = csrow
end_col = cscol
end
return start_row, start_col, end_row, end_col
end
---@param node TSNode
---@return boolean
local function range_matches(node)
local csrow, cscol, cerow, cecol = visual_selection_range()
local srow, scol, erow, ecol = ts_utils.get_vim_range { node:range() }
return srow == csrow and scol == cscol and erow == cerow and ecol == cecol
end
---@param get_parent fun(node: TSNode): TSNode|nil
---@return fun():nil
local function select_incremental(get_parent)
return function()
local buf = api.nvim_get_current_buf()
local nodes = selections[buf]
local csrow, cscol, cerow, cecol = visual_selection_range()
-- Initialize incremental selection with current selection
if not nodes or #nodes == 0 or not range_matches(nodes[#nodes]) then
local root = parsers.get_parser():parse()[1]:root()
local node = root:named_descendant_for_range(csrow - 1, cscol - 1, cerow - 1, cecol)
ts_utils.update_selection(buf, node)
if nodes and #nodes > 0 then
table.insert(selections[buf], node)
else
selections[buf] = { [1] = node }
end
return
end
-- Find a node that changes the current selection.
local node = nodes[#nodes] ---@type TSNode
while true do
local parent = get_parent(node)
if not parent or parent == node then
-- Keep searching in the main tree
-- TODO: we should search on the parent tree of the current node.
local root = parsers.get_parser():parse()[1]:root()
parent = root:named_descendant_for_range(csrow - 1, cscol - 1, cerow - 1, cecol)
if not parent or root == node or parent == node then
ts_utils.update_selection(buf, node)
return
end
end
node = parent
local srow, scol, erow, ecol = ts_utils.get_vim_range { node:range() }
local same_range = (srow == csrow and scol == cscol and erow == cerow and ecol == cecol)
if not same_range then
table.insert(selections[buf], node)
if node ~= nodes[#nodes] then
table.insert(nodes, node)
end
ts_utils.update_selection(buf, node)
return
end
end
end
end
M.node_incremental = select_incremental(function(node)
return node:parent() or node
end)
M.scope_incremental = select_incremental(function(node)
local lang = parsers.get_buf_lang()
if queries.has_locals(lang) then
return locals.containing_scope(node:parent() or node)
else
return node
end
end)
function M.node_decremental()
local buf = api.nvim_get_current_buf()
local nodes = selections[buf]
if not nodes or #nodes < 2 then
return
end
table.remove(selections[buf])
local node = nodes[#nodes] ---@type TSNode
ts_utils.update_selection(buf, node)
end
local FUNCTION_DESCRIPTIONS = {
init_selection = "Start selecting nodes with nvim-treesitter",
node_incremental = "Increment selection to named node",
scope_incremental = "Increment selection to surrounding scope",
node_decremental = "Shrink selection to previous named node",
}
---@param bufnr integer
function M.attach(bufnr)
local config = configs.get_module "incremental_selection"
for funcname, mapping in pairs(config.keymaps) do
if mapping then
---@type string, string|function
local mode, rhs
if funcname == "init_selection" then
mode = "n"
---@type function
rhs = M[funcname]
else
mode = "x"
-- We need to move to command mode to access marks '< (visual area start) and '> (visual area end) which are not
-- properly accessible in visual mode.
rhs = string.format(":lua require'nvim-treesitter.incremental_selection'.%s()<CR>", funcname)
end
vim.keymap.set(
mode,
mapping,
rhs,
{ buffer = bufnr, silent = true, noremap = true, desc = FUNCTION_DESCRIPTIONS[funcname] }
)
end
end
end
function M.detach(bufnr)
local config = configs.get_module "incremental_selection"
for f, mapping in pairs(config.keymaps) do
if mapping then
if f == "init_selection" then
vim.keymap.del("n", mapping, { buffer = bufnr })
else
vim.keymap.del("x", mapping, { buffer = bufnr })
end
end
end
end
return M

View file

@ -1,5 +1,4 @@
local ts = vim.treesitter
local parsers = require "nvim-treesitter.parsers"
local M = {}
@ -14,7 +13,7 @@ M.comment_parsers = {
}
local function getline(lnum)
return vim.api.nvim_buf_get_lines(0, lnum - 1, lnum, false)[1] or ""
return vim.api.nvim_buf_get_lines(0, lnum - 1, lnum, false)[1] or ''
end
---@param root TSNode
@ -55,8 +54,8 @@ local function find_delimiter(bufnr, node, delimiter)
local line = vim.api.nvim_buf_get_lines(bufnr, linenr, linenr + 1, false)[1]
local end_char = { child:end_() }
local trimmed_after_delim
local escaped_delimiter = delimiter:gsub("[%-%.%+%[%]%(%)%$%^%%%?%*]", "%%%1")
trimmed_after_delim, _ = line:sub(end_char[2] + 1):gsub("[%s" .. escaped_delimiter .. "]*", "")
local escaped_delimiter = delimiter:gsub('[%-%.%+%[%]%(%)%$%^%%%?%*]', '%%%1')
trimmed_after_delim = line:sub(end_char[2] + 1):gsub('[%s' .. escaped_delimiter .. ']*', '')
return child, #trimmed_after_delim == 0
end
end
@ -68,7 +67,7 @@ end
---@param hash_fn fun(...): any
---@return F
local function memoize(fn, hash_fn)
local cache = setmetatable({}, { __mode = "kv" }) ---@type table<any,any>
local cache = setmetatable({}, { __mode = 'kv' }) ---@type table<any,any>
return function(...)
local key = hash_fn(...)
@ -84,42 +83,42 @@ end
local get_indents = memoize(function(bufnr, root, lang)
local map = {
["indent.auto"] = {},
["indent.begin"] = {},
["indent.end"] = {},
["indent.dedent"] = {},
["indent.branch"] = {},
["indent.ignore"] = {},
["indent.align"] = {},
["indent.zero"] = {},
['indent.auto'] = {},
['indent.begin'] = {},
['indent.end'] = {},
['indent.dedent'] = {},
['indent.branch'] = {},
['indent.ignore'] = {},
['indent.align'] = {},
['indent.zero'] = {},
}
--TODO(clason): remove when dropping Nvim 0.8 compat
local query = (ts.query.get or ts.get_query)(lang, "indents")
local query = ts.query.get(lang, 'indents')
if not query then
return map
end
for id, node, metadata in query:iter_captures(root, bufnr) do
if query.captures[id]:sub(1, 1) ~= "_" then
if query.captures[id]:sub(1, 1) ~= '_' then
map[query.captures[id]][node:id()] = metadata or {}
end
end
return map
end, function(bufnr, root, lang)
return tostring(bufnr) .. root:id() .. "_" .. lang
return tostring(bufnr) .. root:id() .. '_' .. lang
end)
---@param lnum number (1-indexed)
---@return integer
function M.get_indent(lnum)
local bufnr = vim.api.nvim_get_current_buf()
local parser = parsers.get_parser(bufnr)
local parser = ts.get_parser(bufnr)
if not parser or not lnum then
return -1
end
--TODO(clason): replace when dropping Nvim 0.8 compat
local root_lang = parsers.get_buf_lang(bufnr)
local ft = vim.bo[bufnr].filetype
local root_lang = vim.treesitter.language.get_lang(ft) or ft
-- some languages like Python will actually have worse results when re-parsing at opened new line
if not M.avoid_force_reparsing[root_lang] then
@ -148,15 +147,14 @@ function M.get_indent(lnum)
end
local q = get_indents(vim.api.nvim_get_current_buf(), root, lang_tree:lang())
local is_empty_line = string.match(getline(lnum), "^%s*$") ~= nil
local node ---@type TSNode
if is_empty_line then
if getline(lnum):find('^%s*$') then
local prevlnum = vim.fn.prevnonblank(lnum)
local indent = vim.fn.indent(prevlnum)
local prevline = vim.trim(getline(prevlnum))
-- The final position can be trailing spaces, which should not affect indentation
node = get_last_node_at_line(root, prevlnum, indent + #prevline - 1)
if node:type():match "comment" then
if node:type():find('comment') then
-- The final node we capture of the previous line can be a comment node, which should also be ignored
-- Unless the last line is an entire line of comment, ignore the comment range and find the last node again
local first_node = get_first_node_at_line(root, prevlnum, indent)
@ -169,7 +167,7 @@ function M.get_indent(lnum)
node = get_last_node_at_line(root, prevlnum, col)
end
end
if q["indent.end"][node:id()] then
if q['indent.end'][node:id()] then
node = get_first_node_at_line(root, lnum)
end
else
@ -185,18 +183,18 @@ function M.get_indent(lnum)
end
-- tracks to ensure multiple indent levels are not applied for same line
local is_processed_by_row = {}
local is_processed_by_row = {} --- @type table<integer,boolean>
if q["indent.zero"][node:id()] then
if q['indent.zero'][node:id()] then
return 0
end
while node do
-- do 'autoindent' if not marked as @indent
if
not q["indent.begin"][node:id()]
and not q["indent.align"][node:id()]
and q["indent.auto"][node:id()]
not q['indent.begin'][node:id()]
and not q['indent.align'][node:id()]
and q['indent.auto'][node:id()]
and node:start() < lnum - 1
and lnum - 1 <= node:end_()
then
@ -207,8 +205,8 @@ function M.get_indent(lnum)
-- If a node spans from L1,C1 to L2,C2, we know that lines where L1 < line <= L2 would
-- have their indentations contained by the node.
if
not q["indent.begin"][node:id()]
and q["indent.ignore"][node:id()]
not q['indent.begin'][node:id()]
and q['indent.ignore'][node:id()]
and node:start() < lnum - 1
and lnum - 1 <= node:end_()
then
@ -221,7 +219,10 @@ function M.get_indent(lnum)
if
not is_processed_by_row[srow]
and ((q["indent.branch"][node:id()] and srow == lnum - 1) or (q["indent.dedent"][node:id()] and srow ~= lnum - 1))
and (
(q['indent.branch'][node:id()] and srow == lnum - 1)
or (q['indent.dedent'][node:id()] and srow ~= lnum - 1)
)
then
indent = indent - indent_size
is_processed = true
@ -237,16 +238,16 @@ function M.get_indent(lnum)
if
should_process
and (
q["indent.begin"][node:id()]
and (srow ~= erow or is_in_err or q["indent.begin"][node:id()]["indent.immediate"])
and (srow ~= lnum - 1 or q["indent.begin"][node:id()]["indent.start_at_same_line"])
q['indent.begin'][node:id()]
and (srow ~= erow or is_in_err or q['indent.begin'][node:id()]['indent.immediate'])
and (srow ~= lnum - 1 or q['indent.begin'][node:id()]['indent.start_at_same_line'])
)
then
indent = indent + indent_size
is_processed = true
end
if is_in_err and not q["indent.align"][node:id()] then
if is_in_err and not q['indent.align'][node:id()] then
-- only when the node is in error, promote the
-- first child's aligned indent to the error node
-- to work around ((ERROR "X" . (_)) @aligned_indent (#set! "delimeter" "AB"))
@ -254,34 +255,41 @@ function M.get_indent(lnum)
-- (ERROR "X" @aligned_indent (#set! "delimeter" "AB") . (_))
-- and we will fish it out here.
for c in node:iter_children() do
if q["indent.align"][c:id()] then
q["indent.align"][node:id()] = q["indent.align"][c:id()]
if q['indent.align'][c:id()] then
q['indent.align'][node:id()] = q['indent.align'][c:id()]
break
end
end
end
-- do not indent for nodes that starts-and-ends on same line and starts on target line (lnum)
if should_process and q["indent.align"][node:id()] and (srow ~= erow or is_in_err) and (srow ~= lnum - 1) then
local metadata = q["indent.align"][node:id()]
if
should_process
and q['indent.align'][node:id()]
and (srow ~= erow or is_in_err)
and (srow ~= lnum - 1)
then
local metadata = q['indent.align'][node:id()]
local o_delim_node, o_is_last_in_line ---@type TSNode|nil, boolean|nil
local c_delim_node, c_is_last_in_line ---@type TSNode|nil, boolean|nil, boolean|nil
local indent_is_absolute = false
if metadata["indent.open_delimiter"] then
o_delim_node, o_is_last_in_line = find_delimiter(bufnr, node, metadata["indent.open_delimiter"])
if metadata['indent.open_delimiter'] then
o_delim_node, o_is_last_in_line =
find_delimiter(bufnr, node, metadata['indent.open_delimiter'])
else
o_delim_node = node
end
if metadata["indent.close_delimiter"] then
c_delim_node, c_is_last_in_line = find_delimiter(bufnr, node, metadata["indent.close_delimiter"])
if metadata['indent.close_delimiter'] then
c_delim_node, c_is_last_in_line =
find_delimiter(bufnr, node, metadata['indent.close_delimiter'])
else
c_delim_node = node
end
if o_delim_node then
local o_srow, o_scol = o_delim_node:start()
local c_srow = nil
local c_srow = nil --- @type integer?
if c_delim_node then
c_srow, _ = c_delim_node:start()
c_srow = c_delim_node:start()
end
if o_is_last_in_line then
-- hanging indent (previous line ended with starting delimiter)
@ -303,7 +311,7 @@ function M.get_indent(lnum)
-- Then its indent level shouldn't be affected by `@aligned_indent` node
indent = math.max(indent - indent_size, 0)
else
indent = o_scol + (metadata["indent.increment"] or 1)
indent = o_scol + (metadata['indent.increment'] or 1)
indent_is_absolute = true
end
end
@ -314,7 +322,7 @@ function M.get_indent(lnum)
-- then this last line may need additional indent to avoid clashes
-- with the next. `indent.avoid_last_matching_next` controls this behavior,
-- for example this is needed for function parameters.
avoid_last_matching_next = metadata["indent.avoid_last_matching_next"] or false
avoid_last_matching_next = metadata['indent.avoid_last_matching_next'] or false
end
if avoid_last_matching_next then
-- last line must be indented more in cases where
@ -343,17 +351,4 @@ function M.get_indent(lnum)
return indent
end
---@type table<integer, string>
local indent_funcs = {}
---@param bufnr integer
function M.attach(bufnr)
indent_funcs[bufnr] = vim.bo.indentexpr
vim.bo.indentexpr = "nvim_treesitter#indent()"
end
function M.detach(bufnr)
vim.bo.indentexpr = indent_funcs[bufnr]
end
return M

View file

@ -1,190 +0,0 @@
local api = vim.api
local configs = require "nvim-treesitter.configs"
local parsers = require "nvim-treesitter.parsers"
local M = {}
local function install_info()
local max_len = 0
for _, ft in pairs(parsers.available_parsers()) do
if #ft > max_len then
max_len = #ft
end
end
local parser_list = parsers.available_parsers()
table.sort(parser_list)
for _, lang in pairs(parser_list) do
local is_installed = #api.nvim_get_runtime_file("parser/" .. lang .. ".so", false) > 0
api.nvim_out_write(lang .. string.rep(" ", max_len - #lang + 1))
if is_installed then
api.nvim_out_write "[✓] installed\n"
elseif pcall(vim.treesitter.inspect_lang, lang) then
api.nvim_out_write "[✗] not installed (but still loaded. Restart Neovim!)\n"
else
api.nvim_out_write "[✗] not installed\n"
end
end
end
-- Sort a list of modules into namespaces.
-- {'mod1', 'mod2.sub1', 'mod2.sub2', 'mod3'}
-- ->
-- { default = {'mod1', 'mod3'}, mod2 = {'sub1', 'sub2'}}
---@param modulelist string[]
---@return table
local function namespace_modules(modulelist)
local modules = {}
for _, module in ipairs(modulelist) do
if module:find "%." then
local namespace, submodule = module:match "^(.*)%.(.*)$"
if not modules[namespace] then
modules[namespace] = {}
end
table.insert(modules[namespace], submodule)
else
if not modules.default then
modules.default = {}
end
table.insert(modules.default, module)
end
end
return modules
end
---@param list string[]
---@return integer length
local function longest_string_length(list)
local length = 0
for _, value in ipairs(list) do
if #value > length then
length = #value
end
end
return length
end
---@param curbuf integer
---@param origbuf integer
---@param parserlist string[]
---@param namespace string
---@param modulelist string[]
local function append_module_table(curbuf, origbuf, parserlist, namespace, modulelist)
local maxlen_parser = longest_string_length(parserlist)
table.sort(modulelist)
-- header
local header = ">> " .. namespace .. string.rep(" ", maxlen_parser - #namespace - 1)
for _, module in pairs(modulelist) do
header = header .. module .. " "
end
api.nvim_buf_set_lines(curbuf, -1, -1, true, { header })
-- actual table
for _, parser in ipairs(parserlist) do
local padding = string.rep(" ", maxlen_parser - #parser + 2)
local line = parser .. padding
local namespace_prefix = (namespace == "default") and "" or namespace .. "."
for _, module in pairs(modulelist) do
local modlen = #module
module = namespace_prefix .. module
if configs.is_enabled(module, parser, origbuf) then
line = line .. ""
else
line = line .. ""
end
line = line .. string.rep(" ", modlen + 1)
end
api.nvim_buf_set_lines(curbuf, -1, -1, true, { line })
end
api.nvim_buf_set_lines(curbuf, -1, -1, true, { "" })
end
local function print_info_modules(parserlist, module)
local origbuf = api.nvim_get_current_buf()
api.nvim_command "enew"
local curbuf = api.nvim_get_current_buf()
local modules
if module then
modules = namespace_modules { module }
else
modules = namespace_modules(configs.available_modules())
end
---@type string[]
local namespaces = {}
for k, _ in pairs(modules) do
table.insert(namespaces, k)
end
table.sort(namespaces)
table.sort(parserlist)
for _, namespace in ipairs(namespaces) do
append_module_table(curbuf, origbuf, parserlist, namespace, modules[namespace])
end
api.nvim_buf_set_option(curbuf, "modified", false)
api.nvim_buf_set_option(curbuf, "buftype", "nofile")
vim.cmd [[
syntax match TSModuleInfoGood //
syntax match TSModuleInfoBad //
syntax match TSModuleInfoHeader /^>>.*$/ contains=TSModuleInfoNamespace
syntax match TSModuleInfoNamespace /^>> \w*/ contained
syntax match TSModuleInfoParser /^[^> ]*\ze /
]]
local highlights = {
TSModuleInfoGood = { fg = "LightGreen", bold = true, default = true },
TSModuleInfoBad = { fg = "Crimson", default = true },
TSModuleInfoHeader = { link = "Type", default = true },
TSModuleInfoNamespace = { link = "Statement", default = true },
TSModuleInfoParser = { link = "Identifier", default = true },
}
for k, v in pairs(highlights) do
api.nvim_set_hl(0, k, v)
end
end
local function module_info(module)
if module and not configs.get_module(module) then
return
end
local parserlist = parsers.available_parsers()
if module then
print_info_modules(parserlist, module)
else
print_info_modules(parserlist)
end
end
---@return string[]
function M.installed_parsers()
local installed = {}
for _, p in pairs(parsers.available_parsers()) do
if parsers.has_parser(p) then
table.insert(installed, p)
end
end
return installed
end
M.commands = {
TSInstallInfo = {
run = install_info,
args = {
"-nargs=0",
},
},
TSModuleInfo = {
run = module_info,
args = {
"-nargs=?",
"-complete=custom,nvim_treesitter#available_modules",
},
},
}
return M

View file

@ -0,0 +1,11 @@
local M = {}
function M.setup(...)
require('nvim-treesitter.config').setup(...)
end
function M.indentexpr()
return require('nvim-treesitter.indent').get_indent(vim.v.lnum)
end
return M

File diff suppressed because it is too large Load diff

View file

@ -1,29 +1,46 @@
-- Functions to handle locals
-- Locals are a generalization of definition and scopes
-- its the way nvim-treesitter uses to "understand" the code
-- it's the way nvim-treesitter uses to "understand" the code
local queries = require "nvim-treesitter.query"
local ts_utils = require "nvim-treesitter.ts_utils"
local ts = vim.treesitter
local query = require('nvim-treesitter.query')
local api = vim.api
local ts = vim.treesitter
local M = {}
function M.collect_locals(bufnr)
return queries.collect_group_results(bufnr, "locals")
local function get_named_children(node)
local nodes = {} ---@type TSNode[]
for i = 0, node:named_child_count() - 1, 1 do
nodes[i + 1] = node:named_child(i)
end
return nodes
end
---@param node TSNode
---@return TSNode result
local function get_root_for_node(node)
local parent = node
local result = node
while parent ~= nil do
result = parent
parent = result:parent()
end
return result
end
-- Iterates matches from a locals query file.
-- @param bufnr the buffer
-- @param root the root node
function M.iter_locals(bufnr, root)
return queries.iter_group_results(bufnr, "locals", root)
return query.iter_group_results(bufnr, 'locals', root)
end
---@param bufnr integer
---@return any
function M.get_locals(bufnr)
return queries.get_matches(bufnr, "locals")
function M.collect_locals(bufnr)
return query.collect_group_results(bufnr, 'locals')
end
-- Creates unique id for a node based on text and range
@ -32,11 +49,11 @@ end
---@return string: a string id
function M.get_definition_id(scope, node_text)
-- Add a valid starting character in case node text doesn't start with a valid one.
return table.concat({ "k", node_text or "", scope:range() }, "_")
return table.concat({ 'k', node_text or '', scope:range() }, '_')
end
function M.get_definitions(bufnr)
local locals = M.get_locals(bufnr)
local locals = M.collect_locals(bufnr)
local defs = {}
@ -50,7 +67,7 @@ function M.get_definitions(bufnr)
end
function M.get_scopes(bufnr)
local locals = M.get_locals(bufnr)
local locals = M.collect_locals(bufnr)
local scopes = {}
@ -64,7 +81,7 @@ function M.get_scopes(bufnr)
end
function M.get_references(bufnr)
local locals = M.get_locals(bufnr)
local locals = M.collect_locals(bufnr)
local refs = {}
@ -103,7 +120,7 @@ function M.iter_scope_tree(node, bufnr)
return
end
local scope = M.containing_scope(last_node, bufnr, false) or ts_utils.get_root_for_node(node)
local scope = M.containing_scope(last_node, bufnr, false) or get_root_for_node(node)
last_node = scope:parent()
@ -117,8 +134,8 @@ end
function M.get_local_nodes(local_def)
local result = {}
M.recurse_local_nodes(local_def, function(def, _node, kind)
table.insert(result, vim.tbl_extend("keep", { kind = kind }, def))
M.recurse_local_nodes(local_def, function(def, _, kind)
table.insert(result, vim.tbl_extend('keep', { kind = kind }, def))
end)
return result
@ -135,7 +152,7 @@ end
---@param full_match? string The full match path to append to
---@param last_match? string The last match
function M.recurse_local_nodes(local_def, accumulator, full_match, last_match)
if type(local_def) ~= "table" then
if type(local_def) ~= 'table' then
return
end
@ -143,11 +160,36 @@ function M.recurse_local_nodes(local_def, accumulator, full_match, last_match)
accumulator(local_def, local_def.node, full_match, last_match)
else
for match_key, def in pairs(local_def) do
M.recurse_local_nodes(def, accumulator, full_match and (full_match .. "." .. match_key) or match_key, match_key)
M.recurse_local_nodes(
def,
accumulator,
full_match and (full_match .. '.' .. match_key) or match_key,
match_key
)
end
end
end
---Memoize a function using hash_fn to hash the arguments.
---@generic F: function
---@param fn F
---@param hash_fn fun(...): any
---@return F
local function memoize(fn, hash_fn)
local cache = setmetatable({}, { __mode = 'kv' }) ---@type table<any,any>
return function(...)
local key = hash_fn(...)
if cache[key] == nil then
local v = fn(...) ---@type any
cache[key] = v ~= nil and v or vim.NIL
end
local v = cache[key]
return v ~= vim.NIL and v or nil
end
end
-- Get a single dimension table to look definition nodes.
-- Keys are generated by using the range of the containing scope and the text of the definition node.
-- This makes looking up a definition for a given scope a simple key lookup.
@ -161,7 +203,7 @@ end
--
---@param bufnr integer: the buffer
---@return table result: a table for looking up definitions
M.get_definitions_lookup_table = ts_utils.memoize_by_buf_tick(function(bufnr)
M.get_definitions_lookup_table = memoize(function(bufnr)
local definitions = M.get_definitions(bufnr)
local result = {}
@ -178,6 +220,8 @@ M.get_definitions_lookup_table = ts_utils.memoize_by_buf_tick(function(bufnr)
end
return result
end, function(bufnr)
return tostring(bufnr)
end)
-- Gets all the scopes of a definition based on the scope type
@ -196,10 +240,10 @@ function M.get_definition_scopes(node, bufnr, scope_type)
-- Definition is valid for the containing scope
-- and the containing scope of that scope
if scope_type == "parent" then
if scope_type == 'parent' then
scope_count = 2
-- Definition is valid in all parent scopes
elseif scope_type == "global" then
elseif scope_type == 'global' then
scope_count = nil
end
@ -235,7 +279,7 @@ function M.find_definition(node, bufnr)
end
end
return node, ts_utils.get_root_for_node(node), nil
return node, get_root_for_node(node), nil
end
-- Finds usages of a node in a given scope.
@ -250,11 +294,15 @@ function M.find_usages(node, scope_node, bufnr)
return {}
end
local scope_node = scope_node or ts_utils.get_root_for_node(node)
scope_node = scope_node or get_root_for_node(node)
local usages = {}
for match in M.iter_locals(bufnr, scope_node) do
if match.reference and match.reference.node and ts.get_node_text(match.reference.node, bufnr) == node_text then
if
match.reference
and match.reference.node
and ts.get_node_text(match.reference.node, bufnr) == node_text
then
local def_node, _, kind = M.find_definition(match.reference.node, bufnr)
if kind == nil or def_node == node then
@ -271,8 +319,8 @@ end
---@param allow_scope? boolean
---@return TSNode|nil
function M.containing_scope(node, bufnr, allow_scope)
local bufnr = bufnr or api.nvim_get_current_buf()
local allow_scope = allow_scope == nil or allow_scope == true
bufnr = bufnr or api.nvim_get_current_buf()
allow_scope = allow_scope == nil or allow_scope == true
local scopes = M.get_scopes(bufnr)
if not node or not scopes then
@ -300,7 +348,7 @@ function M.nested_scope(node, cursor_pos)
local col = cursor_pos.col ---@type integer
local scope = M.containing_scope(node)
for _, child in ipairs(ts_utils.get_named_children(scope)) do
for _, child in ipairs(get_named_children(scope)) do
local row_, col_ = child:start()
if vim.tbl_contains(scopes, child) and ((row_ + 1 == row and col_ > col) or row_ + 1 > row) then
return child
@ -317,6 +365,9 @@ function M.next_scope(node)
end
local scope = M.containing_scope(node)
if not scope then
return
end
local parent = scope:parent()
if not parent then
@ -324,7 +375,7 @@ function M.next_scope(node)
end
local is_prev = true
for _, child in ipairs(ts_utils.get_named_children(parent)) do
for _, child in ipairs(get_named_children(parent)) do
if child == scope then
is_prev = false
elseif not is_prev and vim.tbl_contains(scopes, child) then
@ -344,6 +395,9 @@ function M.previous_scope(node)
end
local scope = M.containing_scope(node)
if not scope then
return
end
local parent = scope:parent()
if not parent then
@ -351,7 +405,7 @@ function M.previous_scope(node)
end
local is_prev = true
local children = ts_utils.get_named_children(parent)
local children = get_named_children(parent)
for i = #children, 1, -1 do
if children[i] == scope then
is_prev = false

File diff suppressed because it is too large Load diff

View file

@ -1,154 +1,7 @@
local api = vim.api
local ts = require "nvim-treesitter.compat"
local tsrange = require "nvim-treesitter.tsrange"
local utils = require "nvim-treesitter.utils"
local parsers = require "nvim-treesitter.parsers"
local caching = require "nvim-treesitter.caching"
local M = {}
local EMPTY_ITER = function() end
M.built_in_query_groups = { "highlights", "locals", "folds", "indents", "injections" }
-- Creates a function that checks whether a given query exists
-- for a specific language.
---@param query string
---@return fun(string): boolean
local function get_query_guard(query)
return function(lang)
return M.has_query_files(lang, query)
end
end
for _, query in ipairs(M.built_in_query_groups) do
M["has_" .. query] = get_query_guard(query)
end
---@return string[]
function M.available_query_groups()
local query_files = api.nvim_get_runtime_file("queries/*/*.scm", true)
local groups = {}
for _, f in ipairs(query_files) do
groups[vim.fn.fnamemodify(f, ":t:r")] = true
end
local list = {}
for k, _ in pairs(groups) do
table.insert(list, k)
end
return list
end
do
local query_cache = caching.create_buffer_cache()
local function update_cached_matches(bufnr, changed_tick, query_group)
query_cache.set(query_group, bufnr, {
tick = changed_tick,
cache = M.collect_group_results(bufnr, query_group) or {},
})
end
---@param bufnr integer
---@param query_group string
---@return any
function M.get_matches(bufnr, query_group)
bufnr = bufnr or api.nvim_get_current_buf()
local cached_local = query_cache.get(query_group, bufnr)
if not cached_local or api.nvim_buf_get_changedtick(bufnr) > cached_local.tick then
update_cached_matches(bufnr, api.nvim_buf_get_changedtick(bufnr), query_group)
end
return query_cache.get(query_group, bufnr).cache
end
end
---@param lang string
---@param query_name string
---@return string[]
local function runtime_queries(lang, query_name)
return api.nvim_get_runtime_file(string.format("queries/%s/%s.scm", lang, query_name), true) or {}
end
---@type table<string, table<string, boolean>>
local query_files_cache = {}
---@param lang string
---@param query_name string
---@return boolean
function M.has_query_files(lang, query_name)
if not query_files_cache[lang] then
query_files_cache[lang] = {}
end
if query_files_cache[lang][query_name] == nil then
local files = runtime_queries(lang, query_name)
query_files_cache[lang][query_name] = files and #files > 0
end
return query_files_cache[lang][query_name]
end
do
local mt = {}
mt.__index = function(tbl, key)
if rawget(tbl, key) == nil then
rawset(tbl, key, {})
end
return rawget(tbl, key)
end
-- cache will auto set the table for each lang if it is nil
---@type table<string, table<string, Query>>
local cache = setmetatable({}, mt)
-- Same as `vim.treesitter.query` except will return cached values
---@param lang string
---@param query_name string
function M.get_query(lang, query_name)
if cache[lang][query_name] == nil then
cache[lang][query_name] = ts.get_query(lang, query_name)
end
return cache[lang][query_name]
end
-- Invalidates the query file cache.
--
-- If lang and query_name is both present, will reload for only the lang and query_name.
-- If only lang is present, will reload all query_names for that lang
-- If none are present, will reload everything
---@param lang? string
---@param query_name? string
function M.invalidate_query_cache(lang, query_name)
if lang and query_name then
cache[lang][query_name] = nil
if query_files_cache[lang] then
query_files_cache[lang][query_name] = nil
end
elseif lang and not query_name then
query_files_cache[lang] = nil
for query_name0, _ in pairs(cache[lang]) do
M.invalidate_query_cache(lang, query_name0)
end
elseif not lang and not query_name then
query_files_cache = {}
for lang0, _ in pairs(cache) do
for query_name0, _ in pairs(cache[lang0]) do
M.invalidate_query_cache(lang0, query_name0)
end
end
else
error "Cannot have query_name by itself!"
end
end
end
-- This function is meant for an autocommand and not to be used. Only use if file is a query file.
---@param fname string
function M.invalidate_query_file(fname)
local fnamemodify = vim.fn.fnamemodify
M.invalidate_query_cache(fnamemodify(fname, ":p:h:t"), fnamemodify(fname, ":t:r"))
end
---@class QueryInfo
---@field root TSNode
---@field source integer
@ -161,13 +14,13 @@ end
---@param root_lang string|nil
---@return Query|nil, QueryInfo|nil
local function prepare_query(bufnr, query_name, root, root_lang)
local buf_lang = parsers.get_buf_lang(bufnr)
local ft = vim.bo[bufnr].filetype
local buf_lang = vim.treesitter.language.get_lang(ft) or ft
if not buf_lang then
return
end
local parser = parsers.get_parser(bufnr, buf_lang)
local parser = vim.treesitter.get_parser(bufnr, buf_lang)
if not parser then
return
end
@ -198,7 +51,7 @@ local function prepare_query(bufnr, query_name, root, root_lang)
return
end
local query = M.get_query(root_lang, query_name)
local query = vim.treesitter.query.get(root_lang, query_name)
if not query then
return
end
@ -241,7 +94,7 @@ function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row)
---@return string[]
local function split(to_split)
local t = {}
for str in string.gmatch(to_split, "([^.]+)") do
for str in string.gmatch(to_split, '([^.]+)') do
table.insert(t, str)
end
@ -259,9 +112,9 @@ function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row)
for id, node in pairs(match) do
local name = query.captures[id] -- name of the capture in the query
if name ~= nil then
local path = split(name .. ".node")
local path = split(name .. '.node')
M.insert_to_path(prepared_match, path, node)
local metadata_path = split(name .. ".metadata")
local metadata_path = split(name .. '.metadata')
M.insert_to_path(prepared_match, metadata_path, metadata[id])
end
end
@ -272,16 +125,9 @@ function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row)
if preds then
for _, pred in pairs(preds) do
-- functions
if pred[1] == "set!" and type(pred[2]) == "string" then
if pred[1] == 'set!' and type(pred[2]) == 'string' then
M.insert_to_path(prepared_match, split(pred[2]), pred[3])
end
if pred[1] == "make-range!" and type(pred[2]) == "string" and #pred == 4 then
M.insert_to_path(
prepared_match,
split(pred[2] .. ".node"),
tsrange.TSRange.from_nodes(bufnr, match[pred[3]], match[pred[4]])
)
end
end
end
@ -291,103 +137,6 @@ function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row)
return iterator
end
-- Return all nodes corresponding to a specific capture path (like @definition.var, @reference.type)
-- Works like M.get_references or M.get_scopes except you can choose the capture
-- Can also be a nested capture like @definition.function to get all nodes defining a function.
--
---@param bufnr integer the buffer
---@param captures string|string[]
---@param query_group string the name of query group (highlights or injections for example)
---@param root TSNode|nil node from where to start the search
---@param lang string|nil the language from where to get the captures.
--- Root nodes can have several languages.
---@return table|nil
function M.get_capture_matches(bufnr, captures, query_group, root, lang)
if type(captures) == "string" then
captures = { captures }
end
local strip_captures = {} ---@type string[]
for i, capture in ipairs(captures) do
if capture:sub(1, 1) ~= "@" then
error 'Captures must start with "@"'
return
end
-- Remove leading "@".
strip_captures[i] = capture:sub(2)
end
local matches = {}
for match in M.iter_group_results(bufnr, query_group, root, lang) do
for _, capture in ipairs(strip_captures) do
local insert = utils.get_at_path(match, capture)
if insert then
table.insert(matches, insert)
end
end
end
return matches
end
function M.iter_captures(bufnr, query_name, root, lang)
local query, params = prepare_query(bufnr, query_name, root, lang)
if not query then
return EMPTY_ITER
end
assert(params)
local iter = query:iter_captures(params.root, params.source, params.start, params.stop)
local function wrapped_iter()
local id, node, metadata = iter()
if not id then
return
end
local name = query.captures[id]
if string.sub(name, 1, 1) == "_" then
return wrapped_iter()
end
return name, node, metadata
end
return wrapped_iter
end
---@param bufnr integer
---@param capture_string string
---@param query_group string
---@param filter_predicate fun(match: table): boolean
---@param scoring_function fun(match: table): number
---@param root TSNode
---@return table|unknown
function M.find_best_match(bufnr, capture_string, query_group, filter_predicate, scoring_function, root)
if string.sub(capture_string, 1, 1) == "@" then
--remove leading "@"
capture_string = string.sub(capture_string, 2)
end
local best ---@type table|nil
local best_score ---@type number
for maybe_match in M.iter_group_results(bufnr, query_group, root) do
local match = utils.get_at_path(maybe_match, capture_string)
if match and filter_predicate(match) then
local current_score = scoring_function(match)
if not best then
best = match
best_score = current_score
end
if current_score > best_score then
best = match
best_score = current_score
end
end
end
return best
end
---Iterates matches from a query file.
---@param bufnr integer the buffer
---@param query_group string the query file to use
@ -413,41 +162,4 @@ function M.collect_group_results(bufnr, query_group, root, lang)
return matches
end
---@alias CaptureResFn function(string, LanguageTree, LanguageTree): string, string
-- Same as get_capture_matches except this will recursively get matches for every language in the tree.
---@param bufnr integer The buffer
---@param capture_or_fn string|CaptureResFn The capture to get. If a function is provided then that
--- function will be used to resolve both the capture and query argument.
--- The function can return `nil` to ignore that tree.
---@param query_type string? The query to get the capture from. This is ignored if a function is provided
--- for the capture argument.
---@return table[]
function M.get_capture_matches_recursively(bufnr, capture_or_fn, query_type)
---@type CaptureResFn
local type_fn
if type(capture_or_fn) == "function" then
type_fn = capture_or_fn
else
type_fn = function(_, _, _)
return capture_or_fn, query_type
end
end
local parser = parsers.get_parser(bufnr)
local matches = {}
if parser then
parser:for_each_tree(function(tree, lang_tree)
local lang = lang_tree:lang()
local capture, type_ = type_fn(lang, tree, lang_tree)
if capture then
vim.list_extend(matches, M.get_capture_matches(bufnr, capture, type_, tree:root(), lang) or {})
end
end)
end
return matches
end
return M

View file

@ -1,238 +0,0 @@
local query = require "vim.treesitter.query"
local html_script_type_languages = {
["importmap"] = "json",
["module"] = "javascript",
["application/ecmascript"] = "javascript",
["text/ecmascript"] = "javascript",
}
local non_filetype_match_injection_language_aliases = {
ex = "elixir",
pl = "perl",
sh = "bash",
uxn = "uxntal",
ts = "typescript",
}
local function get_parser_from_markdown_info_string(injection_alias)
local match = vim.filetype.match { filename = "a." .. injection_alias }
return match or non_filetype_match_injection_language_aliases[injection_alias] or injection_alias
end
local function error(str)
vim.api.nvim_err_writeln(str)
end
local function valid_args(name, pred, count, strict_count)
local arg_count = #pred - 1
if strict_count then
if arg_count ~= count then
error(string.format("%s must have exactly %d arguments", name, count))
return false
end
elseif arg_count < count then
error(string.format("%s must have at least %d arguments", name, count))
return false
end
return true
end
---@param match (TSNode|nil)[]
---@param _pattern string
---@param _bufnr integer
---@param pred string[]
---@return boolean|nil
query.add_predicate("nth?", function(match, _pattern, _bufnr, pred)
if not valid_args("nth?", pred, 2, true) then
return
end
local node = match[pred[2]] ---@type TSNode
local n = tonumber(pred[3])
if node and node:parent() and node:parent():named_child_count() > n then
return node:parent():named_child(n) == node
end
return false
end, true)
---@param match (TSNode|nil)[]
---@param _pattern string
---@param _bufnr integer
---@param pred string[]
---@return boolean|nil
local function has_ancestor(match, _pattern, _bufnr, pred)
if not valid_args(pred[1], pred, 2) then
return
end
local node = match[pred[2]]
local ancestor_types = { unpack(pred, 3) }
if not node then
return true
end
local just_direct_parent = pred[1]:find("has-parent", 1, true)
node = node:parent()
while node do
if vim.tbl_contains(ancestor_types, node:type()) then
return true
end
if just_direct_parent then
node = nil
else
node = node:parent()
end
end
return false
end
query.add_predicate("has-ancestor?", has_ancestor, true)
query.add_predicate("has-parent?", has_ancestor, true)
---@param match (TSNode|nil)[]
---@param _pattern string
---@param bufnr integer
---@param pred string[]
---@return boolean|nil
query.add_predicate("is?", function(match, _pattern, bufnr, pred)
if not valid_args("is?", pred, 2) then
return
end
-- Avoid circular dependencies
local locals = require "nvim-treesitter.locals"
local node = match[pred[2]]
local types = { unpack(pred, 3) }
if not node then
return true
end
local _, _, kind = locals.find_definition(node, bufnr)
return vim.tbl_contains(types, kind)
end, true)
---@param match (TSNode|nil)[]
---@param _pattern string
---@param _bufnr integer
---@param pred string[]
---@return boolean|nil
query.add_predicate("has-type?", function(match, _pattern, _bufnr, pred)
if not valid_args(pred[1], pred, 2) then
return
end
local node = match[pred[2]]
local types = { unpack(pred, 3) }
if not node then
return true
end
return vim.tbl_contains(types, node:type())
end, true)
---@param match (TSNode|nil)[]
---@param _ string
---@param bufnr integer
---@param pred string[]
---@return boolean|nil
query.add_directive("set-lang-from-mimetype!", function(match, _, bufnr, pred, metadata)
local capture_id = pred[2]
local node = match[capture_id]
if not node then
return
end
local type_attr_value = vim.treesitter.get_node_text(node, bufnr)
local configured = html_script_type_languages[type_attr_value]
if configured then
metadata["injection.language"] = configured
else
local parts = vim.split(type_attr_value, "/", {})
metadata["injection.language"] = parts[#parts]
end
end, true)
---@param match (TSNode|nil)[]
---@param _ string
---@param bufnr integer
---@param pred string[]
---@return boolean|nil
query.add_directive("set-lang-from-info-string!", function(match, _, bufnr, pred, metadata)
local capture_id = pred[2]
local node = match[capture_id]
if not node then
return
end
local injection_alias = vim.treesitter.get_node_text(node, bufnr)
metadata["injection.language"] = get_parser_from_markdown_info_string(injection_alias)
end, true)
-- Just avoid some annoying warnings for this directive
query.add_directive("make-range!", function() end, true)
--- transform node text to lower case (e.g., to make @injection.language case insensitive)
---
---@param match (TSNode|nil)[]
---@param _ string
---@param bufnr integer
---@param pred string[]
---@return boolean|nil
query.add_directive("downcase!", function(match, _, bufnr, pred, metadata)
local id = pred[2]
local node = match[id]
if not node then
return
end
local text = vim.treesitter.get_node_text(node, bufnr, { metadata = metadata[id] }) or ""
if not metadata[id] then
metadata[id] = {}
end
metadata[id].text = string.lower(text)
end, true)
-- Trim blank lines from end of the region
-- Arguments are the captures to trim.
---@param match (TSNode|nil)[]
---@param _ string
---@param bufnr integer
---@param pred string[]
---@param metadata table
query.add_directive("trim!", function(match, _, bufnr, pred, metadata)
for _, id in ipairs { select(2, unpack(pred)) } do
local node = match[id]
local start_row, start_col, end_row, end_col = node:range()
-- Don't trim if region ends in middle of a line
if end_col ~= 0 then
return
end
while true do
-- As we only care when end_col == 0, always inspect one line above end_row.
local end_line = vim.api.nvim_buf_get_lines(bufnr, end_row - 1, end_row, true)[1]
if end_line ~= "" then
break
end
end_row = end_row - 1
end
-- If this produces an invalid range, we just skip it.
if start_row < end_row or (start_row == end_row and start_col <= end_col) then
if not metadata[id] then
metadata[id] = {}
end
metadata[id].range = { start_row, start_col, end_row, end_col }
end
end
end, true)

View file

@ -0,0 +1,254 @@
local uv = vim.loop
local utils = require('nvim-treesitter.utils')
local iswin = uv.os_uname().sysname == 'Windows_NT'
local M = {}
---@param executables string[]
---@return string|nil
function M.select_executable(executables)
return vim.tbl_filter(function(c) ---@param c string
return c ~= vim.NIL and vim.fn.executable(c) == 1
end, executables)[1]
end
-- Returns the compiler arguments based on the compiler and OS
---@param repo InstallInfo
---@param compiler string
---@return string[]
function M.select_compiler_args(repo, compiler)
if compiler:find('cl$') or compiler:find('cl.exe$') then
return {
'/Fe:',
'parser.so',
'/Isrc',
repo.files,
'-Os',
'/LD',
}
elseif compiler:find('zig$') or compiler:find('zig.exe$') then
return {
'c++',
'-o',
'parser.so',
repo.files,
'-lc',
'-Isrc',
'-shared',
'-Os',
}
else
local args = {
'-o',
'parser.so',
'-I./src',
repo.files,
'-Os',
}
if uv.os_uname().sysname == 'Darwin' then
table.insert(args, '-bundle')
else
table.insert(args, '-shared')
end
if
#vim.tbl_filter(function(file) ---@param file string
local ext = vim.fn.fnamemodify(file, ':e')
return ext == 'cc' or ext == 'cpp' or ext == 'cxx'
end, repo.files) > 0
then
table.insert(args, '-lstdc++')
end
if not iswin then
table.insert(args, '-fPIC')
end
return args
end
end
-- Returns the compile command based on the OS and user options
---@param repo InstallInfo
---@param cc string
---@param compile_location string
---@return Command
function M.select_compile_command(repo, cc, compile_location)
local make = M.select_executable({ 'gmake', 'make' })
if cc:find('cl$') or cc:find('cl.exe$') or not repo.use_makefile or iswin or not make then
return {
cmd = cc,
info = 'Compiling...',
err = 'Error during compilation',
opts = {
args = vim.tbl_flatten(M.select_compiler_args(repo, cc)),
cwd = compile_location,
},
}
else
return {
cmd = make,
info = 'Compiling...',
err = 'Error during compilation',
opts = {
args = {
'--makefile=' .. utils.get_package_path('scripts', 'compile_parsers.makefile'),
'CC=' .. cc,
},
cwd = compile_location,
},
}
end
end
---@param repo InstallInfo
---@param project_name string
---@param cache_dir string
---@param revision string|nil
---@param prefer_git boolean
---@return table
function M.select_download_commands(repo, project_name, cache_dir, revision, prefer_git)
local can_use_tar = vim.fn.executable('tar') == 1 and vim.fn.executable('curl') == 1
local is_github = repo.url:find('github.com', 1, true)
local is_gitlab = repo.url:find('gitlab.com', 1, true)
local project_dir = utils.join_path(cache_dir, project_name)
revision = revision or repo.branch or 'master'
if can_use_tar and (is_github or is_gitlab) and not prefer_git then
local url = repo.url:gsub('.git$', '')
local dir_rev = revision
if is_github and revision:find('^v%d') then
dir_rev = revision:sub(2)
end
local temp_dir = project_dir .. '-tmp'
return {
{
cmd = function()
vim.fn.delete(temp_dir, 'rf')
end,
},
{
cmd = 'curl',
info = 'Downloading ' .. project_name .. '...',
err = 'Error during download, please verify your internet connection',
opts = {
args = {
'--silent',
'-L', -- follow redirects
is_github and url .. '/archive/' .. revision .. '.tar.gz'
or url
.. '/-/archive/'
.. revision
.. '/'
.. project_name
.. '-'
.. revision
.. '.tar.gz',
'--output',
project_name .. '.tar.gz',
},
cwd = cache_dir,
},
},
{
cmd = function()
--TODO(clason): use vim.fn.mkdir(temp_dir, 'p') in case stdpath('cache') is not created
uv.fs_mkdir(temp_dir, 493)
end,
info = 'Creating temporary directory',
err = 'Could not create ' .. project_name .. '-tmp',
},
{
cmd = 'tar',
info = 'Extracting ' .. project_name .. '...',
err = 'Error during tarball extraction.',
opts = {
args = {
'-xvzf',
project_name .. '.tar.gz',
'-C',
project_name .. '-tmp',
},
cwd = cache_dir,
},
},
{
cmd = function()
uv.fs_unlink(project_dir .. '.tar.gz')
end,
},
{
cmd = function()
uv.fs_rename(
utils.join_path(temp_dir, url:match('[^/]-$') .. '-' .. dir_rev),
project_dir
)
end,
},
{
cmd = function()
vim.fn.delete(temp_dir, 'rf')
end,
},
}
else
local git_dir = project_dir
local clone_error = 'Error during download, please verify your internet connection'
return {
{
cmd = 'git',
info = 'Downloading ' .. project_name .. '...',
err = clone_error,
opts = {
args = {
'clone',
repo.url,
project_name,
},
cwd = cache_dir,
},
},
{
cmd = 'git',
info = 'Checking out locked revision',
err = 'Error while checking out revision',
opts = {
args = {
'checkout',
revision,
},
cwd = git_dir,
},
},
}
end
end
--TODO(clason): only needed for iter_cmd_sync -> replace with uv.spawn?
-- Convert path for cmd.exe on Windows (needed when shellslash is set)
---@param p string
---@return string
local function cmdpath(p)
return vim.o.shellslash and p:gsub('/', '\\') or p
end
---@param dir string
---@param command string
---@return string command
function M.make_directory_change_for_command(dir, command)
if iswin then
if string.find(vim.o.shell, 'cmd') ~= nil then
return string.format('pushd %s & %s & popd', cmdpath(dir), command)
else
return string.format('pushd %s ; %s ; popd', cmdpath(dir), command)
end
else
return string.format('cd %s;\n %s', dir, command)
end
end
return M

View file

@ -1,335 +0,0 @@
local fn = vim.fn
local utils = require "nvim-treesitter.utils"
-- Convert path for cmd.exe on Windows.
-- This is needed when vim.opt.shellslash is in use.
---@param p string
---@return string
local function cmdpath(p)
if vim.opt.shellslash:get() then
local r = p:gsub("/", "\\")
return r
else
return p
end
end
local M = {}
-- Returns the mkdir command based on the OS
---@param directory string
---@param cwd string
---@param info_msg string
---@return table
function M.select_mkdir_cmd(directory, cwd, info_msg)
if fn.has "win32" == 1 then
return {
cmd = "cmd",
opts = {
args = { "/C", "mkdir", cmdpath(directory) },
cwd = cwd,
},
info = info_msg,
err = "Could not create " .. directory,
}
else
return {
cmd = "mkdir",
opts = {
args = { directory },
cwd = cwd,
},
info = info_msg,
err = "Could not create " .. directory,
}
end
end
-- Returns the remove command based on the OS
---@param file string
---@param info_msg string
---@return table
function M.select_rm_file_cmd(file, info_msg)
if fn.has "win32" == 1 then
return {
cmd = "cmd",
opts = {
args = { "/C", "if", "exist", cmdpath(file), "del", cmdpath(file) },
},
info = info_msg,
err = "Could not delete " .. file,
}
else
return {
cmd = "rm",
opts = {
args = { file },
},
info = info_msg,
err = "Could not delete " .. file,
}
end
end
---@param executables string[]
---@return string|nil
function M.select_executable(executables)
return vim.tbl_filter(function(c) ---@param c string
return c ~= vim.NIL and fn.executable(c) == 1
end, executables)[1]
end
-- Returns the compiler arguments based on the compiler and OS
---@param repo InstallInfo
---@param compiler string
---@return string[]
function M.select_compiler_args(repo, compiler)
if string.match(compiler, "cl$") or string.match(compiler, "cl.exe$") then
return {
"/Fe:",
"parser.so",
"/Isrc",
repo.files,
"-Os",
"/LD",
}
elseif string.match(compiler, "zig$") or string.match(compiler, "zig.exe$") then
return {
"c++",
"-o",
"parser.so",
repo.files,
"-lc",
"-Isrc",
"-shared",
"-Os",
}
else
local args = {
"-o",
"parser.so",
"-I./src",
repo.files,
"-Os",
}
if fn.has "mac" == 1 then
table.insert(args, "-bundle")
else
table.insert(args, "-shared")
end
if
#vim.tbl_filter(function(file) ---@param file string
local ext = vim.fn.fnamemodify(file, ":e")
return ext == "cc" or ext == "cpp" or ext == "cxx"
end, repo.files) > 0
then
table.insert(args, "-lstdc++")
end
if fn.has "win32" == 0 then
table.insert(args, "-fPIC")
end
return args
end
end
-- Returns the compile command based on the OS and user options
---@param repo InstallInfo
---@param cc string
---@param compile_location string
---@return Command
function M.select_compile_command(repo, cc, compile_location)
local make = M.select_executable { "gmake", "make" }
if
string.match(cc, "cl$")
or string.match(cc, "cl.exe$")
or not repo.use_makefile
or fn.has "win32" == 1
or not make
then
return {
cmd = cc,
info = "Compiling...",
err = "Error during compilation",
opts = {
args = vim.tbl_flatten(M.select_compiler_args(repo, cc)),
cwd = compile_location,
},
}
else
return {
cmd = make,
info = "Compiling...",
err = "Error during compilation",
opts = {
args = {
"--makefile=" .. utils.join_path(utils.get_package_path(), "scripts", "compile_parsers.makefile"),
"CC=" .. cc,
"CXX_STANDARD=" .. (repo.cxx_standard or "c++14"),
},
cwd = compile_location,
},
}
end
end
-- Returns the remove command based on the OS
---@param cache_folder string
---@param project_name string
---@return Command
function M.select_install_rm_cmd(cache_folder, project_name)
if fn.has "win32" == 1 then
local dir = cache_folder .. "\\" .. project_name
return {
cmd = "cmd",
opts = {
args = { "/C", "if", "exist", cmdpath(dir), "rmdir", "/s", "/q", cmdpath(dir) },
},
}
else
return {
cmd = "rm",
opts = {
args = { "-rf", cache_folder .. "/" .. project_name },
},
}
end
end
-- Returns the move command based on the OS
---@param from string
---@param to string
---@param cwd string
---@return Command
function M.select_mv_cmd(from, to, cwd)
if fn.has "win32" == 1 then
return {
cmd = "cmd",
opts = {
args = { "/C", "move", "/Y", cmdpath(from), cmdpath(to) },
cwd = cwd,
},
}
else
return {
cmd = "mv",
opts = {
args = { "-f", from, to },
cwd = cwd,
},
}
end
end
---@param repo InstallInfo
---@param project_name string
---@param cache_folder string
---@param revision string|nil
---@param prefer_git boolean
---@return table
function M.select_download_commands(repo, project_name, cache_folder, revision, prefer_git)
local can_use_tar = vim.fn.executable "tar" == 1 and vim.fn.executable "curl" == 1
local is_github = repo.url:find("github.com", 1, true)
local is_gitlab = repo.url:find("gitlab.com", 1, true)
revision = revision or repo.branch or "master"
if can_use_tar and (is_github or is_gitlab) and not prefer_git then
local path_sep = utils.get_path_sep()
local url = repo.url:gsub(".git$", "")
local folder_rev = revision
if is_github and revision:match "^v%d" then
folder_rev = revision:sub(2)
end
return {
M.select_install_rm_cmd(cache_folder, project_name .. "-tmp"),
{
cmd = "curl",
info = "Downloading " .. project_name .. "...",
err = "Error during download, please verify your internet connection",
opts = {
args = {
"--silent",
"-L", -- follow redirects
is_github and url .. "/archive/" .. revision .. ".tar.gz"
or url .. "/-/archive/" .. revision .. "/" .. project_name .. "-" .. revision .. ".tar.gz",
"--output",
project_name .. ".tar.gz",
},
cwd = cache_folder,
},
},
M.select_mkdir_cmd(project_name .. "-tmp", cache_folder, "Creating temporary directory"),
{
cmd = "tar",
info = "Extracting " .. project_name .. "...",
err = "Error during tarball extraction.",
opts = {
args = {
"-xvzf",
project_name .. ".tar.gz",
"-C",
project_name .. "-tmp",
},
cwd = cache_folder,
},
},
M.select_rm_file_cmd(cache_folder .. path_sep .. project_name .. ".tar.gz"),
M.select_mv_cmd(
utils.join_path(project_name .. "-tmp", url:match "[^/]-$" .. "-" .. folder_rev),
project_name,
cache_folder
),
M.select_install_rm_cmd(cache_folder, project_name .. "-tmp"),
}
else
local git_folder = utils.join_path(cache_folder, project_name)
local clone_error = "Error during download, please verify your internet connection"
return {
{
cmd = "git",
info = "Downloading " .. project_name .. "...",
err = clone_error,
opts = {
args = {
"clone",
repo.url,
project_name,
},
cwd = cache_folder,
},
},
{
cmd = "git",
info = "Checking out locked revision",
err = "Error while checking out revision",
opts = {
args = {
"checkout",
revision,
},
cwd = git_folder,
},
},
}
end
end
---@param dir string
---@param command string
---@return string command
function M.make_directory_change_for_command(dir, command)
if fn.has "win32" == 1 then
if string.find(vim.o.shell, "cmd") ~= nil then
return string.format("pushd %s & %s & popd", cmdpath(dir), command)
else
return string.format("pushd %s ; %s ; popd", cmdpath(dir), command)
end
else
return string.format("cd %s;\n %s", dir, command)
end
end
return M

View file

@ -1,53 +0,0 @@
local parsers = require "nvim-treesitter.parsers"
local ts_utils = require "nvim-treesitter.ts_utils"
local M = {}
-- Trim spaces and opening brackets from end
local transform_line = function(line)
return line:gsub("%s*[%[%(%{]*%s*$", "")
end
function M.statusline(opts)
if not parsers.has_parser() then
return
end
local options = opts or {}
if type(opts) == "number" then
options = { indicator_size = opts }
end
local bufnr = options.bufnr or 0
local indicator_size = options.indicator_size or 100
local type_patterns = options.type_patterns or { "class", "function", "method" }
local transform_fn = options.transform_fn or transform_line
local separator = options.separator or " -> "
local allow_duplicates = options.allow_duplicates or false
local current_node = ts_utils.get_node_at_cursor()
if not current_node then
return ""
end
local lines = {}
local expr = current_node
while expr do
local line = ts_utils._get_line_for_node(expr, type_patterns, transform_fn, bufnr)
if line ~= "" then
if allow_duplicates or not vim.tbl_contains(lines, line) then
table.insert(lines, 1, line)
end
end
expr = expr:parent()
end
local text = table.concat(lines, separator)
local text_len = #text
if text_len > indicator_size then
return "..." .. text:sub(text_len - indicator_size, text_len)
end
return text
end
return M

View file

@ -1,467 +0,0 @@
local api = vim.api
local parsers = require "nvim-treesitter.parsers"
local utils = require "nvim-treesitter.utils"
local ts = vim.treesitter
local M = {}
local function get_node_text(node, bufnr)
bufnr = bufnr or api.nvim_get_current_buf()
if not node then
return {}
end
-- We have to remember that end_col is end-exclusive
local start_row, start_col, end_row, end_col = ts.get_node_range(node)
if start_row ~= end_row then
local lines = api.nvim_buf_get_lines(bufnr, start_row, end_row + 1, false)
if next(lines) == nil then
return {}
end
lines[1] = string.sub(lines[1], start_col + 1)
-- end_row might be just after the last line. In this case the last line is not truncated.
if #lines == end_row - start_row + 1 then
lines[#lines] = string.sub(lines[#lines], 1, end_col)
end
return lines
else
local line = api.nvim_buf_get_lines(bufnr, start_row, start_row + 1, false)[1]
-- If line is nil then the line is empty
return line and { string.sub(line, start_col + 1, end_col) } or {}
end
end
---@private
---@param node TSNode
---@param type_patterns string[]
---@param transform_fn fun(line: string): string
---@param bufnr integer
---@return string
function M._get_line_for_node(node, type_patterns, transform_fn, bufnr)
local node_type = node:type()
local is_valid = false
for _, rgx in ipairs(type_patterns) do
if node_type:find(rgx) then
is_valid = true
break
end
end
if not is_valid then
return ""
end
local line = transform_fn(vim.trim(get_node_text(node, bufnr)[1] or ""), node)
-- Escape % to avoid statusline to evaluate content as expression
return line:gsub("%%", "%%%%")
end
-- Gets the actual text content of a node
-- @deprecated Use vim.treesitter.query.get_node_text
-- @param node the node to get the text from
-- @param bufnr the buffer containing the node
-- @return list of lines of text of the node
function M.get_node_text(node, bufnr)
vim.notify_once(
"nvim-treesitter.ts_utils.get_node_text is deprecated: use vim.treesitter.query.get_node_text",
vim.log.levels.WARN
)
return get_node_text(node, bufnr)
end
-- Determines whether a node is the parent of another
-- @param dest the possible parent
-- @param source the possible child node
function M.is_parent(dest, source)
if not (dest and source) then
return false
end
local current = source
while current ~= nil do
if current == dest then
return true
end
current = current:parent()
end
return false
end
-- Get next node with same parent
---@param node TSNode
---@param allow_switch_parents? boolean allow switching parents if last node
---@param allow_next_parent? boolean allow next parent if last node and next parent without children
function M.get_next_node(node, allow_switch_parents, allow_next_parent)
local destination_node ---@type TSNode
local parent = node:parent()
if not parent then
return
end
local found_pos = 0
for i = 0, parent:named_child_count() - 1, 1 do
if parent:named_child(i) == node then
found_pos = i
break
end
end
if parent:named_child_count() > found_pos + 1 then
destination_node = parent:named_child(found_pos + 1)
elseif allow_switch_parents then
local next_node = M.get_next_node(node:parent())
if next_node and next_node:named_child_count() > 0 then
destination_node = next_node:named_child(0)
elseif next_node and allow_next_parent then
destination_node = next_node
end
end
return destination_node
end
-- Get previous node with same parent
---@param node TSNode
---@param allow_switch_parents? boolean allow switching parents if first node
---@param allow_previous_parent? boolean allow previous parent if first node and previous parent without children
function M.get_previous_node(node, allow_switch_parents, allow_previous_parent)
local destination_node ---@type TSNode
local parent = node:parent()
if not parent then
return
end
local found_pos = 0
for i = 0, parent:named_child_count() - 1, 1 do
if parent:named_child(i) == node then
found_pos = i
break
end
end
if 0 < found_pos then
destination_node = parent:named_child(found_pos - 1)
elseif allow_switch_parents then
local previous_node = M.get_previous_node(node:parent())
if previous_node and previous_node:named_child_count() > 0 then
destination_node = previous_node:named_child(previous_node:named_child_count() - 1)
elseif previous_node and allow_previous_parent then
destination_node = previous_node
end
end
return destination_node
end
function M.get_named_children(node)
local nodes = {} ---@type TSNode[]
for i = 0, node:named_child_count() - 1, 1 do
nodes[i + 1] = node:named_child(i)
end
return nodes
end
function M.get_node_at_cursor(winnr, ignore_injected_langs)
winnr = winnr or 0
local cursor = api.nvim_win_get_cursor(winnr)
local cursor_range = { cursor[1] - 1, cursor[2] }
local buf = vim.api.nvim_win_get_buf(winnr)
local root_lang_tree = parsers.get_parser(buf)
if not root_lang_tree then
return
end
local root ---@type TSNode|nil
if ignore_injected_langs then
for _, tree in ipairs(root_lang_tree:trees()) do
local tree_root = tree:root()
if tree_root and ts.is_in_node_range(tree_root, cursor_range[1], cursor_range[2]) then
root = tree_root
break
end
end
else
root = M.get_root_for_position(cursor_range[1], cursor_range[2], root_lang_tree)
end
if not root then
return
end
return root:named_descendant_for_range(cursor_range[1], cursor_range[2], cursor_range[1], cursor_range[2])
end
function M.get_root_for_position(line, col, root_lang_tree)
if not root_lang_tree then
if not parsers.has_parser() then
return
end
root_lang_tree = parsers.get_parser()
end
local lang_tree = root_lang_tree:language_for_range { line, col, line, col }
for _, tree in ipairs(lang_tree:trees()) do
local root = tree:root()
if root and ts.is_in_node_range(root, line, col) then
return root, tree, lang_tree
end
end
-- This isn't a likely scenario, since the position must belong to a tree somewhere.
return nil, nil, lang_tree
end
---comment
---@param node TSNode
---@return TSNode result
function M.get_root_for_node(node)
local parent = node
local result = node
while parent ~= nil do
result = parent
parent = result:parent()
end
return result
end
function M.highlight_node(node, buf, hl_namespace, hl_group)
if not node then
return
end
M.highlight_range({ node:range() }, buf, hl_namespace, hl_group)
end
-- Get a compatible vim range (1 index based) from a TS node range.
--
-- TS nodes start with 0 and the end col is ending exclusive.
-- They also treat a EOF/EOL char as a char ending in the first
-- col of the next row.
---comment
---@param range integer[]
---@param buf integer|nil
---@return integer, integer, integer, integer
function M.get_vim_range(range, buf)
---@type integer, integer, integer, integer
local srow, scol, erow, ecol = unpack(range)
srow = srow + 1
scol = scol + 1
erow = erow + 1
if ecol == 0 then
-- Use the value of the last col of the previous row instead.
erow = erow - 1
if not buf or buf == 0 then
ecol = vim.fn.col { erow, "$" } - 1
else
ecol = #api.nvim_buf_get_lines(buf, erow - 1, erow, false)[1]
end
ecol = math.max(ecol, 1)
end
return srow, scol, erow, ecol
end
function M.highlight_range(range, buf, hl_namespace, hl_group)
---@type integer, integer, integer, integer
local start_row, start_col, end_row, end_col = unpack(range)
---@diagnostic disable-next-line: missing-parameter
vim.highlight.range(buf, hl_namespace, hl_group, { start_row, start_col }, { end_row, end_col })
end
-- Set visual selection to node
-- @param selection_mode One of "charwise" (default) or "v", "linewise" or "V",
-- "blockwise" or "<C-v>" (as a string with 5 characters or a single character)
function M.update_selection(buf, node, selection_mode)
local start_row, start_col, end_row, end_col = M.get_vim_range({ ts.get_node_range(node) }, buf)
local v_table = { charwise = "v", linewise = "V", blockwise = "<C-v>" }
selection_mode = selection_mode or "charwise"
-- Normalise selection_mode
if vim.tbl_contains(vim.tbl_keys(v_table), selection_mode) then
selection_mode = v_table[selection_mode]
end
-- enter visual mode if normal or operator-pending (no) mode
-- Why? According to https://learnvimscriptthehardway.stevelosh.com/chapters/15.html
-- If your operator-pending mapping ends with some text visually selected, Vim will operate on that text.
-- Otherwise, Vim will operate on the text between the original cursor position and the new position.
local mode = api.nvim_get_mode()
if mode.mode ~= selection_mode then
-- Call to `nvim_replace_termcodes()` is needed for sending appropriate command to enter blockwise mode
selection_mode = vim.api.nvim_replace_termcodes(selection_mode, true, true, true)
api.nvim_cmd({ cmd = "normal", bang = true, args = { selection_mode } }, {})
end
api.nvim_win_set_cursor(0, { start_row, start_col - 1 })
vim.cmd "normal! o"
api.nvim_win_set_cursor(0, { end_row, end_col - 1 })
end
-- Byte length of node range
---@param node TSNode
---@return number
function M.node_length(node)
local _, _, start_byte = node:start()
local _, _, end_byte = node:end_()
return end_byte - start_byte
end
---@deprecated Use `vim.treesitter.is_in_node_range()` instead
function M.is_in_node_range(node, line, col)
vim.notify_once(
"nvim-treesitter.ts_utils.is_in_node_range is deprecated: use vim.treesitter.is_in_node_range",
vim.log.levels.WARN
)
return ts.is_in_node_range(node, line, col)
end
---@deprecated Use `vim.treesitter.get_node_range()` instead
function M.get_node_range(node_or_range)
vim.notify_once(
"nvim-treesitter.ts_utils.get_node_range is deprecated: use vim.treesitter.get_node_range",
vim.log.levels.WARN
)
return ts.get_node_range(node_or_range)
end
---@param node TSNode
---@return table
function M.node_to_lsp_range(node)
local start_line, start_col, end_line, end_col = ts.get_node_range(node)
local rtn = {}
rtn.start = { line = start_line, character = start_col }
rtn["end"] = { line = end_line, character = end_col }
return rtn
end
-- Memoizes a function based on the buffer tick of the provided bufnr.
-- The cache entry is cleared when the buffer is detached to avoid memory leaks.
-- The options argument is a table with two optional values:
-- - bufnr: extracts a bufnr from the given arguments.
-- - key: extracts the cache key from the given arguments.
---@param fn function the fn to memoize, taking the buffer as first argument
---@param options? {bufnr: integer?, key: string|fun(...): string?} the memoization options
---@return function: a memoized function
function M.memoize_by_buf_tick(fn, options)
options = options or {}
---@type table<string, {result: any, last_tick: integer}>
local cache = setmetatable({}, { __mode = "kv" })
local bufnr_fn = utils.to_func(options.bufnr or utils.identity)
local key_fn = utils.to_func(options.key or utils.identity)
return function(...)
local bufnr = bufnr_fn(...)
local key = key_fn(...)
local tick = api.nvim_buf_get_changedtick(bufnr)
if cache[key] then
if cache[key].last_tick == tick then
return cache[key].result
end
else
local function detach_handler()
cache[key] = nil
end
-- Clean up logic only!
api.nvim_buf_attach(bufnr, false, {
on_detach = detach_handler,
on_reload = detach_handler,
})
end
cache[key] = {
result = fn(...),
last_tick = tick,
}
return cache[key].result
end
end
function M.swap_nodes(node_or_range1, node_or_range2, bufnr, cursor_to_second)
if not node_or_range1 or not node_or_range2 then
return
end
local range1 = M.node_to_lsp_range(node_or_range1)
local range2 = M.node_to_lsp_range(node_or_range2)
local text1 = get_node_text(node_or_range1, bufnr)
local text2 = get_node_text(node_or_range2, bufnr)
local edit1 = { range = range1, newText = table.concat(text2, "\n") }
local edit2 = { range = range2, newText = table.concat(text1, "\n") }
vim.lsp.util.apply_text_edits({ edit1, edit2 }, bufnr, "utf-8")
if cursor_to_second then
utils.set_jump()
local char_delta = 0
local line_delta = 0
if
range1["end"].line < range2.start.line
or (range1["end"].line == range2.start.line and range1["end"].character <= range2.start.character)
then
line_delta = #text2 - #text1
end
if range1["end"].line == range2.start.line and range1["end"].character <= range2.start.character then
if line_delta ~= 0 then
--- why?
--correction_after_line_change = -range2.start.character
--text_now_before_range2 = #(text2[#text2])
--space_between_ranges = range2.start.character - range1["end"].character
--char_delta = correction_after_line_change + text_now_before_range2 + space_between_ranges
--- Equivalent to:
char_delta = #text2[#text2] - range1["end"].character
-- add range1.start.character if last line of range1 (now text2) does not start at 0
if range1.start.line == range2.start.line + line_delta then
char_delta = char_delta + range1.start.character
end
else
char_delta = #text2[#text2] - #text1[#text1]
end
end
api.nvim_win_set_cursor(
api.nvim_get_current_win(),
{ range2.start.line + 1 + line_delta, range2.start.character + char_delta }
)
end
end
function M.goto_node(node, goto_end, avoid_set_jump)
if not node then
return
end
if not avoid_set_jump then
utils.set_jump()
end
local range = { M.get_vim_range { node:range() } }
---@type table<number>
local position
if not goto_end then
position = { range[1], range[2] }
else
position = { range[3], range[4] }
end
-- Enter visual mode if we are in operator pending mode
-- If we don't do this, it will miss the last character.
local mode = vim.api.nvim_get_mode()
if mode.mode == "no" then
vim.cmd "normal! v"
end
-- Position is 1, 0 indexed.
api.nvim_win_set_cursor(0, { position[1], position[2] - 1 })
end
return M

View file

@ -1,154 +0,0 @@
local M = {}
local TSRange = {}
TSRange.__index = TSRange
local api = vim.api
local ts_utils = require "nvim-treesitter.ts_utils"
local parsers = require "nvim-treesitter.parsers"
local function get_byte_offset(buf, row, col)
return api.nvim_buf_get_offset(buf, row) + vim.fn.byteidx(api.nvim_buf_get_lines(buf, row, row + 1, false)[1], col)
end
function TSRange.new(buf, start_row, start_col, end_row, end_col)
return setmetatable({
start_pos = { start_row, start_col, get_byte_offset(buf, start_row, start_col) },
end_pos = { end_row, end_col, get_byte_offset(buf, end_row, end_col) },
buf = buf,
[1] = start_row,
[2] = start_col,
[3] = end_row,
[4] = end_col,
}, TSRange)
end
function TSRange.from_nodes(buf, start_node, end_node)
TSRange.__index = TSRange
local start_pos = start_node and { start_node:start() } or { end_node:start() }
local end_pos = end_node and { end_node:end_() } or { start_node:end_() }
return setmetatable({
start_pos = { start_pos[1], start_pos[2], start_pos[3] },
end_pos = { end_pos[1], end_pos[2], end_pos[3] },
buf = buf,
[1] = start_pos[1],
[2] = start_pos[2],
[3] = end_pos[1],
[4] = end_pos[2],
}, TSRange)
end
function TSRange.from_table(buf, range)
return setmetatable({
start_pos = { range[1], range[2], get_byte_offset(buf, range[1], range[2]) },
end_pos = { range[3], range[4], get_byte_offset(buf, range[3], range[4]) },
buf = buf,
[1] = range[1],
[2] = range[2],
[3] = range[3],
[4] = range[4],
}, TSRange)
end
function TSRange:parent()
local root_lang_tree = parsers.get_parser(self.buf)
local root = ts_utils.get_root_for_position(self[1], self[2], root_lang_tree)
return root
and root:named_descendant_for_range(self.start_pos[1], self.start_pos[2], self.end_pos[1], self.end_pos[2])
or nil
end
function TSRange:field() end
function TSRange:child_count()
return #self:collect_children()
end
function TSRange:named_child_count()
return #self:collect_children(function(c)
return c:named()
end)
end
function TSRange:iter_children()
local raw_iterator = self:parent().iter_children()
return function()
while true do
local node = raw_iterator()
if not node then
return
end
local _, _, start_byte = node:start()
local _, _, end_byte = node:end_()
if start_byte >= self.start_pos[3] and end_byte <= self.end_pos[3] then
return node
end
end
end
end
function TSRange:collect_children(filter_fun)
local children = {}
for _, c in self:iter_children() do
if not filter_fun or filter_fun(c) then
table.insert(children, c)
end
end
return children
end
function TSRange:child(index)
return self:collect_children()[index + 1]
end
function TSRange:named_child(index)
return self:collect_children(function(c)
return c.named()
end)[index + 1]
end
function TSRange:start()
return unpack(self.start_pos)
end
function TSRange:end_()
return unpack(self.end_pos)
end
function TSRange:range()
return self.start_pos[1], self.start_pos[2], self.end_pos[1], self.end_pos[2]
end
function TSRange:type()
return "nvim-treesitter-range"
end
function TSRange:symbol()
return -1
end
function TSRange:named()
return false
end
function TSRange:missing()
return false
end
function TSRange:has_error()
return #self:collect_children(function(c)
return c:has_error()
end) > 0 and true or false
end
function TSRange:sexpr()
return table.concat(
vim.tbl_map(function(c)
return c:sexpr()
end, self:collect_children()),
" "
)
end
M.TSRange = TSRange
return M

View file

@ -1,237 +1,12 @@
local api = vim.api
local fn = vim.fn
local luv = vim.loop
local M = {}
-- Wrapper around vim.notify with common options set.
---@param msg string
---@param log_level number|nil
---@param opts table|nil
function M.notify(msg, log_level, opts)
local default_opts = { title = "nvim-treesitter" }
vim.notify(msg, log_level, vim.tbl_extend("force", default_opts, opts or {}))
--TODO(clason): replace by vim.fs._join_paths
function M.join_path(...)
return (table.concat({ ... }, '/'):gsub('//+', '/'))
end
-- Returns the system-specific path separator.
---@return string
function M.get_path_sep()
return (fn.has "win32" == 1 and not vim.opt.shellslash:get()) and "\\" or "/"
end
-- Returns a function that joins the given arguments with separator. Arguments
-- can't be nil. Example:
--
--[[
print(M.generate_join(" ")("foo", "bar"))
--]]
--prints "foo bar"
---@param separator string
---@return fun(...: string): string
function M.generate_join(separator)
return function(...)
return table.concat({ ... }, separator)
end
end
M.join_path = M.generate_join(M.get_path_sep())
M.join_space = M.generate_join " "
---@class Command
---@field run function
---@field f_args string
---@field args string
-- Define user defined vim command which calls nvim-treesitter module function
-- - If module name is 'mod', it should be defined in hierarchy 'nvim-treesitter.mod'
-- - A table with name 'commands' should be defined in 'mod' which needs to be passed as
-- the commands param of this function
--
---@param mod string Name of the module that resides in the hierarchy - nvim-treesitter.module
---@param commands table<string, Command> Command list for the module
--- - {command_name} Name of the vim user defined command, Keys:
--- - {run}: (function) callback function that needs to be executed
--- - {f_args}: (string, default <f-args>)
--- - type of arguments that needs to be passed to the vim command
--- - {args}: (string, optional)
--- - vim command attributes
---
---* @example
--- If module is nvim-treesitter.custom_mod
--- <pre>
--- M.commands = {
--- custom_command = {
--- run = M.module_function,
--- f_args = "<f-args>",
--- args = {
--- "-range"
--- }
--- }
--- }
---
--- utils.setup_commands("custom_mod", require("nvim-treesitter.custom_mod").commands)
--- </pre>
---
--- Will generate command :
--- <pre>
--- command! -range custom_command \
--- lua require'nvim-treesitter.custom_mod'.commands.custom_command['run<bang>'](<f-args>)
--- </pre>
function M.setup_commands(mod, commands)
for command_name, def in pairs(commands) do
local f_args = def.f_args or "<f-args>"
local call_fn =
string.format("lua require'nvim-treesitter.%s'.commands.%s['run<bang>'](%s)", mod, command_name, f_args)
local parts = vim.tbl_flatten {
"command!",
"-bar",
def.args,
command_name,
call_fn,
}
api.nvim_command(table.concat(parts, " "))
end
end
---@param dir string
---@param create_err string
---@param writeable_err string
---@return string|nil, string|nil
function M.create_or_reuse_writable_dir(dir, create_err, writeable_err)
create_err = create_err or M.join_space("Could not create dir '", dir, "': ")
writeable_err = writeable_err or M.join_space("Invalid rights, '", dir, "' should be read/write")
-- Try creating and using parser_dir if it doesn't exist
if not luv.fs_stat(dir) then
local ok, error = pcall(vim.fn.mkdir, dir, "p", "0755")
if not ok then
return nil, M.join_space(create_err, error)
end
return dir
end
-- parser_dir exists, use it if it's read/write
if luv.fs_access(dir, "RW") then
return dir
end
-- parser_dir exists but isn't read/write, give up
return nil, M.join_space(writeable_err, dir, "'")
end
function M.get_package_path()
-- Path to this source file, removing the leading '@'
local source = string.sub(debug.getinfo(1, "S").source, 2)
-- Path to the package root
return fn.fnamemodify(source, ":p:h:h:h")
end
function M.get_cache_dir()
local cache_dir = fn.stdpath "data"
if luv.fs_access(cache_dir, "RW") then
return cache_dir
elseif luv.fs_access("/tmp", "RW") then
return "/tmp"
end
return nil, M.join_space("Invalid cache rights,", fn.stdpath "data", "or /tmp should be read/write")
end
-- Returns $XDG_DATA_HOME/nvim/site, but could use any directory that is in
-- runtimepath
function M.get_site_dir()
return M.join_path(fn.stdpath "data", "site")
end
-- Gets a property at path
---@param tbl table the table to access
---@param path string the '.' separated path
---@return table|nil result the value at path or nil
function M.get_at_path(tbl, path)
if path == "" then
return tbl
end
local segments = vim.split(path, ".", true)
---@type table[]|table
local result = tbl
for _, segment in ipairs(segments) do
if type(result) == "table" then
---@type table
-- TODO: figure out the actual type of tbl
result = result[segment]
end
end
return result
end
function M.set_jump()
vim.cmd "normal! m'"
end
-- Filters a list based on the given predicate
---@param tbl any[] The list to filter
---@param predicate fun(v:any, i:number):boolean The predicate to filter with
function M.filter(tbl, predicate)
local result = {}
for i, v in ipairs(tbl) do
if predicate(v, i) then
table.insert(result, v)
end
end
return result
end
-- Returns a list of all values from the first list
-- that are not present in the second list.
---@param tbl1 any[] The first table
---@param tbl2 any[] The second table
---@return table
function M.difference(tbl1, tbl2)
return M.filter(tbl1, function(v)
return not vim.tbl_contains(tbl2, v)
end)
end
function M.identity(a)
return a
end
-- Returns a function returning the given value
---@param a any
---@return fun():any
function M.constant(a)
return function()
return a
end
end
-- Returns a function that returns the given value if it is a function,
-- otherwise returns a function that returns the given value.
---@param a any
---@return fun(...):any
function M.to_func(a)
return type(a) == "function" and a or M.constant(a)
end
---@return string|nil
function M.ts_cli_version()
if fn.executable "tree-sitter" == 1 then
local handle = io.popen "tree-sitter -V"
if not handle then
return
end
local result = handle:read "*a"
handle:close()
return vim.split(result, "\n")[1]:match "[^tree%psitter ].*"
end
function M.get_package_path(...)
return M.join_path(vim.fn.fnamemodify(debug.getinfo(1, 'S').source:sub(2), ':p:h:h:h'), ...)
end
return M