From de0b1ef6793fb0cee8c1e5602e1693b999a4bff5 Mon Sep 17 00:00:00 2001 From: "synacktra.work@gmail.com" Date: Fri, 6 Sep 2024 15:55:16 +0530 Subject: [PATCH] [chore] resolved bugs and patched version --- Makefile | 4 +- pyproject.toml | 4 +- tests/test_registry.py | 315 ++++++++++++++++++++++++++++------------ tests/test_tool.py | 3 + tool_parse/_registry.py | 10 +- tool_parse/_tool.py | 6 +- tool_parse/compile.py | 4 +- tool_parse/marshal.py | 5 +- 8 files changed, 243 insertions(+), 108 deletions(-) diff --git a/Makefile b/Makefile index f2fde13..87385cd 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ .PHONY: install install: ## Install the poetry environment and install the pre-commit hooks - @echo "🚀 Creating virtual environment using pyenv and poetry" + @echo "🚀 Creating virtual environment using poetry" @poetry install @poetry run pre-commit install @poetry shell @@ -28,7 +28,7 @@ build: clean-build ## Build wheel file using poetry .PHONY: clean-build clean-build: ## clean build artifacts - @rm -rf dist + @python -c "import shutil; shutil.rmtree('dist', ignore_errors=True)" .PHONY: publish publish: ## publish a release to pypi. diff --git a/pyproject.toml b/pyproject.toml index 04148fc..d2e1223 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] -name = "tool_parse" -version = "0.2.1" +name = "tool-parse" +version = "0.2.2" description = "Making LLM Tool-Calling Simpler." authors = ["Harsh Verma "] repository = "https://github.com/synacktraa/tool-parse" diff --git a/tests/test_registry.py b/tests/test_registry.py index 49a959b..4a0bfdf 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -1,12 +1,17 @@ import typing as t +from enum import Enum +from pathlib import Path import pytest from tool_parse import ToolRegistry +from tool_parse.compile import CompileError +from tool_parse.marshal import MarshalError +# Basic registry fixture @pytest.fixture -def registry(): +def basic_registry(): tr = ToolRegistry() @tr.register @@ -45,117 +50,245 @@ class HeroData(t.NamedTuple): tr["HeroData"] = HeroData - yield tr + return tr -def test_tool_registry(registry): - assert len(registry) == 4 +# Complex registry fixture +@pytest.fixture +def complex_registry(): + tr = ToolRegistry() - assert "get_flight_times" in registry - assert "CallApi" in registry - assert "user_info" in registry - assert "HeroData" in registry + @tr.register + def process_data( + text: str, + count: int, + ratio: float, + is_valid: bool, + tags: set[str], + items: list[str], + metadata: dict[str, t.Any], + file_path: Path, + optional_param: t.Optional[int] = None, + ) -> dict: + """ + Process various types of data. + :param text: A string input + :param count: An integer count + :param ratio: A float ratio + :param is_valid: A boolean flag + :param tags: A set of string tags + :param items: A list of string items + :param metadata: A dictionary of metadata + :param file_path: A file path + :param optional_param: An optional integer parameter + """ + return { + "text_length": len(text), + "count_squared": count**2, + "ratio_rounded": round(ratio, 2), + "is_valid": is_valid, + "unique_tags": len(tags), + "items_count": len(items), + "metadata_keys": list(metadata.keys()), + "file_name": file_path.name, + "optional_param": optional_param, + } - with pytest.raises(KeyError): - registry["some_tool"] + class ColorEnum(Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + @tr.register + def color_brightness(color: ColorEnum, brightness: t.Literal["light", "dark"]) -> str: + """ + Get color brightness description. + :param color: The color enum + :param brightness: The brightness level + """ + return f"{brightness} {color.value}" -def test_registry_marshal_method(registry): - tools = registry.marshal("base") + class UserProfile(t.TypedDict): + name: str + age: int + hobbies: t.List[str] - assert tools is not None - assert not isinstance(tools, str) + @tr.register + def create_user_profile(profile: UserProfile) -> str: + """ + Create a user profile. + :param profile: The user profile data + """ + return f"{profile['name']} ({profile['age']}) likes {', '.join(profile['hobbies'])}" - assert tools[0]["type"] == "function" - assert tools[0]["function"]["name"] == "get_flight_times" - assert tools[0]["function"]["description"] == "Get flight times." - assert tools[0]["function"]["parameters"]["type"] == "object" - assert tools[0]["function"]["parameters"]["required"] == ["departure", "arrival"] - assert tools[0]["function"]["parameters"]["properties"]["departure"]["type"] == "string" - assert ( - tools[0]["function"]["parameters"]["properties"]["departure"]["description"] - == "Departure location code" - ) - assert tools[0]["function"]["parameters"]["properties"]["arrival"]["type"] == "string" - assert ( - tools[0]["function"]["parameters"]["properties"]["arrival"]["description"] - == "Arrival location code" - ) + class BookInfo(t.NamedTuple): + title: str + author: str + year: int - assert tools[1]["type"] == "function" - assert tools[1]["function"]["name"] == "CallApi" - assert tools[1]["function"]["description"] == "Call the API." - assert tools[1]["function"]["parameters"]["type"] == "object" - assert tools[1]["function"]["parameters"]["required"] == ["host", "port"] - assert tools[1]["function"]["parameters"]["properties"]["host"]["type"] == "string" - assert tools[1]["function"]["parameters"]["properties"]["host"]["description"] == "Target host." - assert tools[1]["function"]["parameters"]["properties"]["port"]["type"] == "integer" - assert ( - tools[1]["function"]["parameters"]["properties"]["port"]["description"] - == "Port number to request." - ) + @tr.register + def format_book_info(book: BookInfo) -> str: + """ + Format book information. + :param book: The book information + """ + return f"{book.title} by {book.author} ({book.year})" - assert tools[2]["type"] == "function" - assert tools[2]["function"]["name"] == "user_info" - assert tools[2]["function"]["description"] == "Information of the user." - assert tools[2]["function"]["parameters"]["type"] == "object" - assert tools[2]["function"]["parameters"]["required"] == ["name"] - assert tools[2]["function"]["parameters"]["properties"]["name"]["type"] == "string" - assert tools[2]["function"]["parameters"]["properties"]["role"]["type"] == "string" - assert tools[2]["function"]["parameters"]["properties"]["role"]["enum"] == ["admin", "tester"] - - assert tools[3]["type"] == "function" - assert tools[3]["function"]["name"] == "HeroData" - assert tools[3]["function"]["parameters"]["type"] == "object" - assert tools[3]["function"]["parameters"]["required"] == ["info"] - assert tools[3]["function"]["parameters"]["properties"]["info"]["type"] == "object" - assert ( - tools[3]["function"]["parameters"]["properties"]["info"]["properties"]["name"]["type"] - == "string" + return tr + + +# Tests for basic registry +def test_basic_registry_content(basic_registry): + assert len(basic_registry) == 4 + assert all( + tool in basic_registry for tool in ["get_flight_times", "CallApi", "user_info", "HeroData"] ) + + with pytest.raises(KeyError): + basic_registry["non_existent_tool"] + + +def test_basic_registry_marshal(basic_registry): + tools = basic_registry.marshal("base") + assert len(tools) == 4 + + flight_tool = next(tool for tool in tools if tool["function"]["name"] == "get_flight_times") + assert flight_tool["function"]["parameters"]["required"] == ["departure", "arrival"] + + user_info_tool = next(tool for tool in tools if tool["function"]["name"] == "user_info") + assert user_info_tool["function"]["parameters"]["properties"]["role"]["enum"] == [ + "admin", + "tester", + ] + + +def test_basic_registry_compile(basic_registry): assert ( - tools[3]["function"]["parameters"]["properties"]["info"]["properties"]["age"]["type"] - == "integer" + basic_registry.compile( + name="get_flight_times", arguments={"departure": "NYC", "arrival": "JFK"} + ) + == "2 hours" ) - assert tools[3]["function"]["parameters"]["properties"]["info"]["required"] == ["name"] - assert tools[3]["function"]["parameters"]["properties"]["powers"]["type"] == "array" - assert tools[3]["function"]["parameters"]["properties"]["powers"]["items"]["type"] == "string" + user_info = basic_registry.compile(name="user_info", arguments={"name": "Alice"}) + assert user_info == {"name": "Alice", "role": "tester"} -def test_registry_compile_method(registry): - get_flight_times_output = registry.compile( - name="get_flight_times", arguments={"departure": "NYC", "arrival": "JFK"} - ) - assert get_flight_times_output == "2 hours" - call_api_output = registry.compile( - name="CallApi", arguments={"host": "localhost", "port": 8080} +# Tests for complex registry +def test_complex_registry_process_data(complex_registry): + result = complex_registry.compile( + name="process_data", + arguments={ + "text": "Hello, World!", + "count": 5, + "ratio": 3.14159, + "is_valid": True, + "tags": ["python", "testing", "testing"], + "items": ["item1", "item2", "item3"], + "metadata": {"key1": "value1", "key2": "value2"}, + "file_path": "../test.txt", + "optional_param": 42, + }, ) - assert call_api_output.get("status") == "ok" + assert result["text_length"] == 13 + assert result["count_squared"] == 25 + assert result["ratio_rounded"] == 3.14 + assert result["unique_tags"] == 2 + assert result["file_name"] == "test.txt" + - UserInfo_1_output = registry.compile( - name="user_info", arguments={"name": "Andrej", "role": "admin"} +def test_complex_registry_enum_and_literal(complex_registry): + result = complex_registry.compile( + name="color_brightness", arguments={"color": "RED", "brightness": "light"} ) - assert UserInfo_1_output["name"] == "Andrej" - assert UserInfo_1_output["role"] == "admin" + assert result == "light red" - UserInfo_2_output = registry.compile("user_info(name='Synacktra')") - assert UserInfo_2_output["name"] == "Synacktra" - assert UserInfo_2_output["role"] == "tester" + with pytest.raises(CompileError): + complex_registry.compile( + name="color_brightness", arguments={"color": "YELLOW", "brightness": "light"} + ) - HeroData_1_output = registry.compile( - name="HeroData", + with pytest.raises(CompileError): + complex_registry.compile( + name="color_brightness", arguments={"color": "RED", "brightness": "medium"} + ) + + +def test_complex_registry_typed_dict(complex_registry): + result = complex_registry.compile( + name="create_user_profile", arguments={ - "info": {"name": "homelander", "age": "1"}, - "powers": ["bullying", "laser beam"], + "profile": {"name": "Alice", "age": 30, "hobbies": ["reading", "hiking", "photography"]} }, ) - assert HeroData_1_output[0]["name"] == "homelander" - assert HeroData_1_output.info["age"] == 1 - assert HeroData_1_output.powers[0] == "bullying" - assert HeroData_1_output.powers[1] == "laser beam" - - HeroData_2_output = registry.compile(name="HeroData", arguments={"info": {"name": "Man"}}) - assert HeroData_2_output.info["name"] == "Man" - assert HeroData_2_output[0]["age"] is None - assert HeroData_2_output.powers is None + assert result == "Alice (30) likes reading, hiking, photography" + + with pytest.raises(CompileError): + complex_registry.compile( + name="create_user_profile", + arguments={"profile": {"name": "Bob", "hobbies": ["coding"]}}, + ) + + +def test_complex_registry_named_tuple(complex_registry): + result = complex_registry.compile( + name="format_book_info", + arguments={"book": {"title": "1984", "author": "George Orwell", "year": 1949}}, + ) + assert result == "1984 by George Orwell (1949)" + + +# Tests for MarshalError +def test_marshal_error_unsupported_type(): + tr = ToolRegistry() + + @tr.register + def process_bytes(data: bytes) -> str: + """Process byte data""" + return data.decode() + + with pytest.raises(MarshalError): + _ = tr.marshal("base") + + +def test_marshal_error_complex_unsupported_type(): + tr = ToolRegistry() + + @tr.register + def process_complex_data(data: t.List[bytes]) -> str: + """Process complex data with unsupported type""" + return str(len(data)) + + with pytest.raises(MarshalError): + _ = tr.marshal("base") + + +# Additional tests for edge cases +def test_empty_registry(): + tr = ToolRegistry() + assert len(tr) == 0 + assert tr.marshal("base") is None + + +def test_registry_addition(): + tr1 = ToolRegistry() + tr2 = ToolRegistry() + + @tr1.register + def func1(): + pass + + @tr2.register + def func2(): + pass + + combined = tr1 + tr2 + assert len(combined) == 2 + assert "func1" in combined and "func2" in combined + + +def test_marshal_as_json(complex_registry): + json_output = complex_registry.marshal(as_json=True) + assert isinstance(json_output, str) + assert "process_data" in json_output + assert "color_brightness" in json_output diff --git a/tests/test_tool.py b/tests/test_tool.py index 356e626..bbca7b8 100644 --- a/tests/test_tool.py +++ b/tests/test_tool.py @@ -31,10 +31,13 @@ def get_flight_times(departure: str, arrival: str) -> str: == "Arrival location code" ) + # passing dictionary as arguments assert get_flight_times.compile(arguments={"departure": "NYC", "arrival": "JFK"}) == "2 hours" + # passing json as arguments assert get_flight_times.compile(arguments='{"departure": "NYC", "arrival": "JFK"}') == "2 hours" + # passing call expression assert get_flight_times.compile("get_flight_times(departure='NYC', arrival='JFK')") == "2 hours" with pytest.raises(ValueError): diff --git a/tool_parse/_registry.py b/tool_parse/_registry.py index c4ff547..fbdbab8 100644 --- a/tool_parse/_registry.py +++ b/tool_parse/_registry.py @@ -174,7 +174,7 @@ def marshal( :param as_json: If `True`, schema is returned as JSON object. :param persist_at: Path to `.json` file to persist schema. """ - if not self: + if not self.__entries: return None schema = [] @@ -199,7 +199,7 @@ def compile(self, __expression: str) -> t.Any: """ @t.overload - def compile(self, *, name: str, arguments: t.Optional[str | dict[str, t.Any]] = None) -> t.Any: + def compile(self, *, name: str, arguments: str | dict[str, t.Any]) -> t.Any: """ Compile a tool from call metadata @@ -214,12 +214,12 @@ def compile( name: t.Optional[str] = None, arguments: t.Optional[str | dict[str, t.Any]] = None, ): - if not __expression and not name: - raise ValueError("Either tool expression or name & arguments required.") - if __expression: name, arguments = compile.parse_expression(__expression) + if name is None and arguments is None: + raise ValueError("Either tool expression or name & arguments required.") + if (entry := self.__entries.get(name)) is None: raise NotRegisteredError(f"{name!r} tool has not been registered") diff --git a/tool_parse/_tool.py b/tool_parse/_tool.py index 2baaa11..b95defe 100644 --- a/tool_parse/_tool.py +++ b/tool_parse/_tool.py @@ -51,12 +51,12 @@ def compile( *, arguments: t.Optional[str | dict[str, t.Any]] = None, ): - if not __expression and not arguments: - raise ValueError("Either tool call expression or arguments required.") - if __expression: name, arguments = compile.parse_expression(__expression) if name != self.name: raise ValueError(f"Expected call expression for tool {self.name!r}") + if arguments is None: + raise ValueError("Either tool call expression or arguments required.") + return compile.compile_object(self.__obj, arguments=arguments or {}) diff --git a/tool_parse/compile.py b/tool_parse/compile.py index c2781c5..0a0fc1a 100644 --- a/tool_parse/compile.py +++ b/tool_parse/compile.py @@ -173,10 +173,10 @@ def compile_value( # noqa: C901 ) return raw_value, is_optional - if annot in (list, t.List): + if annot in (list, set, t.List, t.Set): if not isinstance(raw_value, list): raise CompileError(f"Expected list value, {rest_err}") - return [compile_value(args[0], e)[0] for e in raw_value], is_optional + return annot(compile_value(args[0], e)[0] for e in raw_value), is_optional if issubclass(annot, Path): if not isinstance(raw_value, str): diff --git a/tool_parse/marshal.py b/tool_parse/marshal.py index b942ffd..a0a17ca 100644 --- a/tool_parse/marshal.py +++ b/tool_parse/marshal.py @@ -117,15 +117,14 @@ def marshal_annotation(__annotation: type | t.ForwardRef) -> tuple[dict[str, t.A annot, args, is_optional = resolve_annotation(__annotation) if args: - if annot in (list, t.List): + if annot in (list, set, t.List, t.Set): return {"type": "array", "items": marshal_annotation(args[0])[0]}, is_optional if annot is t.Literal: arg_types = list({type(e) for e in args}) if len(arg_types) != 1: raise MarshalError("Literal args must be of same type.") - arg_type = arg_types[0] - if arg_type not in (str, int, float, bool): + if (arg_type := arg_types[0]) not in (str, int, float, bool): raise MarshalError( f"{getattr(arg_type, '__name__', arg_type)!r} type is not supported in typing.Literal." )