From cc0e432247c0b6cf8bd529810841b030ce7ce237 Mon Sep 17 00:00:00 2001
From: mrT23
Date: Fri, 13 Sep 2024 22:17:24 +0300
Subject: [PATCH] refactor logic

---
 pr_agent/algo/git_patch_processing.py | 124 +++++++++++++++-----------
 pr_agent/algo/pr_processing.py        |   4 +-
 pr_agent/tools/pr_reviewer.py         |   2 +-
 tests/unittest/test_extend_patch.py   |  11 +++
 4 files changed, 83 insertions(+), 58 deletions(-)

diff --git a/pr_agent/algo/git_patch_processing.py b/pr_agent/algo/git_patch_processing.py
index 0a21875c6..4d85f2987 100644
--- a/pr_agent/algo/git_patch_processing.py
+++ b/pr_agent/algo/git_patch_processing.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import re
+import traceback
 
 from pr_agent.config_loader import get_settings
 from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
@@ -12,27 +13,48 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
     if not patch_str or (patch_extra_lines_before == 0 and patch_extra_lines_after == 0) or not original_file_str:
         return patch_str
 
-    if type(original_file_str) == bytes:
+    original_file_str = decode_if_bytes(original_file_str)
+    if not original_file_str:
+        return patch_str
+
+    if should_skip_patch(filename):
+        return patch_str
+
+    try:
+        extended_patch_str = process_patch_lines(patch_str, original_file_str,
+                                                 patch_extra_lines_before, patch_extra_lines_after)
+    except Exception as e:
+        get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()})
+        return patch_str
+
+    return extended_patch_str
+
+
+def decode_if_bytes(original_file_str):
+    if isinstance(original_file_str, bytes):
         try:
-            original_file_str = original_file_str.decode('utf-8')
+            return original_file_str.decode('utf-8')
         except UnicodeDecodeError:
+            encodings_to_try = ['iso-8859-1', 'latin-1', 'ascii', 'utf-16']
+            for encoding in encodings_to_try:
+                try:
+                    return original_file_str.decode(encoding)
+                except UnicodeDecodeError:
+                    continue
             return ""
+    return original_file_str
 
-    # skip patches
-    patch_extension_skip_types = get_settings().config.patch_extension_skip_types #[".md",".txt"]
+
+def should_skip_patch(filename):
+    patch_extension_skip_types = get_settings().config.patch_extension_skip_types
     if patch_extension_skip_types and filename:
-        if any([filename.endswith(skip_type) for skip_type in patch_extension_skip_types]):
-            return patch_str
+        return any(filename.endswith(skip_type) for skip_type in patch_extension_skip_types)
+    return False
+
 
-    # dynamic context settings
+def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before, patch_extra_lines_after):
     allow_dynamic_context = get_settings().config.allow_dynamic_context
-    max_extra_lines_before_dynamic_context = get_settings().config.max_extra_lines_before_dynamic_context
-    patch_extra_lines_before_dynamic = patch_extra_lines_before
-    if allow_dynamic_context:
-        if max_extra_lines_before_dynamic_context > patch_extra_lines_before:
-            patch_extra_lines_before_dynamic = max_extra_lines_before_dynamic_context
-        else:
-            get_logger().warning(f"'max_extra_lines_before_dynamic_context' should be greater than 'patch_extra_lines_before'")
+    patch_extra_lines_before_dynamic = get_settings().config.max_extra_lines_before_dynamic_context
 
     original_lines = original_file_str.splitlines()
     len_original_lines = len(original_lines)
@@ -46,23 +68,14 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
         for line in patch_lines:
             if line.startswith('@@'):
                 match = RE_HUNK_HEADER.match(line)
+                # identify hunk header
                 if match:
-                    # finish last hunk
+                    # finish processing previous hunk
                     if start1 != -1 and patch_extra_lines_after > 0:
-                        delta_lines = original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]
-                        delta_lines = [f' {line}' for line in delta_lines]
+                        delta_lines = [f' {line}' for line in original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]]
                         extended_patch_lines.extend(delta_lines)
 
-                    res = list(match.groups())
-                    for i in range(len(res)):
-                        if res[i] is None:
-                            res[i] = 0
-                    try:
-                        start1, size1, start2, size2 = map(int, res[:4])
-                    except:  # '@@ -0,0 +1 @@' case
-                        start1, size1, size2 = map(int, res[:3])
-                        start2 = 0
-                    section_header = res[4]
+                    section_header, size1, size2, start1, start2 = extract_hunk_headers(match)
 
                     if patch_extra_lines_before > 0 or patch_extra_lines_after > 0:
                         def _calc_context_limits(patch_lines_before):
@@ -82,7 +95,7 @@ def _calc_context_limits(patch_lines_before):
                                 _calc_context_limits(patch_extra_lines_before_dynamic)
                             lines_before = original_lines[extended_start1 - 1:start1 - 1]
                             found_header = False
-                            for i,line, in enumerate(lines_before):
+                            for i, line, in enumerate(lines_before):
                                 if section_header in line:
                                     found_header = True
                                     # Update start and size in one line each
@@ -99,12 +112,13 @@ def _calc_context_limits(patch_lines_before):
                             extended_start1, extended_size1, extended_start2, extended_size2 = \
                                 _calc_context_limits(patch_extra_lines_before)
 
-                        delta_lines = original_lines[extended_start1 - 1:start1 - 1]
-                        delta_lines = [f' {line}' for line in delta_lines]
+                        delta_lines = [f' {line}' for line in original_lines[extended_start1 - 1:start1 - 1]]
+
+                        # remove the section header if it's in the extra delta lines (in dynamic context, this is also done)
                         if section_header and not allow_dynamic_context:
                             for line in delta_lines:
                                 if section_header in line:
-                                    section_header = '' # remove section header if it is in the extra delta lines
+                                    section_header = ''  # remove section header if it is in the extra delta lines
                                     break
                     else:
                         extended_start1 = start1
@@ -120,11 +134,10 @@ def _calc_context_limits(patch_lines_before):
                 continue
             extended_patch_lines.append(line)
     except Exception as e:
-        if get_settings().config.verbosity_level >= 2:
-            get_logger().error(f"Failed to extend patch: {e}")
+        get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()})
         return patch_str
 
-    # finish last hunk
+    # finish processing last hunk
     if start1 != -1 and patch_extra_lines_after > 0:
         delta_lines = original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]
         # add space at the beginning of each extra line
@@ -135,6 +148,20 @@ def _calc_context_limits(patch_lines_before):
     return extended_patch_str
 
 
+def extract_hunk_headers(match):
+    res = list(match.groups())
+    for i in range(len(res)):
+        if res[i] is None:
+            res[i] = 0
+    try:
+        start1, size1, start2, size2 = map(int, res[:4])
+    except:  # '@@ -0,0 +1 @@' case
+        start1, size1, size2 = map(int, res[:3])
+        start2 = 0
+    section_header = res[4]
+    return section_header, size1, size2, start1, start2
+
+
 def omit_deletion_hunks(patch_lines) -> str:
     """
     Omit deletion hunks from the patch and return the modified patch.
@@ -253,8 +280,8 @@ def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
     start1, size1, start2, size2 = -1, -1, -1, -1
     prev_header_line = []
     header_line = []
-    for line in patch_lines:
-        if 'no newline at end of file' in line.lower():
+    for line_i, line in enumerate(patch_lines):
+        if 'no newline at end of file' in line.lower().strip().strip('//'):
             continue
 
         if line.startswith('@@'):
@@ -280,21 +307,18 @@ def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
             if match:
                 prev_header_line = header_line
 
-                res = list(match.groups())
-                for i in range(len(res)):
-                    if res[i] is None:
-                        res[i] = 0
-                try:
-                    start1, size1, start2, size2 = map(int, res[:4])
-                except:  # '@@ -0,0 +1 @@' case
-                    start1, size1, size2 = map(int, res[:3])
-                    start2 = 0
+                section_header, size1, size2, start1, start2 = extract_hunk_headers(match)
 
         elif line.startswith('+'):
             new_content_lines.append(line)
         elif line.startswith('-'):
             old_content_lines.append(line)
         else:
+            if not line and line_i:  # skip an empty line if the next line is a hunk header, or if it is the last line
+                if line_i + 1 < len(patch_lines) and patch_lines[line_i + 1].startswith('@@'):
+                    continue
+                elif line_i + 1 == len(patch_lines):
+                    continue
             new_content_lines.append(line)
             old_content_lines.append(line)
 
@@ -339,15 +363,7 @@ def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, s
 
             match = RE_HUNK_HEADER.match(line)
 
-            res = list(match.groups())
-            for i in range(len(res)):
-                if res[i] is None:
-                    res[i] = 0
-            try:
-                start1, size1, start2, size2 = map(int, res[:4])
-            except:  # '@@ -0,0 +1 @@' case
-                start1, size1, size2 = map(int, res[:3])
-                start2 = 0
+            section_header, size1, size2, start1, start2 = extract_hunk_headers(match)
 
             # check if line range is in this hunk
             if side.lower() == 'left':
diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py
index 02c416d0f..95d2fda72 100644
--- a/pr_agent/algo/pr_processing.py
+++ b/pr_agent/algo/pr_processing.py
@@ -347,11 +347,9 @@ async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelT
         except:
             get_logger().warning(
                 f"Failed to generate prediction with {model}"
-                f"{(' from deployment ' + deployment_id) if deployment_id else ''}: "
-                f"{traceback.format_exc()}"
             )
             if i == len(all_models) - 1:  # If it's the last iteration
-                raise  # Re-raise the last exception
+                raise Exception(f"Failed to generate prediction with any model of {all_models}")
 
 
 def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:
diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py
index 8000450f5..88799d987 100644
--- a/pr_agent/tools/pr_reviewer.py
+++ b/pr_agent/tools/pr_reviewer.py
@@ -164,7 +164,7 @@ async def _prepare_prediction(self, model: str) -> None:
                                                   self.token_handler,
                                                   model,
                                                   add_line_numbers_to_hunks=True,
-                                                  disable_extra_lines=True,)
+                                                  disable_extra_lines=False,)
 
         if self.patches_diff:
             get_logger().debug(f"PR diff", diff=self.patches_diff)
diff --git a/tests/unittest/test_extend_patch.py b/tests/unittest/test_extend_patch.py
index 03fb5ad9a..2d8913f39 100644
--- a/tests/unittest/test_extend_patch.py
+++ b/tests/unittest/test_extend_patch.py
@@ -60,11 +60,22 @@ def test_multiple_hunks(self):
         original_file_str = 'line1\nline2\nline3\nline4\nline5\nline6'
         patch_str = '@@ -2,3 +2,3 @@ init()\n-line2\n+new_line2\n line3\n line4\n@@ -4,1 +4,1 @@ init2()\n-line4\n+new_line4'  # noqa: E501
         num_lines = 1
+        original_allow_dynamic_context = get_settings().config.allow_dynamic_context
+
+        get_settings().config.allow_dynamic_context = False
         expected_output = '\n@@ -1,5 +1,5 @@ init()\n line1\n-line2\n+new_line2\n line3\n line4\n line5\n\n@@ -3,3 +3,3 @@ init2()\n line3\n-line4\n+new_line4\n line5'  # noqa: E501
         actual_output = extend_patch(original_file_str, patch_str,
                                      patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
         assert actual_output == expected_output
+        get_settings().config.allow_dynamic_context = True
+        expected_output = '\n@@ -1,5 +1,5 @@ init()\n line1\n-line2\n+new_line2\n line3\n line4\n line5\n\n@@ -3,3 +3,3 @@ init2()\n line3\n-line4\n+new_line4\n line5'  # noqa: E501
+        actual_output = extend_patch(original_file_str, patch_str,
+                                     patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
+        assert actual_output == expected_output
+        get_settings().config.allow_dynamic_context = original_allow_dynamic_context
+
+
 
     def test_dynamic_context(self):
         get_settings().config.max_extra_lines_before_dynamic_context = 10
         original_file_str = "def foo():"
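
The decode_if_bytes helper introduced above tries UTF-8 first and then walks a fixed list of fallback encodings. Below is a minimal standalone sketch of that fallback, assuming the same encoding list as the diff; note that iso-8859-1/latin-1 accept any byte sequence, so with this list the final return "" is effectively unreachable.

def decode_if_bytes(original_file_str):
    # Try UTF-8 first, then the fallback encodings, in order.
    if isinstance(original_file_str, bytes):
        for encoding in ['utf-8', 'iso-8859-1', 'latin-1', 'ascii', 'utf-16']:
            try:
                return original_file_str.decode(encoding)
            except UnicodeDecodeError:
                continue
        return ""  # undecodable input: the caller falls back to the raw patch
    return original_file_str

assert decode_if_bytes(b'caf\xc3\xa9') == 'café'  # valid UTF-8
assert decode_if_bytes(b'caf\xe9') == 'café'      # falls back to iso-8859-1
assert decode_if_bytes('already text') == 'already text'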
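extract_hunk_headers centralizes the hunk-header parsing that was previously duplicated in extend_patch, convert_to_hunks_with_lines_numbers and extract_hunk_lines_from_patch. The sketch below shows how it consumes a regex match; the pattern here is a typical unified-diff hunk-header regex standing in for pr-agent's RE_HUNK_HEADER, which may differ in detail.

import re

RE_HUNK_HEADER = re.compile(r'^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)')

def extract_hunk_headers(match):
    # Optional size groups that are missing (e.g. '@@ -5 +5 @@') default to 0.
    res = [g if g is not None else 0 for g in match.groups()]
    try:
        start1, size1, start2, size2 = map(int, res[:4])
    except ValueError:  # defensive fallback kept from the original code
        start1, size1, size2 = map(int, res[:3])
        start2 = 0
    section_header = res[4]
    return section_header, size1, size2, start1, start2

m = RE_HUNK_HEADER.match('@@ -12,27 +13,48 @@ def extend_patch(original_file_str, ...):')
section_header, size1, size2, start1, start2 = extract_hunk_headers(m)
assert (start1, size1, start2, size2) == (12, 27, 13, 48)
assert section_header.startswith('def extend_patch')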
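process_patch_lines widens each hunk with extra context lines and must rewrite the hunk header to match. Here is a simplified standalone sketch of the arithmetic done by the nested _calc_context_limits helper, with names following the diff; the real helper's clamping details may differ.

def calc_context_limits(start1, size1, start2, size2,
                        lines_before, lines_after, len_original_lines):
    # Move the hunk start up by the requested context, clamped at line 1,
    # and grow the size by the lines actually gained plus the trailing context.
    extended_start1 = max(1, start1 - lines_before)
    extended_size1 = size1 + (start1 - extended_start1) + lines_after
    extended_start2 = max(1, start2 - lines_before)
    extended_size2 = size2 + (start2 - extended_start2) + lines_after
    # Don't claim lines past the end of the original file.
    if extended_start1 + extended_size1 > len_original_lines + 1:
        extended_size1 = len_original_lines - extended_start1 + 1
        extended_size2 = len_original_lines - extended_start2 + 1
    return extended_start1, extended_size1, extended_start2, extended_size2

# These reproduce the expectations in test_multiple_hunks above
# (a 6-line file, one extra context line before and after):
print(calc_context_limits(2, 3, 2, 3, 1, 1, 6))  # (1, 5, 1, 5) -> '@@ -1,5 +1,5 @@'
print(calc_context_limits(4, 1, 4, 1, 1, 1, 6))  # (3, 3, 3, 3) -> '@@ -3,3 +3,3 @@'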