Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring AI Handlers and Adding Support for Multiple AI Models #525

Merged
merged 26 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
7e47baa
Refactor AI handler classes
brianphamsia Dec 9, 2023
f2abe5c
Abstract AiHandler to BaseAiHandler
brianpham93 Dec 9, 2023
c0303ff
Merge remote-tracking branch 'upstream/main' into abstract-BaseAiHandler
brianpham93 Dec 9, 2023
b640992
Remove extra code
brianpham93 Dec 9, 2023
523a896
Rename AiHandler to LiteLLMAiHandler
brianphamsia Dec 11, 2023
b8021d7
rename file
brianphamsia Dec 11, 2023
a1cbd80
update base ai handler
brianphamsia Dec 11, 2023
ebf7027
add openai handler
brianphamsia Dec 11, 2023
5239e1c
Load default AI Handler from util function
brianphamsia Dec 12, 2023
7eb2e76
Move ai handlers to specific folder
brianphamsia Dec 12, 2023
6c7becc
add LangChain AI Handler
brianphamsia Dec 12, 2023
506eafc
add langchain in requirement
brianphamsia Dec 12, 2023
0c66554
langchain: move model and temperature to chat_completion
brianphamsia Dec 12, 2023
a627dcd
Update langchain
brianphamsia Dec 12, 2023
b7225cc
update langchain
brianphamsia Dec 12, 2023
ca1ccd7
update base
brianphamsia Dec 12, 2023
8fb4a42
Update AI handler instantiation in server files
brianphamsia Dec 13, 2023
be8d6af
Add code documentation generation for PR diffs
brianphamsia Dec 13, 2023
ebb2ed8
Add logging to pr_agent.py
brianphamsia Dec 13, 2023
69a7c77
Refactor PRAgent class and has_ai_handler_param
brianphamsia Dec 13, 2023
557b39e
Merge branch 'base-ai-handler' into abstract-BaseAiHandler
brianphamsia Dec 13, 2023
e37598f
Merge remote-tracking branch 'upstream/main' into abstract-BaseAiHandler
brianphamsia Dec 13, 2023
3531016
Refactor AI handler instantiation in PRAgent and related classes
mrT23 Dec 14, 2023
246be61
Set LiteLLMAIHandler as default AI handler in all PR tools and simpli…
mrT23 Dec 14, 2023
38ea914
Make LangChain dependency optional in pr-agent and update requirement…
mrT23 Dec 14, 2023
02871b1
Remove logging from pr_agent.py and add line breaks in cli.py and git…
mrT23 Dec 14, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions pr_agent/agent/pr_agent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import shlex
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler

from pr_agent.algo.utils import update_settings_from_args
from pr_agent.config_loader import get_settings
from pr_agent.git_providers.utils import apply_repo_settings
from pr_agent.log import get_logger
from pr_agent.tools.pr_add_docs import PRAddDocs
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
from pr_agent.tools.pr_config import PRConfig
Expand Down Expand Up @@ -38,8 +41,8 @@
commands = list(command2class.keys())

