Skip to content

Commit

Permalink
Merge pull request #1227 from Codium-ai/tr/dynamic
Browse files Browse the repository at this point in the history
refactor logic
  • Loading branch information
mrT23 authored Sep 13, 2024
2 parents 0fb158f + cc0e432 commit 95d1b0d
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 58 deletions.
124 changes: 70 additions & 54 deletions pr_agent/algo/git_patch_processing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import re
import traceback

from pr_agent.config_loader import get_settings
from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
Expand All @@ -12,27 +13,48 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
if not patch_str or (patch_extra_lines_before == 0 and patch_extra_lines_after == 0) or not original_file_str:
return patch_str

if type(original_file_str) == bytes:
original_file_str = decode_if_bytes(original_file_str)
if not original_file_str:
return patch_str

if should_skip_patch(filename):
return patch_str

try:
extended_patch_str = process_patch_lines(patch_str, original_file_str,
patch_extra_lines_before, patch_extra_lines_after)
except Exception as e:
get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()})
return patch_str

return extended_patch_str


def decode_if_bytes(original_file_str):
if isinstance(original_file_str, bytes):
try:
original_file_str = original_file_str.decode('utf-8')
return original_file_str.decode('utf-8')
except UnicodeDecodeError:
encodings_to_try = ['iso-8859-1', 'latin-1', 'ascii', 'utf-16']
for encoding in encodings_to_try:
try:
return original_file_str.decode(encoding)
except UnicodeDecodeError:
continue
return ""
return original_file_str

# skip patches
patch_extension_skip_types = get_settings().config.patch_extension_skip_types #[".md",".txt"]

def should_skip_patch(filename):
patch_extension_skip_types = get_settings().config.patch_extension_skip_types
if patch_extension_skip_types and filename:
if any([filename.endswith(skip_type) for skip_type in patch_extension_skip_types]):
return patch_str
return any(filename.endswith(skip_type) for skip_type in patch_extension_skip_types)
return False


# dynamic context settings
def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before, patch_extra_lines_after):
allow_dynamic_context = get_settings().config.allow_dynamic_context
max_extra_lines_before_dynamic_context = get_settings().config.max_extra_lines_before_dynamic_context
patch_extra_lines_before_dynamic = patch_extra_lines_before
if allow_dynamic_context:
if max_extra_lines_before_dynamic_context > patch_extra_lines_before:
patch_extra_lines_before_dynamic = max_extra_lines_before_dynamic_context
else:
get_logger().warning(f"'max_extra_lines_before_dynamic_context' should be greater than 'patch_extra_lines_before'")
patch_extra_lines_before_dynamic = get_settings().config.max_extra_lines_before_dynamic_context

original_lines = original_file_str.splitlines()
len_original_lines = len(original_lines)
Expand All @@ -46,23 +68,14 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
for line in patch_lines:
if line.startswith('@@'):
match = RE_HUNK_HEADER.match(line)
# identify hunk header
if match:
# finish last hunk
# finish processing previous hunk
if start1 != -1 and patch_extra_lines_after > 0:
delta_lines = original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]
delta_lines = [f' {line}' for line in delta_lines]
delta_lines = [f' {line}' for line in original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]]
extended_patch_lines.extend(delta_lines)

res = list(match.groups())
for i in range(len(res)):
if res[i] is None:
res[i] = 0
try:
start1, size1, start2, size2 = map(int, res[:4])
except: # '@@ -0,0 +1 @@' case
start1, size1, size2 = map(int, res[:3])
start2 = 0
section_header = res[4]
section_header, size1, size2, start1, start2 = extract_hunk_headers(match)

if patch_extra_lines_before > 0 or patch_extra_lines_after > 0:
def _calc_context_limits(patch_lines_before):
Expand All @@ -82,7 +95,7 @@ def _calc_context_limits(patch_lines_before):
_calc_context_limits(patch_extra_lines_before_dynamic)
lines_before = original_lines[extended_start1 - 1:start1 - 1]
found_header = False
for i,line, in enumerate(lines_before):
for i, line, in enumerate(lines_before):
if section_header in line:
found_header = True
# Update start and size in one line each
Expand All @@ -99,12 +112,13 @@ def _calc_context_limits(patch_lines_before):
extended_start1, extended_size1, extended_start2, extended_size2 = \
_calc_context_limits(patch_extra_lines_before)

delta_lines = original_lines[extended_start1 - 1:start1 - 1]
delta_lines = [f' {line}' for line in delta_lines]
delta_lines = [f' {line}' for line in original_lines[extended_start1 - 1:start1 - 1]]

# logic to remove section header if its in the extra delta lines (in dynamic context, this is also done)
if section_header and not allow_dynamic_context:
for line in delta_lines:
if section_header in line:
section_header = '' # remove section header if it is in the extra delta lines
section_header = '' # remove section header if it is in the extra delta lines
break
else:
extended_start1 = start1
Expand All @@ -120,11 +134,10 @@ def _calc_context_limits(patch_lines_before):
continue
extended_patch_lines.append(line)
except Exception as e:
if get_settings().config.verbosity_level >= 2:
get_logger().error(f"Failed to extend patch: {e}")
get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()})
return patch_str

