diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py index 8b319446c..b195f9f4a 100644 --- a/pr_agent/algo/pr_processing.py +++ b/pr_agent/algo/pr_processing.py @@ -11,7 +11,7 @@ from pr_agent.algo import MAX_TOKENS from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions from pr_agent.algo.language_handler import sort_files_by_main_languages -from pr_agent.algo.token_handler import TokenHandler +from pr_agent.algo.token_handler import TokenHandler, get_token_encoder from pr_agent.config_loader import get_settings from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider @@ -284,3 +284,26 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo], absolute_position = start2 + delta - 1 break return position, absolute_position + + +def clip_tokens(text: str, max_tokens: int) -> str: + """ + Clip the number of tokens in a string to a maximum number of tokens. + + Args: + text (str): The string to clip. + max_tokens (int): The maximum number of tokens allowed in the string. + + Returns: + str: The clipped string. + """ + # We'll estimate the number of tokens by hueristically assuming 2.5 tokens per word + encoder = get_token_encoder() + num_input_tokens = len(encoder.encode(text)) + if num_input_tokens <= max_tokens: + return text + num_chars = len(text) + chars_per_token = num_chars / num_input_tokens + num_output_chars = int(chars_per_token * max_tokens) + clipped_text = text[:num_output_chars] + return clipped_text diff --git a/pr_agent/algo/token_handler.py b/pr_agent/algo/token_handler.py index 3686f521e..f018a92b0 100644 --- a/pr_agent/algo/token_handler.py +++ b/pr_agent/algo/token_handler.py @@ -4,6 +4,10 @@ from pr_agent.config_loader import get_settings +def get_token_encoder(): + return encoding_for_model(get_settings().config.model) if "gpt" in get_settings().config.model else get_encoding( + "cl100k_base") + class TokenHandler: """ A class for handling tokens in the context of a pull request. @@ -27,7 +31,7 @@ def __init__(self, pr, vars: dict, system, user): - system: The system string. - user: The user string. """ - self.encoder = encoding_for_model(get_settings().config.model) if "gpt" in get_settings().config.model else get_encoding("cl100k_base") + self.encoder = get_token_encoder() self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user) def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user): diff --git a/pr_agent/git_providers/bitbucket_provider.py b/pr_agent/git_providers/bitbucket_provider.py index 122b0db3a..07b922957 100644 --- a/pr_agent/git_providers/bitbucket_provider.py +++ b/pr_agent/git_providers/bitbucket_provider.py @@ -5,6 +5,7 @@ import requests from atlassian.bitbucket import Cloud +from ..algo.pr_processing import clip_tokens from ..config_loader import get_settings from .git_provider import FilePatchInfo @@ -81,6 +82,9 @@ def get_pr_branch(self): return self.pr.source_branch def get_pr_description(self): + max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None) + if max_tokens: + return clip_tokens(self.pr.description, max_tokens) return self.pr.description def get_user_id(self): diff --git a/pr_agent/git_providers/git_provider.py b/pr_agent/git_providers/git_provider.py index 8e161252d..2a891938c 100644 --- a/pr_agent/git_providers/git_provider.py +++ b/pr_agent/git_providers/git_provider.py @@ -97,6 +97,10 @@ def add_eyes_reaction(self, issue_comment_id: int) -> Optional[int]: def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool: pass + @abstractmethod + def get_commit_messages(self): + pass + def get_main_pr_language(languages, files) -> str: """ Get the main language of the commit. Return an empty string if cannot determine. diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index bc5cc6a75..dbad53884 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -12,7 +12,7 @@ from .git_provider import FilePatchInfo, GitProvider, IncrementalPR from ..algo.language_handler import is_valid_file from ..algo.utils import load_large_diff -from ..algo.pr_processing import find_line_number_of_relevant_line_in_file +from ..algo.pr_processing import find_line_number_of_relevant_line_in_file, clip_tokens from ..config_loader import get_settings from ..servers.utils import RateLimitExceeded @@ -234,6 +234,9 @@ def get_pr_branch(self): return self.pr.head.ref def get_pr_description(self): + max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None) + if max_tokens: + return clip_tokens(self.pr.body, max_tokens) return self.pr.body def get_user_id(self): @@ -375,19 +378,22 @@ def get_labels(self): logging.exception(f"Failed to get labels, error: {e}") return [] - def get_commit_messages(self) -> str: + def get_commit_messages(self): """ Retrieves the commit messages of a pull request. Returns: str: A string containing the commit messages of the pull request. """ + max_tokens = get_settings().get("CONFIG.MAX_COMMITS_TOKENS", None) try: commit_list = self.pr.get_commits() commit_messages = [commit.commit.message for commit in commit_list] commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages)]) - except: + except Exception: commit_messages_str = "" + if max_tokens: + commit_messages_str = clip_tokens(commit_messages_str, max_tokens) return commit_messages_str def generate_link_to_relevant_line_number(self, suggestion) -> str: diff --git a/pr_agent/git_providers/gitlab_provider.py b/pr_agent/git_providers/gitlab_provider.py index a4d2d1274..73a3a2f92 100644 --- a/pr_agent/git_providers/gitlab_provider.py +++ b/pr_agent/git_providers/gitlab_provider.py @@ -7,6 +7,7 @@ from gitlab import GitlabGetError from ..algo.language_handler import is_valid_file +from ..algo.pr_processing import clip_tokens from ..algo.utils import load_large_diff from ..config_loader import get_settings from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider @@ -275,6 +276,9 @@ def get_pr_branch(self): return self.mr.source_branch def get_pr_description(self): + max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None) + if max_tokens: + return clip_tokens(self.mr.description, max_tokens) return self.mr.description def get_issue_comments(self): @@ -338,16 +342,19 @@ def publish_inline_comments(self, comments: list[dict]): def get_labels(self): return self.mr.labels - def get_commit_messages(self) -> str: + def get_commit_messages(self): """ Retrieves the commit messages of a pull request. Returns: str: A string containing the commit messages of the pull request. """ + max_tokens = get_settings().get("CONFIG.MAX_COMMITS_TOKENS", None) try: commit_messages_list = [commit['message'] for commit in self.mr.commits()._list] commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages_list)]) - except: + except Exception: commit_messages_str = "" + if max_tokens: + commit_messages_str = clip_tokens(commit_messages_str, max_tokens) return commit_messages_str \ No newline at end of file diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index 8334049da..0c502df9f 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -8,6 +8,8 @@ verbosity_level=0 # 0,1,2 use_extra_bad_extensions=false use_repo_settings_file=true ai_timeout=180 +max_description_tokens = 500 +max_commits_tokens = 500 [pr_reviewer] # /review # require_focused_review=true diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index 982f50007..f679851b6 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -8,7 +8,7 @@ from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models, \ - find_line_number_of_relevant_line_in_file + find_line_number_of_relevant_line_in_file, clip_tokens from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import convert_to_markdown, try_fix_json from pr_agent.config_loader import get_settings