From e6e507d7295fce5db65cc26aeda2e54485de8f2a Mon Sep 17 00:00:00 2001 From: rizerphe <44440399+rizerphe@users.noreply.github.com> Date: Mon, 3 Jul 2023 19:18:02 +0300 Subject: [PATCH] Add separate remove_call property --- docs/conversation.md | 4 +- docs/skills.md | 8 ++ openai_functions/conversation.py | 90 +++++++++++++--------- openai_functions/functions/basic_set.py | 4 +- openai_functions/functions/disabled_set.py | 0 openai_functions/functions/functions.py | 8 +- openai_functions/functions/sets.py | 9 ++- openai_functions/functions/wrapper.py | 12 +++ pyproject.toml | 2 +- 9 files changed, 96 insertions(+), 41 deletions(-) create mode 100644 openai_functions/functions/disabled_set.py diff --git a/docs/conversation.md b/docs/conversation.md index 3301603..7adda40 100644 --- a/docs/conversation.md +++ b/docs/conversation.md @@ -40,6 +40,7 @@ def my_awesome_function(...): @conversation.add_function( save_return=True, serialize=False, + remove_call=False, interpret_as_response=False ) def my_amazing_function(): @@ -54,7 +55,8 @@ The arguments passed to `add_function` are the same as those an [OpenAIFunction] - `save_return` - whether to send the return value of the function back to the AI; some functions - mainly those that don't return anything - don't need to do this - `serialize` - whether to serialize the function's return value before sending the result back to the AI; openai expects a function call to be a string, so if this is False, the result of the function execution should be a string. Otherwise, it will use JSON serialization, so if `serialize` is set to True, the function return needs to be JSON-serializable -- `interpret_as_response` - whether to interpret the return value of the function (the serialized one if `serialize` is set to True) as the response from the AI, replacing the function call +- `remove_call` - whether to remove the function call message itself; be careful to avoid infinite loops when using with `save_return=False`; the function should then, for example, disappear from the schema; it's your responsibility to make sure this happens +- `interpret_as_response` - whether to interpret the return value of the function (the serialized one if `serialize` is set to True) as the response from the AI You can read more about how to use skills [here](skills). diff --git a/docs/skills.md b/docs/skills.md index 6a31927..c9e4403 100644 --- a/docs/skills.md +++ b/docs/skills.md @@ -12,6 +12,7 @@ def get_current_weather(location: str) -> dict: @skill.add_function( save_return=True, serialize=False, + remove_call=False, interpret_as_response=True ) def set_weather(location: str, weather_description: str): @@ -20,6 +21,13 @@ def set_weather(location: str, weather_description: str): schema = skill.functions_schema ``` +The parameters here are: + +- `save_return` - whether to send the return value of the function back to the AI; some functions - mainly those that don't return anything - don't need to do this +- `serialize` - whether to serialize the function's return value before sending the result back to the AI; openai expects a function call to be a string, so if this is False, the result of the function execution should be a string. Otherwise, it will use JSON serialization, so if `serialize` is set to True, the function return needs to be JSON-serializable +- `remove_call` - whether to remove the function call message itself; be careful to avoid infinite loops when using with `save_return=False`; the function should then, for example, disappear from the schema; it's your responsibility to make sure this happens +- `interpret_as_response` - whether to interpret the return value of the function (the serialized one if `serialize` is set to True) as the response from the AI + `schema` will be a list of JSON objects ready to be sent to OpenAI. You can then call your functions directly with the response returned from OpenAI: ```python diff --git a/openai_functions/conversation.py b/openai_functions/conversation.py index 61fdba4..41b50b3 100644 --- a/openai_functions/conversation.py +++ b/openai_functions/conversation.py @@ -134,20 +134,19 @@ def _generate_message( ) return response["choices"][0]["message"] # type: ignore - def substitute_last_with_function_result(self, result: str) -> None: - """Substitute the last message with the result of a function + def remove_function_call(self, function_name: str) -> None: + """Remove a function call from the messages, if it is the last message Args: - result (str): The function result + function_name (str): The function name """ - self.pop_message() - response: NonFunctionMessageType = { - "role": "assistant", - "content": result, - } - self.add_message(response) + if ( + self.messages[-1].function_call + and self.messages[-1].function_call["name"] == function_name + ): + self.pop_message() - def add_function_result(self, function_result: FunctionResult) -> bool: + def _add_function_result(self, function_result: FunctionResult) -> bool: """Add a function execution result to the chat Args: @@ -159,41 +158,56 @@ def add_function_result(self, function_result: FunctionResult) -> bool: """ if function_result.content is None: return False + if function_result.interpret_return_as_response: + self._add_function_result_as_response(function_result.content) + else: + self._add_function_result_as_function_call(function_result) + return True + + def _add_function_result_as_response(self, function_result: str) -> None: + """Add a function execution result to the chat as an assistant response + + Args: + function_result (str): The function execution result + """ + response: FinalResponseMessageType = { + "role": "assistant", + "content": function_result, + } + self.add_message(response) + + def _add_function_result_as_function_call( + self, function_result: FunctionResult + ) -> None: + """Add a function execution result to the chat as a function call + + Args: + function_result (FunctionResult): The function execution result + """ response: FunctionMessageType = { "role": "function", "name": function_result.name, "content": function_result.content, } self.add_message(response) - return True - def add_or_substitute_function_result( - self, function_result: FunctionResult - ) -> bool: - """Add or substitute a function execution result + def add_function_result(self, function_result: FunctionResult) -> bool: + """Add a function execution result - If the function has interpret_as_response set to True, the last message, - which is assumed to be a function call, will be replaced with the function - execution result. Otherwise, the function execution result will be added - to the chat. + If the function has a return value (save_return is True), it will be added to + the chat. The function call will be removed depending on the remove_call + attribute, and the function result will be interpreted as a response or a + function call depending on the interpret_return_as_response attribute. Args: function_result (FunctionResult): The function result - Raises: - TypeError: If the function returns a None value - Returns: bool: Whether the function result was added """ - if function_result.substitute: - if function_result.content is None: - raise TypeError( - f"Function {function_result.name} did not provide a return" - ) - self.substitute_last_with_function_result(function_result.content) - return True - return self.add_function_result(function_result) + if function_result.remove_call: + self.remove_function_call(function_result.name) + return self._add_function_result(function_result) def run_function_and_substitute( self, @@ -211,9 +225,7 @@ def run_function_and_substitute( bool: Whether the function result was added to the chat (whether save_return was True) """ - return self.add_or_substitute_function_result( - self.skills.run_function(function_call) - ) + return self.add_function_result(self.skills.run_function(function_call)) def run_function_if_needed(self) -> bool: """Run a function if the last message was a function call @@ -283,6 +295,7 @@ def add_function( *, save_return: bool = True, serialize: bool = True, + remove_call: bool = False, interpret_as_response: bool = False, ) -> Callable[..., JsonType]: ... @@ -293,6 +306,7 @@ def add_function( *, save_return: bool = True, serialize: bool = True, + remove_call: bool = False, interpret_as_response: bool = False, ) -> Callable[[Callable[..., JsonType]], Callable[..., JsonType]]: ... @@ -303,6 +317,7 @@ def add_function( *, save_return: bool = True, serialize: bool = True, + remove_call: bool = False, interpret_as_response: bool = False, ) -> ( Callable[[Callable[..., JsonType]], Callable[..., JsonType]] @@ -315,10 +330,11 @@ def add_function( save_return (bool): Whether to send the return value of this function back to the AI. Defaults to True. serialize (bool): Whether to serialize the return value of this function. - Defaults to True. Otherwise, the return value must be a string. + Otherwise, the return value must be a string. + remove_call (bool): Whether to remove the function call itself from the chat + history interpret_as_response (bool): Whether to interpret the return value of this - function as a response of the agent, replacing the function call - message. Defaults to False. + function as the natural language response of the AI. Returns: Callable[[Callable[..., JsonType]], Callable[..., JsonType]]: A decorator @@ -328,12 +344,14 @@ def add_function( return self.skills.add_function( save_return=save_return, serialize=serialize, + remove_call=remove_call, interpret_as_response=interpret_as_response, ) return self.skills.add_function( function, save_return=save_return, serialize=serialize, + remove_call=remove_call, interpret_as_response=interpret_as_response, ) diff --git a/openai_functions/functions/basic_set.py b/openai_functions/functions/basic_set.py index be3f675..df22a01 100644 --- a/openai_functions/functions/basic_set.py +++ b/openai_functions/functions/basic_set.py @@ -50,7 +50,9 @@ def run_function(self, input_data: FunctionCall) -> FunctionResult: """ function = self.find_function(input_data["name"]) result = self.get_function_result(function, json.loads(input_data["arguments"])) - return FunctionResult(function.name, result, function.interpret_as_response) + return FunctionResult( + function.name, result, function.remove_call, function.interpret_as_response + ) def find_function(self, function_name: str) -> OpenAIFunction: """Find a function in the skillset diff --git a/openai_functions/functions/disabled_set.py b/openai_functions/functions/disabled_set.py new file mode 100644 index 0000000..e69de29 diff --git a/openai_functions/functions/functions.py b/openai_functions/functions/functions.py index dc486c7..52f8939 100644 --- a/openai_functions/functions/functions.py +++ b/openai_functions/functions/functions.py @@ -41,6 +41,11 @@ def serialize(self) -> bool: """Get whether to continue running after this function""" ... # pylint: disable=unnecessary-ellipsis + @property + def remove_call(self) -> bool: + """Get whether to remove the call to this function from the chat history""" + ... # pylint: disable=unnecessary-ellipsis + @property def interpret_as_response(self) -> bool: """Get whether to interpret the return value of this function as a response""" @@ -77,7 +82,8 @@ class FunctionResult: name: str raw_result: RawFunctionResult | None - substitute: bool = False + remove_call: bool = False + interpret_return_as_response: bool = False @property def content(self) -> str | None: diff --git a/openai_functions/functions/sets.py b/openai_functions/functions/sets.py index a476c67..394fdbe 100644 --- a/openai_functions/functions/sets.py +++ b/openai_functions/functions/sets.py @@ -61,6 +61,7 @@ def add_function( *, save_return: bool = True, serialize: bool = True, + remove_call: bool = False, interpret_as_response: bool = False, ) -> Callable[..., JsonType]: ... @@ -71,6 +72,7 @@ def add_function( *, save_return: bool = True, serialize: bool = True, + remove_call: bool = False, interpret_as_response: bool = False, ) -> Callable[[Callable[..., JsonType]], Callable[..., JsonType]]: ... @@ -81,6 +83,7 @@ def add_function( *, save_return: bool = True, serialize: bool = True, + remove_call: bool = False, interpret_as_response: bool = False, ) -> ( Callable[[Callable[..., JsonType]], Callable[..., JsonType]] @@ -95,6 +98,8 @@ def add_function( serialize (bool): Whether to serialize the return value of this function. Defaults to True. Otherwise, the return value must be a string. + remove_call (bool): Whether to remove the function call from the AI's + chat history. Defaults to False. interpret_as_response (bool): Whether to interpret the return value of this function as a response of the agent. Defaults to False. @@ -109,7 +114,9 @@ def add_function( self._add_function( FunctionWrapper( function, - WrapperConfig(None, save_return, serialize, interpret_as_response), + WrapperConfig( + None, save_return, serialize, remove_call, interpret_as_response + ), ) ) return function diff --git a/openai_functions/functions/wrapper.py b/openai_functions/functions/wrapper.py index 24b6c45..7bcd353 100644 --- a/openai_functions/functions/wrapper.py +++ b/openai_functions/functions/wrapper.py @@ -24,6 +24,8 @@ class WrapperConfig: save_return (bool): Whether to send the return value back to the AI serialize (bool): Whether to serialize the return value; if False, the return value must be a string + remove_call (bool): Whether to remove the call to this function from the + chat history interpret_as_response (bool): Whether to interpret the return value as a response from the agent directly, or to base the response on the return value @@ -32,6 +34,7 @@ class WrapperConfig: parsers: list[Type[ArgSchemaParser]] | None = None save_return: bool = True serialize: bool = True + remove_call: bool = False interpret_as_response: bool = False @@ -91,6 +94,15 @@ def serialize(self) -> bool: """ return self.config.serialize + @property + def remove_call(self) -> bool: + """Get whether to remove the call to this function from the chat history + + Returns: + bool: Whether to remove the call to this function from the chat history + """ + return self.config.remove_call + @property def interpret_as_response(self) -> bool: """Get whether to interpret the return value as an assistant response diff --git a/pyproject.toml b/pyproject.toml index 160d503..400aaa4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "openai-functions" -version = "0.6.5" +version = "0.6.7" description = "Simplifies the usage of OpenAI ChatGPT's function calling by generating the schemas and parsing OpenAI's responses for you." authors = ["rizerphe <44440399+rizerphe@users.noreply.github.com>"] readme = "README.md"