From cd54db6161c44608e2e23efcc11c24c01a90a39d Mon Sep 17 00:00:00 2001 From: rizerphe <44440399+rizerphe@users.noreply.github.com> Date: Mon, 3 Jul 2023 21:52:37 +0300 Subject: [PATCH] Add proper invalid response handling --- openai_functions/__init__.py | 6 +++ openai_functions/exceptions.py | 48 +++++++++++++++++++ openai_functions/functions/basic_set.py | 8 +++- openai_functions/functions/wrapper.py | 19 +++++--- openai_functions/parsers/abc.py | 9 +++- .../parsers/atomic_type_parser.py | 3 +- openai_functions/parsers/dataclass_parser.py | 43 +++++++++++++---- openai_functions/parsers/dict_parser.py | 3 +- openai_functions/parsers/enum_parser.py | 8 +++- openai_functions/parsers/float_parser.py | 3 +- openai_functions/parsers/int_parser.py | 3 +- openai_functions/parsers/list_parser.py | 3 +- openai_functions/parsers/none_parser.py | 3 +- openai_functions/parsers/union_parser.py | 5 +- pyproject.toml | 2 +- tests/test_function_wrapper.py | 16 +++---- 16 files changed, 144 insertions(+), 38 deletions(-) diff --git a/openai_functions/__init__.py b/openai_functions/__init__.py index ac6a7f1..e035f97 100644 --- a/openai_functions/__init__.py +++ b/openai_functions/__init__.py @@ -1,7 +1,10 @@ """ChatGPT function calling based on function docstrings.""" from .conversation import Conversation from .exceptions import ( + BrokenSchemaError, + CannotParseTypeError, FunctionNotFoundError, + InvalidJsonError, NonSerializableOutputError, OpenAIFunctionsError, ) @@ -23,7 +26,10 @@ __all__ = [ "Conversation", + "BrokenSchemaError", + "CannotParseTypeError", "FunctionNotFoundError", + "InvalidJsonError", "NonSerializableOutputError", "OpenAIFunctionsError", "BasicFunctionSet", diff --git a/openai_functions/exceptions.py b/openai_functions/exceptions.py index 1de385c..93432b8 100644 --- a/openai_functions/exceptions.py +++ b/openai_functions/exceptions.py @@ -3,6 +3,8 @@ from typing import Any +from .json_type import JsonType + class OpenAIFunctionsError(Exception): """The base exception for all OpenAI Functions errors.""" @@ -12,6 +14,19 @@ class FunctionNotFoundError(OpenAIFunctionsError): """The function was not found in the given skillset.""" +class CannotParseTypeError(OpenAIFunctionsError): + """This type of the argument could not be parsed.""" + + def __init__(self, argtype: Any) -> None: + """Initialize the CannotParseTypeError. + + Args: + argtype (Any): The type that could not be parsed + """ + super().__init__(f"Cannot parse type {argtype}") + self.argtype = argtype + + class NonSerializableOutputError(OpenAIFunctionsError): """The function returned a non-serializable output.""" @@ -26,3 +41,36 @@ def __init__(self, result: Any) -> None: "Set serialize=False to use str() instead." ) self.result = result + + +class InvalidJsonError(OpenAIFunctionsError): + """OpenAI returned invalid JSON for the arguments.""" + + def __init__(self, response: str) -> None: + """Initialize the InvalidJsonError. + + Args: + response (str): The response that was not valid JSON + """ + super().__init__( + f"OpenAI returned invalid (perhaps incomplete) JSON: {response}" + ) + self.response = response + + +class BrokenSchemaError(OpenAIFunctionsError): + """The OpenAI response did not match the schema.""" + + def __init__(self, response: JsonType, schema: JsonType) -> None: + """Initialize the BrokenSchemaError. + + Args: + response (JsonType): The response that did not match the schema + schema (JsonType): The schema that the response did not match + """ + super().__init__( + "OpenAI returned a response that did not match the schema: " + f"{response!r} does not match {schema}" + ) + self.response = response + self.schema = schema diff --git a/openai_functions/functions/basic_set.py b/openai_functions/functions/basic_set.py index a40be72..2db45fe 100644 --- a/openai_functions/functions/basic_set.py +++ b/openai_functions/functions/basic_set.py @@ -3,7 +3,7 @@ import json from typing import TYPE_CHECKING -from ..exceptions import FunctionNotFoundError +from ..exceptions import FunctionNotFoundError, InvalidJsonError from .functions import FunctionResult, OpenAIFunction, RawFunctionResult from .sets import MutableFunctionSet @@ -49,7 +49,11 @@ def run_function(self, input_data: FunctionCall) -> FunctionResult: FunctionNotFoundError: If the function is not found """ function = self.find_function(input_data["name"]) - result = self.get_function_result(function, json.loads(input_data["arguments"])) + try: + arguments = json.loads(input_data["arguments"]) + except json.decoder.JSONDecodeError as e: + raise InvalidJsonError(input_data["arguments"]) from e + result = self.get_function_result(function, arguments) return FunctionResult( function.name, result, function.remove_call, function.interpret_as_response ) diff --git a/openai_functions/functions/wrapper.py b/openai_functions/functions/wrapper.py index 695e031..45f1abe 100644 --- a/openai_functions/functions/wrapper.py +++ b/openai_functions/functions/wrapper.py @@ -6,6 +6,7 @@ from docstring_parser import Docstring, parse +from ..exceptions import BrokenSchemaError, CannotParseTypeError from ..parsers import ArgSchemaParser, defargparsers if TYPE_CHECKING: @@ -221,7 +222,7 @@ def parse_argument(self, argument: inspect.Parameter) -> ArgSchemaParser: argument (inspect.Parameter): The argument to parse Raises: - TypeError: If the argument cannot be parsed + CannotParseTypeError: If the argument cannot be parsed Returns: ArgSchemaParser: The parser for the argument @@ -232,7 +233,7 @@ def parse_argument(self, argument: inspect.Parameter) -> ArgSchemaParser: for parser in self.parsers: if parser.can_parse(argument.annotation): return parser(argument.annotation, self.parsers) - raise TypeError(f"Cannot parse argument {argument}") + raise CannotParseTypeError(argument.annotation) def parse_arguments(self, arguments: dict[str, JsonType]) -> OrderedDict[str, Any]: """Parse arguments @@ -243,10 +244,16 @@ def parse_arguments(self, arguments: dict[str, JsonType]) -> OrderedDict[str, An Returns: OrderedDict[str, Any]: The parsed arguments """ - return OrderedDict( - (name, self.argument_parsers[name].parse_value(value)) - for name, value in arguments.items() - ) + argument_parsers = self.argument_parsers + if not all(name in arguments for name in argument_parsers): + raise BrokenSchemaError(arguments, self.arguments_schema) + try: + return OrderedDict( + (name, argument_parsers[name].parse_value(value)) + for name, value in arguments.items() + ) + except KeyError as e: + raise BrokenSchemaError(arguments, self.arguments_schema) from e def __call__(self, arguments: dict[str, JsonType]) -> Any: """Call the wrapped function diff --git a/openai_functions/parsers/abc.py b/openai_functions/parsers/abc.py index d9df461..c5db821 100644 --- a/openai_functions/parsers/abc.py +++ b/openai_functions/parsers/abc.py @@ -3,6 +3,8 @@ from abc import ABC, abstractmethod from typing import Any, Generic, TYPE_CHECKING, Type, TypeVar +from ..exceptions import CannotParseTypeError + if TYPE_CHECKING: from ..json_type import JsonType from typing_extensions import TypeGuard @@ -34,12 +36,12 @@ def parse_rec(self, argtype: Type[S]) -> ArgSchemaParser[S]: ArgSchemaParser[S]: The parser for the type Raises: - ValueError: If the type cannot be parsed + CannotParseTypeError: If the type cannot be parsed """ for parser in self.rec_parsers: if parser.can_parse(argtype): return parser(argtype, self.rec_parsers) - raise ValueError(f"Cannot parse type {argtype}") + raise CannotParseTypeError(argtype) @classmethod @abstractmethod @@ -61,4 +63,7 @@ def parse_value(self, value: JsonType) -> T: Args: value (JsonType): The value to parse + + Raises: + BrokenSchemaError: If the value does not match the schema """ diff --git a/openai_functions/parsers/atomic_type_parser.py b/openai_functions/parsers/atomic_type_parser.py index 65e42bc..f7248d7 100644 --- a/openai_functions/parsers/atomic_type_parser.py +++ b/openai_functions/parsers/atomic_type_parser.py @@ -3,6 +3,7 @@ from abc import abstractmethod from typing import Any, TYPE_CHECKING, Type, TypeVar +from ..exceptions import BrokenSchemaError from .abc import ArgSchemaParser if TYPE_CHECKING: @@ -34,5 +35,5 @@ def argument_schema(self) -> dict[str, JsonType]: def parse_value(self, value: JsonType) -> T: if not isinstance(value, self._type): - raise TypeError(f"Expected {self._type}, got {type(value)}") + raise BrokenSchemaError(value, self.argument_schema) return value diff --git a/openai_functions/parsers/dataclass_parser.py b/openai_functions/parsers/dataclass_parser.py index 2256572..d24b6f0 100644 --- a/openai_functions/parsers/dataclass_parser.py +++ b/openai_functions/parsers/dataclass_parser.py @@ -3,6 +3,7 @@ import dataclasses from typing import Any, ClassVar, Protocol, TYPE_CHECKING, Type +from ..exceptions import BrokenSchemaError from .abc import ArgSchemaParser if TYPE_CHECKING: @@ -23,25 +24,47 @@ class DataclassParser(ArgSchemaParser[IsDataclass]): def can_parse(cls, argtype: Any) -> TypeGuard[Type[IsDataclass]]: return dataclasses.is_dataclass(argtype) + @property + def required_fields(self) -> list[str]: + """All required fields of the dataclass + + Returns: + list[str]: The required fields of the dataclass + """ + return [ + field.name + for field in dataclasses.fields(self.argtype) + if field.default is dataclasses.MISSING + ] + + @property + def fields(self) -> dict[str, JsonType]: + """All fields of the dataclass, with their schemas + + Returns: + dict[str, JsonType]: The fields of the dataclass + """ + return { + field.name: self.parse_rec(field.type).argument_schema + for field in dataclasses.fields(self.argtype) + } + @property def argument_schema(self) -> dict[str, JsonType]: return { "type": "object", "description": self.argtype.__doc__, - "properties": { - field.name: self.parse_rec(field.type).argument_schema - for field in dataclasses.fields(self.argtype) - }, - "required": [ - field.name - for field in dataclasses.fields(self.argtype) - if field.default is dataclasses.MISSING - ], + "properties": self.fields, + "required": self.required_fields, # type: ignore } def parse_value(self, value: JsonType) -> IsDataclass: if not isinstance(value, dict): - raise TypeError(f"Expected dict, got {value}") + raise BrokenSchemaError(value, self.argument_schema) + if not all(field in value for field in self.required_fields): + raise BrokenSchemaError(value, self.argument_schema) + if not all(field in self.fields for field in value): + raise BrokenSchemaError(value, self.argument_schema) return self.argtype( **{ field.name: self.parse_rec(field.type).parse_value(value[field.name]) diff --git a/openai_functions/parsers/dict_parser.py b/openai_functions/parsers/dict_parser.py index 2ad02dc..1fc6166 100644 --- a/openai_functions/parsers/dict_parser.py +++ b/openai_functions/parsers/dict_parser.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import Any, Dict, TYPE_CHECKING, Type, TypeVar, get_args, get_origin +from ..exceptions import BrokenSchemaError from .abc import ArgSchemaParser if TYPE_CHECKING: @@ -36,7 +37,7 @@ def argument_schema(self) -> Dict[str, JsonType]: def parse_value(self, value: JsonType) -> Dict[str, T]: if not isinstance(value, dict): - raise TypeError(f"Expected dict, got {value}") + raise BrokenSchemaError(value, self.argument_schema) return { k: self.parse_rec(get_args(self.argtype)[1]).parse_value(v) for k, v in value.items() diff --git a/openai_functions/parsers/enum_parser.py b/openai_functions/parsers/enum_parser.py index 9767a1c..92beb30 100644 --- a/openai_functions/parsers/enum_parser.py +++ b/openai_functions/parsers/enum_parser.py @@ -3,6 +3,7 @@ import enum from typing import Any, TYPE_CHECKING, Type, TypeVar +from ..exceptions import BrokenSchemaError from .abc import ArgSchemaParser if TYPE_CHECKING: @@ -33,5 +34,10 @@ def argument_schema(self) -> dict[str, JsonType]: def parse_value(self, value: JsonType) -> T: if not isinstance(value, str): - raise TypeError(f"Expected str, got {value}") + raise BrokenSchemaError(value, self.argument_schema) + if value not in self.argument_schema["enum"]: # type: ignore + # TODO: consider using something other than JsonType for + # all of these, because disabling mypy is definitely + # not the right way to do this + raise BrokenSchemaError(value, self.argument_schema) return self.argtype[value] diff --git a/openai_functions/parsers/float_parser.py b/openai_functions/parsers/float_parser.py index 5d14109..e0eef9b 100644 --- a/openai_functions/parsers/float_parser.py +++ b/openai_functions/parsers/float_parser.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import TYPE_CHECKING +from ..exceptions import BrokenSchemaError from .atomic_type_parser import AtomicParser if TYPE_CHECKING: @@ -16,5 +17,5 @@ class FloatParser(AtomicParser[float]): def parse_value(self, value: JsonType) -> float: if not isinstance(value, (float, int)): - raise TypeError(f"Expected float, got {value}") + raise BrokenSchemaError(value, self.argument_schema) return float(value) diff --git a/openai_functions/parsers/int_parser.py b/openai_functions/parsers/int_parser.py index 34adc80..fa4e1e1 100644 --- a/openai_functions/parsers/int_parser.py +++ b/openai_functions/parsers/int_parser.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import TYPE_CHECKING +from ..exceptions import BrokenSchemaError from .atomic_type_parser import AtomicParser if TYPE_CHECKING: @@ -18,5 +19,5 @@ def parse_value(self, value: JsonType) -> int: if isinstance(value, bool): # This has to happen for historical reasons # bool is a subclass of int, so isinstance(value, int) is True - raise TypeError(f"Expected int, got {value}") + raise BrokenSchemaError(value, self.argument_schema) return super().parse_value(value) diff --git a/openai_functions/parsers/list_parser.py b/openai_functions/parsers/list_parser.py index e5592ff..fc719a5 100644 --- a/openai_functions/parsers/list_parser.py +++ b/openai_functions/parsers/list_parser.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import Any, List, TYPE_CHECKING, Type, TypeVar, get_args, get_origin +from ..exceptions import BrokenSchemaError from .abc import ArgSchemaParser if TYPE_CHECKING: @@ -30,5 +31,5 @@ def argument_schema(self) -> dict[str, JsonType]: def parse_value(self, value: JsonType) -> List[T]: if not isinstance(value, list): - raise TypeError(f"Expected list, got {value}") + raise BrokenSchemaError(value, self.argument_schema) return [self.parse_rec(get_args(self.argtype)[0]).parse_value(v) for v in value] diff --git a/openai_functions/parsers/none_parser.py b/openai_functions/parsers/none_parser.py index 2ac627e..be72729 100644 --- a/openai_functions/parsers/none_parser.py +++ b/openai_functions/parsers/none_parser.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import Any, TYPE_CHECKING, Type +from ..exceptions import BrokenSchemaError from .abc import ArgSchemaParser if TYPE_CHECKING: @@ -22,4 +23,4 @@ def argument_schema(self) -> dict[str, JsonType]: def parse_value(self, value: JsonType) -> None: if value is not None: - raise TypeError(f"Expected None, got {type(value)}") + raise BrokenSchemaError(value, self.argument_schema) diff --git a/openai_functions/parsers/union_parser.py b/openai_functions/parsers/union_parser.py index 6262794..6011509 100644 --- a/openai_functions/parsers/union_parser.py +++ b/openai_functions/parsers/union_parser.py @@ -2,6 +2,7 @@ from __future__ import annotations import contextlib +from ..exceptions import BrokenSchemaError from .abc import ArgSchemaParser try: @@ -39,6 +40,6 @@ def argument_schema(self) -> dict[str, JsonType]: def parse_value(self, value: JsonType) -> UnionType: for single_type in get_args(self.argtype): - with contextlib.suppress(TypeError): + with contextlib.suppress(BrokenSchemaError): return self.parse_rec(single_type).parse_value(value) - raise TypeError(f"Expected one of {get_args(self.argtype)}, got {value}") + raise BrokenSchemaError(value, self.argument_schema) diff --git a/pyproject.toml b/pyproject.toml index 1b434f5..d732a61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "openai-functions" -version = "0.6.11" +version = "0.6.12" description = "Simplifies the usage of OpenAI ChatGPT's function calling by generating the schemas and parsing OpenAI's responses for you." authors = ["rizerphe <44440399+rizerphe@users.noreply.github.com>"] readme = "README.md" diff --git a/tests/test_function_wrapper.py b/tests/test_function_wrapper.py index b0ebc1e..a35e089 100644 --- a/tests/test_function_wrapper.py +++ b/tests/test_function_wrapper.py @@ -6,7 +6,7 @@ import pytest -from openai_functions import FunctionWrapper +from openai_functions import BrokenSchemaError, CannotParseTypeError, FunctionWrapper def test_function_schema_generation_empty(): @@ -177,7 +177,7 @@ def test_function_schema_generation_invalid_parameters(): def test_function(param1: object, param2: str, param3: bool): """Test function docstring.""" - with pytest.raises(TypeError): + with pytest.raises(CannotParseTypeError): FunctionWrapper(test_function).schema @@ -207,7 +207,7 @@ def test_function(param1: Union[int, str, None]): assert function_wrapper({"param1": 1}) == 1 assert function_wrapper({"param1": "test"}) == "test" assert function_wrapper({"param1": None}) is None - with pytest.raises(TypeError): + with pytest.raises(BrokenSchemaError): function_wrapper({"param1": True}) @@ -250,7 +250,7 @@ def test_function(container: Container): }, } function_wrapper({"container": {"item": 1, "priority": 2}}) - with pytest.raises(TypeError): + with pytest.raises(BrokenSchemaError): function_wrapper({"container": 1}) @@ -324,7 +324,7 @@ class Container: def test_function(container: Container): """Test function docstring.""" - with pytest.raises(ValueError): + with pytest.raises(CannotParseTypeError): FunctionWrapper(test_function).schema @@ -355,7 +355,7 @@ def test_function(container: Dict[str, int]): }, } function_wrapper({"container": {"item": 1, "priority": 2}}) - with pytest.raises(TypeError): + with pytest.raises(BrokenSchemaError): function_wrapper({"container": [(1, 2), (3, 4)]}) @@ -384,7 +384,7 @@ def test_function(container: List[Union[int, str]]): }, } function_wrapper({"container": [1, "test"]}) - with pytest.raises(TypeError): + with pytest.raises(BrokenSchemaError): function_wrapper({"container": "test"}) @@ -421,7 +421,7 @@ def test_function(priority: Priority): }, } function_wrapper({"priority": "LOW"}) - with pytest.raises(TypeError): + with pytest.raises(BrokenSchemaError): function_wrapper({"priority": 1})