diff --git a/evals/completion_fns/openai.py b/evals/completion_fns/openai.py index ed50818630..d08c250b3e 100644 --- a/evals/completion_fns/openai.py +++ b/evals/completion_fns/openai.py @@ -1,5 +1,7 @@ +import logging from typing import Any, Optional, Union +import openai from openai import OpenAI from evals.api import CompletionFn, CompletionResult @@ -12,12 +14,44 @@ Prompt, ) from evals.record import record_sampling -from evals.utils.api_utils import ( - openai_chat_completion_create_retrying, - openai_completion_create_retrying, +from evals.utils.api_utils import create_retrying + +OPENAI_TIMEOUT_EXCEPTIONS = ( + openai.RateLimitError, + openai.APIConnectionError, + openai.APITimeoutError, + openai.InternalServerError, ) + + +def openai_completion_create_retrying(client: OpenAI, *args, **kwargs): + """ + Helper function for creating a completion. + `args` and `kwargs` match what is accepted by `openai.Completion.create`. + """ + result = create_retrying( + client.completions.create, retry_exceptions=OPENAI_TIMEOUT_EXCEPTIONS, *args, **kwargs + ) + if "error" in result: + logging.warning(result) + raise openai.APIError(result["error"]) + return result + + +def openai_chat_completion_create_retrying(client: OpenAI, *args, **kwargs): + """ + Helper function for creating a chat completion. + `args` and `kwargs` match what is accepted by `openai.ChatCompletion.create`. 
+ """ + result = create_retrying( + client.chat.completions.create, retry_exceptions=OPENAI_TIMEOUT_EXCEPTIONS, *args, **kwargs + ) + if "error" in result: + logging.warning(result) + raise openai.APIError(result["error"]) + return result + + class OpenAIBaseCompletionResult(CompletionResult): def __init__(self, raw_data: Any, prompt: Any): self.raw_data = raw_data @@ -82,7 +116,7 @@ def __call__( openai_create_prompt: OpenAICreatePrompt = prompt.to_formatted_prompt() result = openai_completion_create_retrying( - OpenAI(api_key=self.api_key, base_url=self.api_base), + client=OpenAI(api_key=self.api_key, base_url=self.api_base), model=self.model, prompt=openai_create_prompt, **{**kwargs, **self.extra_options}, @@ -127,7 +161,7 @@ def __call__( openai_create_prompt: OpenAICreateChatPrompt = prompt.to_formatted_prompt() result = openai_chat_completion_create_retrying( - OpenAI(api_key=self.api_key, base_url=self.api_base), + client=OpenAI(api_key=self.api_key, base_url=self.api_base), model=self.model, messages=openai_create_prompt, **{**kwargs, **self.extra_options}, diff --git a/evals/solvers/providers/anthropic/anthropic_solver.py b/evals/solvers/providers/anthropic/anthropic_solver.py index 9f0766598d..bb7fe50e24 100644 --- a/evals/solvers/providers/anthropic/anthropic_solver.py +++ b/evals/solvers/providers/anthropic/anthropic_solver.py @@ -1,20 +1,25 @@ from typing import Any, Optional, Union -from evals.solvers.solver import Solver, SolverResult -from evals.task_state import TaskState, Message -from evals.record import record_sampling -from evals.utils.api_utils import request_with_timeout - import anthropic from anthropic import Anthropic from anthropic.types import ContentBlock, MessageParam, Usage -import backoff + +from evals.record import record_sampling +from evals.solvers.solver import Solver, SolverResult +from evals.task_state import Message, TaskState +from evals.utils.api_utils import create_retrying oai_to_anthropic_role = { "system": "user", 
"user": "user", "assistant": "assistant", } +ANTHROPIC_TIMEOUT_EXCEPTIONS = ( + anthropic.RateLimitError, + anthropic.APIConnectionError, + anthropic.APITimeoutError, + anthropic.InternalServerError, +) class AnthropicSolver(Solver): @@ -59,9 +64,7 @@ def _solve(self, task_state: TaskState, **kwargs) -> SolverResult: ) # for logging purposes: prepend the task desc to the orig msgs as a system message - orig_msgs.insert( - 0, Message(role="system", content=task_state.task_description).to_dict() - ) + orig_msgs.insert(0, Message(role="system", content=task_state.task_description).to_dict()) record_sampling( prompt=orig_msgs, # original message format, supported by our logviz sampled=[solver_result.output], @@ -113,23 +116,14 @@ def _convert_msgs_to_anthropic_format(msgs: list[Message]) -> list[MessageParam] return alt_msgs -@backoff.on_exception( - wait_gen=backoff.expo, - exception=( - anthropic.RateLimitError, - anthropic.APIConnectionError, - anthropic.APITimeoutError, - anthropic.InternalServerError, - ), - max_value=60, - factor=1.5, -) def anthropic_create_retrying(client: Anthropic, *args, **kwargs): """ Helper function for creating a backoff-retry enabled message request. `args` and `kwargs` match what is accepted by `client.messages.create`. """ - result = request_with_timeout(client.messages.create, *args, **kwargs) + result = create_retrying( + client.messages.create, retry_exceptions=ANTHROPIC_TIMEOUT_EXCEPTIONS, *args, **kwargs + ) if "error" in result: raise Exception(result["error"]) return result diff --git a/evals/utils/api_utils.py b/evals/utils/api_utils.py index 7479d5e9a2..a45476c9af 100644 --- a/evals/utils/api_utils.py +++ b/evals/utils/api_utils.py @@ -1,73 +1,22 @@ -""" -This file defines various helper functions for interacting with the OpenAI API. 
-""" -import concurrent import logging import os import backoff -import openai -from openai import OpenAI EVALS_THREAD_TIMEOUT = float(os.environ.get("EVALS_THREAD_TIMEOUT", "40")) logging.getLogger("httpx").setLevel(logging.WARNING)  # suppress "OK" logs from openai API calls -@backoff.on_exception( +@backoff.on_predicate( wait_gen=backoff.expo, - exception=( - openai.RateLimitError, - openai.APIConnectionError, - openai.APITimeoutError, - openai.InternalServerError, - ), max_value=60, factor=1.5, ) -def openai_completion_create_retrying(client: OpenAI, *args, **kwargs): +def create_retrying(func: callable, retry_exceptions: tuple[type[Exception], ...], *args, **kwargs): """ - Helper function for creating a completion. - `args` and `kwargs` match what is accepted by `openai.Completion.create`. + Retries the given function if one of the given exceptions is raised """ - result = client.completions.create(*args, **kwargs) - if "error" in result: - logging.warning(result) - raise openai.error.APIError(result["error"]) - return result - - -def request_with_timeout(func, *args, timeout=EVALS_THREAD_TIMEOUT, **kwargs): - """ - Worker thread for making a single request within allotted time. - """ - while True: - with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(func, *args, **kwargs) - try: - result = future.result(timeout=timeout) - return result - except concurrent.futures.TimeoutError: - continue - - -@backoff.on_exception( - wait_gen=backoff.expo, - exception=( - openai.RateLimitError, - openai.APIConnectionError, - openai.APITimeoutError, - openai.InternalServerError, - ), - max_value=60, - factor=1.5, -) -def openai_chat_completion_create_retrying(client: OpenAI, *args, **kwargs): - """ - Helper function for creating a chat completion. - `args` and `kwargs` match what is accepted by `openai.ChatCompletion.create`. 
- """ - result = request_with_timeout(client.chat.completions.create, *args, **kwargs) - if "error" in result: - logging.warning(result) - raise openai.error.APIError(result["error"]) - return result + try: + return func(*args, **kwargs) + except retry_exceptions: + return False