Skip to content

Commit

Permalink
Add proper invalid response handling
Browse files Browse the repository at this point in the history
  • Loading branch information
rizerphe committed Jul 3, 2023
1 parent 35725f6 commit cd54db6
Show file tree
Hide file tree
Showing 16 changed files with 144 additions and 38 deletions.
6 changes: 6 additions & 0 deletions openai_functions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""ChatGPT function calling based on function docstrings."""
from .conversation import Conversation
from .exceptions import (
BrokenSchemaError,
CannotParseTypeError,
FunctionNotFoundError,
InvalidJsonError,
NonSerializableOutputError,
OpenAIFunctionsError,
)
Expand All @@ -23,7 +26,10 @@

__all__ = [
"Conversation",
"BrokenSchemaError",
"CannotParseTypeError",
"FunctionNotFoundError",
"InvalidJsonError",
"NonSerializableOutputError",
"OpenAIFunctionsError",
"BasicFunctionSet",
Expand Down
48 changes: 48 additions & 0 deletions openai_functions/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from typing import Any

from .json_type import JsonType


class OpenAIFunctionsError(Exception):
"""The base exception for all OpenAI Functions errors."""
Expand All @@ -12,6 +14,19 @@ class FunctionNotFoundError(OpenAIFunctionsError):
"""The function was not found in the given skillset."""


class CannotParseTypeError(OpenAIFunctionsError):
"""This type of the argument could not be parsed."""

def __init__(self, argtype: Any) -> None:
"""Initialize the CannotParseTypeError.
Args:
argtype (Any): The type that could not be parsed
"""
super().__init__(f"Cannot parse type {argtype}")
self.argtype = argtype


class NonSerializableOutputError(OpenAIFunctionsError):
"""The function returned a non-serializable output."""

Expand All @@ -26,3 +41,36 @@ def __init__(self, result: Any) -> None:
"Set serialize=False to use str() instead."
)
self.result = result


class InvalidJsonError(OpenAIFunctionsError):
"""OpenAI returned invalid JSON for the arguments."""

def __init__(self, response: str) -> None:
"""Initialize the InvalidJsonError.
Args:
response (str): The response that was not valid JSON
"""
super().__init__(
f"OpenAI returned invalid (perhaps incomplete) JSON: {response}"
)
self.response = response


class BrokenSchemaError(OpenAIFunctionsError):
"""The OpenAI response did not match the schema."""

def __init__(self, response: JsonType, schema: JsonType) -> None:
"""Initialize the BrokenSchemaError.
Args:
response (JsonType): The response that did not match the schema
schema (JsonType): The schema that the response did not match
"""
super().__init__(
"OpenAI returned a response that did not match the schema: "
f"{response!r} does not match {schema}"
)
self.response = response
self.schema = schema
8 changes: 6 additions & 2 deletions openai_functions/functions/basic_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
from typing import TYPE_CHECKING

from ..exceptions import FunctionNotFoundError
from ..exceptions import FunctionNotFoundError, InvalidJsonError
from .functions import FunctionResult, OpenAIFunction, RawFunctionResult
from .sets import MutableFunctionSet

Expand Down Expand Up @@ -49,7 +49,11 @@ def run_function(self, input_data: FunctionCall) -> FunctionResult:
FunctionNotFoundError: If the function is not found
"""
function = self.find_function(input_data["name"])
result = self.get_function_result(function, json.loads(input_data["arguments"]))
try:
arguments = json.loads(input_data["arguments"])
except json.decoder.JSONDecodeError as e:
raise InvalidJsonError(input_data["arguments"]) from e
result = self.get_function_result(function, arguments)
return FunctionResult(
function.name, result, function.remove_call, function.interpret_as_response
)
Expand Down
19 changes: 13 additions & 6 deletions openai_functions/functions/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from docstring_parser import Docstring, parse

from ..exceptions import BrokenSchemaError, CannotParseTypeError
from ..parsers import ArgSchemaParser, defargparsers

if TYPE_CHECKING:
Expand Down Expand Up @@ -221,7 +222,7 @@ def parse_argument(self, argument: inspect.Parameter) -> ArgSchemaParser:
argument (inspect.Parameter): The argument to parse
Raises:
TypeError: If the argument cannot be parsed
CannotParseTypeError: If the argument cannot be parsed
Returns:
ArgSchemaParser: The parser for the argument
Expand All @@ -232,7 +233,7 @@ def parse_argument(self, argument: inspect.Parameter) -> ArgSchemaParser:
for parser in self.parsers:
if parser.can_parse(argument.annotation):
return parser(argument.annotation, self.parsers)
raise TypeError(f"Cannot parse argument {argument}")
raise CannotParseTypeError(argument.annotation)

