Skip to content

Commit

Permalink
Merge pull request Codium-ai#183 from zmeir/zmeir-fallback_deployments
Browse files Browse the repository at this point in the history
Support fallback deployments to accompany fallback models
  • Loading branch information
okotek authored Aug 14, 2023
2 parents 684ba82 + cfb741b commit 3ade5b3
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 9 deletions.
16 changes: 14 additions & 2 deletions pr_agent/algo/ai_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def __init__(self):
self.azure = False
if get_settings().get("OPENAI.ORG", None):
litellm.organization = get_settings().openai.org
self.deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None)
if get_settings().get("OPENAI.API_TYPE", None):
if get_settings().openai.api_type == "azure":
self.azure = True
Expand All @@ -47,6 +46,13 @@ def __init__(self):
except AttributeError as e:
raise ValueError("OpenAI key is required") from e

@property
def deployment_id(self):
    """
    Deployment ID to pass to the OpenAI API call.

    The value is looked up in the settings under "OPENAI.DEPLOYMENT_ID" on
    every access instead of being cached on the instance, so a value written
    to the settings after construction is picked up by the next read.
    Returns None when the setting is absent.
    """
    current_deployment = get_settings().get("OPENAI.DEPLOYMENT_ID", None)
    return current_deployment

@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
async def chat_completion(self, model: str, temperature: float, system: str, user: str):
Expand All @@ -70,9 +76,15 @@ async def chat_completion(self, model: str, temperature: float, system: str, use
TryAgain: If there is an attribute error during OpenAI inference.
"""
try:
deployment_id = self.deployment_id
if get_settings().config.verbosity_level >= 2:
logging.debug(
f"Generating completion with {model}"
f"{(' from deployment ' + deployment_id) if deployment_id else ''}"
)
response = await acompletion(
model=model,
deployment_id=self.deployment_id,
deployment_id=deployment_id,
messages=[
{"role": "system", "content": system},
{"role": "user", "content": user}
Expand Down
41 changes: 34 additions & 7 deletions pr_agent/algo/pr_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,20 +208,47 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo


async def retry_with_fallback_models(f: Callable):
    """
    Await ``f(model)`` for each configured (model, deployment_id) pair in order
    and return the first successful result.

    Before each attempt the pair's deployment ID is written to the
    "openai.deployment_id" setting. A failed attempt is logged as a warning
    together with its traceback; if the last pair also fails, its exception
    is re-raised.

    Args:
        f: Async callable that takes a model name and returns the prediction.

    Returns:
        Whatever ``f`` returns for the first pair that succeeds.

    Raises:
        Exception: The exception raised by the final attempt when every pair
            fails.
        ValueError: Propagated from ``_get_all_deployments`` when fewer
            deployments than models are configured.
    """
    all_models = _get_all_models()
    all_deployments = _get_all_deployments(all_models)
    # The loop below overwrites the shared "openai.deployment_id" setting.
    # Remember the configured value so it can be put back afterwards; otherwise
    # a later call would read a fallback deployment as if it were the primary.
    original_deployment_id = get_settings().get("openai.deployment_id", None)
    last_index = len(all_models) - 1
    try:
        # try each (model, deployment_id) pair until one is successful, otherwise raise exception
        for i, (model, deployment_id) in enumerate(zip(all_models, all_deployments)):
            try:
                get_settings().set("openai.deployment_id", deployment_id)
                return await f(model)
            except Exception:
                logging.warning(
                    f"Failed to generate prediction with {model}"
                    f"{(' from deployment ' + deployment_id) if deployment_id else ''}: "
                    f"{traceback.format_exc()}"
                )
                if i == last_index:  # If it's the last iteration
                    raise  # Re-raise the last exception
    finally:
        # Runs on both success (after f has completed) and failure.
        get_settings().set("openai.deployment_id", original_deployment_id)


def _get_all_models() -> List[str]:
    """
    Build the ordered list of models to try: the primary model first,
    followed by the fallback models.

    ``config.fallback_models`` may be a list or a comma-separated string.
    An unset or empty value means "no fallbacks", and blank entries in a
    comma-separated string (e.g. a trailing comma) are dropped so that an
    empty model name is never scheduled.

    Returns:
        A list whose first element is ``config.model``, followed by the
        fallback models in their configured order.
    """
    model = get_settings().config.model
    fallback_models = get_settings().config.fallback_models
    if not fallback_models:
        # None, "" or [] -> nothing to fall back to.
        fallback_models = []
    elif not isinstance(fallback_models, list):
        # Comma-separated string form: split, trim, and skip blank entries.
        fallback_models = [m.strip() for m in fallback_models.split(",") if m.strip()]
    return [model] + list(fallback_models)


def _get_all_deployments(all_models: List[str]) -> List[str]:
    """
    Build the list of deployment IDs that is paired position-by-position
    with ``all_models``.

    The first entry is the "openai.deployment_id" setting (None when unset).
    When "openai.fallback_deployments" is configured, as a list or as a
    comma-separated string, its entries follow the primary deployment.
    Otherwise the primary deployment is repeated once per model.

    Args:
        all_models: The models that will be tried, primary model first.

    Returns:
        The deployment IDs, in the order they are to be paired with the models.

    Raises:
        ValueError: If fallback deployments are configured but the resulting
            list is shorter than ``all_models``.
    """
    primary = get_settings().get("openai.deployment_id", None)
    fallbacks = get_settings().get("openai.fallback_deployments", [])

    # A non-empty, non-list value is treated as a comma-separated string.
    if fallbacks and not isinstance(fallbacks, list):
        fallbacks = [part.strip() for part in fallbacks.split(",")]

    # No fallback deployments: every model uses the primary deployment.
    if not fallbacks:
        return [primary] * len(all_models)

    all_deployments = [primary] + fallbacks
    if len(all_deployments) < len(all_models):
        raise ValueError(f"The number of deployments ({len(all_deployments)}) "
                         f"is less than the number of models ({len(all_models)})")
    return all_deployments


def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
relevant_file: str,
relevant_line_in_file: str) -> Tuple[int, int]:
Expand Down
1 change: 1 addition & 0 deletions pr_agent/settings/.secrets_template.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ key = "" # Acquire through https://platform.openai.com
#api_version = '2023-05-15' # Check Azure documentation for the current API version
#api_base = "" # The base URL for your Azure OpenAI resource. e.g. "https://<your resource name>.openai.azure.com"
#deployment_id = "" # The deployment name you chose when you deployed the engine
#fallback_deployments = [] # For each fallback model specified in configuration.toml in the [config] section, specify the appropriate deployment_id

[anthropic]
key = "" # Optional, uncomment if you want to use Anthropic. Acquire through https://www.anthropic.com/
Expand Down

0 comments on commit 3ade5b3

Please sign in to comment.