From 4ca48bdb536683d69b27dc4a2a69e2316faeac04 Mon Sep 17 00:00:00 2001 From: Philipp Temminghoff Date: Sat, 23 Nov 2024 00:17:33 +0100 Subject: [PATCH] chore: tool fix --- src/llmling/tools/base.py | 10 ++++++---- src/llmling/tools/registry.py | 4 ++-- tests/test_tools.py | 16 ++++++++-------- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/llmling/tools/base.py b/src/llmling/tools/base.py index 5e38a11..f5d12df 100644 --- a/src/llmling/tools/base.py +++ b/src/llmling/tools/base.py @@ -24,7 +24,7 @@ class BaseTool(ABC): parameters_schema: ClassVar[dict[str, Any]] @classmethod - def get_schema(cls) -> py2openai.ToolSchema: + def get_schema(cls) -> py2openai.OpenAIFunctionTool: """Get the tool's schema for LLM function calling.""" return py2openai.create_schema(cls.execute).model_dump_openai() @@ -72,12 +72,14 @@ def func(self) -> Callable[..., Any]: self._func = calling.import_callable(self.import_path) return self._func - def get_schema(self) -> py2openai.ToolSchema: + def get_schema(self) -> py2openai.OpenAIFunctionTool: """Generate schema from function signature.""" schema_dict = py2openai.create_schema(self.func).model_dump_openai() # Override name and description - schema_dict["name"] = self.name or schema_dict["name"] - schema_dict["description"] = self._description or schema_dict["description"] + schema_dict["function"]["name"] = self.name or schema_dict["function"]["name"] + schema_dict["function"]["description"] = ( + self._description or schema_dict["function"]["description"] + ) return schema_dict async def execute(self, **params: Any) -> Any: diff --git a/src/llmling/tools/registry.py b/src/llmling/tools/registry.py index 0049138..b621b70 100644 --- a/src/llmling/tools/registry.py +++ b/src/llmling/tools/registry.py @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) if TYPE_CHECKING: - from py2openai import ToolSchema + from py2openai import OpenAIFunctionTool from llmling.tools import exceptions @@ -49,7 +49,7 @@ def _validate_item(self, item: Any) -> BaseTool | DynamicTool: msg = f"Invalid tool type: {type(item)}" raise ToolError(msg) - def get_schema(self, name: str) -> ToolSchema: + def get_schema(self, name: str) -> OpenAIFunctionTool: """Get schema for a tool.""" tool = self.get(name) return tool.get_schema() diff --git a/tests/test_tools.py b/tests/test_tools.py index bd501f2..f448011 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -92,10 +92,10 @@ def test_schema_generation(self) -> None: tool = DynamicTool(EXAMPLE_IMPORT) schema = tool.get_schema() - assert schema["name"] == "example_tool" - assert "text" in schema["parameters"]["properties"] - assert "repeat" in schema["parameters"]["properties"] - assert schema["parameters"]["required"] == ["text"] + assert schema["function"]["name"] == "example_tool" + assert "text" in schema["function"]["parameters"]["properties"] + assert "repeat" in schema["function"]["parameters"]["properties"] + assert schema["function"]["parameters"]["required"] == ["text"] @pytest.mark.asyncio async def test_execution(self) -> None: @@ -165,9 +165,9 @@ def test_schema_generation(self, registry: ToolRegistry) -> None: registry["analyze_ast"] = ANALYZE_IMPORT schema = registry.get_schema("analyze_ast") - assert "code" in schema["parameters"]["properties"] - assert schema["parameters"]["required"] == ["code"] - assert "Analyze Python code AST" in schema["description"] + assert "code" in schema["function"]["parameters"]["properties"] + assert schema["function"]["parameters"]["required"] == ["code"] + assert "Analyze Python code AST" in schema["function"]["description"] # Integration tests @@ -180,7 +180,7 @@ async def test_tool_integration() -> None: # Get schema schema = registry.get_schema("analyze") - assert schema["name"] == "analyze_ast" + assert schema["function"]["name"] == "analyze_ast" # Execute tool code = """ class TestClass: