Skip to content

Commit

Permalink
[chore] resolved bugs and patched version
Browse files Browse the repository at this point in the history
  • Loading branch information
synacktraa committed Sep 6, 2024
1 parent d5aee88 commit de0b1ef
Show file tree
Hide file tree
Showing 8 changed files with 243 additions and 108 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
.PHONY: install
install: ## Install the poetry environment and install the pre-commit hooks
@echo "🚀 Creating virtual environment using pyenv and poetry"
@echo "🚀 Creating virtual environment using poetry"
@poetry install
@poetry run pre-commit install
@poetry shell
Expand Down Expand Up @@ -28,7 +28,7 @@ build: clean-build ## Build wheel file using poetry

.PHONY: clean-build
clean-build: ## clean build artifacts
@rm -rf dist
@python -c "import shutil; shutil.rmtree('dist', ignore_errors=True)"

.PHONY: publish
publish: ## publish a release to pypi.
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "tool_parse"
version = "0.2.1"
name = "tool-parse"
version = "0.2.2"
description = "Making LLM Tool-Calling Simpler."
authors = ["Harsh Verma <synacktra.work@gmail.com>"]
repository = "https://github.com/synacktraa/tool-parse"
Expand Down
315 changes: 224 additions & 91 deletions tests/test_registry.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import typing as t
from enum import Enum
from pathlib import Path

import pytest

from tool_parse import ToolRegistry
from tool_parse.compile import CompileError
from tool_parse.marshal import MarshalError


# Basic registry fixture
@pytest.fixture
def registry():
def basic_registry():
tr = ToolRegistry()

@tr.register
Expand Down Expand Up @@ -45,117 +50,245 @@ class HeroData(t.NamedTuple):

tr["HeroData"] = HeroData

yield tr
return tr


def test_tool_registry(registry):
assert len(registry) == 4
# Complex registry fixture
@pytest.fixture
def complex_registry():
tr = ToolRegistry()

assert "get_flight_times" in registry
assert "CallApi" in registry
assert "user_info" in registry
assert "HeroData" in registry
@tr.register
def process_data(
text: str,
count: int,
ratio: float,
is_valid: bool,
tags: set[str],
items: list[str],
metadata: dict[str, t.Any],
file_path: Path,
optional_param: t.Optional[int] = None,
) -> dict:
"""
Process various types of data.
:param text: A string input
:param count: An integer count
:param ratio: A float ratio
:param is_valid: A boolean flag
:param tags: A set of string tags
:param items: A list of string items
:param metadata: A dictionary of metadata
:param file_path: A file path
:param optional_param: An optional integer parameter
"""
return {
"text_length": len(text),
"count_squared": count**2,
"ratio_rounded": round(ratio, 2),
"is_valid": is_valid,
"unique_tags": len(tags),
"items_count": len(items),
"metadata_keys": list(metadata.keys()),
"file_name": file_path.name,
"optional_param": optional_param,
}

with pytest.raises(KeyError):
registry["some_tool"]
class ColorEnum(Enum):
RED = "red"
GREEN = "green"
BLUE = "blue"

@tr.register
def color_brightness(color: ColorEnum, brightness: t.Literal["light", "dark"]) -> str:
"""
Get color brightness description.
:param color: The color enum
:param brightness: The brightness level
"""
return f"{brightness} {color.value}"

def test_registry_marshal_method(registry):
tools = registry.marshal("base")
class UserProfile(t.TypedDict):
name: str
age: int
hobbies: t.List[str]

assert tools is not None
assert not isinstance(tools, str)
@tr.register
def create_user_profile(profile: UserProfile) -> str:
"""
Create a user profile.
:param profile: The user profile data
"""
return f"{profile['name']} ({profile['age']}) likes {', '.join(profile['hobbies'])}"

assert tools[0]["type"] == "function"
assert tools[0]["function"]["name"] == "get_flight_times"
assert tools[0]["function"]["description"] == "Get flight times."
assert tools[0]["function"]["parameters"]["type"] == "object"
assert tools[0]["function"]["parameters"]["required"] == ["departure", "arrival"]
assert tools[0]["function"]["parameters"]["properties"]["departure"]["type"] == "string"
assert (
tools[0]["function"]["parameters"]["properties"]["departure"]["description"]
== "Departure location code"
)
assert tools[0]["function"]["parameters"]["properties"]["arrival"]["type"] == "string"
assert (
tools[0]["function"]["parameters"]["properties"]["arrival"]["description"]
== "Arrival location code"
)
class BookInfo(t.NamedTuple):
title: str
author: str
year: int

