chore(textobject): use query.find_best_match to find next/previous textobject

This commit is contained in:
Stephan Seitz 2020-08-03 00:25:12 +02:00
parent e629efafd8
commit f6681c230f
2 changed files with 60 additions and 40 deletions

View file

@ -212,6 +212,36 @@ function M.get_capture_matches(bufnr, capture_string, query_group)
return matches
end
function M.find_best_match(bufnr, capture_string, query_group, filter_predicate, scoring_function)
if not string.sub(capture_string, 1,2) == '@' then
api.nvim_err_writeln('capture_string must start with "@"')
return
end
--remove leading "@"
capture_string = string.sub(capture_string, 2)
local best
local best_score
for maybe_match in M.iter_group_results(bufnr, query_group) do
local match = utils.get_at_path(maybe_match, capture_string)
if match and filter_predicate(match) then
local current_score = scoring_function(match)
if not best then
best = match
best_score = current_score
end
if current_score > best_score then
best = match
best_score = current_score
end
end
end
return best
end
-- Iterates matches from a query file.
-- @param bufnr the buffer
-- @param query_group the query file to use

View file

@ -145,56 +145,46 @@ end
function M.next_textobject(node, query_string, same_parent, bufnr)
local node = node or ts_utils.get_node_at_cursor()
if not node then return end
local _, _, node_end = node:end_()
local bufnr = bufnr or api.nvim_get_current_buf()
local matches = queries.get_capture_matches(bufnr, query_string, 'textobjects')
local _, _ , node_end = node:end_()
local next_node
local next_node_start
for _, m in pairs(matches) do
local _, _, other_end = m.node:start()
if other_end > node_end then
if not same_parent or node:parent() == m.node:parent() then
if not next_node then
next_node = m
_, _, next_node_start = next_node.node:start()
end
if other_end < next_node_start then
next_node = m
_, _, next_node_start = next_node.node:start()
end
end
end
end
local next_node = queries.find_best_match(bufnr,
query_string,
'textobjects',
function(match)
if not same_parent or node:parent() == match.node:parent() then
local _, _, start = match.node:start()
return start > node_end
end
end,
function(match)
local _, _, node_start = match.node:start()
return -node_start
end)
return next_node and next_node.node
end
function M.previous_textobject(node, query_string, same_parent, bufnr)
local node = node or ts_utils.get_node_at_cursor()
if not node then return end
local _, _, node_start = node:start()
local bufnr = bufnr or api.nvim_get_current_buf()
local matches = queries.get_capture_matches(bufnr, query_string, 'textobjects')
local _, _ , node_start = node:start()
local previous_node
local previous_node_end
for _, m in pairs(matches) do
local _, _, other_end = m.node:end_()
if other_end < node_start then
if not same_parent or node:parent() == m.node:parent() then
if not previous_node then
previous_node = m
_, _, previous_node_end = previous_node.node:end_()
end
if other_end > previous_node_end then
previous_node = m
_, _, previous_node_end = previous_node.node:end_()
end
end
end
end
local previous_node = queries.find_best_match(bufnr,
query_string,
'textobjects',
function(match)
if not same_parent or node:parent() == match.node:parent() then
local _, _, end_ = match.node:end_()
return end_ < node_start
end
end,
function(match)
local _, _, node_end = match.node:end_()
return node_end
end)
return previous_node and previous_node.node
end