Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring PR label handling, enhancing language detection, and updating CLI commands #447

Merged
merged 18 commits into from
Nov 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pr_agent/algo/language_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

from pr_agent.config_loader import get_settings

language_extension_map_org = get_settings().language_extension_map_org
language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}


# Bad Extensions, source: https://github.com/EleutherAI/github-downloader/blob/345e7c4cbb9e0dc8a0615fd995a08bf9d73b3fe6/download_repo_text.py # noqa: E501
bad_extensions = get_settings().bad_extensions.default
Expand All @@ -29,6 +28,8 @@ def sort_files_by_main_languages(languages: Dict, files: list):
# languages_sorted = sorted(languages, key=lambda x: x[1], reverse=True)
# get all extensions for the languages
main_extensions = []
language_extension_map_org = get_settings().language_extension_map_org
language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
for language in languages_sorted_list:
if language.lower() in language_extension_map:
main_extensions.append(language_extension_map[language.lower()])
Expand Down
31 changes: 1 addition & 30 deletions pr_agent/algo/pr_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.file_filter import filter_ignored
from pr_agent.algo.token_handler import TokenHandler, get_token_encoder
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import get_max_tokens
from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider, EDIT_TYPE
Expand Down Expand Up @@ -326,35 +326,6 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
return position, absolute_position


def clip_tokens(text: str, max_tokens: int) -> str:
"""
Clip the number of tokens in a string to a maximum number of tokens.

Args:
text (str): The string to clip.
max_tokens (int): The maximum number of tokens allowed in the string.

Returns:
str: The clipped string.
"""
if not text:
return text

try:
encoder = get_token_encoder()
num_input_tokens = len(encoder.encode(text))
if num_input_tokens <= max_tokens:
return text
num_chars = len(text)
chars_per_token = num_chars / num_input_tokens
num_output_chars = int(chars_per_token * max_tokens)
clipped_text = text[:num_output_chars]
return clipped_text
except Exception as e:
get_logger().warning(f"Failed to clip tokens: {e}")
return text