assert tools[1]["type"] == "function"
assert tools[1]["function"]["name"] == "CallApi"
assert tools[1]["function"]["description"] == "Call the API."
assert tools[1]["function"]["parameters"]["type"] == "object"
assert tools[1]["function"]["parameters"]["required"] == ["host", "port"]
assert tools[1]["function"]["parameters"]["properties"]["host"]["type"] == "string"
assert tools[1]["function"]["parameters"]["properties"]["host"]["description"] == "Target host."
assert tools[1]["function"]["parameters"]["properties"]["port"]["type"] == "integer"
assert (
tools[1]["function"]["parameters"]["properties"]["port"]["description"]
== "Port number to request."
)
@tr.register
def format_book_info(book: BookInfo) -> str:
"""
Format book information.
:param book: The book information
"""
return f"{book.title} by {book.author} ({book.year})"

assert tools[2]["type"] == "function"
assert tools[2]["function"]["name"] == "user_info"
assert tools[2]["function"]["description"] == "Information of the user."
assert tools[2]["function"]["parameters"]["type"] == "object"
assert tools[2]["function"]["parameters"]["required"] == ["name"]
assert tools[2]["function"]["parameters"]["properties"]["name"]["type"] == "string"
assert tools[2]["function"]["parameters"]["properties"]["role"]["type"] == "string"
assert tools[2]["function"]["parameters"]["properties"]["role"]["enum"] == ["admin", "tester"]

