From ed65493718d2e307d58f332da9e0543bd4e1bda2 Mon Sep 17 00:00:00 2001 From: mrT23 Date: Sun, 11 Aug 2024 12:08:00 +0300 Subject: [PATCH] Handle edge cases for patch extension and update tests --- pr_agent/algo/git_patch_processing.py | 13 +++++++++---- tests/unittest/test_extend_patch.py | 11 ++++++----- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/pr_agent/algo/git_patch_processing.py b/pr_agent/algo/git_patch_processing.py index 5cb18b3a0..ba98e54d5 100644 --- a/pr_agent/algo/git_patch_processing.py +++ b/pr_agent/algo/git_patch_processing.py @@ -18,6 +18,7 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0, patch return "" original_lines = original_file_str.splitlines() + len_original_lines = len(original_lines) patch_lines = patch_str.splitlines() extended_patch_lines = [] @@ -29,8 +30,8 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0, patch if line.startswith('@@'): match = RE_HUNK_HEADER.match(line) if match: - # finish previous hunk - if start1 != -1: + # finish last hunk + if start1 != -1 and patch_extra_lines_after > 0: extended_patch_lines.extend( original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]) @@ -46,8 +47,12 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0, patch section_header = res[4] extended_start1 = max(1, start1 - patch_extra_lines_before) extended_size1 = size1 + (start1 - extended_start1) + patch_extra_lines_after + if extended_start1 - 1 + extended_size1 > len(original_lines): + extended_size1 = len_original_lines - extended_start1 + 1 extended_start2 = max(1, start2 - patch_extra_lines_before) extended_size2 = size2 + (start2 - extended_start2) + patch_extra_lines_after + if extended_start2 - 1 + extended_size2 > len_original_lines: + extended_size2 = len_original_lines - extended_start2 + 1 extended_patch_lines.append( f'@@ -{extended_start1},{extended_size1} ' f'+{extended_start2},{extended_size2} @@ {section_header}') @@ -60,8 +65,8 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0, patch get_logger().error(f"Failed to extend patch: {e}") return patch_str - # finish previous hunk - if start1 != -1: + # finish last hunk + if start1 != -1 and patch_extra_lines_after > 0: extended_patch_lines.extend( original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]) diff --git a/tests/unittest/test_extend_patch.py b/tests/unittest/test_extend_patch.py index 7737ee8d1..f44d74179 100644 --- a/tests/unittest/test_extend_patch.py +++ b/tests/unittest/test_extend_patch.py @@ -44,11 +44,12 @@ def test_no_hunks(self): def test_single_hunk(self): original_file_str = 'line1\nline2\nline3\nline4\nline5' patch_str = '@@ -2,3 +2,3 @@ init()\n-line2\n+new_line2\nline3\nline4' - num_lines = 1 - expected_output = '@@ -1,5 +1,5 @@ init()\nline1\n-line2\n+new_line2\nline3\nline4\nline5' - actual_output = extend_patch(original_file_str, patch_str, - patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines) - assert actual_output == expected_output + + for num_lines in [1, 2, 3]: # check that even if we are over the number of lines in the file, the function still works + expected_output = '@@ -1,5 +1,5 @@ init()\nline1\n-line2\n+new_line2\nline3\nline4\nline5' + actual_output = extend_patch(original_file_str, patch_str, + patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines) + assert actual_output == expected_output # Tests the functionality of extending a patch with multiple hunks. def test_multiple_hunks(self):