From 1c4e64333cdf0677b957f0eefb29dfce38c51855 Mon Sep 17 00:00:00 2001 From: mrT23 Date: Mon, 18 Dec 2023 12:29:06 +0200 Subject: [PATCH] feat: Implement label case conversion and update label descriptions in settings files --- pr_agent/algo/utils.py | 10 ++++++++-- pr_agent/settings/pr_custom_labels.toml | 2 +- pr_agent/settings/pr_description_prompts.toml | 9 ++++++--- pr_agent/tools/pr_description.py | 11 +++++++++++ pr_agent/tools/pr_generate_labels.py | 11 +++++++++++ 5 files changed, 37 insertions(+), 6 deletions(-) diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 9e1000423..303b3da53 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -379,9 +379,15 @@ def set_custom_labels(variables, git_provider=None): # Set custom labels variables["custom_labels_class"] = "class Label(str, Enum):" + counter = 0 + labels_minimal_to_labels_dict = {} for k, v in labels.items(): - description = v['description'].strip('\n').replace('\n', '\\n') - variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = '{k}' # {description}" + description = "'" + v['description'].strip('\n').replace('\n', '\\n') + "'" + # variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = '{k}' # {description}" + variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = {description}" + labels_minimal_to_labels_dict[k.lower().replace(' ', '_')] = k + counter += 1 + variables["labels_minimal_to_labels_dict"] = labels_minimal_to_labels_dict def get_user_labels(current_labels: List[str] = None): """ diff --git a/pr_agent/settings/pr_custom_labels.toml b/pr_agent/settings/pr_custom_labels.toml index d9a5e0041..44b0ada89 100644 --- a/pr_agent/settings/pr_custom_labels.toml +++ b/pr_agent/settings/pr_custom_labels.toml @@ -30,7 +30,7 @@ class Label(str, Enum): {%- endif %} class Labels(BaseModel): - labels: List[Label] = Field(min_items=0, description="custom labels that describe the PR. Return the label value, not the name.") + labels: List[Label] = Field(min_items=0, description="choose the relevant custom labels that describe the PR content, and return their keys. Use the value field of the Label object to better understand the label meaning.") ====== diff --git a/pr_agent/settings/pr_description_prompts.toml b/pr_agent/settings/pr_description_prompts.toml index 6e0c395e6..38b95a3a3 100644 --- a/pr_agent/settings/pr_description_prompts.toml +++ b/pr_agent/settings/pr_description_prompts.toml @@ -37,6 +37,7 @@ class FileWalkthrough(BaseModel): {%- endif %} {%- if enable_semantic_files_types %} + Class FileDescription(BaseModel): filename: str = Field(description="the relevant file full path") changes_summary: str = Field(description="minimal and concise summary of the changes in the relevant file") @@ -48,7 +49,7 @@ Class PRDescription(BaseModel): type: List[PRType] = Field(description="one or more types that describe the PR type. Return the label value, not the name.") description: str = Field(description="an informative and concise description of the PR. {%- if use_bullet_points %} Use bullet points.{% endif %}") {%- if enable_custom_labels %} - labels: List[Label] = Field(min_items=0, description="custom labels that describe the PR. Return the label value, not the name.") + labels: List[Label] = Field(min_items=0, description="choose the relevant custom labels that describe the PR content, and return their keys. Use the value field of the Label object to better understand the label meaning.") {%- endif %} {%- if enable_file_walkthrough %} main_files_walkthrough: List[FileWalkthrough] = Field(max_items=10) @@ -69,8 +70,10 @@ type: - ... {%- if enable_custom_labels %} labels: -- ... -- ... +- | + ... +- | + ... {%- endif %} description: |- ... diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index 4915c5b68..7d7b56782 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -162,6 +162,7 @@ async def _get_prediction(self, model: str) -> str: environment = Environment(undefined=StrictUndefined) set_custom_labels(variables, self.git_provider) + self.variables = variables system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables) user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables) @@ -203,6 +204,16 @@ def _prepare_labels(self) -> List[str]: pr_types = self.data['type'] elif type(self.data['type']) == str: pr_types = self.data['type'].split(',') + + # convert lowercase labels to original case + try: + if "labels_minimal_to_labels_dict" in self.variables: + d: dict = self.variables["labels_minimal_to_labels_dict"] + for i, label_i in enumerate(pr_types): + if label_i in d: + pr_types[i] = d[label_i] + except Exception as e: + get_logger().error(f"Error converting labels to original case {self.pr_id}: {e}") return pr_types def _prepare_pr_answer_with_markers(self) -> Tuple[str, str]: diff --git a/pr_agent/tools/pr_generate_labels.py b/pr_agent/tools/pr_generate_labels.py index 25e80a55b..28f4b8ef6 100644 --- a/pr_agent/tools/pr_generate_labels.py +++ b/pr_agent/tools/pr_generate_labels.py @@ -135,6 +135,7 @@ async def _get_prediction(self, model: str) -> str: environment = Environment(undefined=StrictUndefined) set_custom_labels(variables, self.git_provider) + self.variables = variables system_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.system).render(variables) user_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.user).render(variables) @@ -170,4 +171,14 @@ def _prepare_labels(self) -> List[str]: elif type(self.data['labels']) == str: pr_types = self.data['labels'].split(',') + # convert lowercase labels to original case + try: + if "labels_minimal_to_labels_dict" in self.variables: + d: dict = self.variables["labels_minimal_to_labels_dict"] + for i, label_i in enumerate(pr_types): + if label_i in d: + pr_types[i] = d[label_i] + except Exception as e: + get_logger().error(f"Error converting labels to original case {self.pr_id}: {e}") + return pr_types