Skip to content

Commit

Permalink
fix: support unit tests for chains using embeddings (GoogleCloudPlatf…
Browse files Browse the repository at this point in the history
…orm#1275)

Allow unit tests to seamless support chains using embeddings without
raising errors due mocking
  • Loading branch information
eliasecchig authored Oct 17, 2024
1 parent af84db6 commit b1249a4
Showing 1 changed file with 34 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
# limitations under the License.
# pylint: disable=W0707, C0415

import importlib.util
import json
import logging
import os
from typing import Any, Generator
from unittest.mock import MagicMock, patch

from app.utils.input_types import InputChat
from google.auth import exceptions as google_auth_exceptions
from google.auth.credentials import Credentials
from httpx import AsyncClient
from langchain_core.messages import HumanMessage
Expand Down Expand Up @@ -65,6 +67,31 @@ def sample_input_chat() -> InputChat:
)


@pytest.fixture(autouse=True)
def mock_dependencies() -> Generator[None, None, None]:
"""
Mock Vertex AI dependencies for testing.
Patches VertexAIEmbeddings (if defined) and ChatVertexAI.
"""
patches = []
try:
try:
importlib.util.find_spec("app.chain.VertexAIEmbeddings")
except (ModuleNotFoundError, google_auth_exceptions.DefaultCredentialsError):
pass
else:
patches.append(patch("app.chain.VertexAIEmbeddings"))
patches.append(patch("app.chain.ChatVertexAI"))

for patch_item in patches:
mock = patch_item.start()
mock.return_value = MagicMock()

yield
except google_auth_exceptions.GoogleAuthError:
yield


class AsyncIterator:
"""
A helper class to create asynchronous iterators for testing.
Expand All @@ -87,13 +114,14 @@ def test_redirect_root_to_docs() -> None:
"""
Test that the root endpoint (/) redirects to the Swagger UI documentation.
"""
from app.server import app
from fastapi.testclient import TestClient
with patch("app.server.chain") as _:
from app.server import app
from fastapi.testclient import TestClient

client = TestClient(app)
response = client.get("/")
assert response.status_code == 200
assert "Swagger UI" in response.text
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200
assert "Swagger UI" in response.text


@pytest.mark.asyncio
Expand Down

0 comments on commit b1249a4

Please sign in to comment.