Skip to content

Commit

Permalink
Added test for custom client
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinmessiaen committed Nov 8, 2024
1 parent 4633aa4 commit 63ace19
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
6 changes: 5 additions & 1 deletion giskard/llm/client/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ def __init__(self, model: str = "gpt-4o", completion_params: Optional[Dict[str,
def _build_supported_completion_params(self, **kwargs):
    """Filter *kwargs* down to the completion params the model supports.

    ``litellm.get_supported_openai_params`` returns ``None`` when it does
    not know the model (e.g. a custom provider); in that case no filtering
    is possible, so every param is passed through unchanged.

    :param kwargs: candidate completion parameters (e.g. temperature, seed).
    :return: dict containing only the supported (or all, if unknown) params.
    """
    supported_params = litellm.get_supported_openai_params(model=self.model)

    return {
        param_name: param_value
        for param_name, param_value in kwargs.items()
        if supported_params is None or param_name in supported_params
    }

def complete(
self,
Expand Down
32 changes: 32 additions & 0 deletions tests/llm/test_llm_client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from unittest.mock import Mock, patch

import litellm
import pydantic
import pytest
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice

from giskard.llm import get_default_client, set_llm_model
from giskard.llm.client import ChatMessage
from giskard.llm.client.litellm import LiteLLMClient
from giskard.llm.client.openai import OpenAIClient
Expand Down Expand Up @@ -64,3 +66,33 @@ def test_litellm_client(completion):

assert isinstance(res, ChatMessage)
assert res.content == "This is a test!"


# Sentinel key forwarded via completion params; MockLLM.completion asserts it
# arrives intact, proving params set through set_llm_model are propagated.
API_KEY = "MOCK_API_KEY"


class MockLLM(litellm.CustomLLM):
    """Fake litellm provider that echoes the last user message back.

    The completion handler also asserts that the configured ``api_key``
    reaches it, verifying completion-param propagation end to end.
    """

    def completion(self, model: str, messages: list, api_key: str, **kwargs) -> litellm.ModelResponse:
        # api_key must have been threaded through from the client's completion params.
        assert api_key == API_KEY, "Completion params are not passed properly"

        last_content = messages[-1].get("content")
        reply = litellm.Message(role="assistant", content=f"Mock response - {last_content}")
        choice = litellm.Choices(model=model, message=reply)
        return litellm.ModelResponse(choices=[choice])


# Register the handler so model names prefixed "mock/" are routed to MockLLM.
litellm.custom_provider_map = litellm.custom_provider_map + [{"provider": "mock", "custom_handler": MockLLM()}]


@pytest.mark.skipif(not PYDANTIC_V2, reason="LiteLLM raise an error with pydantic < 2")
def test_litellm_client_custom_model():
    """Completion params (here api_key) must reach a custom litellm provider."""
    set_llm_model("mock/faux-bot", api_key=API_KEY)

    client = get_default_client()
    prompt = "Mock input"
    result = client.complete([ChatMessage(role="user", content=prompt)])
    assert result.content == f"Mock response - {prompt}"

0 comments on commit 63ace19

Please sign in to comment.