diff --git a/lua/nvim-treesitter/node_movement.lua b/lua/nvim-treesitter/node_movement.lua index 65e31a01b..e2fa7b8fa 100644 --- a/lua/nvim-treesitter/node_movement.lua +++ b/lua/nvim-treesitter/node_movement.lua @@ -17,12 +17,40 @@ local function node_start_to_vim(node) if not node then return end local row, col = node:start() - local exec_command = string.format('call cursor(%d, %d)', row+1, col+1) - api.nvim_exec(exec_command, false) + + local mode = api.nvim_get_mode().mode + print(vim.inspect(mode)) + --if mode == 'v' then + --local _, current_line, current_col, _ = unpack(vim.fn.getpos(".")) + --local _, sel_start_line, sel_start_col, _ = unpack(vim.fn.getpos("'<")) + --local _, sel_end_line, sel_end_col, _ = unpack(vim.fn.getpos("'>")) + + --if current_line == sel_start_line and current_col == sel_start_col then + --sel_start_line = row + 1 + --sel_start_col = col + 1 + --vim.fn.setpos("'<", {row + 1, col + 1}) + --end + --if current_line == sel_end_line and current_col == sel_end_col then + --row, col = node:end_() + --sel_end_line = row + 1 + --sel_end_col = col + --vim.fn.setpos("'>", {row + 1, col + 1}) + --end + --local exec_command = string.format(select_range, + --sel_start_line, sel_start_col, + --sel_end_line, sel_end_col) + + --api.nvim_exec(exec_command, false) + --else + api.nvim_exec('normal gv', false) + api.nvim_exec('normal o', false) + api.nvim_win_set_cursor(0, {row + 1, col}) + --end end -M.do_node_movement = function(kind) - local buf, line, col = unpack(vim.fn.getpos(".")) +M.do_node_movement = function(kind, move_node) + local line, col = unpack(api.nvim_win_get_cursor(0)) + local buf = api.nvim_win_get_buf(0) local current_node = M.current_node[buf] @@ -40,7 +68,6 @@ M.do_node_movement = function(kind) current_node = root:named_descendant_for_range(line-1, col-1, line-1, col) end - --UP if kind == M.NodeMovementKind.up then destination_node = current_node:parent() @@ -66,6 +93,20 @@ M.do_node_movement = function(kind) if destination_node then node_start_to_vim(destination_node) + if move_node then + if kind ~= M.NodeMovementKind.down then + local _, new_destination_range = utils.swap_nodes(buf, current_node, destination_node) + + local root = parsers.get_parser():parse():root() + if new_destination_range then + local new_destination_node = root:named_descendant_for_range(new_destination_range[0], + new_destination_range[1], + new_destination_range[2], + new_destination_range[3]) + M.current_node[buf] = new_destination_node or current_node + end + end + end end end @@ -74,6 +115,11 @@ M.move_down = function() M.do_node_movement(M.NodeMovementKind.down) end M.move_left = function() M.do_node_movement(M.NodeMovementKind.left) end M.move_right = function() M.do_node_movement(M.NodeMovementKind.right) end +M.node_move_up = function() M.do_node_movement(M.NodeMovementKind.up, true) end +M.node_move_down = function() M.do_node_movement(M.NodeMovementKind.down, true) end +M.node_move_left = function() M.do_node_movement(M.NodeMovementKind.left, true) end +M.node_move_right = function() M.do_node_movement(M.NodeMovementKind.right, true) end + function M.attach(bufnr) local buf = bufnr or api.nvim_get_current_buf() diff --git a/lua/nvim-treesitter/utils.lua b/lua/nvim-treesitter/utils.lua index c74512633..f2d0bf290 100644 --- a/lua/nvim-treesitter/utils.lua +++ b/lua/nvim-treesitter/utils.lua @@ -64,6 +64,95 @@ function M.is_parent(dest, source) return false end +function M.string_to_lines(str) + local t={} + local line_breaks = 0 + for line in str:gmatch("(.-[\n(\r\n)\r])") do + table.insert(t, line) + line_breaks = line_breaks + 1 + end + if #t == 0 then + return {str}, 0 + --for line in str:gmatch("[^\n(\r\n)\r]+$") do + --table.insert(t, line) + end + + return t, line_breaks +end + +function M.replace_node(buf, source, destination) + local replacement_lines = M.get_node_text(source) + return M.replace_node_text(buf, destination, replacement_lines) +end + +function M.replace_node_text(buf, node_or_range, replacement_lines) + local start_row, start_col, end_row, end_col + if type(node_or_range) == 'table' then + start_row, start_col, end_row, end_col = unpack(node_or_range) + else + start_row, start_col, end_row, end_col = node_or_range:range() + end + + local original_lines = api.nvim_buf_get_lines(buf, start_row, end_row + 1, false) + -- original_lines[1]..'' <- Empty string is necessary! Bug in vim string to lua string conversion?? + local new_text = string.sub(original_lines[1]..'', 1, start_col)..table.concat(replacement_lines, '')..string.sub(original_lines[#original_lines], end_col + 1) + + local new_lines, line_count = M.string_to_lines(new_text) + + api.nvim_buf_set_lines(buf, start_row, end_row + 1, false, new_lines) + + return {start_row, + start_col, + start_row + line_count, + (line_count == 0 and start_col or 0) + #replacement_lines[#replacement_lines]} +end + +function M.node_lenght(node) + local start_row, _, end_row, end_col = node:range() + return end_row - start_row, end_col +end + +function M.range_difference(node1, node2) + local rows1, cols1 = M.node_lenght(node1) + local rows2, cols2 = M.node_lenght(node2) + + return rows1 - rows2, (node2:end_() == node1:end_() and cols1 - cols2 or 0) +end + +function M.swap_nodes(buf, source, destination) + local dst_start_row, dst_start_col, dst_start = destination:start() + local dst_end_row, dst_end_col, dst_end = destination:end_() + local src_start_row, src_start_col, src_start = source:start() + local src_end_row, src_end_col, src_end = source:end_() + + if dst_start <= src_start and dst_end >= src_end then + local src_range = M.replace_node(buf, source, destination) + return src_range, nil + end + + local source_text = M.get_node_text(source) + local destination_text = M.get_node_text(destination) + + if dst_end < src_start then + local diff_rows, diff_cols = M.range_difference(source, destination) + local dst_range = M.replace_node(buf, source, destination) + --local src_range = M.replace_node_text(buf, {src_start_row + diff_rows, + --src_start_col + (source:start() == destination:end_() and diff_cols or 0), + --src_end_row + diff_rows, + --(source:end_() == destination:end_() and diff_cols or 0)}, destination_text) + return src_range, dst_range + elseif src_end < dst_start then + local diff_rows, diff_cols = M.range_difference(destination, source) + local src_range = M.replace_node(buf, destination, source) + --local dst_range = M.replace_node_text(buf, {dst_start_row + diff_rows, + --dst_start_col + (destination:start() == source:end_() and diff_cols or 0), + --dst_end_row + diff_rows, + --(destination:end_() == source:end_() and diff_cols or 0)}, source_text) + return src_range, dst_range + end + +end + function M.setup_commands(mod, commands) for command_name, def in pairs(commands) do local call_fn = string.format("lua require'nvim-treesitter.%s'.commands.%s.run()", mod, command_name) @@ -98,7 +187,7 @@ end -- @param allow_switch_parents allow switching parents if last node -- @param allow_next_parent allow next parent if last node and next parent without children function M.get_next_node(node, allow_switch_parents, allow_next_parent) - local destination_node + local destination_node local parent = node:parent() if parent then