def get_pr_multi_diffs(git_provider: GitProvider,
token_handler: TokenHandler,
model: str,
Expand Down
45 changes: 40 additions & 5 deletions pr_agent/algo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from starlette_context import context

from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.token_handler import get_token_encoder
from pr_agent.config_loader import get_settings, global_settings
from pr_agent.log import get_logger

Expand Down Expand Up @@ -338,12 +339,15 @@ def set_custom_labels(variables):
labels_list = f" - {labels_list}" if labels_list else ""
variables["custom_labels"] = labels_list
return
final_labels = ""
#final_labels = ""
#for k, v in labels.items():
# final_labels += f" - {k} ({v['description']})\n"
#variables["custom_labels"] = final_labels
#variables["custom_labels_examples"] = f" - {list(labels.keys())[0]}"
variables["custom_labels_class"] = "class Label(str, Enum):"
for k, v in labels.items():
final_labels += f" - {k} ({v['description']})\n"
variables["custom_labels"] = final_labels
variables["custom_labels_examples"] = f" - {list(labels.keys())[0]}"

description = v['description'].strip('\n').replace('\n', '\\n')
variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = '{k}' # {description}"
mrT23 marked this conversation as resolved.
Show resolved Hide resolved

def get_user_labels(current_labels: List[str] = None):
"""
Expand Down Expand Up @@ -375,3 +379,34 @@ def get_max_tokens(model):
max_tokens_model = min(settings.config.max_model_tokens, max_tokens_model)
# get_logger().debug(f"limiting max tokens to {max_tokens_model}")
return max_tokens_model


def clip_tokens(text: str, max_tokens: int, add_three_dots=True) -> str:
"""
Clip the number of tokens in a string to a maximum number of tokens.

Args:
text (str): The string to clip.
max_tokens (int): The maximum number of tokens allowed in the string.
add_three_dots (bool, optional): A boolean indicating whether to add three dots at the end of the clipped
Returns:
str: The clipped string.
"""
if not text:
return text

try:
encoder = get_token_encoder()
num_input_tokens = len(encoder.encode(text))
if num_input_tokens <= max_tokens:
return text
num_chars = len(text)
chars_per_token = num_chars / num_input_tokens
num_output_chars = int(chars_per_token * max_tokens)
clipped_text = text[:num_output_chars]
if add_three_dots:
clipped_text += "...(truncated)"
return clipped_text
except Exception as e:
get_logger().warning(f"Failed to clip tokens: {e}")
return text
16 changes: 10 additions & 6 deletions pr_agent/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,22 @@ def run(inargs=None):
- cli.py --issue_url=... similar_issue

Supported commands:
-review / review_pr - Add a review that includes a summary of the PR and specific suggestions for improvement.
- review / review_pr - Add a review that includes a summary of the PR and specific suggestions for improvement.

-ask / ask_question [question] - Ask a question about the PR.
- ask / ask_question [question] - Ask a question about the PR.

-describe / describe_pr - Modify the PR title and description based on the PR's contents.
- describe / describe_pr - Modify the PR title and description based on the PR's contents.

-improve / improve_code - Suggest improvements to the code in the PR as pull request comments ready to commit.
- improve / improve_code - Suggest improvements to the code in the PR as pull request comments ready to commit.
Extended mode ('improve --extended') employs several calls, and provides a more thorough feedback

-reflect - Ask the PR author questions about the PR.
- reflect - Ask the PR author questions about the PR.

-update_changelog - Update the changelog based on the PR's contents.
- update_changelog - Update the changelog based on the PR's contents.

- add_docs

- generate_labels
mrT23 marked this conversation as resolved.
Show resolved Hide resolved


Configuration:
Expand Down
3 changes: 1 addition & 2 deletions pr_agent/git_providers/azuredevops_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
except ImportError:
AZURE_DEVOPS_AVAILABLE = False

from ..algo.pr_processing import clip_tokens
from ..config_loader import get_settings
from ..algo.utils import load_large_diff
from ..algo.utils import load_large_diff, clip_tokens
from ..algo.language_handler import is_valid_file
from .git_provider import EDIT_TYPE, FilePatchInfo

Expand Down
4 changes: 3 additions & 1 deletion pr_agent/git_providers/codecommit_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

from pr_agent.git_providers.codecommit_client import CodeCommitClient

from ..algo.language_handler import is_valid_file, language_extension_map
from ..algo.utils import load_large_diff
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
from ..config_loader import get_settings
from ..log import get_logger


Expand Down Expand Up @@ -269,6 +269,8 @@ def get_languages(self):
# where each dictionary item is a language name.
# We build that language->extension dictionary here in main_extensions_flat.
main_extensions_flat = {}
language_extension_map_org = get_settings().language_extension_map_org
language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
for language, extensions in language_extension_map.items():
for ext in extensions:
main_extensions_flat[ext] = language
Expand Down
59 changes: 38 additions & 21 deletions pr_agent/git_providers/git_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from enum import Enum
from typing import Optional

from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger


Expand Down Expand Up @@ -62,7 +63,7 @@ def get_pr_description_full(self) -> str:

def get_pr_description(self, *, full: bool = True) -> str:
from pr_agent.config_loader import get_settings
from pr_agent.algo.pr_processing import clip_tokens
from pr_agent.algo.utils import clip_tokens
max_tokens_description = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
description = self.get_pr_description_full() if full else self.get_user_description()
if max_tokens_description:
Expand Down Expand Up @@ -173,26 +174,42 @@ def get_main_pr_language(languages, files) -> str:
extension_list.append(file.filename.rsplit('.')[-1])

# get the most common extension
most_common_extension = max(set(extension_list), key=extension_list.count)

# look for a match. TBD: add more languages, do this systematically
if most_common_extension == 'py' and top_language == 'python' or \
most_common_extension == 'js' and top_language == 'javascript' or \
most_common_extension == 'ts' and top_language == 'typescript' or \
most_common_extension == 'go' and top_language == 'go' or \
most_common_extension == 'java' and top_language == 'java' or \
most_common_extension == 'c' and top_language == 'c' or \
most_common_extension == 'cpp' and top_language == 'c++' or \
most_common_extension == 'cs' and top_language == 'c#' or \
most_common_extension == 'swift' and top_language == 'swift' or \
most_common_extension == 'php' and top_language == 'php' or \
most_common_extension == 'rb' and top_language == 'ruby' or \
most_common_extension == 'rs' and top_language == 'rust' or \
most_common_extension == 'scala' and top_language == 'scala' or \
most_common_extension == 'kt' and top_language == 'kotlin' or \
most_common_extension == 'pl' and top_language == 'perl' or \
most_common_extension == top_language:
main_language_str = top_language
most_common_extension = '.' + max(set(extension_list), key=extension_list.count)
try:
language_extension_map_org = get_settings().language_extension_map_org
language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}

if top_language in language_extension_map and most_common_extension in language_extension_map[top_language]:
main_language_str = top_language
else:
for language, extensions in language_extension_map.items():
if most_common_extension in extensions:
main_language_str = language
break
except Exception as e:
get_logger().exception(f"Failed to get main language: {e}")
pass

## old approach:
# most_common_extension = max(set(extension_list), key=extension_list.count)
# if most_common_extension == 'py' and top_language == 'python' or \
# most_common_extension == 'js' and top_language == 'javascript' or \
# most_common_extension == 'ts' and top_language == 'typescript' or \
# most_common_extension == 'tsx' and top_language == 'typescript' or \
# most_common_extension == 'go' and top_language == 'go' or \
# most_common_extension == 'java' and top_language == 'java' or \
# most_common_extension == 'c' and top_language == 'c' or \
# most_common_extension == 'cpp' and top_language == 'c++' or \
# most_common_extension == 'cs' and top_language == 'c#' or \
# most_common_extension == 'swift' and top_language == 'swift' or \
# most_common_extension == 'php' and top_language == 'php' or \
# most_common_extension == 'rb' and top_language == 'ruby' or \
# most_common_extension == 'rs' and top_language == 'rust' or \
# most_common_extension == 'scala' and top_language == 'scala' or \
# most_common_extension == 'kt' and top_language == 'kotlin' or \
# most_common_extension == 'pl' and top_language == 'perl' or \
# most_common_extension == top_language:
# main_language_str = top_language

except Exception as e:
get_logger().exception(e)
Expand Down
4 changes: 2 additions & 2 deletions pr_agent/git_providers/github_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from starlette_context import context

from ..algo.language_handler import is_valid_file
from ..algo.pr_processing import clip_tokens, find_line_number_of_relevant_line_in_file
from ..algo.utils import load_large_diff
from ..algo.pr_processing import find_line_number_of_relevant_line_in_file
from ..algo.utils import load_large_diff, clip_tokens
from ..config_loader import get_settings
from ..log import get_logger
from ..servers.utils import RateLimitExceeded
Expand Down
4 changes: 2 additions & 2 deletions pr_agent/git_providers/gitlab_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from gitlab import GitlabGetError

from ..algo.language_handler import is_valid_file
from ..algo.pr_processing import clip_tokens, find_line_number_of_relevant_line_in_file
from ..algo.utils import load_large_diff
from ..algo.pr_processing import find_line_number_of_relevant_line_in_file
from ..algo.utils import load_large_diff, clip_tokens
from ..config_loader import get_settings
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
from ..log import get_logger
Expand Down
14 changes: 7 additions & 7 deletions pr_agent/settings/custom_labels.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@ enable_custom_labels=false

## template for custom labels
#[custom_labels."Bug fix"]
#description = "Fixes a bug in the code"
#description = """Fixes a bug in the code"""
#[custom_labels."Tests"]
#description = "Adds or modifies tests"
#description = """Adds or modifies tests"""
#[custom_labels."Bug fix with tests"]
#description = "Fixes a bug in the code and adds or modifies tests"
#description = """Fixes a bug in the code and adds or modifies tests"""
#[custom_labels."Refactoring"]
#description = "Code refactoring without changing functionality"
#description = """Code refactoring without changing functionality"""
#[custom_labels."Enhancement"]
#description = "Adds new features or functionality"
#description = """Adds new features or functionality"""
#[custom_labels."Documentation"]
#description = "Adds or modifies documentation"
#description = """Adds or modifies documentation"""
#[custom_labels."Other"]
#description = "Other changes that do not fit in any of the above categories"
#description = """Other changes that do not fit in any of the above categories"""
4 changes: 2 additions & 2 deletions pr_agent/settings/pr_add_docs.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[pr_add_docs_prompt]
system="""You are a language model called PR-Code-Documentation Agent, that specializes in generating documentation for code.
Your task is to generate meaningfull {{ docs_for_language }} to a PR (the '+' lines).
Your task is to generate meaningfull {{ docs_for_language }} to a PR (lines starting with '+').

Example for a PR Diff input:
'
Expand Down Expand Up @@ -103,7 +103,7 @@ Description: '{{description}}'

{%- if language %}

Main language: {{language}}
Main PR language: '{{language}}'
{%- endif %}


Expand Down
6 changes: 3 additions & 3 deletions pr_agent/settings/pr_code_suggestions_prompts.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[pr_code_suggestions_prompt]
system="""You are a language model called PR-Code-Reviewer, that specializes in suggesting code improvements for Pull Request (PR).
Your task is to provide meaningful and actionable code suggestions, to improve the new code presented in a PR (the '+' lines in the diff).
system="""You are PR-Reviewer, a language model that specializes in suggesting code improvements for a Pull Request (PR).
Your task is to provide meaningful and actionable code suggestions, to improve the new code presented in a PR diff (lines starting with '+').

Example for a PR Diff input:
'
Expand Down Expand Up @@ -120,7 +120,7 @@ Description: '{{description}}'

{%- if language %}

Main language: {{language}}
Main PR language: '{{ language }}'
{%- endif %}


Expand Down
Loading
Loading