diff --git a/src/llmling/tools/openapi.py b/src/llmling/tools/openapi.py index b63b8aa..97078a9 100644 --- a/src/llmling/tools/openapi.py +++ b/src/llmling/tools/openapi.py @@ -17,6 +17,8 @@ if TYPE_CHECKING: from collections.abc import Callable + from jsonschema_path.typing import Schema + logger = get_logger(__name__) T = TypeVar("T") @@ -50,11 +52,11 @@ def __init__( self.base_url = base_url self.headers = headers or {} self._client = httpx.AsyncClient(base_url=self.base_url, headers=self.headers) - self._spec: dict[str, Any] = {} + self._spec: Schema = {} self._schemas: dict[str, Any] = {} self._operations: dict[str, Any] = {} - def _store_spec(self, spec_data: dict[str, Any]) -> None: + def _store_spec(self, spec_data: Schema) -> None: """Helper to store and parse spec data.""" self._spec = spec_data self._schemas = self._spec.get("components", {}).get("schemas", {}) @@ -70,7 +72,7 @@ def _ensure_loaded(self) -> None: spec_data = self._load_spec() self._store_spec(spec_data) - def _load_spec(self) -> dict[str, Any]: + def _load_spec(self) -> Schema: """Load OpenAPI specification.""" try: if self.spec_url.startswith(("http://", "https://")): diff --git a/tests/test_openapi_toolsets.py b/tests/test_openapi_toolsets.py index 8b6fd1e..df230b1 100644 --- a/tests/test_openapi_toolsets.py +++ b/tests/test_openapi_toolsets.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +from typing import TYPE_CHECKING from openapi_spec_validator import validate from openapi_spec_validator.exceptions import OpenAPISpecValidatorError @@ -9,8 +10,12 @@ from llmling.tools.openapi import OpenAPITools +if TYPE_CHECKING: + from jsonschema_path.typing import Schema + + BASE_URL = "https://api.example.com" -PETSTORE_SPEC = { +PETSTORE_SPEC: Schema = { "openapi": "3.0.0", "info": {"title": "Pet Store API", "version": "1.0.0"}, "paths": {