diff --git a/openai_functions/conversation.py b/openai_functions/conversation.py index 280438d..65dec8c 100644 --- a/openai_functions/conversation.py +++ b/openai_functions/conversation.py @@ -1,8 +1,10 @@ """A module for running OpenAI functions""" from __future__ import annotations +import time from typing import Any, Callable, Literal, TYPE_CHECKING, overload import openai +from openai.error import RateLimitError from .functions.union import UnionSkillSet from .openai_types import ( @@ -99,40 +101,90 @@ def clear_messages(self) -> None: @overload def _generate_message( - self, function_call: ForcedFunctionCall + self, function_call: ForcedFunctionCall, retries: int | None = 1 ) -> IntermediateResponseMessageType: ... @overload def _generate_message( - self, function_call: Literal["none"] + self, function_call: Literal["none"], retries: int | None = 1 ) -> FinalResponseMessageType: ... @overload def _generate_message( - self, function_call: Literal["auto"] = "auto" + self, function_call: Literal["auto"] = "auto", retries: int | None = 1 ) -> NonFunctionMessageType: ... def _generate_message( - self, function_call: OpenAiFunctionCallInput = "auto" + self, function_call: OpenAiFunctionCallInput = "auto", retries: int | None = 1 ) -> NonFunctionMessageType: - """Generate a response + """Generate a response, retrying if necessary Args: function_call (OpenAiFunctionCallInput): The function call. + retries (int | None): The number of retries. Defaults to 4. + Will retry indefinitely if None. + + Raises: + openai.error.RateLimitError: If the rate limit is exceeded Returns: NonFunctionMessageType: The response """ - response = openai.ChatCompletion.create( + if retries is None: + retries = -1 + while True: + try: + response = self._generate_raw_message(function_call) + except RateLimitError as error: + if retries == 0: + raise + retries -= 1 + time.sleep(self._retry_time_from_headers(error.headers)) + else: + return response["choices"][0]["message"] # type: ignore + + def _parse_retry_time(self, wait_for: str) -> float: + """Parse the time returned by an x-ratelimit-reset-requests header + + Args: + wait_for (str): The time + + Returns: + float: The time to the next reset + """ + return float(wait_for[:-1]) * {"s": 1, "m": 60, "h": 3600}[wait_for[-1]] + + def _retry_time_from_headers(self, headers: dict[str, str]) -> float: + """Get the time returned by the headers of an 429 reply + + Args: + headers (dict[str, str]): The headers of the reply + + Returns: + float: The time to wait for before retrying + """ + return self._parse_retry_time(headers["x-ratelimit-reset-requests"]) / int( + headers["x-ratelimit-limit-requests"] + ) + + def _generate_raw_message(self, function_call: OpenAiFunctionCallInput) -> Any: + """Generate a raw OpenAI response + + Args: + function_call (OpenAiFunctionCallInput): The function call. + + Returns: + The raw OpenAI response + """ + return openai.ChatCompletion.create( model=self.model, messages=[message.as_dict() for message in self.messages], functions=self.functions_schema, function_call=function_call, ) - return response["choices"][0]["message"] # type: ignore def remove_function_call(self, function_name: str) -> None: """Remove a function call from the messages, if it is the last message @@ -246,7 +298,7 @@ def run_function_if_needed(self) -> bool: return self.run_function_and_substitute(function_call) def generate_message( - self, function_call: OpenAiFunctionCallInput = "auto" + self, function_call: OpenAiFunctionCallInput = "auto", retries: int | None = 1 ) -> GenericMessage: """Generate the next message. Will run a function if the last message was a function call and the function call is not being overridden; @@ -254,6 +306,8 @@ def generate_message( Args: function_call (OpenAiFunctionCallInput): The function call + retries (int | None): The number of retries; if None, will retry + indefinitely Returns: GenericMessage: The response @@ -261,17 +315,19 @@ def generate_message( if function_call in ["auto", "none"] and self.run_function_if_needed(): return self.messages[-1] - message: NonFunctionMessageType = self._generate_message(function_call) + message: NonFunctionMessageType = self._generate_message(function_call, retries) self.add_message(message) return Message(message) def run_until_response( - self, allow_function_calls: bool = True + self, allow_function_calls: bool = True, retries: int | None = 1 ) -> FinalResponseMessage: """Run functions query the AI until a response is generated Args: allow_function_calls (bool): Whether to allow the AI to call functions + retries (int | None): The number of retries; if None, will retry + indefinitely Returns: FinalResponseMessage: The final response, either from the AI or a function @@ -279,7 +335,8 @@ def run_until_response( """ while True: message = self.generate_message( - function_call="auto" if allow_function_calls else "none" + function_call="auto" if allow_function_calls else "none", + retries=retries, ) if is_final_response_message(message): return message @@ -378,17 +435,19 @@ def remove_function( """ self.skills.remove_function(function) - def ask(self, question: str) -> str: + def ask(self, question: str, retries: int | None = 1) -> str: """Ask the AI a question, running until a response is generated Args: question (str): The question + retries (int | None): The number of retries; if None, will retry + indefinitely Returns: str: The answer to the question """ self.add_message(question) - return self.run_until_response().content + return self.run_until_response(retries=retries).content def add_skill(self, skill: FunctionSet) -> None: """Add a skill to those available to the AI @@ -398,12 +457,16 @@ def add_skill(self, skill: FunctionSet) -> None: """ self.skills.add_skill(skill) - def run(self, function: str, prompt: str | None = None) -> Any: + def run( + self, function: str, prompt: str | None = None, retries: int | None = 1 + ) -> Any: """Run a specified function and return the raw function result Args: function (str): The function to run prompt (str | None): The prompt to use + retries (int | None): The number of retries; if None, will retry + indefinitely Returns: The raw function result @@ -412,5 +475,7 @@ def run(self, function: str, prompt: str | None = None) -> Any: self.add_message(prompt) # We can do type: ignore as we know we're forcing a function call response: FunctionCallMessage - response = self.generate_message({"name": function}) # type: ignore + response = self.generate_message( + {"name": function}, retries=retries + ) # type: ignore return self.skills(response.function_call) diff --git a/openai_functions/nlp.py b/openai_functions/nlp.py index b897a03..9f10f83 100644 --- a/openai_functions/nlp.py +++ b/openai_functions/nlp.py @@ -71,23 +71,27 @@ def _initialize_conversation(self) -> None: } ) - def from_natural_language(self, prompt: str) -> Return: + def from_natural_language(self, prompt: str, retries: int | None = 1) -> Return: """Run the function with the given natural language input Args: prompt (str): The prompt to use + retries (int | None): The number of retries; if None, will retry + indefinitely Returns: The result of the original function """ self._initialize_conversation() - return self.conversation.run(self.openai_function.name, prompt) + return self.conversation.run(self.openai_function.name, prompt, retries=retries) - def natural_language_response(self, prompt: str) -> str: + def natural_language_response(self, prompt: str, retries: int | None = 1) -> str: """Run the function and respond to the user with natural language Args: prompt (str): The prompt to use + retries (int | None): The number of retries; if None, will retry + indefinitely Returns: str: The response from the AI @@ -97,24 +101,28 @@ def natural_language_response(self, prompt: str) -> str: self.conversation.generate_message( function_call={"name": self.openai_function.name} ) - response = self.conversation.run_until_response(False) + response = self.conversation.run_until_response(False, retries=retries) return response.content def natural_language_annotated( - self, prompt: str + self, prompt: str, retries: int | None = 1 ) -> NaturalLanguageAnnotated[Return]: """Run the function and respond to the user with natural language as well as the raw function result Args: prompt (str): The prompt to use + retries (int | None): The number of retries; if None, will retry + indefinitely Returns: NaturalLanguageAnnotated: The response from the AI """ self._initialize_conversation() - function_result = self.conversation.run(self.openai_function.name, prompt) - response = self.conversation.run_until_response(False) + function_result = self.conversation.run( + self.openai_function.name, prompt, retries=retries + ) + response = self.conversation.run_until_response(False, retries=retries) return NaturalLanguageAnnotated(function_result, response.content) diff --git a/pyproject.toml b/pyproject.toml index 2c44a4e..60681ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "openai-functions" -version = "0.6.9" +version = "0.6.10" description = "Simplifies the usage of OpenAI ChatGPT's function calling by generating the schemas and parsing OpenAI's responses for you." authors = ["rizerphe <44440399+rizerphe@users.noreply.github.com>"] readme = "README.md"