Commit

Add rate limit handling
rizerphe committed Jul 3, 2023
1 parent 72e97b8 commit a90ee0a
Showing 3 changed files with 96 additions and 23 deletions.
95 changes: 80 additions & 15 deletions openai_functions/conversation.py
@@ -1,8 +1,10 @@
"""A module for running OpenAI functions"""
from __future__ import annotations
import time
from typing import Any, Callable, Literal, TYPE_CHECKING, overload

import openai
from openai.error import RateLimitError

from .functions.union import UnionSkillSet
from .openai_types import (
@@ -99,40 +101,90 @@ def clear_messages(self) -> None:

     @overload
     def _generate_message(
-        self, function_call: ForcedFunctionCall
+        self, function_call: ForcedFunctionCall, retries: int | None = 1
     ) -> IntermediateResponseMessageType:
         ...

     @overload
     def _generate_message(
-        self, function_call: Literal["none"]
+        self, function_call: Literal["none"], retries: int | None = 1
     ) -> FinalResponseMessageType:
         ...

     @overload
     def _generate_message(
-        self, function_call: Literal["auto"] = "auto"
+        self, function_call: Literal["auto"] = "auto", retries: int | None = 1
     ) -> NonFunctionMessageType:
         ...

     def _generate_message(
-        self, function_call: OpenAiFunctionCallInput = "auto"
+        self, function_call: OpenAiFunctionCallInput = "auto", retries: int | None = 1
     ) -> NonFunctionMessageType:
"""Generate a response
"""Generate a response, retrying if necessary
Args:
function_call (OpenAiFunctionCallInput): The function call.
retries (int | None): The number of retries. Defaults to 4.
Will retry indefinitely if None.
Raises:
openai.error.RateLimitError: If the rate limit is exceeded
Returns:
NonFunctionMessageType: The response
"""
-        response = openai.ChatCompletion.create(
+        if retries is None:
+            retries = -1
+        while True:
+            try:
+                response = self._generate_raw_message(function_call)
+            except RateLimitError as error:
+                if retries == 0:
+                    raise
+                retries -= 1
+                time.sleep(self._retry_time_from_headers(error.headers))
+            else:
+                return response["choices"][0]["message"]  # type: ignore
+
+    def _parse_retry_time(self, wait_for: str) -> float:
+        """Parse the time returned by an x-ratelimit-reset-requests header
+        Args:
+            wait_for (str): The time
+        Returns:
+            float: The time to the next reset
+        """
+        return float(wait_for[:-1]) * {"s": 1, "m": 60, "h": 3600}[wait_for[-1]]
+
+    def _retry_time_from_headers(self, headers: dict[str, str]) -> float:
+        """Get the time returned by the headers of a 429 reply
+        Args:
+            headers (dict[str, str]): The headers of the reply
+        Returns:
+            float: The time to wait for before retrying
+        """
+        return self._parse_retry_time(headers["x-ratelimit-reset-requests"]) / int(
+            headers["x-ratelimit-limit-requests"]
+        )
+
+    def _generate_raw_message(self, function_call: OpenAiFunctionCallInput) -> Any:
+        """Generate a raw OpenAI response
+        Args:
+            function_call (OpenAiFunctionCallInput): The function call.
+        Returns:
+            The raw OpenAI response
+        """
+        return openai.ChatCompletion.create(
             model=self.model,
             messages=[message.as_dict() for message in self.messages],
             functions=self.functions_schema,
             function_call=function_call,
         )
-        return response["choices"][0]["message"]  # type: ignore

     def remove_function_call(self, function_name: str) -> None:
         """Remove a function call from the messages, if it is the last message
@@ -246,40 +298,45 @@ def run_function_if_needed(self) -> bool:
         return self.run_function_and_substitute(function_call)

     def generate_message(
-        self, function_call: OpenAiFunctionCallInput = "auto"
+        self, function_call: OpenAiFunctionCallInput = "auto", retries: int | None = 1
     ) -> GenericMessage:
         """Generate the next message. Will run a function if the last message
         was a function call and the function call is not being overridden;
         if the function does not save the return, a message will still be generated.
         Args:
             function_call (OpenAiFunctionCallInput): The function call
+            retries (int | None): The number of retries; if None, will retry
+                indefinitely
         Returns:
             GenericMessage: The response
         """
         if function_call in ["auto", "none"] and self.run_function_if_needed():
             return self.messages[-1]

-        message: NonFunctionMessageType = self._generate_message(function_call)
+        message: NonFunctionMessageType = self._generate_message(function_call, retries)
         self.add_message(message)
         return Message(message)

     def run_until_response(
-        self, allow_function_calls: bool = True
+        self, allow_function_calls: bool = True, retries: int | None = 1
     ) -> FinalResponseMessage:
         """Run functions and query the AI until a response is generated
         Args:
             allow_function_calls (bool): Whether to allow the AI to call functions
+            retries (int | None): The number of retries; if None, will retry
+                indefinitely
         Returns:
             FinalResponseMessage: The final response, either from the AI or a function
             that has interpret_as_response set to True
         """
         while True:
             message = self.generate_message(
-                function_call="auto" if allow_function_calls else "none"
+                function_call="auto" if allow_function_calls else "none",
+                retries=retries,
             )
             if is_final_response_message(message):
                 return message
@@ -378,17 +435,19 @@ def remove_function(
         """
         self.skills.remove_function(function)

-    def ask(self, question: str) -> str:
+    def ask(self, question: str, retries: int | None = 1) -> str:
         """Ask the AI a question, running until a response is generated
         Args:
             question (str): The question
+            retries (int | None): The number of retries; if None, will retry
+                indefinitely
         Returns:
             str: The answer to the question
         """
         self.add_message(question)
-        return self.run_until_response().content
+        return self.run_until_response(retries=retries).content

     def add_skill(self, skill: FunctionSet) -> None:
         """Add a skill to those available to the AI
@@ -398,12 +457,16 @@ def add_skill(self, skill: FunctionSet) -> None:
         """
         self.skills.add_skill(skill)

-    def run(self, function: str, prompt: str | None = None) -> Any:
+    def run(
+        self, function: str, prompt: str | None = None, retries: int | None = 1
+    ) -> Any:
         """Run a specified function and return the raw function result
         Args:
             function (str): The function to run
             prompt (str | None): The prompt to use
+            retries (int | None): The number of retries; if None, will retry
+                indefinitely
         Returns:
             The raw function result
@@ -412,5 +475,7 @@ def run(self, function: str, prompt: str | None = None) -> Any:
             self.add_message(prompt)
         # We can do type: ignore as we know we're forcing a function call
         response: FunctionCallMessage
-        response = self.generate_message({"name": function})  # type: ignore
+        response = self.generate_message(
+            {"name": function}, retries=retries
+        )  # type: ignore
         return self.skills(response.function_call)
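
For context (not part of the diff): the wait added above divides the reset window from the x-ratelimit-reset-requests header by the request limit from x-ratelimit-limit-requests, so the conversation pauses for roughly one request's share of the window rather than the full reset time. A minimal standalone sketch mirroring _parse_retry_time and _retry_time_from_headers, with illustrative header values rather than ones taken from a real response:

def parse_retry_time(wait_for: str) -> float:
    # Mirrors Conversation._parse_retry_time: a number followed by an s/m/h suffix.
    return float(wait_for[:-1]) * {"s": 1, "m": 60, "h": 3600}[wait_for[-1]]

def retry_time_from_headers(headers: dict[str, str]) -> float:
    # Mirrors Conversation._retry_time_from_headers: spread the reset window
    # across the request budget instead of sleeping for the whole window.
    return parse_retry_time(headers["x-ratelimit-reset-requests"]) / int(
        headers["x-ratelimit-limit-requests"]
    )

# A "6m" reset with a limit of 200 requests yields a 1.8 second pause.
print(retry_time_from_headers(
    {"x-ratelimit-reset-requests": "6m", "x-ratelimit-limit-requests": "200"}
))  # 1.8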
22 changes: 15 additions & 7 deletions openai_functions/nlp.py
@@ -71,23 +71,27 @@ def _initialize_conversation(self) -> None:
             }
         )

-    def from_natural_language(self, prompt: str) -> Return:
+    def from_natural_language(self, prompt: str, retries: int | None = 1) -> Return:
         """Run the function with the given natural language input
         Args:
             prompt (str): The prompt to use
+            retries (int | None): The number of retries; if None, will retry
+                indefinitely
         Returns:
             The result of the original function
         """
         self._initialize_conversation()
-        return self.conversation.run(self.openai_function.name, prompt)
+        return self.conversation.run(self.openai_function.name, prompt, retries=retries)

-    def natural_language_response(self, prompt: str) -> str:
+    def natural_language_response(self, prompt: str, retries: int | None = 1) -> str:
         """Run the function and respond to the user with natural language
         Args:
             prompt (str): The prompt to use
+            retries (int | None): The number of retries; if None, will retry
+                indefinitely
         Returns:
             str: The response from the AI
@@ -97,24 +101,28 @@ def natural_language_response(self, prompt: str) -> str:
         self.conversation.generate_message(
             function_call={"name": self.openai_function.name}
         )
-        response = self.conversation.run_until_response(False)
+        response = self.conversation.run_until_response(False, retries=retries)
         return response.content

     def natural_language_annotated(
-        self, prompt: str
+        self, prompt: str, retries: int | None = 1
     ) -> NaturalLanguageAnnotated[Return]:
         """Run the function and respond to the user with natural language as well as
         the raw function result
         Args:
             prompt (str): The prompt to use
+            retries (int | None): The number of retries; if None, will retry
+                indefinitely
         Returns:
             NaturalLanguageAnnotated: The response from the AI
         """
         self._initialize_conversation()
-        function_result = self.conversation.run(self.openai_function.name, prompt)
-        response = self.conversation.run_until_response(False)
+        function_result = self.conversation.run(
+            self.openai_function.name, prompt, retries=retries
+        )
+        response = self.conversation.run_until_response(False, retries=retries)
         return NaturalLanguageAnnotated(function_result, response.content)
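
A hedged usage sketch of the new retries parameter through the nlp wrapper (not part of the commit): the nlp decorator name, its import path, and the example function are assumed from the project's README rather than from this diff; only from_natural_language and its retries argument appear above.

from openai_functions import nlp

@nlp
def get_weather(city: str) -> str:
    """Get the current weather for a city."""
    # Hypothetical stand-in for a real lookup.
    return f"Sunny in {city}"

# Retry indefinitely on 429s instead of the default single retry.
result = get_weather.from_natural_language(
    "What's the weather like in Paris?", retries=None
)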


2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "openai-functions"
-version = "0.6.9"
+version = "0.6.10"
 description = "Simplifies the usage of OpenAI ChatGPT's function calling by generating the schemas and parsing OpenAI's responses for you."
 authors = ["rizerphe <44440399+rizerphe@users.noreply.github.com>"]
 readme = "README.md"
