diff --git a/pr_agent/algo/__init__.py b/pr_agent/algo/__init__.py index 87cd3ae5a..798fc6c5e 100644 --- a/pr_agent/algo/__init__.py +++ b/pr_agent/algo/__init__.py @@ -7,4 +7,8 @@ 'gpt-4': 8000, 'gpt-4-0613': 8000, 'gpt-4-32k': 32000, + 'claude-instant-1': 100000, + 'claude-2': 100000, + 'command-nightly': 4096, + 'replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1': 4096, } diff --git a/pr_agent/algo/ai_handler.py b/pr_agent/algo/ai_handler.py index aa92f5e56..572215189 100644 --- a/pr_agent/algo/ai_handler.py +++ b/pr_agent/algo/ai_handler.py @@ -3,9 +3,10 @@ import openai from openai.error import APIError, RateLimitError, Timeout, TryAgain from retry import retry - +import litellm +from litellm import acompletion from pr_agent.config_loader import get_settings - +import traceback OPENAI_RETRIES=5 class AiHandler: @@ -22,15 +23,25 @@ def __init__(self): """ try: openai.api_key = get_settings().openai.key + litellm.openai_key = get_settings().openai.key + self.azure = False if get_settings().get("OPENAI.ORG", None): - openai.organization = get_settings().openai.org + litellm.organization = get_settings().openai.org self.deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None) if get_settings().get("OPENAI.API_TYPE", None): - openai.api_type = get_settings().openai.api_type + if get_settings().openai.api_type == "azure": + self.azure = True + litellm.azure_key = get_settings().openai.key if get_settings().get("OPENAI.API_VERSION", None): - openai.api_version = get_settings().openai.api_version + litellm.api_version = get_settings().openai.api_version if get_settings().get("OPENAI.API_BASE", None): - openai.api_base = get_settings().openai.api_base + litellm.api_base = get_settings().openai.api_base + if get_settings().get("ANTHROPIC.KEY", None): + litellm.anthropic_key = get_settings().anthropic.key + if get_settings().get("COHERE.KEY", None): + litellm.cohere_key = get_settings().cohere.key + if 
get_settings().get("REPLICATE.KEY", None): + litellm.replicate_key = get_settings().replicate.key except AttributeError as e: raise ValueError("OpenAI key is required") from e @@ -57,7 +68,7 @@ async def chat_completion(self, model: str, temperature: float, system: str, use TryAgain: If there is an attribute error during OpenAI inference. """ try: - response = await openai.ChatCompletion.acreate( + response = await acompletion( model=model, deployment_id=self.deployment_id, messages=[ @@ -65,6 +76,7 @@ async def chat_completion(self, model: str, temperature: float, system: str, use {"role": "user", "content": user} ], temperature=temperature, + azure=self.azure ) except (APIError, Timeout, TryAgain) as e: logging.error("Error during OpenAI inference: ", e) @@ -75,8 +87,9 @@ async def chat_completion(self, model: str, temperature: float, system: str, use except (Exception) as e: logging.error("Unknown error during OpenAI inference: ", e) raise TryAgain from e - if response is None or len(response.choices) == 0: + if response is None or len(response["choices"]) == 0: raise TryAgain - resp = response.choices[0]['message']['content'] - finish_reason = response.choices[0].finish_reason + resp = response["choices"][0]['message']['content'] + finish_reason = response["choices"][0]["finish_reason"] + logging.debug("LLM response: %s (finish_reason=%s)", resp, finish_reason) return resp, finish_reason \ No newline at end of file diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py index fb66583d5..8b319446c 100644 --- a/pr_agent/algo/pr_processing.py +++ b/pr_agent/algo/pr_processing.py @@ -1,9 +1,11 @@ from __future__ import annotations -import re import difflib import logging -from typing import Callable, Tuple, List, Any +import re +import traceback +from typing import Any, Callable, List, Tuple + from github import RateLimitExceededException from pr_agent.algo import MAX_TOKENS @@ -11,7 +13,7 @@ from pr_agent.algo.language_handler import sort_files_by_main_languages from
pr_agent.algo.token_handler import TokenHandler from pr_agent.config_loader import get_settings -from pr_agent.git_providers.git_provider import GitProvider, FilePatchInfo +from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider DELETED_FILES_ = "Deleted files:\n" @@ -215,7 +217,7 @@ async def retry_with_fallback_models(f: Callable): try: return await f(model) except Exception as e: - logging.warning(f"Failed to generate prediction with {model}: {e}") + logging.warning(f"Failed to generate prediction with {model}: {traceback.format_exc()}") if i == len(all_models) - 1: # If it's the last iteration raise # Re-raise the last exception diff --git a/pr_agent/algo/token_handler.py b/pr_agent/algo/token_handler.py index 0888f8b8f..3686f521e 100644 --- a/pr_agent/algo/token_handler.py +++ b/pr_agent/algo/token_handler.py @@ -1,5 +1,5 @@ from jinja2 import Environment, StrictUndefined -from tiktoken import encoding_for_model +from tiktoken import encoding_for_model, get_encoding from pr_agent.config_loader import get_settings @@ -27,7 +27,7 @@ def __init__(self, pr, vars: dict, system, user): - system: The system string. - user: The user string. 
""" - self.encoder = encoding_for_model(get_settings().config.model) + self.encoder = encoding_for_model(get_settings().config.model) if "gpt" in get_settings().config.model else get_encoding("cl100k_base") self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user) def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user): @@ -47,7 +47,6 @@ def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user): environment = Environment(undefined=StrictUndefined) system_prompt = environment.from_string(system).render(vars) user_prompt = environment.from_string(user).render(vars) - system_prompt_tokens = len(encoder.encode(system_prompt)) user_prompt_tokens = len(encoder.encode(user_prompt)) return system_prompt_tokens + user_prompt_tokens diff --git a/pr_agent/settings/.secrets_template.toml b/pr_agent/settings/.secrets_template.toml index 097cb6b4f..36b529a6e 100644 --- a/pr_agent/settings/.secrets_template.toml +++ b/pr_agent/settings/.secrets_template.toml @@ -7,17 +7,26 @@ # See README for details about GitHub App deployment. [openai] -key = "" # Acquire through https://platform.openai.com -org = "" # Optional, may be commented out. +key = "" # Acquire through https://platform.openai.com +#org = "" # Optional, may be commented out. # Uncomment the following for Azure OpenAI #api_type = "azure" #api_version = '2023-05-15' # Check Azure documentation for the current API version -#api_base = "" # The base URL for your Azure OpenAI resource. e.g. "https://.openai.azure.com" -#deployment_id = "" # The deployment name you chose when you deployed the engine +#api_base = "" # The base URL for your Azure OpenAI resource. e.g. "https://.openai.azure.com" +#deployment_id = "" # The deployment name you chose when you deployed the engine +[anthropic] +key = "" # Optional, uncomment if you want to use Anthropic. Acquire through https://www.anthropic.com/ + +[cohere] +key = "" # Optional, uncomment if you want to use Cohere. 
Acquire through https://dashboard.cohere.ai/ + +[replicate] +key = "" # Optional, uncomment if you want to use Replicate. Acquire through https://replicate.com/ [github] # ---- Set the following only for deployment type == "user" -user_token = "" # A GitHub personal access token with 'repo' scope. +user_token = "" # A GitHub personal access token with 'repo' scope. +deployment_type = "user" # defaults to "user" deployment; see README for "app" deployment # ---- Set the following only for deployment type == "app", see README for details. private_key = """\ diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index 3f8671f8d..982f50007 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -174,7 +174,7 @@ def _prepare_pr_review(self) -> str: del pr_feedback['Security concerns'] data.setdefault('PR Analysis', {})['Security concerns'] = security_concerns - # + # if 'Code feedback' in pr_feedback: code_feedback = pr_feedback['Code feedback'] @@ -218,6 +218,9 @@ def _prepare_pr_review(self) -> str: if get_settings().config.verbosity_level >= 2: logging.info(f"Markdown response:\n{markdown_text}") + if markdown_text is None or len(markdown_text) == 0: + markdown_text = review + return markdown_text def _publish_inline_code_comments(self) -> None: diff --git a/pyproject.toml b/pyproject.toml index ac9a48894..4ca0c0b63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,8 @@ dependencies = [ "aiohttp~=3.8.4", "atlassian-python-api==3.39.0", "GitPython~=3.1.32", - "starlette-context==0.3.6" + "starlette-context==0.3.6", + "litellm~=0.1.351" ] [project.urls] diff --git a/requirements.txt b/requirements.txt index 51dc6feed..07a33514d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,5 @@ python-gitlab==3.15.0 pytest~=7.4.0 aiohttp~=3.8.4 atlassian-python-api==3.39.0 -GitPython~=3.1.32 \ No newline at end of file +GitPython~=3.1.32 +litellm~=0.1.351 \ No newline at end of file