class PRAgent:
def __init__(self):
pass
def __init__(self, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
self.ai_handler = ai_handler

async def handle_request(self, pr_url, request, notify=None) -> bool:
# First, apply repo specific settings if exists
Expand All @@ -61,13 +64,14 @@ async def handle_request(self, pr_url, request, notify=None) -> bool:
if action == "answer":
if notify:
notify()
await PRReviewer(pr_url, is_answer=True, args=args).run()
await PRReviewer(pr_url, is_answer=True, args=args, ai_handler=self.ai_handler).run()
elif action == "auto_review":
await PRReviewer(pr_url, is_auto=True, args=args).run()
await PRReviewer(pr_url, is_auto=True, args=args, ai_handler=self.ai_handler).run()
elif action in command2class:
if notify:
notify()
await command2class[action](pr_url, args=args).run()

await command2class[action](pr_url, ai_handler=self.ai_handler, args=args).run()
else:
return False
return True
Expand Down
28 changes: 28 additions & 0 deletions pr_agent/algo/ai_handlers/base_ai_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from abc import ABC, abstractmethod

class BaseAiHandler(ABC):
"""
This class defines the interface for an AI handler to be used by the PR Agents.
"""

@abstractmethod
def __init__(self):
pass

@property
@abstractmethod
def deployment_id(self):
pass

@abstractmethod
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
"""
This method should be implemented to return a chat completion from the AI model.
Args:
model (str): the name of the model to use for the chat completion
system (str): the system message string to use for the chat completion
user (str): the user message string to use for the chat completion
temperature (float): the temperature to use for the chat completion
"""
pass

49 changes: 49 additions & 0 deletions pr_agent/algo/ai_handlers/langchain_ai_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
try:
from langchain.chat_models import ChatOpenAI
from langchain.schema import SystemMessage, HumanMessage
except: # we don't enforce langchain as a dependency, so if it's not installed, just move on
pass

from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger

from openai.error import APIError, RateLimitError, Timeout, TryAgain
from retry import retry

OPENAI_RETRIES = 5

class LangChainOpenAIHandler(BaseAiHandler):
def __init__(self):
# Initialize OpenAIHandler specific attributes here
try:
super().__init__()
self._chat = ChatOpenAI(openai_api_key=get_settings().openai.key)

except AttributeError as e:
raise ValueError("OpenAI key is required") from e

@property
def chat(self):
return self._chat

@property
def deployment_id(self):
"""
Returns the deployment ID for the OpenAI API.
"""
return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
try:
messages=[SystemMessage(content=system), HumanMessage(content=user)]

# get a chat completion from the formatted messages
resp = self.chat(messages, model=model, temperature=temperature)
finish_reason="completed"
return resp.content, finish_reason

except (Exception) as e:
get_logger().error("Unknown error during OpenAI inference: ", e)
raise e
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
from litellm import acompletion
from openai.error import APIError, RateLimitError, Timeout, TryAgain
from retry import retry
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger

OPENAI_RETRIES = 5


class AiHandler:
class LiteLLMAIHandler(BaseAiHandler):
"""
This class handles interactions with the OpenAI API for chat completions.
It initializes the API key and other settings from a configuration file,
Expand Down Expand Up @@ -134,4 +135,4 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
usage = response.get("usage")
get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason,
model=model, usage=usage)
return resp, finish_reason
return resp, finish_reason
67 changes: 67 additions & 0 deletions pr_agent/algo/ai_handlers/openai_ai_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
import openai
from openai.error import APIError, RateLimitError, Timeout, TryAgain
from retry import retry

from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger

OPENAI_RETRIES = 5


class OpenAIHandler(BaseAiHandler):
def __init__(self):
# Initialize OpenAIHandler specific attributes here
try:
super().__init__()
openai.api_key = get_settings().openai.key
if get_settings().get("OPENAI.ORG", None):
openai.organization = get_settings().openai.org
if get_settings().get("OPENAI.API_TYPE", None):
if get_settings().openai.api_type == "azure":
self.azure = True
openai.azure_key = get_settings().openai.key
if get_settings().get("OPENAI.API_VERSION", None):
openai.api_version = get_settings().openai.api_version
if get_settings().get("OPENAI.API_BASE", None):
openai.api_base = get_settings().openai.api_base

except AttributeError as e:
raise ValueError("OpenAI key is required") from e
@property
def deployment_id(self):
"""
Returns the deployment ID for the OpenAI API.
"""
return get_settings().get("OPENAI.DEPLOYMENT_ID", None)

@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
try:
deployment_id = self.deployment_id
get_logger().info("System: ", system)
get_logger().info("User: ", user)
messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]

chat_completion = await openai.ChatCompletion.acreate(
model=model,
deployment_id=deployment_id,
messages=messages,
temperature=temperature,
)
resp = chat_completion["choices"][0]['message']['content']
finish_reason = chat_completion["choices"][0]["finish_reason"]
usage = chat_completion.get("usage")
get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason,
model=model, usage=usage)
return resp, finish_reason
except (APIError, Timeout, TryAgain) as e:
get_logger().error("Error during OpenAI inference: ", e)
raise
except (RateLimitError) as e:
get_logger().error("Rate limit error during OpenAI inference: ", e)
raise
except (Exception) as e:
get_logger().error("Unknown error during OpenAI inference: ", e)
raise TryAgain from e
2 changes: 1 addition & 1 deletion pr_agent/algo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,4 +447,4 @@ def clip_tokens(text: str, max_tokens: int, add_three_dots=True) -> str:
return clipped_text
except Exception as e:
get_logger().warning(f"Failed to clip tokens: {e}")
return text
return text
8 changes: 5 additions & 3 deletions pr_agent/tools/pr_add_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from jinja2 import Environment, StrictUndefined

from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml
Expand All @@ -15,14 +16,15 @@


class PRAddDocs:
def __init__(self, pr_url: str, cli_mode=False, args: list = None):
def __init__(self, pr_url: str, cli_mode=False, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):

self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)

self.ai_handler = AiHandler()
self.ai_handler = ai_handler
self.patches_diff = None
self.prediction = None
self.cli_mode = cli_mode
Expand Down
8 changes: 5 additions & 3 deletions pr_agent/tools/pr_code_suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from typing import Dict, List
from jinja2 import Environment, StrictUndefined

from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml
Expand All @@ -14,7 +15,8 @@


class PRCodeSuggestions:
def __init__(self, pr_url: str, cli_mode=False, args: list = None):
def __init__(self, pr_url: str, cli_mode=False, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):

self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
Expand All @@ -31,7 +33,7 @@ def __init__(self, pr_url: str, cli_mode=False, args: list = None):
else:
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions

self.ai_handler = AiHandler()
self.ai_handler = ai_handler
self.patches_diff = None
self.prediction = None
self.cli_mode = cli_mode
Expand Down
8 changes: 5 additions & 3 deletions pr_agent/tools/pr_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from jinja2 import Environment, StrictUndefined

from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels
Expand All @@ -15,7 +16,8 @@


class PRDescription:
def __init__(self, pr_url: str, args: list = None):
def __init__(self, pr_url: str, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
"""
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
using an AI model.
Expand All @@ -36,7 +38,7 @@ def __init__(self, pr_url: str, args: list = None):
get_settings().pr_description.enable_semantic_files_types = False

# Initialize the AI handler
self.ai_handler = AiHandler()
self.ai_handler = ai_handler

# Initialize the variables dictionary
self.vars = {
Expand Down
8 changes: 5 additions & 3 deletions pr_agent/tools/pr_generate_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from jinja2 import Environment, StrictUndefined

from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels
Expand All @@ -15,7 +16,8 @@


class PRGenerateLabels:
def __init__(self, pr_url: str, args: list = None):
def __init__(self, pr_url: str, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
"""
Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels
corresponding to the PR using an AI model.
Expand All @@ -31,7 +33,7 @@ def __init__(self, pr_url: str, args: list = None):
self.pr_id = self.git_provider.get_pr_id()

# Initialize the AI handler
self.ai_handler = AiHandler()
self.ai_handler = ai_handler

# Initialize the variables dictionary
self.vars = {
Expand Down
8 changes: 5 additions & 3 deletions pr_agent/tools/pr_information_from_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from jinja2 import Environment, StrictUndefined

from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import get_settings
Expand All @@ -12,12 +13,13 @@


class PRInformationFromUser:
def __init__(self, pr_url: str, args: list = None):
def __init__(self, pr_url: str, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.ai_handler = AiHandler()
self.ai_handler = ai_handler
self.vars = {
"title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(),
Expand Down
7 changes: 4 additions & 3 deletions pr_agent/tools/pr_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from jinja2 import Environment, StrictUndefined

from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import get_settings
Expand All @@ -12,13 +13,13 @@


class PRQuestions:
def __init__(self, pr_url: str, args=None):
def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
question_str = self.parse_args(args)
self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.ai_handler = AiHandler()
self.ai_handler = ai_handler
self.question_str = question_str
self.vars = {
"title": self.git_provider.pr.title,
Expand Down
Loading
Loading