From 47e39d4d1ffa5460f1eed485e93ad5c8fcaa7c59 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Fri, 10 Jan 2025 00:40:25 +0000 Subject: [PATCH] fix: improve OpenAPI schema validation and type safety Co-Authored-By: Chaoyu Yang --- src/_bentoml_sdk/service/openapi.py | 37 ++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/src/_bentoml_sdk/service/openapi.py b/src/_bentoml_sdk/service/openapi.py index 673a9b2908a..56ac8039f03 100644 --- a/src/_bentoml_sdk/service/openapi.py +++ b/src/_bentoml_sdk/service/openapi.py @@ -33,9 +33,9 @@ merger = Merger( # merge dicts recursively [(dict, "merge")], - # merge lists by concatenating - ["append"], - # override other types + # override all other types (including lists) + ["override"], + # override conflicting types ["override"], ) @@ -212,11 +212,22 @@ def _get_api_routes(svc: Service[t.Any]) -> dict[str, PathItem]: raise TypeError( f"Content for {content_type} must be a dictionary" ) - if "schema" in content: - schema = content["schema"] - if not isinstance(schema, dict): - raise TypeError("Schema must be a dictionary") - if "type" in schema and schema["type"] not in { + try: + content_dict: dict[str, t.Any] = { + "schema": content.get("schema"), + "example": content.get("example"), + "examples": content.get("examples"), + "encoding": content.get("encoding"), + } + content_obj = MediaType(**content_dict) + except (TypeError, ValueError) as e: + raise TypeError(f"Invalid content format: {e}") + + if content_obj.schema is not None: + schema = content_obj.schema + if isinstance(schema, Reference): + continue + valid_types = { "object", "array", "string", @@ -224,10 +235,14 @@ def _get_api_routes(svc: Service[t.Any]) -> dict[str, PathItem]: "integer", "boolean", "null", - }: + } + if ( + schema.type is not None + and schema.type not in valid_types + ): raise ValueError( - f"Invalid schema type: {schema['type']}. " - "Must be one of: object, array, string, number, integer, boolean, null" + f"Invalid schema type: {schema.type}. " + f"Must be one of: {', '.join(sorted(valid_types))}" ) elif field == "tags": if not isinstance(value, list):