def parse_arguments(self, arguments: dict[str, JsonType]) -> OrderedDict[str, Any]:
"""Parse arguments
Expand All @@ -243,10 +244,16 @@ def parse_arguments(self, arguments: dict[str, JsonType]) -> OrderedDict[str, An
Returns:
OrderedDict[str, Any]: The parsed arguments
"""
return OrderedDict(
(name, self.argument_parsers[name].parse_value(value))
for name, value in arguments.items()
)
argument_parsers = self.argument_parsers
if not all(name in arguments for name in argument_parsers):
raise BrokenSchemaError(arguments, self.arguments_schema)
try:
return OrderedDict(
(name, argument_parsers[name].parse_value(value))
for name, value in arguments.items()
)
except KeyError as e:
raise BrokenSchemaError(arguments, self.arguments_schema) from e

def __call__(self, arguments: dict[str, JsonType]) -> Any:
"""Call the wrapped function
Expand Down
9 changes: 7 additions & 2 deletions openai_functions/parsers/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from abc import ABC, abstractmethod
from typing import Any, Generic, TYPE_CHECKING, Type, TypeVar

from ..exceptions import CannotParseTypeError

if TYPE_CHECKING:
from ..json_type import JsonType
from typing_extensions import TypeGuard
Expand Down Expand Up @@ -34,12 +36,12 @@ def parse_rec(self, argtype: Type[S]) -> ArgSchemaParser[S]:
ArgSchemaParser[S]: The parser for the type
Raises:
ValueError: If the type cannot be parsed
CannotParseTypeError: If the type cannot be parsed
"""
for parser in self.rec_parsers:
if parser.can_parse(argtype):
return parser(argtype, self.rec_parsers)
raise ValueError(f"Cannot parse type {argtype}")
raise CannotParseTypeError(argtype)

