Skip to content

Commit

Permalink
Merge pull request #447 from Codium-ai/tr/pydantic
Browse files Browse the repository at this point in the history
Refactor PR label handling and update CLI commands
  • Loading branch information
mrT23 authored Nov 26, 2023
2 parents 46d4d04 + 668041c commit d4e979c
Show file tree
Hide file tree
Showing 22 changed files with 308 additions and 215 deletions.
5 changes: 3 additions & 2 deletions pr_agent/algo/language_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

from pr_agent.config_loader import get_settings

language_extension_map_org = get_settings().language_extension_map_org
language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}


# Bad Extensions, source: https://github.com/EleutherAI/github-downloader/blob/345e7c4cbb9e0dc8a0615fd995a08bf9d73b3fe6/download_repo_text.py # noqa: E501
bad_extensions = get_settings().bad_extensions.default
Expand All @@ -29,6 +28,8 @@ def sort_files_by_main_languages(languages: Dict, files: list):
# languages_sorted = sorted(languages, key=lambda x: x[1], reverse=True)
# get all extensions for the languages
main_extensions = []
language_extension_map_org = get_settings().language_extension_map_org
language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
for language in languages_sorted_list:
if language.lower() in language_extension_map:
main_extensions.append(language_extension_map[language.lower()])
Expand Down
31 changes: 1 addition & 30 deletions pr_agent/algo/pr_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.file_filter import filter_ignored
from pr_agent.algo.token_handler import TokenHandler, get_token_encoder
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import get_max_tokens
from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider, EDIT_TYPE
Expand Down Expand Up @@ -326,35 +326,6 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
return position, absolute_position


def clip_tokens(text: str, max_tokens: int) -> str:
"""
Clip the number of tokens in a string to a maximum number of tokens.
Args:
text (str): The string to clip.
max_tokens (int): The maximum number of tokens allowed in the string.
Returns:
str: The clipped string.
"""
if not text:
return text

try:
encoder = get_token_encoder()
num_input_tokens = len(encoder.encode(text))
if num_input_tokens <= max_tokens:
return text
num_chars = len(text)
chars_per_token = num_chars / num_input_tokens
num_output_chars = int(chars_per_token * max_tokens)
clipped_text = text[:num_output_chars]
return clipped_text
except Exception as e:
get_logger().warning(f"Failed to clip tokens: {e}")
return text


