Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Adds LlamaCpp LLM #96

Merged
merged 16 commits into from
Jan 13, 2024
699 changes: 699 additions & 0 deletions docs/05-local-execution.ipynb

Large diffs are not rendered by default.

114 changes: 73 additions & 41 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ pinecone-text = {version = "^0.7.1", optional = true}
fastembed = {version = "^0.1.3", optional = true, python = "<3.12"}
torch = {version = "^2.1.2", optional = true}
transformers = {version = "^4.36.2", optional = true}
llama-cpp-python = {version = "^0.2.28", optional = true}

[tool.poetry.extras]
hybrid = ["pinecone-text"]
fastembed = ["fastembed"]
local = ["torch", "transformers", "llama-cpp-python"]

[tool.poetry.group.dev.dependencies]
ipykernel = "^6.25.0"
Expand Down
3 changes: 2 additions & 1 deletion semantic_router/llms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from semantic_router.llms.base import BaseLLM
from semantic_router.llms.cohere import CohereLLM
from semantic_router.llms.llamacpp import LlamaCppLLM
from semantic_router.llms.openai import OpenAILLM
from semantic_router.llms.openrouter import OpenRouterLLM

# Public API of the llms subpackage; LlamaCppLLM added alongside the
# existing providers.
__all__ = ["BaseLLM", "OpenAILLM", "OpenRouterLLM", "CohereLLM", "LlamaCppLLM"]
74 changes: 73 additions & 1 deletion semantic_router/llms/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import List, Optional
import json
from typing import Any, List, Optional

from pydantic import BaseModel

from semantic_router.schema import Message
from semantic_router.utils.logger import logger


class BaseLLM(BaseModel):
Expand All @@ -11,5 +13,75 @@
class Config:
arbitrary_types_allowed = True

def __init__(self, name: str, **kwargs):
    """Create the LLM wrapper with a required ``name``.

    Extra keyword arguments are forwarded unchanged to the pydantic
    ``BaseModel`` initializer so subclasses can declare additional fields.
    """
    super().__init__(name=name, **kwargs)

def __call__(self, messages: List[Message]) -> Optional[str]:
    """Generate a completion for ``messages``.

    Abstract: every concrete LLM subclass must override this.
    :raises NotImplementedError: always, on the base class.
    """
    raise NotImplementedError("Subclasses must implement this method")

def _is_valid_inputs(
self, inputs: dict[str, Any], function_schema: dict[str, Any]
) -> bool:
"""Validate the extracted inputs against the function schema"""
try:
# Extract parameter names and types from the signature string
signature = function_schema["signature"]
param_info = [param.strip() for param in signature[1:-1].split(",")]
param_names = [info.split(":")[0].strip() for info in param_info]
param_types = [
info.split(":")[1].strip().split("=")[0].strip() for info in param_info
]

for name, type_str in zip(param_names, param_types):
if name not in inputs:
logger.error(f"Input {name} missing from query")
return False
return True
except Exception as e:
logger.error(f"Input validation error: {str(e)}")
return False

def extract_function_inputs(
self, query: str, function_schema: dict[str, Any]
) -> dict:
logger.info("Extracting function input...")

prompt = f"""
You are a helpful assistant designed to output JSON.
Given the following function schema
<< {function_schema} >>
and query
<< {query} >>
extract the parameters values from the query, in a valid JSON format.
Example:
Input:
query: "How is the weather in Hawaii right now in International units?"
schema:
{{
"name": "get_weather",
"description": "Useful to get the weather in a specific location",
"signature": "(location: str, degree: str) -> str",
"output": "<class 'str'>",
}}

Result: {{
"location": "London",
"degree": "Celsius",
}}

Input:
query: {query}
schema: {function_schema}
Result:
"""
llm_input = [Message(role="user", content=prompt)]
output = self(llm_input)
if not output:
raise Exception("No output generated for extract function input")

Check warning on line 80 in semantic_router/llms/base.py

View check run for this annotation

Codecov / codecov/patch

semantic_router/llms/base.py#L80

Added line #L80 was not covered by tests

output = output.replace("'", '"').strip().rstrip(",")

function_inputs = json.loads(output)
if not self._is_valid_inputs(function_inputs, function_schema):
raise ValueError("Invalid inputs")
return function_inputs
25 changes: 25 additions & 0 deletions semantic_router/llms/grammars/json.gbnf
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# GBNF grammar constraining llama.cpp token generation to syntactically
# valid JSON.  The start symbol forces the top-level value to be an object.
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
  "{" ws (
            string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array  ::=
  "[" ws (
            value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
jamescalam marked this conversation as resolved.
Show resolved Hide resolved
Loading
Loading