Skip to content

Commit

Permalink
chore: tool fix
Browse files Browse the repository at this point in the history
  • Loading branch information
phil65 committed Nov 22, 2024
1 parent ffbd23e commit 4ca48bd
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 14 deletions.
10 changes: 6 additions & 4 deletions src/llmling/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class BaseTool(ABC):
parameters_schema: ClassVar[dict[str, Any]]

@classmethod
def get_schema(cls) -> py2openai.ToolSchema:
def get_schema(cls) -> py2openai.OpenAIFunctionTool:
"""Get the tool's schema for LLM function calling."""
return py2openai.create_schema(cls.execute).model_dump_openai()

Expand Down Expand Up @@ -72,12 +72,14 @@ def func(self) -> Callable[..., Any]:
self._func = calling.import_callable(self.import_path)
return self._func

def get_schema(self) -> py2openai.ToolSchema:
def get_schema(self) -> py2openai.OpenAIFunctionTool:
"""Generate schema from function signature."""
schema_dict = py2openai.create_schema(self.func).model_dump_openai()
# Override name and description
schema_dict["name"] = self.name or schema_dict["name"]
schema_dict["description"] = self._description or schema_dict["description"]
schema_dict["function"]["name"] = self.name or schema_dict["function"]["name"]
schema_dict["function"]["description"] = (
self._description or schema_dict["function"]["description"]
)
return schema_dict

async def execute(self, **params: Any) -> Any:
Expand Down
4 changes: 2 additions & 2 deletions src/llmling/tools/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from py2openai import ToolSchema
from py2openai import OpenAIFunctionTool

from llmling.tools import exceptions

Expand Down Expand Up @@ -49,7 +49,7 @@ def _validate_item(self, item: Any) -> BaseTool | DynamicTool:
msg = f"Invalid tool type: {type(item)}"
raise ToolError(msg)

def get_schema(self, name: str) -> ToolSchema:
def get_schema(self, name: str) -> OpenAIFunctionTool:
"""Get schema for a tool."""
tool = self.get(name)
return tool.get_schema()
Expand Down
16 changes: 8 additions & 8 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ def test_schema_generation(self) -> None:
tool = DynamicTool(EXAMPLE_IMPORT)
schema = tool.get_schema()

assert schema["name"] == "example_tool"
assert "text" in schema["parameters"]["properties"]
assert "repeat" in schema["parameters"]["properties"]
assert schema["parameters"]["required"] == ["text"]
assert schema["function"]["name"] == "example_tool"
assert "text" in schema["function"]["parameters"]["properties"]
assert "repeat" in schema["function"]["parameters"]["properties"]
assert schema["function"]["parameters"]["required"] == ["text"]

@pytest.mark.asyncio
async def test_execution(self) -> None:
Expand Down Expand Up @@ -165,9 +165,9 @@ def test_schema_generation(self, registry: ToolRegistry) -> None:
registry["analyze_ast"] = ANALYZE_IMPORT
schema = registry.get_schema("analyze_ast")

assert "code" in schema["parameters"]["properties"]
assert schema["parameters"]["required"] == ["code"]
assert "Analyze Python code AST" in schema["description"]
assert "code" in schema["function"]["parameters"]["properties"]
assert schema["function"]["parameters"]["required"] == ["code"]
assert "Analyze Python code AST" in schema["function"]["description"]


# Integration tests
Expand All @@ -180,7 +180,7 @@ async def test_tool_integration() -> None:

# Get schema
schema = registry.get_schema("analyze")
assert schema["name"] == "analyze_ast"
assert schema["function"]["name"] == "analyze_ast"
# Execute tool
code = """
class TestClass:
Expand Down

0 comments on commit 4ca48bd

Please sign in to comment.