@classmethod
@abstractmethod
Expand All @@ -61,4 +63,7 @@ def parse_value(self, value: JsonType) -> T:
Args:
value (JsonType): The value to parse
Raises:
BrokenSchemaError: If the value does not match the schema
"""
3 changes: 2 additions & 1 deletion openai_functions/parsers/atomic_type_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from abc import abstractmethod
from typing import Any, TYPE_CHECKING, Type, TypeVar

from ..exceptions import BrokenSchemaError
from .abc import ArgSchemaParser

if TYPE_CHECKING:
Expand Down Expand Up @@ -34,5 +35,5 @@ def argument_schema(self) -> dict[str, JsonType]:

def parse_value(self, value: JsonType) -> T:
if not isinstance(value, self._type):
raise TypeError(f"Expected {self._type}, got {type(value)}")
raise BrokenSchemaError(value, self.argument_schema)
return value
43 changes: 33 additions & 10 deletions openai_functions/parsers/dataclass_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import dataclasses
from typing import Any, ClassVar, Protocol, TYPE_CHECKING, Type

from ..exceptions import BrokenSchemaError
from .abc import ArgSchemaParser

if TYPE_CHECKING:
Expand All @@ -23,25 +24,47 @@ class DataclassParser(ArgSchemaParser[IsDataclass]):
def can_parse(cls, argtype: Any) -> TypeGuard[Type[IsDataclass]]:
return dataclasses.is_dataclass(argtype)

@property
def required_fields(self) -> list[str]:
"""All required fields of the dataclass
Returns:
list[str]: The required fields of the dataclass
"""
return [
field.name
for field in dataclasses.fields(self.argtype)
if field.default is dataclasses.MISSING
]

@property
def fields(self) -> dict[str, JsonType]:
"""All fields of the dataclass, with their schemas
Returns:
dict[str, JsonType]: The fields of the dataclass
"""
return {
field.name: self.parse_rec(field.type).argument_schema
for field in dataclasses.fields(self.argtype)
}

@property
def argument_schema(self) -> dict[str, JsonType]:
return {
"type": "object",
"description": self.argtype.__doc__,
"properties": {
field.name: self.parse_rec(field.type).argument_schema
for field in dataclasses.fields(self.argtype)
},
"required": [
field.name
for field in dataclasses.fields(self.argtype)
if field.default is dataclasses.MISSING
],
"properties": self.fields,
"required": self.required_fields, # type: ignore
}

def parse_value(self, value: JsonType) -> IsDataclass:
if not isinstance(value, dict):
raise TypeError(f"Expected dict, got {value}")
raise BrokenSchemaError(value, self.argument_schema)
if not all(field in value for field in self.required_fields):
raise BrokenSchemaError(value, self.argument_schema)
if not all(field in self.fields for field in value):
raise BrokenSchemaError(value, self.argument_schema)
return self.argtype(
**{
field.name: self.parse_rec(field.type).parse_value(value[field.name])
Expand Down
3 changes: 2 additions & 1 deletion openai_functions/parsers/dict_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations
from typing import Any, Dict, TYPE_CHECKING, Type, TypeVar, get_args, get_origin

from ..exceptions import BrokenSchemaError
from .abc import ArgSchemaParser

if TYPE_CHECKING:
Expand Down Expand Up @@ -36,7 +37,7 @@ def argument_schema(self) -> Dict[str, JsonType]:

def parse_value(self, value: JsonType) -> Dict[str, T]:
if not isinstance(value, dict):
raise TypeError(f"Expected dict, got {value}")
raise BrokenSchemaError(value, self.argument_schema)
return {
k: self.parse_rec(get_args(self.argtype)[1]).parse_value(v)
for k, v in value.items()
Expand Down
8 changes: 7 additions & 1 deletion openai_functions/parsers/enum_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import enum
from typing import Any, TYPE_CHECKING, Type, TypeVar

from ..exceptions import BrokenSchemaError
from .abc import ArgSchemaParser

if TYPE_CHECKING:
Expand Down Expand Up @@ -33,5 +34,10 @@ def argument_schema(self) -> dict[str, JsonType]:

def parse_value(self, value: JsonType) -> T:
if not isinstance(value, str):
raise TypeError(f"Expected str, got {value}")
raise BrokenSchemaError(value, self.argument_schema)
if value not in self.argument_schema["enum"]: # type: ignore
# TODO: consider using something other than JsonType for
# all of these, because disabling mypy is definitely
# not the right way to do this
raise BrokenSchemaError(value, self.argument_schema)
return self.argtype[value]
3 changes: 2 additions & 1 deletion openai_functions/parsers/float_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING

from ..exceptions import BrokenSchemaError
from .atomic_type_parser import AtomicParser

if TYPE_CHECKING:
Expand All @@ -16,5 +17,5 @@ class FloatParser(AtomicParser[float]):

def parse_value(self, value: JsonType) -> float:
if not isinstance(value, (float, int)):
raise TypeError(f"Expected float, got {value}")
raise BrokenSchemaError(value, self.argument_schema)
return float(value)
3 changes: 2 additions & 1 deletion openai_functions/parsers/int_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING

from ..exceptions import BrokenSchemaError
from .atomic_type_parser import AtomicParser

if TYPE_CHECKING:
Expand All @@ -18,5 +19,5 @@ def parse_value(self, value: JsonType) -> int:
if isinstance(value, bool):
# This has to happen for historical reasons
# bool is a subclass of int, so isinstance(value, int) is True
raise TypeError(f"Expected int, got {value}")
raise BrokenSchemaError(value, self.argument_schema)
return super().parse_value(value)
3 changes: 2 additions & 1 deletion openai_functions/parsers/list_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations
from typing import Any, List, TYPE_CHECKING, Type, TypeVar, get_args, get_origin

from ..exceptions import BrokenSchemaError
from .abc import ArgSchemaParser

if TYPE_CHECKING:
Expand Down Expand Up @@ -30,5 +31,5 @@ def argument_schema(self) -> dict[str, JsonType]:

def parse_value(self, value: JsonType) -> List[T]:
if not isinstance(value, list):
raise TypeError(f"Expected list, got {value}")
raise BrokenSchemaError(value, self.argument_schema)
return [self.parse_rec(get_args(self.argtype)[0]).parse_value(v) for v in value]
3 changes: 2 additions & 1 deletion openai_functions/parsers/none_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations
from typing import Any, TYPE_CHECKING, Type

from ..exceptions import BrokenSchemaError
from .abc import ArgSchemaParser

if TYPE_CHECKING:
Expand All @@ -22,4 +23,4 @@ def argument_schema(self) -> dict[str, JsonType]:

def parse_value(self, value: JsonType) -> None:
if value is not None:
raise TypeError(f"Expected None, got {type(value)}")
raise BrokenSchemaError(value, self.argument_schema)
5 changes: 3 additions & 2 deletions openai_functions/parsers/union_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations
import contextlib

from ..exceptions import BrokenSchemaError
from .abc import ArgSchemaParser

try:
Expand Down Expand Up @@ -39,6 +40,6 @@ def argument_schema(self) -> dict[str, JsonType]:

def parse_value(self, value: JsonType) -> UnionType:
for single_type in get_args(self.argtype):
with contextlib.suppress(TypeError):
with contextlib.suppress(BrokenSchemaError):
return self.parse_rec(single_type).parse_value(value)
raise TypeError(f"Expected one of {get_args(self.argtype)}, got {value}")
raise BrokenSchemaError(value, self.argument_schema)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "openai-functions"
version = "0.6.11"
version = "0.6.12"
description = "Simplifies the usage of OpenAI ChatGPT's function calling by generating the schemas and parsing OpenAI's responses for you."
authors = ["rizerphe <44440399+rizerphe@users.noreply.github.com>"]
readme = "README.md"
Expand Down
Loading

0 comments on commit cd54db6

Please sign in to comment.