Skip to content

Commit

Permalink
Refactor AI handler instantiation to use lazy initialization in PR tools
Browse files Browse the repository at this point in the history
  • Loading branch information
mrT23 committed Dec 17, 2023
1 parent 54891ad commit 5fb373b
Show file tree
Hide file tree
Showing 9 changed files with 28 additions and 18 deletions.
6 changes: 4 additions & 2 deletions pr_agent/agent/pr_agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import shlex
from functools import partial

from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler

Expand Down Expand Up @@ -41,8 +43,8 @@
commands = list(command2class.keys())

class PRAgent:
def __init__(self, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
self.ai_handler = ai_handler
def __init__(self, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
self.ai_handler = ai_handler # will be initialized in run_action

async def handle_request(self, pr_url, request, notify=None) -> bool:
# First, apply repo specific settings if exists
Expand Down
5 changes: 3 additions & 2 deletions pr_agent/tools/pr_add_docs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import textwrap
from functools import partial
from typing import Dict

from jinja2 import Environment, StrictUndefined
Expand All @@ -17,14 +18,14 @@

class PRAddDocs:
def __init__(self, pr_url: str, cli_mode=False, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):

self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)

self.ai_handler = ai_handler
self.ai_handler = ai_handler()
self.patches_diff = None
self.prediction = None
self.cli_mode = cli_mode
Expand Down
5 changes: 3 additions & 2 deletions pr_agent/tools/pr_code_suggestions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import textwrap
from functools import partial
from typing import Dict, List
from jinja2 import Environment, StrictUndefined

Expand All @@ -16,7 +17,7 @@

class PRCodeSuggestions:
def __init__(self, pr_url: str, cli_mode=False, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):

self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
Expand All @@ -33,7 +34,7 @@ def __init__(self, pr_url: str, cli_mode=False, args: list = None,
else:
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions

self.ai_handler = ai_handler
self.ai_handler = ai_handler()
self.patches_diff = None
self.prediction = None
self.cli_mode = cli_mode
Expand Down
5 changes: 3 additions & 2 deletions pr_agent/tools/pr_description.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import re
from functools import partial
from typing import List, Tuple

from jinja2 import Environment, StrictUndefined
Expand All @@ -17,7 +18,7 @@

class PRDescription:
def __init__(self, pr_url: str, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
"""
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
using an AI model.
Expand All @@ -38,7 +39,7 @@ def __init__(self, pr_url: str, args: list = None,
get_settings().pr_description.enable_semantic_files_types = False

# Initialize the AI handler
self.ai_handler = ai_handler
self.ai_handler = ai_handler()

# Initialize the variables dictionary
self.vars = {
Expand Down
5 changes: 3 additions & 2 deletions pr_agent/tools/pr_generate_labels.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import re
from functools import partial
from typing import List, Tuple

from jinja2 import Environment, StrictUndefined
Expand All @@ -17,7 +18,7 @@

class PRGenerateLabels:
def __init__(self, pr_url: str, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
"""
Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels
corresponding to the PR using an AI model.
Expand All @@ -33,7 +34,7 @@ def __init__(self, pr_url: str, args: list = None,
self.pr_id = self.git_provider.get_pr_id()

# Initialize the AI handler
self.ai_handler = ai_handler
self.ai_handler = ai_handler()

# Initialize the variables dictionary
self.vars = {
Expand Down
5 changes: 3 additions & 2 deletions pr_agent/tools/pr_information_from_user.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
from functools import partial

from jinja2 import Environment, StrictUndefined

Expand All @@ -14,12 +15,12 @@

class PRInformationFromUser:
def __init__(self, pr_url: str, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.ai_handler = ai_handler
self.ai_handler = ai_handler()
self.vars = {
"title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(),
Expand Down
5 changes: 3 additions & 2 deletions pr_agent/tools/pr_questions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
from functools import partial

from jinja2 import Environment, StrictUndefined

Expand All @@ -13,13 +14,13 @@


class PRQuestions:
def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
question_str = self.parse_args(args)
self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.ai_handler = ai_handler
self.ai_handler = ai_handler()
self.question_str = question_str
self.vars = {
"title": self.git_provider.pr.title,
Expand Down
5 changes: 3 additions & 2 deletions pr_agent/tools/pr_reviewer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import datetime
from collections import OrderedDict
from functools import partial
from typing import List, Tuple

import yaml
Expand All @@ -24,7 +25,7 @@ class PRReviewer:
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
"""
def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
"""
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
Expand All @@ -47,7 +48,7 @@ def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False,

if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now")
self.ai_handler = ai_handler
self.ai_handler = ai_handler()
self.patches_diff = None
self.prediction = None

Expand Down
5 changes: 3 additions & 2 deletions pr_agent/tools/pr_update_changelog.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
from datetime import date
from functools import partial
from time import sleep
from typing import Tuple

Expand All @@ -18,15 +19,15 @@


class PRUpdateChangelog:
def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):

self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes
self._get_changlog_file() # self.changelog_file_str
self.ai_handler = ai_handler
self.ai_handler = ai_handler()
self.patches_diff = None
self.prediction = None
self.cli_mode = cli_mode
Expand Down

0 comments on commit 5fb373b

Please sign in to comment.