diff --git a/lua/nvim-treesitter/query.lua b/lua/nvim-treesitter/query.lua index cd38bee53..006345890 100644 --- a/lua/nvim-treesitter/query.lua +++ b/lua/nvim-treesitter/query.lua @@ -94,13 +94,15 @@ local function compile_regex(regex) end end -local function strip_beginning(lines, regex) +local function strip_beginning(lines, regex, continuation_line_regex) local current_col = 0 local current_line = 0 - local re = compile_regex('^('..regex..')') + local re_first = compile_regex('^('..regex..')') + local re_next = compile_regex('^('..continuation_line_regex..')') for linenr, line in ipairs(lines) do current_line = linenr + local re = (linenr == 1 and re_first or re_next) local match_start, match_end = re:match_str(line) if match_start ~= 0 then break @@ -115,14 +117,17 @@ local function strip_beginning(lines, regex) return current_line - 1, current_col end -local function strip_end(lines, regex) +local function strip_end(lines, regex, continuation_line_regex) local current_col = 0 local line_diff = 0 - local re = compile_regex('('..regex..')$') - for linenr=#lines, 1, -1 do + local re_last = compile_regex('('..regex..')$') + local re_before = compile_regex('('..continuation_line_regex..')$') + for linenr = #lines, 1, -1 do local line = lines[linenr] line_diff = #lines - linenr + local re = linenr == #lines and re_last or re_before local match_start, match_end = re:match_str(line) + if match_end ~= vim.str_byteindex(line, #line) then break else @@ -185,9 +190,10 @@ function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row) if pred[1] == "set!" and type(pred[2]) == "string" then insert_to_path(prepared_match, split(pred[2]), pred[3]) end - if pred[1] == "strip!" and #pred == 3 then + if pred[1] == "strip!" and #pred == 4 then local capture_name = query.captures[pred[2]] local regex = pred[3] + local continuation_line_regex = pred[4] local function process_range() local node = utils.get_at_path(prepared_match, capture_name..'.node') @@ -195,14 +201,15 @@ function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row) local node_lines = ts_utils.get_node_text(node, bufnr) local start_line, start_col, end_line, _ = node:range() - local strip_line, strip_col = strip_beginning(node_lines, regex) + local strip_line, strip_col = strip_beginning(node_lines, regex, continuation_line_regex) + strip_beginning(node_lines, regex, continuation_line_regex) start_line = start_line + strip_line if strip_line == 0 then start_col = start_col + strip_col else start_col = strip_col end - local strip_line, strip_col = strip_end(node_lines, regex) + local strip_line, strip_col = strip_end(node_lines, regex, continuation_line_regex) end_line = end_line - strip_line local end_col if strip_line == 0 then