def get_pr_multi_diffs(git_provider: GitProvider,
token_handler: TokenHandler,
model: str,
Expand Down
45 changes: 40 additions & 5 deletions pr_agent/algo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from starlette_context import context

from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.token_handler import get_token_encoder
from pr_agent.config_loader import get_settings, global_settings
from pr_agent.log import get_logger

Expand Down Expand Up @@ -338,12 +339,15 @@ def set_custom_labels(variables):
labels_list = f" - {labels_list}" if labels_list else ""
variables["custom_labels"] = labels_list
return
final_labels = ""
#final_labels = ""
#for k, v in labels.items():
# final_labels += f" - {k} ({v['description']})\n"
#variables["custom_labels"] = final_labels
#variables["custom_labels_examples"] = f" - {list(labels.keys())[0]}"
variables["custom_labels_class"] = "class Label(str, Enum):"
for k, v in labels.items():
final_labels += f" - {k} ({v['description']})\n"
variables["custom_labels"] = final_labels
variables["custom_labels_examples"] = f" - {list(labels.keys())[0]}"

description = v['description'].strip('\n').replace('\n', '\\n')
variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = '{k}' # {description}"

def get_user_labels(current_labels: List[str] = None):
"""
Expand Down Expand Up @@ -375,3 +379,34 @@ def get_max_tokens(model):
max_tokens_model = min(settings.config.max_model_tokens, max_tokens_model)
# get_logger().debug(f"limiting max tokens to {max_tokens_model}")
return max_tokens_model


def clip_tokens(text: str, max_tokens: int, add_three_dots=True) -> str:
"""
Clip the number of tokens in a string to a maximum number of tokens.
Args:
text (str): The string to clip.
max_tokens (int): The maximum number of tokens allowed in the string.
add_three_dots (bool, optional): A boolean indicating whether to add three dots at the end of the clipped
Returns:
str: The clipped string.
"""
if not text:
return text

try:
encoder = get_token_encoder()
num_input_tokens = len(encoder.encode(text))
if num_input_tokens <= max_tokens:
return text
num_chars = len(text)
chars_per_token = num_chars / num_input_tokens
num_output_chars = int(chars_per_token * max_tokens)
clipped_text = text[:num_output_chars]
if add_three_dots:
clipped_text += "...(truncated)"
return clipped_text
except Exception as e:
get_logger().warning(f"Failed to clip tokens: {e}")
return text
16 changes: 10 additions & 6 deletions pr_agent/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,22 @@ def run(inargs=None):
- cli.py --issue_url=... similar_issue
Supported commands:
-review / review_pr - Add a review that includes a summary of the PR and specific suggestions for improvement.
- review / review_pr - Add a review that includes a summary of the PR and specific suggestions for improvement.
-ask / ask_question [question] - Ask a question about the PR.
- ask / ask_question [question] - Ask a question about the PR.
-describe / describe_pr - Modify the PR title and description based on the PR's contents.
- describe / describe_pr - Modify the PR title and description based on the PR's contents.
-improve / improve_code - Suggest improvements to the code in the PR as pull request comments ready to commit.
- improve / improve_code - Suggest improvements to the code in the PR as pull request comments ready to commit.
Extended mode ('improve --extended') employs several calls, and provides a more thorough feedback
-reflect - Ask the PR author questions about the PR.
- reflect - Ask the PR author questions about the PR.
-update_changelog - Update the changelog based on the PR's contents.
- update_changelog - Update the changelog based on the PR's contents.
- add_docs
- generate_labels
Configuration:
Expand Down
3 changes: 1 addition & 2 deletions pr_agent/git_providers/azuredevops_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
except ImportError:
AZURE_DEVOPS_AVAILABLE = False

from ..algo.pr_processing import clip_tokens
from ..config_loader import get_settings
from ..algo.utils import load_large_diff
from ..algo.utils import load_large_diff, clip_tokens
from ..algo.language_handler import is_valid_file
from .git_provider import EDIT_TYPE, FilePatchInfo

Expand Down
4 changes: 3 additions & 1 deletion pr_agent/git_providers/codecommit_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

from pr_agent.git_providers.codecommit_client import CodeCommitClient

from ..algo.language_handler import is_valid_file, language_extension_map
from ..algo.utils import load_large_diff
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
from ..config_loader import get_settings
from ..log import get_logger


Expand Down Expand Up @@ -269,6 +269,8 @@ def get_languages(self):
# where each dictionary item is a language name.
# We build that language->extension dictionary here in main_extensions_flat.
main_extensions_flat = {}
language_extension_map_org = get_settings().language_extension_map_org
language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
for language, extensions in language_extension_map.items():
for ext in extensions:
main_extensions_flat[ext] = language
Expand Down
59 changes: 38 additions & 21 deletions pr_agent/git_providers/git_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from enum import Enum
from typing import Optional

from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger


Expand Down Expand Up @@ -62,7 +63,7 @@ def get_pr_description_full(self) -> str:

def get_pr_description(self, *, full: bool = True) -> str:
from pr_agent.config_loader import get_settings
from pr_agent.algo.pr_processing import clip_tokens
from pr_agent.algo.utils import clip_tokens
max_tokens_description = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
description = self.get_pr_description_full() if full else self.get_user_description()
if max_tokens_description:
Expand Down Expand Up @@ -173,26 +174,42 @@ def get_main_pr_language(languages, files) -> str:
extension_list.append(file.filename.rsplit('.')[-1])

# get the most common extension
most_common_extension = max(set(extension_list), key=extension_list.count)

# look for a match. TBD: add more languages, do this systematically
if most_common_extension == 'py' and top_language == 'python' or \
most_common_extension == 'js' and top_language == 'javascript' or \
most_common_extension == 'ts' and top_language == 'typescript' or \
most_common_extension == 'go' and top_language == 'go' or \
most_common_extension == 'java' and top_language == 'java' or \
most_common_extension == 'c' and top_language == 'c' or \
most_common_extension == 'cpp' and top_language == 'c++' or \
most_common_extension == 'cs' and top_language == 'c#' or \
most_common_extension == 'swift' and top_language == 'swift' or \
most_common_extension == 'php' and top_language == 'php' or \
most_common_extension == 'rb' and top_language == 'ruby' or \
most_common_extension == 'rs' and top_language == 'rust' or \
most_common_extension == 'scala' and top_language == 'scala' or \
most_common_extension == 'kt' and top_language == 'kotlin' or \
most_common_extension == 'pl' and top_language == 'perl' or \
most_common_extension == top_language:
main_language_str = top_language
most_common_extension = '.' + max(set(extension_list), key=extension_list.count)
try:
language_extension_map_org = get_settings().language_extension_map_org
language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}

if top_language in language_extension_map and most_common_extension in language_extension_map[top_language]:
main_language_str = top_language
else:
for language, extensions in language_extension_map.items():
if most_common_extension in extensions:
main_language_str = language
break
except Exception as e:
get_logger().exception(f"Failed to get main language: {e}")
pass

## old approach:
# most_common_extension = max(set(extension_list), key=extension_list.count)
# if most_common_extension == 'py' and top_language == 'python' or \
# most_common_extension == 'js' and top_language == 'javascript' or \
# most_common_extension == 'ts' and top_language == 'typescript' or \
# most_common_extension == 'tsx' and top_language == 'typescript' or \
# most_common_extension == 'go' and top_language == 'go' or \
# most_common_extension == 'java' and top_language == 'java' or \
# most_common_extension == 'c' and top_language == 'c' or \
# most_common_extension == 'cpp' and top_language == 'c++' or \
# most_common_extension == 'cs' and top_language == 'c#' or \
# most_common_extension == 'swift' and top_language == 'swift' or \
# most_common_extension == 'php' and top_language == 'php' or \
# most_common_extension == 'rb' and top_language == 'ruby' or \
# most_common_extension == 'rs' and top_language == 'rust' or \
# most_common_extension == 'scala' and top_language == 'scala' or \
# most_common_extension == 'kt' and top_language == 'kotlin' or \
# most_common_extension == 'pl' and top_language == 'perl' or \
# most_common_extension == top_language:
# main_language_str = top_language

except Exception as e:
get_logger().exception(e)
Expand Down
4 changes: 2 additions & 2 deletions pr_agent/git_providers/github_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from starlette_context import context

from ..algo.language_handler import is_valid_file
from ..algo.pr_processing import clip_tokens, find_line_number_of_relevant_line_in_file
from ..algo.utils import load_large_diff
from ..algo.pr_processing import find_line_number_of_relevant_line_in_file
from ..algo.utils import load_large_diff, clip_tokens
from ..config_loader import get_settings
from ..log import get_logger
from ..servers.utils import RateLimitExceeded
Expand Down
4 changes: 2 additions & 2 deletions pr_agent/git_providers/gitlab_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from gitlab import GitlabGetError

from ..algo.language_handler import is_valid_file
from ..algo.pr_processing import clip_tokens, find_line_number_of_relevant_line_in_file
from ..algo.utils import load_large_diff
from ..algo.pr_processing import find_line_number_of_relevant_line_in_file
from ..algo.utils import load_large_diff, clip_tokens
from ..config_loader import get_settings
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
from ..log import get_logger
Expand Down
14 changes: 7 additions & 7 deletions pr_agent/settings/custom_labels.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@ enable_custom_labels=false

## template for custom labels
#[custom_labels."Bug fix"]
#description = "Fixes a bug in the code"
#description = """Fixes a bug in the code"""
#[custom_labels."Tests"]
#description = "Adds or modifies tests"
#description = """Adds or modifies tests"""
#[custom_labels."Bug fix with tests"]
#description = "Fixes a bug in the code and adds or modifies tests"
#description = """Fixes a bug in the code and adds or modifies tests"""
#[custom_labels."Refactoring"]
#description = "Code refactoring without changing functionality"
#description = """Code refactoring without changing functionality"""
#[custom_labels."Enhancement"]
#description = "Adds new features or functionality"
#description = """Adds new features or functionality"""
#[custom_labels."Documentation"]
#description = "Adds or modifies documentation"
#description = """Adds or modifies documentation"""
#[custom_labels."Other"]
#description = "Other changes that do not fit in any of the above categories"
#description = """Other changes that do not fit in any of the above categories"""
4 changes: 2 additions & 2 deletions pr_agent/settings/pr_add_docs.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[pr_add_docs_prompt]
system="""You are a language model called PR-Code-Documentation Agent, that specializes in generating documentation for code.
Your task is to generate meaningfull {{ docs_for_language }} to a PR (the '+' lines).
Your task is to generate meaningfull {{ docs_for_language }} to a PR (lines starting with '+').
Example for a PR Diff input:
'
Expand Down Expand Up @@ -103,7 +103,7 @@ Description: '{{description}}'
{%- if language %}
Main language: {{language}}
Main PR language: '{{language}}'
{%- endif %}
Expand Down
6 changes: 3 additions & 3 deletions pr_agent/settings/pr_code_suggestions_prompts.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[pr_code_suggestions_prompt]
system="""You are a language model called PR-Code-Reviewer, that specializes in suggesting code improvements for Pull Request (PR).
Your task is to provide meaningful and actionable code suggestions, to improve the new code presented in a PR (the '+' lines in the diff).
system="""You are PR-Reviewer, a language model that specializes in suggesting code improvements for a Pull Request (PR).
Your task is to provide meaningful and actionable code suggestions, to improve the new code presented in a PR diff (lines starting with '+').
Example for a PR Diff input:
'
Expand Down Expand Up @@ -120,7 +120,7 @@ Description: '{{description}}'
{%- if language %}
Main language: {{language}}
Main PR language: '{{ language }}'
{%- endif %}
Expand Down
Loading

0 comments on commit d4e979c

Please sign in to comment.