# finish last hunk
# finish processing last hunk
if start1 != -1 and patch_extra_lines_after > 0:
delta_lines = original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]
# add space at the beginning of each extra line
Expand All @@ -135,6 +148,20 @@ def _calc_context_limits(patch_lines_before):
return extended_patch_str


def extract_hunk_headers(match):
res = list(match.groups())
for i in range(len(res)):
if res[i] is None:
res[i] = 0
try:
start1, size1, start2, size2 = map(int, res[:4])
except: # '@@ -0,0 +1 @@' case
start1, size1, size2 = map(int, res[:3])
start2 = 0
section_header = res[4]
return section_header, size1, size2, start1, start2


def omit_deletion_hunks(patch_lines) -> str:
"""
Omit deletion hunks from the patch and return the modified patch.
Expand Down Expand Up @@ -253,8 +280,8 @@ def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
start1, size1, start2, size2 = -1, -1, -1, -1
prev_header_line = []
header_line = []
for line in patch_lines:
if 'no newline at end of file' in line.lower():
for line_i, line in enumerate(patch_lines):
if 'no newline at end of file' in line.lower().strip().strip('//'):
continue

if line.startswith('@@'):
Expand All @@ -280,21 +307,18 @@ def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
if match:
prev_header_line = header_line

res = list(match.groups())
for i in range(len(res)):
if res[i] is None:
res[i] = 0
try:
start1, size1, start2, size2 = map(int, res[:4])
except: # '@@ -0,0 +1 @@' case
start1, size1, size2 = map(int, res[:3])
start2 = 0
section_header, size1, size2, start1, start2 = extract_hunk_headers(match)

elif line.startswith('+'):
new_content_lines.append(line)
elif line.startswith('-'):
old_content_lines.append(line)
else:
if not line and line_i: # if this line is empty and the next line is a hunk header, skip it
if line_i + 1 < len(patch_lines) and patch_lines[line_i + 1].startswith('@@'):
continue
elif line_i + 1 == len(patch_lines):
continue
new_content_lines.append(line)
old_content_lines.append(line)

Expand Down Expand Up @@ -339,15 +363,7 @@ def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, s

match = RE_HUNK_HEADER.match(line)

res = list(match.groups())
for i in range(len(res)):
if res[i] is None:
res[i] = 0
try:
start1, size1, start2, size2 = map(int, res[:4])
except: # '@@ -0,0 +1 @@' case
start1, size1, size2 = map(int, res[:3])
start2 = 0
section_header, size1, size2, start1, start2 = extract_hunk_headers(match)

# check if line range is in this hunk
if side.lower() == 'left':
Expand Down
4 changes: 1 addition & 3 deletions pr_agent/algo/pr_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,11 +347,9 @@ async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelT
except:
get_logger().warning(
f"Failed to generate prediction with {model}"
f"{(' from deployment ' + deployment_id) if deployment_id else ''}: "
f"{traceback.format_exc()}"
)
if i == len(all_models) - 1: # If it's the last iteration
raise # Re-raise the last exception
raise Exception(f"Failed to generate prediction with any model of {all_models}")


def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:
Expand Down
2 changes: 1 addition & 1 deletion pr_agent/tools/pr_reviewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ async def _prepare_prediction(self, model: str) -> None:
self.token_handler,
model,
add_line_numbers_to_hunks=True,
disable_extra_lines=True,)
disable_extra_lines=False,)

if self.patches_diff:
get_logger().debug(f"PR diff", diff=self.patches_diff)
Expand Down
11 changes: 11 additions & 0 deletions tests/unittest/test_extend_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,22 @@ def test_multiple_hunks(self):
original_file_str = 'line1\nline2\nline3\nline4\nline5\nline6'
patch_str = '@@ -2,3 +2,3 @@ init()\n-line2\n+new_line2\n line3\n line4\n@@ -4,1 +4,1 @@ init2()\n-line4\n+new_line4' # noqa: E501
num_lines = 1
original_allow_dynamic_context = get_settings().config.allow_dynamic_context

get_settings().config.allow_dynamic_context = False
expected_output = '\n@@ -1,5 +1,5 @@ init()\n line1\n-line2\n+new_line2\n line3\n line4\n line5\n\n@@ -3,3 +3,3 @@ init2()\n line3\n-line4\n+new_line4\n line5' # noqa: E501
actual_output = extend_patch(original_file_str, patch_str,
patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
assert actual_output == expected_output

get_settings().config.allow_dynamic_context = True
expected_output = '\n@@ -1,5 +1,5 @@ init()\n line1\n-line2\n+new_line2\n line3\n line4\n line5\n\n@@ -3,3 +3,3 @@ init2()\n line3\n-line4\n+new_line4\n line5' # noqa: E501
actual_output = extend_patch(original_file_str, patch_str,
patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
assert actual_output == expected_output
get_settings().config.allow_dynamic_context = original_allow_dynamic_context


def test_dynamic_context(self):
get_settings().config.max_extra_lines_before_dynamic_context = 10
original_file_str = "def foo():"
Expand Down

0 comments on commit 95d1b0d

Please sign in to comment.