assert tools[3]["type"] == "function"
assert tools[3]["function"]["name"] == "HeroData"
assert tools[3]["function"]["parameters"]["type"] == "object"
assert tools[3]["function"]["parameters"]["required"] == ["info"]
assert tools[3]["function"]["parameters"]["properties"]["info"]["type"] == "object"
assert (
tools[3]["function"]["parameters"]["properties"]["info"]["properties"]["name"]["type"]
== "string"
return tr


# Tests for basic registry
def test_basic_registry_content(basic_registry):
assert len(basic_registry) == 4
assert all(
tool in basic_registry for tool in ["get_flight_times", "CallApi", "user_info", "HeroData"]
)

with pytest.raises(KeyError):
basic_registry["non_existent_tool"]


def test_basic_registry_marshal(basic_registry):
tools = basic_registry.marshal("base")
assert len(tools) == 4

flight_tool = next(tool for tool in tools if tool["function"]["name"] == "get_flight_times")
assert flight_tool["function"]["parameters"]["required"] == ["departure", "arrival"]

user_info_tool = next(tool for tool in tools if tool["function"]["name"] == "user_info")
assert user_info_tool["function"]["parameters"]["properties"]["role"]["enum"] == [
"admin",
"tester",
]


def test_basic_registry_compile(basic_registry):
assert (
tools[3]["function"]["parameters"]["properties"]["info"]["properties"]["age"]["type"]
== "integer"
basic_registry.compile(
name="get_flight_times", arguments={"departure": "NYC", "arrival": "JFK"}
)
== "2 hours"
)
assert tools[3]["function"]["parameters"]["properties"]["info"]["required"] == ["name"]
assert tools[3]["function"]["parameters"]["properties"]["powers"]["type"] == "array"
assert tools[3]["function"]["parameters"]["properties"]["powers"]["items"]["type"] == "string"

user_info = basic_registry.compile(name="user_info", arguments={"name": "Alice"})
assert user_info == {"name": "Alice", "role": "tester"}

def test_registry_compile_method(registry):
get_flight_times_output = registry.compile(
name="get_flight_times", arguments={"departure": "NYC", "arrival": "JFK"}
)
assert get_flight_times_output == "2 hours"

call_api_output = registry.compile(
name="CallApi", arguments={"host": "localhost", "port": 8080}
# Tests for complex registry
def test_complex_registry_process_data(complex_registry):
result = complex_registry.compile(
name="process_data",
arguments={
"text": "Hello, World!",
"count": 5,
"ratio": 3.14159,
"is_valid": True,
"tags": ["python", "testing", "testing"],
"items": ["item1", "item2", "item3"],
"metadata": {"key1": "value1", "key2": "value2"},
"file_path": "../test.txt",
"optional_param": 42,
},
)
assert call_api_output.get("status") == "ok"
assert result["text_length"] == 13
assert result["count_squared"] == 25
assert result["ratio_rounded"] == 3.14
assert result["unique_tags"] == 2
assert result["file_name"] == "test.txt"


UserInfo_1_output = registry.compile(
name="user_info", arguments={"name": "Andrej", "role": "admin"}
def test_complex_registry_enum_and_literal(complex_registry):
result = complex_registry.compile(
name="color_brightness", arguments={"color": "RED", "brightness": "light"}
)
assert UserInfo_1_output["name"] == "Andrej"
assert UserInfo_1_output["role"] == "admin"
assert result == "light red"

UserInfo_2_output = registry.compile("user_info(name='Synacktra')")
assert UserInfo_2_output["name"] == "Synacktra"
assert UserInfo_2_output["role"] == "tester"
with pytest.raises(CompileError):
complex_registry.compile(
name="color_brightness", arguments={"color": "YELLOW", "brightness": "light"}
)

HeroData_1_output = registry.compile(
name="HeroData",
with pytest.raises(CompileError):
complex_registry.compile(
name="color_brightness", arguments={"color": "RED", "brightness": "medium"}
)


def test_complex_registry_typed_dict(complex_registry):
result = complex_registry.compile(
name="create_user_profile",
arguments={
"info": {"name": "homelander", "age": "1"},
"powers": ["bullying", "laser beam"],
"profile": {"name": "Alice", "age": 30, "hobbies": ["reading", "hiking", "photography"]}
},
)
assert HeroData_1_output[0]["name"] == "homelander"
assert HeroData_1_output.info["age"] == 1
assert HeroData_1_output.powers[0] == "bullying"
assert HeroData_1_output.powers[1] == "laser beam"

HeroData_2_output = registry.compile(name="HeroData", arguments={"info": {"name": "Man"}})
assert HeroData_2_output.info["name"] == "Man"
assert HeroData_2_output[0]["age"] is None
assert HeroData_2_output.powers is None
assert result == "Alice (30) likes reading, hiking, photography"

with pytest.raises(CompileError):
complex_registry.compile(
name="create_user_profile",
arguments={"profile": {"name": "Bob", "hobbies": ["coding"]}},
)


def test_complex_registry_named_tuple(complex_registry):
result = complex_registry.compile(
name="format_book_info",
arguments={"book": {"title": "1984", "author": "George Orwell", "year": 1949}},
)
assert result == "1984 by George Orwell (1949)"


# Tests for MarshalError
def test_marshal_error_unsupported_type():
tr = ToolRegistry()

@tr.register
def process_bytes(data: bytes) -> str:
"""Process byte data"""
return data.decode()

with pytest.raises(MarshalError):
_ = tr.marshal("base")


def test_marshal_error_complex_unsupported_type():
tr = ToolRegistry()

@tr.register
def process_complex_data(data: t.List[bytes]) -> str:
"""Process complex data with unsupported type"""
return str(len(data))

with pytest.raises(MarshalError):
_ = tr.marshal("base")


# Additional tests for edge cases
def test_empty_registry():
tr = ToolRegistry()
assert len(tr) == 0
assert tr.marshal("base") is None


def test_registry_addition():
tr1 = ToolRegistry()
tr2 = ToolRegistry()

@tr1.register
def func1():
pass

@tr2.register
def func2():
pass

combined = tr1 + tr2
assert len(combined) == 2
assert "func1" in combined and "func2" in combined


def test_marshal_as_json(complex_registry):
json_output = complex_registry.marshal(as_json=True)
assert isinstance(json_output, str)
assert "process_data" in json_output
assert "color_brightness" in json_output
3 changes: 3 additions & 0 deletions tests/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,13 @@ def get_flight_times(departure: str, arrival: str) -> str:
== "Arrival location code"
)

# passing dictionary as arguments
assert get_flight_times.compile(arguments={"departure": "NYC", "arrival": "JFK"}) == "2 hours"

# passing json as arguments
assert get_flight_times.compile(arguments='{"departure": "NYC", "arrival": "JFK"}') == "2 hours"

# passing call expression
assert get_flight_times.compile("get_flight_times(departure='NYC', arrival='JFK')") == "2 hours"

with pytest.raises(ValueError):
Expand Down
Loading

0 comments on commit de0b1ef

Please sign in to comment.