feat: adds support for the Unify AI #412

Status: Open. Wants to merge 53 commits into base branch main.

Commits (53)
fe3ef96
Create unify.py
Yara97Mansour Aug 21, 2024
efb310f
Update __init__.py
Yara97Mansour Aug 21, 2024
babef94
Update defaults.py
Yara97Mansour Aug 21, 2024
d91bc3b
Update unify.py
Yara97Mansour Aug 21, 2024
918c4f9
Update unify.py
Yara97Mansour Aug 21, 2024
812dd77
Update unify.py
Yara97Mansour Aug 21, 2024
0864077
updates to EncodeDefault.UNIFY and unify.py
Kacper-W-Kozdon Aug 22, 2024
428bf73
adds test
Kacper-W-Kozdon Aug 22, 2024
ed60369
updates unify.py and tests
Kacper-W-Kozdon Aug 22, 2024
691edb8
updates unify.py and tests
Kacper-W-Kozdon Aug 22, 2024
bf312a4
updates unify.py and tests
Kacper-W-Kozdon Aug 22, 2024
434203b
updates unify.py and tests
Kacper-W-Kozdon Aug 22, 2024
636496f
updates unify.py and tests
Kacper-W-Kozdon Aug 22, 2024
0886aaf
updates tests
Kacper-W-Kozdon Aug 22, 2024
533d045
updates unify.py and tests
Kacper-W-Kozdon Aug 22, 2024
577e0fd
updates tests
Kacper-W-Kozdon Aug 22, 2024
8ce1bce
Merge pull request #1 from Kacper-W-Kozdon/main-Kacper
Yara97Mansour Aug 22, 2024
c858590
Update unify.py
Yara97Mansour Aug 22, 2024
1a379e3
Update defaults.py
Yara97Mansour Aug 22, 2024
225f103
Update unify.py
Yara97Mansour Aug 22, 2024
38f81a6
Update unify.py
Yara97Mansour Aug 22, 2024
9b39b79
Update unify.py
Yara97Mansour Aug 22, 2024
cafbe8a
Update pyproject.toml
Yara97Mansour Aug 23, 2024
58d8e91
Update pyproject.toml
Yara97Mansour Aug 23, 2024
c5a245a
Update pyproject.toml
Yara97Mansour Aug 23, 2024
7407033
Update unify.py
Yara97Mansour Aug 25, 2024
0d49633
adds mock to tests
Kacper-W-Kozdon Aug 27, 2024
4e9a811
fixing pytest mock
Kacper-W-Kozdon Aug 27, 2024
06369b5
fixing pytest mock
Kacper-W-Kozdon Aug 27, 2024
7b624c7
fixing pytest mock
Kacper-W-Kozdon Aug 27, 2024
bec1b47
fixing pytest mock
Kacper-W-Kozdon Aug 27, 2024
861adc3
rewrites __call__ as a generator for _call | _acall, updates mock tests
Kacper-W-Kozdon Aug 27, 2024
5d95892
rewrites __call__ as a generator for _call | _acall, updates mock tests
Kacper-W-Kozdon Aug 27, 2024
fb2c83c
Update poetry.lock
Yara97Mansour Aug 27, 2024
662faff
Update test_llm_unify.py
Yara97Mansour Aug 27, 2024
dfc352d
Update test_llm_unify.py
Yara97Mansour Aug 27, 2024
3254d73
Update test_llm_unify.py
Yara97Mansour Aug 27, 2024
4cb4d17
update
Kacper-W-Kozdon Aug 28, 2024
f11bae3
Merge branch 'main-Kacper-old' into main-Kacper
Kacper-W-Kozdon Aug 28, 2024
eb3ca08
cleans errors after merge
Kacper-W-Kozdon Aug 28, 2024
20b3968
pytest mock passing 6/6
Kacper-W-Kozdon Aug 28, 2024
7f1f3bf
pytest mock passing 7/7
Kacper-W-Kozdon Aug 28, 2024
0e9de87
Merge pull request #2 from Kacper-W-Kozdon/main-Kacper
Yara97Mansour Aug 28, 2024
bb8f7d3
Update test_llm_unify.py
Yara97Mansour Aug 28, 2024
c53db87
Update test_llm_unify.py
Yara97Mansour Aug 28, 2024
3ccffcb
Update test_llm_unify.py
Yara97Mansour Aug 28, 2024
a874be1
Update test_llm_unify.py
Yara97Mansour Aug 28, 2024
2cba0c2
applies pre-commit
Kacper-W-Kozdon Aug 28, 2024
5feae17
Merge pull request #3 from Kacper-W-Kozdon/main-Kacper-review
Yara97Mansour Aug 28, 2024
552cad0
Merge branch 'aurelio-labs:main' into main
Yara97Mansour Aug 28, 2024
af64e75
Update unify.py
Yara97Mansour Aug 29, 2024
dd6f251
Merge branch 'aurelio-labs:main' into main
Yara97Mansour Aug 29, 2024
e363892
Merge branch 'main' into main-semantic-router
Kacper-W-Kozdon Sep 2, 2024
19 changes: 18 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -21,7 +21,8 @@ python = ">=3.9,<3.13"
 pydantic = "^2.5.3"
 openai = ">=1.10.0,<2.0.0"
 cohere = ">=5.00,<6.00"
-mistralai= {version = ">=0.0.12,<0.1.0", optional = true}
+mistralai = {version = ">=0.0.12,<0.1.0", optional = true}
+unifyai = "^0.9.1"
 numpy = "^1.25.2"
 colorlog = "^6.8.0"
 pyyaml = "^6.0.1"
2 changes: 2 additions & 0 deletions semantic_router/llms/__init__.py
@@ -4,6 +4,7 @@
 from semantic_router.llms.mistral import MistralAILLM
 from semantic_router.llms.openai import OpenAILLM
 from semantic_router.llms.openrouter import OpenRouterLLM
+from semantic_router.llms.unify import UnifyLLM
 from semantic_router.llms.zure import AzureOpenAILLM

 __all__ = [
@@ -14,4 +15,5 @@
     "CohereLLM",
     "AzureOpenAILLM",
     "MistralAILLM",
+    "UnifyLLM",
 ]
81 changes: 81 additions & 0 deletions semantic_router/llms/unify.py
@@ -0,0 +1,81 @@
import asyncio  # noqa: F401
from typing import List, Optional, Coroutine, Callable, Any, Union

from semantic_router.llms import BaseLLM
from semantic_router.schema import Message
from semantic_router.utils.defaults import EncoderDefault

from unify.exceptions import UnifyError
from unify.clients import Unify, AsyncUnify


class UnifyLLM(BaseLLM):
    client: Optional[Unify]
    async_client: Optional[AsyncUnify]
    temperature: Optional[float]
    max_tokens: Optional[int]
    stream: Optional[bool]
    Async: Optional[bool]

    def __init__(
        self,
        name: Optional[str] = None,
        unify_api_key: Optional[str] = None,
        temperature: Optional[float] = 0.01,
        max_tokens: Optional[int] = 200,
        stream: bool = False,
        Async: bool = False,
    ):
        if name is None:
            name = (f"{EncoderDefault.UNIFY.value['language_model']}"+
                    f"@{EncoderDefault.UNIFY.value['language_provider']}")

        super().__init__(name=name)
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.stream = stream
        self.client = Unify(endpoint=name, api_key=unify_api_key)
        self.async_client = AsyncUnify(endpoint=name, api_key=unify_api_key)
        self.Async = Async  # noqa: C0103

    def __call__(self, messages: List[Message]) -> Any:
        func: Union[Callable[..., str], Callable[..., Coroutine[Any, Any, str]]] = (
            self._call if not self.Async else self._acall
        )
        return func(messages)

    def _call(self, messages: List[Message]) -> str:
        if self.client is None:
            raise UnifyError("Unify client is not initialized.")
        try:
            output = self.client.generate(
                messages=[m.to_openai() for m in messages],
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                stream=self.stream,
            )

            if not output:
                raise UnifyError("No output generated")
            return output

        except Exception as e:
            raise UnifyError(f"Unify API call failed. Error: {e}") from e

    async def _acall(self, messages: List[Message]) -> str:
        if self.async_client is None:
            raise UnifyError("Unify async_client is not initialized.")
        try:
            output = await self.async_client.generate(
                messages=[m.to_openai() for m in messages],
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                stream=self.stream,
            )

            if not output:
                raise UnifyError("No output generated")
            return output

        except Exception as e:
            raise UnifyError(f"Unify API call failed. Error: {e}") from e
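For reference, a minimal usage sketch of UnifyLLM (not part of this diff), assuming a valid Unify API key; "YOUR_UNIFY_KEY" and the greeting message are placeholders:

```python
import asyncio

from semantic_router.llms import UnifyLLM
from semantic_router.schema import Message

messages = [Message(role="user", content="Hello!")]

# Synchronous path: __call__ dispatches to _call when Async is False.
llm = UnifyLLM(name="llama-3-8b-chat@together-ai", unify_api_key="YOUR_UNIFY_KEY")
print(llm(messages))

# Asynchronous path: with Async=True, __call__ returns the _acall coroutine,
# which the caller is responsible for awaiting.
async_llm = UnifyLLM(
    name="llama-3-8b-chat@together-ai", unify_api_key="YOUR_UNIFY_KEY", Async=True
)
print(asyncio.run(async_llm(messages)))
```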
10 changes: 9 additions & 1 deletion semantic_router/utils/defaults.py
@@ -34,5 +34,13 @@ class EncoderDefault(Enum):
     BEDROCK = {
         "embedding_model": os.environ.get(
             "BEDROCK_EMBEDDING_MODEL", "amazon.titan-embed-image-v1"
-        )
+        ),
     }
+    UNIFY = {
+        "language_model": os.environ.get(
+            "UNIFY_CHAT_MODEL_NAME", "llama-3-8b-chat"
+        ),
+        "language_provider": os.environ.get(
+            "UNIFY_CHAT_MODEL_PROVIDER", "together-ai"
+        ),
+    }
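For context, a short sketch (not included in this PR) of how these environment variables drive the default Unify endpoint name; the model and provider values below are placeholders chosen for illustration:

```python
# The UNIFY defaults are read from the environment at import time, so any
# overrides must be set before semantic_router.utils.defaults is first imported.
import os

os.environ["UNIFY_CHAT_MODEL_NAME"] = "llama-3-70b-chat"      # placeholder model
os.environ["UNIFY_CHAT_MODEL_PROVIDER"] = "fireworks-ai"       # placeholder provider

from semantic_router.utils.defaults import EncoderDefault

# UnifyLLM builds its default endpoint name from these two values:
endpoint = (
    f"{EncoderDefault.UNIFY.value['language_model']}"
    f"@{EncoderDefault.UNIFY.value['language_provider']}"
)
print(endpoint)  # -> llama-3-70b-chat@fireworks-ai
```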
65 changes: 65 additions & 0 deletions tests/unit/llms/test_llm_unify.py
@@ -0,0 +1,65 @@
import pytest

from semantic_router.llms.unify import UnifyLLM
from semantic_router.schema import Message

from unify.clients import Unify, AsyncUnify
from unify.exceptions import UnifyError


@pytest.fixture
def unify_llm(mocker):
    mocker.patch("unify.clients.Unify")
    # mocker.patch("json.loads", return_value=["llama-3-8b-chat@together-ai"])
    mocker.patch.object(Unify, "set_endpoint", return_value=None)
    mocker.patch.object(AsyncUnify, "set_endpoint", return_value=None)

    return UnifyLLM(unify_api_key="fake-api-key")


class TestUnifyLLM:
    # def test_unify_llm_init_success_1(self, unify_llm, mocker):
    #     mocker.patch("os.getenv", return_value="fake-api-key")
    #     mocker.patch.object(unify_llm.client, "set_endpoint", return_value=None)

    #     assert unify_llm.client is not None

    def test_unify_llm_init_success(self, unify_llm):
        # mocker.patch("os.getenv", return_value="fake-api-key")
        assert unify_llm.name == "llama-3-8b-chat@together-ai"
        assert unify_llm.temperature == 0.01
        assert unify_llm.max_tokens == 200
        assert unify_llm.stream is False

    def test_unify_llm_init_with_api_key(self, unify_llm):
        assert unify_llm.client is not None, "Client should be initialized"
        assert (
            unify_llm.name == "llama-3-8b-chat@together-ai"
        ), "Default name not set correctly"

    def test_unify_llm_init_without_api_key(self, mocker):
        mocker.patch("os.environ.get", return_value=None)
        with pytest.raises(KeyError) as _:
            UnifyLLM()

    def test_unify_llm_call_uninitialized_client(self, unify_llm):
        unify_llm.client = None
        with pytest.raises(UnifyError) as e:
            llm_input = [Message(role="user", content="test")]
            unify_llm(llm_input)
        assert "Unify client is not initialized." in str(e.value)

    def test_unify_llm_error_handling(self, unify_llm, mocker):
        mocker.patch.object(
            unify_llm.client, "generate", side_effect=Exception("LLM error")
        )
        with pytest.raises(UnifyError) as exc_info:
            unify_llm([Message(role="user", content="test")])
        assert "LLM error" in f"{str(exc_info)}, {str(exc_info.value)}"

    def test_unify_llm_call_success(self, unify_llm, mocker):
        mock_response = "test response"
        mocker.patch.object(unify_llm.client, "generate", return_value=mock_response)

        output = unify_llm([Message(role="user", content="test")])
        assert output == "test response"
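A possible follow-up test for the async path (not in this PR), sketched under the assumption that pytest-asyncio is installed and that the unify_llm fixture above can be reused; AsyncMock stands in for the async client's generate():

```python
from unittest.mock import AsyncMock

import pytest

from semantic_router.schema import Message


@pytest.mark.asyncio
async def test_unify_llm_acall_success(unify_llm, mocker):
    # Switch the dispatcher to the async branch so __call__ returns the
    # _acall coroutine, then stub out the async client's generate().
    unify_llm.Async = True
    mocker.patch.object(
        unify_llm.async_client,
        "generate",
        new=AsyncMock(return_value="async response"),
    )
    output = await unify_llm([Message(role="user", content="test")])
    assert output == "async response"
```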