Linting
bruvduroiu committed Jan 13, 2024
1 parent 955739c commit 708b63b
Showing 5 changed files with 47 additions and 16 deletions.
11 changes: 9 additions & 2 deletions docs/05-local-execution.ipynb
@@ -87,6 +87,7 @@
"from semantic_router import Route\n",
"from semantic_router.utils.function_call import get_schema\n",
"\n",
"\n",
"def get_time(timezone: str) -> str:\n",
" \"\"\"Finds the current time in a specific timezone.\n",
"\n",
@@ -100,6 +101,7 @@
" now = datetime.now(ZoneInfo(timezone))\n",
" return now.strftime(\"%H:%M\")\n",
"\n",
"\n",
"time_schema = get_schema(get_time)\n",
"time_schema\n",
"time = Route(\n",
@@ -314,9 +316,14 @@
"from llama_cpp import Llama\n",
"from semantic_router.llms import LlamaCppLLM\n",
"\n",
"enable_gpu = True # offload LLM layers to the GPU (must fit in memory)\n",
"enable_gpu = True # offload LLM layers to the GPU (must fit in memory)\n",
"\n",
"_llm = Llama(model_path=\"./mistral-7b-instruct-v0.2.Q4_0.gguf\", n_gpu_layers=-1 if enable_gpu else 0, n_ctx=2048, verbose=False)\n",
"_llm = Llama(\n",
" model_path=\"./mistral-7b-instruct-v0.2.Q4_0.gguf\",\n",
" n_gpu_layers=-1 if enable_gpu else 0,\n",
" n_ctx=2048,\n",
" verbose=False,\n",
")\n",
"llm = LlamaCppLLM(name=\"Mistral-7B-v0.2-Instruct\", llm=_llm, max_tokens=None)\n",
"\n",
"rl = RouteLayer(encoder=encoder, routes=routes, llm=llm)"
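
Once the notebook cell above has built `rl`, a dynamic route call looks roughly like the following sketch (the query and outputs are illustrative; the field names follow semantic_router's RouteChoice):

# Illustrative only: route a query through the layer built above.
# For dynamic routes, the LLM extracts inputs matching the route's
# function schema (here, the notebook's `get_time` route).
out = rl("what time is it in Rome right now?")
print(out.name)           # e.g. "get_time"
print(out.function_call)  # e.g. {"timezone": "Europe/Rome"}
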
12 changes: 9 additions & 3 deletions semantic_router/llms/base.py
@@ -16,14 +16,18 @@ class Config:
    def __call__(self, messages: list[Message]) -> Optional[str]:
        raise NotImplementedError("Subclasses must implement this method")

-    def _is_valid_inputs(self, inputs: dict[str, Any], function_schema: dict[str, Any]) -> bool:
+    def _is_valid_inputs(
+        self, inputs: dict[str, Any], function_schema: dict[str, Any]
+    ) -> bool:
        """Validate the extracted inputs against the function schema"""
        try:
            # Extract parameter names and types from the signature string
            signature = function_schema["signature"]
            param_info = [param.strip() for param in signature[1:-1].split(",")]
            param_names = [info.split(":")[0].strip() for info in param_info]
-            param_types = [info.split(":")[1].strip().split("=")[0].strip() for info in param_info]
+            param_types = [
+                info.split(":")[1].strip().split("=")[0].strip() for info in param_info
+            ]

            for name, type_str in zip(param_names, param_types):
                if name not in inputs:
@@ -34,7 +38,9 @@ def _is_valid_inputs(self, inputs: dict[str, Any], function_schema: dict[str, Any]) -> bool:
            logger.error(f"Input validation error: {str(e)}")
            return False

-    def extract_function_inputs(self, query: str, function_schema: dict[str, Any]) -> dict:
+    def extract_function_inputs(
+        self, query: str, function_schema: dict[str, Any]
+    ) -> dict:
        logger.info("Extracting function input...")

        prompt = f"""
13 changes: 10 additions & 3 deletions semantic_router/llms/llamacpp.py
@@ -1,6 +1,6 @@
+from contextlib import contextmanager
from pathlib import Path
from typing import Any
-from contextlib import contextmanager

from llama_cpp import Llama, LlamaGrammar

@@ -22,6 +22,8 @@ def __init__(
        temperature: float = 0.2,
        max_tokens: int = 200,
    ):
+        if not llm:
+            raise ValueError("`llama_cpp.Llama` llm is required")
        super().__init__(name=name)
        self.llm = llm
        self.temperature = temperature
@@ -37,6 +39,7 @@ def __call__(
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            grammar=self.grammar,
+            stream=False,
        )

        output = completion["choices"][0]["message"]["content"]
@@ -58,6 +61,10 @@ def _grammar(self):
        finally:
            self.grammar = None

-    def extract_function_inputs(self, query: str, function_schema: dict[str, Any]) -> dict:
+    def extract_function_inputs(
+        self, query: str, function_schema: dict[str, Any]
+    ) -> dict:
        with self._grammar():
-            return super().extract_function_inputs(query=query, function_schema=function_schema)
+            return super().extract_function_inputs(
+                query=query, function_schema=function_schema
+            )
17 changes: 13 additions & 4 deletions semantic_router/route.py
@@ -19,13 +19,17 @@ def is_valid(route_config: str) -> bool:
            for item in output_json:
                missing_keys = [key for key in required_keys if key not in item]
                if missing_keys:
-                    logger.warning(f"Missing keys in route config: {', '.join(missing_keys)}")
+                    logger.warning(
+                        f"Missing keys in route config: {', '.join(missing_keys)}"
+                    )
                    return False
            return True
        else:
            missing_keys = [key for key in required_keys if key not in output_json]
            if missing_keys:
-                logger.warning(f"Missing keys in route config: {', '.join(missing_keys)}")
+                logger.warning(
+                    f"Missing keys in route config: {', '.join(missing_keys)}"
+                )
                return False
            else:
                return True
@@ -44,9 +48,14 @@ class Route(BaseModel):
    def __call__(self, query: str) -> RouteChoice:
        if self.function_schema:
            if not self.llm:
-                raise ValueError("LLM is required for dynamic routes. Please ensure the `llm` " "attribute is set.")
+                raise ValueError(
+                    "LLM is required for dynamic routes. Please ensure the `llm` "
+                    "attribute is set."
+                )
            # if a function schema is provided we generate the inputs
-            extracted_inputs = self.llm.extract_function_inputs(query=query, function_schema=self.function_schema)
+            extracted_inputs = self.llm.extract_function_inputs(
+                query=query, function_schema=self.function_schema
+            )
            func_call = extracted_inputs
        else:
            # otherwise we just pass None for the call
10 changes: 6 additions & 4 deletions semantic_router/utils/function_call.py
@@ -1,11 +1,9 @@
import inspect
-import json
from typing import Any, Callable, Union

from pydantic import BaseModel

from semantic_router.llms import BaseLLM
-from semantic_router.llms.llamacpp import LlamaCppLLM
from semantic_router.schema import Message, RouteChoice
from semantic_router.utils.logger import logger

@@ -19,7 +17,9 @@ def get_schema(item: Union[BaseModel, Callable]) -> dict[str, Any]:

            if default_value:
                default_repr = repr(default_value)
-                signature_part = f"{field_name}: {field_model.__name__} = {default_repr}"
+                signature_part = (
+                    f"{field_name}: {field_model.__name__} = {default_repr}"
+                )
            else:
                signature_part = f"{field_name}: {field_model.__name__}"

@@ -41,7 +41,9 @@ def get_schema(item: Union[BaseModel, Callable]) -> dict[str, Any]:


# TODO: Add route layer object to the input, solve circular import issue
-async def route_and_execute(query: str, llm: BaseLLM, functions: list[Callable], layer) -> Any:
+async def route_and_execute(
+    query: str, llm: BaseLLM, functions: list[Callable], layer
+) -> Any:
    route_choice: RouteChoice = layer(query)

    for function in functions:
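
For reference, `get_schema` (imported in the notebook diff above) converts a callable into the dict consumed by `extract_function_inputs`. A hedged sketch with the notebook's `get_time`; the "signature" key is confirmed by the base.py hunk above, the surrounding shape is approximate:

from semantic_router.utils.function_call import get_schema


def get_time(timezone: str) -> str:
    """Finds the current time in a specific timezone."""
    ...


schema = get_schema(get_time)
# Approximate shape: the function's name, its docstring as description,
# and a signature string such as "(timezone: str) -> str", which
# BaseLLM._is_valid_inputs parses into parameter names and types.
print(schema["signature"])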
