Merge pull request #241 from steelsojka/master

feat(predicates): add adjacent predicate
This commit is contained in:
Steven Sojka 2020-07-31 12:12:45 -05:00 committed by GitHub
commit e95c14c81d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,4 +1,5 @@
local utils = require'nvim-treesitter.utils'
local ts_utils = require'nvim-treesitter.ts_utils'
local M = {}
@ -12,6 +13,42 @@ local function get_node(query, match, pred_item)
return utils.get_at_path(match, query.captures[pred_item]..'.node')
end
local function create_adjacent_predicate(match_successive_nodes)
return function(query, match, pred)
if #pred < 3 then error("adjacent? must have at least two arguments!") end
local node = get_node(query, match, pred[2])
if not node then return true end
local adjacent_types = {unpack(pred, 3)}
local adjacent_node = ts_utils.get_next_node(node)
if match_successive_nodes then
-- Move to the last node in a series that doesn't match the node type
-- and use that node to compare with.
while adjacent_node and adjacent_node:type() == node:type() do
node = adjacent_node
adjacent_node = ts_utils.get_next_node(node)
end
end
if not adjacent_node then return false end
for _, adjacent_type in ipairs(adjacent_types) do
if type(adjacent_type) == "number" then
if get_node(query, match, adjacent_type) == adjacent_node then
return true
end
elseif type(adjacent_type) == "string" then
if adjacent_node:type() == adjacent_type then
return true
end
end
end
return false
end
end
function M.check_predicate(query, match, pred)
local check_function = M[pred[1]]
if check_function then
@ -55,7 +92,7 @@ end
end
end
M['has_ancestor?'] = function(query, match, pred)
M['has-ancestor?'] = function(query, match, pred)
if #pred ~= 3 then error("has-ancestor? must have exactly two arguments!") end
local node = get_node(query, match, pred[2])
local ancestor_type = pred[3]
@ -71,4 +108,7 @@ M['has_ancestor?'] = function(query, match, pred)
return false
end
M['adjacent?'] = create_adjacent_predicate(false)
M['adjacent-block?'] = create_adjacent_predicate(true)
return M