Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding support for Anthropic, Cohere, Replicate, Azure #172

Merged
merged 5 commits into from
Aug 6, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pr_agent/algo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,8 @@
'gpt-4': 8000,
'gpt-4-0613': 8000,
'gpt-4-32k': 32000,
'claude-instant-1': 100000,
'claude-2': 100000,
'command-nightly': 4096,
'replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1': 4096,
}
17 changes: 12 additions & 5 deletions pr_agent/algo/ai_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import openai
from openai.error import APIError, RateLimitError, Timeout, TryAgain
from retry import retry

import litellm
from litellm import acompletion
from pr_agent.config_loader import get_settings

OPENAI_RETRIES=5
Expand All @@ -22,6 +23,7 @@ def __init__(self):
"""
try:
openai.api_key = get_settings().openai.key
litellm.openai_key = get_settings().openai.key
if get_settings().get("OPENAI.ORG", None):
openai.organization = get_settings().openai.org
self.deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None)
Expand All @@ -31,6 +33,9 @@ def __init__(self):
openai.api_version = get_settings().openai.api_version
if get_settings().get("OPENAI.API_BASE", None):
openai.api_base = get_settings().openai.api_base
litellm.api_base = get_settings().openai.api_base
if get_settings().get("LITE.KEY", None):
self.llm_api_key = get_settings().lite.key
except AttributeError as e:
raise ValueError("OpenAI key is required") from e

Expand All @@ -57,14 +62,15 @@ async def chat_completion(self, model: str, temperature: float, system: str, use
TryAgain: If there is an attribute error during OpenAI inference.
"""
try:
response = await openai.ChatCompletion.acreate(
response = await acompletion(
model=model,
deployment_id=self.deployment_id,
messages=[
{"role": "system", "content": system},
{"role": "user", "content": user}
],
temperature=temperature,
api_key=self.llm_api_key
)
except (APIError, Timeout, TryAgain) as e:
logging.error("Error during OpenAI inference: ", e)
Expand All @@ -75,8 +81,9 @@ async def chat_completion(self, model: str, temperature: float, system: str, use
except (Exception) as e:
logging.error("Unknown error during OpenAI inference: ", e)
raise TryAgain from e
if response is None or len(response.choices) == 0:
if response is None or len(response["choices"]) == 0:
raise TryAgain
resp = response.choices[0]['message']['content']
finish_reason = response.choices[0].finish_reason
resp = response["choices"][0]['message']['content']
finish_reason = response["choices"][0]["finish_reason"]
print(resp, finish_reason)
return resp, finish_reason
4 changes: 2 additions & 2 deletions pr_agent/algo/pr_processing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations

import traceback
import logging
from typing import Callable, Tuple

Expand Down Expand Up @@ -221,6 +221,6 @@ async def retry_with_fallback_models(f: Callable):
try:
return await f(model)
except Exception as e:
logging.warning(f"Failed to generate prediction with {model}: {e}")
logging.warning(f"Failed to generate prediction with {model}: {traceback.format_exc()}")
if i == len(all_models) - 1: # If it's the last iteration
raise # Re-raise the last exception
5 changes: 2 additions & 3 deletions pr_agent/algo/token_handler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from jinja2 import Environment, StrictUndefined
from tiktoken import encoding_for_model
from tiktoken import encoding_for_model, get_encoding

from pr_agent.config_loader import get_settings

Expand Down Expand Up @@ -27,7 +27,7 @@ def __init__(self, pr, vars: dict, system, user):
- system: The system string.
- user: The user string.
"""
self.encoder = encoding_for_model(get_settings().config.model)
self.encoder = encoding_for_model(get_settings().config.model) if "gpt" in get_settings().config.model else get_encoding("cl100k_base")
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)

def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
Expand All @@ -47,7 +47,6 @@ def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(system).render(vars)
user_prompt = environment.from_string(user).render(vars)

system_prompt_tokens = len(encoder.encode(system_prompt))
user_prompt_tokens = len(encoder.encode(user_prompt))
return system_prompt_tokens + user_prompt_tokens
Expand Down
9 changes: 6 additions & 3 deletions pr_agent/settings/.secrets_template.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,20 @@
# See README for details about GitHub App deployment.

[openai]
key = "<API_KEY>" # Acquire through https://platform.openai.com
org = "<ORGANIZATION>" # Optional, may be commented out.
key = "" # Acquire through https://platform.openai.com
#org = "<ORGANIZATION>" # Optional, may be commented out.
# Uncomment the following for Azure OpenAI
#api_type = "azure"
#api_version = '2023-05-15' # Check Azure documentation for the current API version
#api_base = "<API_BASE>" # The base URL for your Azure OpenAI resource. e.g. "https://<your resource name>.openai.azure.com"
#deployment_id = "<DEPLOYMENT_ID>" # The deployment name you chose when you deployed the engine

[lite]
key = "YOUR_LLM_API_KEY" # Optional, use this if you'd like to use Anthropic, Llama2 (Replicate), or Cohere models
[github]
# ---- Set the following only for deployment type == "user"
user_token = "<TOKEN>" # A GitHub personal access token with 'repo' scope.
user_token = "" # A GitHub personal access token with 'repo' scope.
deployment_type = "user" #set to user by default

# ---- Set the following only for deployment type == "app", see README for details.
private_key = """\
Expand Down
27 changes: 16 additions & 11 deletions pr_agent/tools/pr_reviewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,27 +160,28 @@ def _prepare_pr_review(self) -> str:
the feedback.
"""
review = self.prediction.strip()

print(f"review: {review}")
try:
data = json.loads(review)
except json.decoder.JSONDecodeError:
data = try_fix_json(review)

print(f"data: {data}")
# Move 'Security concerns' key to 'PR Analysis' section for better display
if 'PR Feedback' in data and 'Security concerns' in data['PR Feedback']:
val = data['PR Feedback']['Security concerns']
del data['PR Feedback']['Security concerns']
data['PR Analysis']['Security concerns'] = val

# Filter out code suggestions that can be submitted as inline comments
if get_settings().config.git_provider != 'bitbucket' and get_settings().pr_reviewer.inline_code_comments \
and 'Code suggestions' in data['PR Feedback']:
data['PR Feedback']['Code suggestions'] = [
d for d in data['PR Feedback']['Code suggestions']
if any(key not in d for key in ('relevant file', 'relevant line in file', 'suggestion content'))
]
if not data['PR Feedback']['Code suggestions']:
del data['PR Feedback']['Code suggestions']
if 'PR Feedback' in data:
if get_settings().config.git_provider != 'bitbucket' and get_settings().pr_reviewer.inline_code_comments \
and 'Code suggestions' in data['PR Feedback']:
data['PR Feedback']['Code suggestions'] = [
d for d in data['PR Feedback']['Code suggestions']
if any(key not in d for key in ('relevant file', 'relevant line in file', 'suggestion content'))
]
if not data['PR Feedback']['Code suggestions']:
del data['PR Feedback']['Code suggestions']

# Add incremental review section
if self.incremental.is_incremental:
Expand All @@ -205,7 +206,11 @@ def _prepare_pr_review(self) -> str:
# Log markdown response if verbosity level is high
if get_settings().config.verbosity_level >= 2:
logging.info(f"Markdown response:\n{markdown_text}")


if markdown_text == None or len(markdown_text) == 0:
markdown_text = review

print(f"markdown text: {markdown_text}")
return markdown_text

def _publish_inline_code_comments(self) -> None:
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ dependencies = [
"aiohttp~=3.8.4",
"atlassian-python-api==3.39.0",
"GitPython~=3.1.32",
"starlette-context==0.3.6"
"starlette-context==0.3.6",
"litellm==0.1.2291"
]

[project.urls]
Expand Down