From 1411b366300ad996867e91b98f516fdae20946dd Mon Sep 17 00:00:00 2001 From: cszsolnai Date: Fri, 27 Sep 2024 16:32:33 +0200 Subject: [PATCH] Removed mock settings --- src/neuroagent/agents/simple_chat_agent.py | 2 +- src/neuroagent/app/dependencies.py | 4 +- tests/app/test_dependencies.py | 112 ++++++++++++--------- 3 files changed, 70 insertions(+), 48 deletions(-) diff --git a/src/neuroagent/agents/simple_chat_agent.py b/src/neuroagent/agents/simple_chat_agent.py index 274d18b..b331d75 100644 --- a/src/neuroagent/agents/simple_chat_agent.py +++ b/src/neuroagent/agents/simple_chat_agent.py @@ -17,7 +17,7 @@ class SimpleChatAgent(BaseAgent): """Simple Agent class.""" - memory: BaseCheckpointSaver[Any] + memory: BaseCheckpointSaver @model_validator(mode="before") @classmethod diff --git a/src/neuroagent/app/dependencies.py b/src/neuroagent/app/dependencies.py index 4d875ba..be00639 100644 --- a/src/neuroagent/app/dependencies.py +++ b/src/neuroagent/app/dependencies.py @@ -321,7 +321,7 @@ def get_language_model( async def get_agent_memory( connection_string: Annotated[str | None, Depends(get_connection_string)], -) -> AsyncIterator[BaseCheckpointSaver[Any] | None]: +) -> AsyncIterator[BaseCheckpointSaver | None]: """Get the agent checkpointer.""" if connection_string: if connection_string.startswith("sqlite"): @@ -404,7 +404,7 @@ def get_agent( def get_chat_agent( llm: Annotated[ChatOpenAI, Depends(get_language_model)], - memory: Annotated[BaseCheckpointSaver[Any], Depends(get_agent_memory)], + memory: Annotated[BaseCheckpointSaver, Depends(get_agent_memory)], literature_tool: Annotated[LiteratureSearchTool, Depends(get_literature_tool)], br_resolver_tool: Annotated[ ResolveBrainRegionTool, Depends(get_brain_region_resolver_tool) diff --git a/tests/app/test_dependencies.py b/tests/app/test_dependencies.py index 66161e2..67d3e05 100644 --- a/tests/app/test_dependencies.py +++ b/tests/app/test_dependencies.py @@ -44,7 +44,6 @@ LiteratureSearchTool, MorphologyFeatureTool, ) -from pydantic import Secret from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import Session @@ -387,73 +386,82 @@ async def test_get_cell_types_kg_hierarchy( assert os.path.exists(settings.knowledge_graph.ct_saving_path) -def fake_get_settings(): - class MockedDb: - def __init__(self, prefix, user, password, host, port, name): - self.prefix = prefix - self.user = user - self.password = Secret(password) - self.host = host - self.port = port - self.name = name - - class MockedSettings: - def __init__(self, db): - self.db = db - - return [ - MockedSettings(MockedDb("http://", "John", "Doe", "localhost", 5000, "test")), - MockedSettings(MockedDb("", "", "", "", None, None)), - ] - +def test_get_connection_string_full(monkeypatch): + monkeypatch.setenv("NEUROAGENT_TOOLS__LITERATURE__URL", "http://localhost") + monkeypatch.setenv("NEUROAGENT_KNOWLEDGE_GRAPH__BASE_URL", "http://localhost") + monkeypatch.setenv("NEUROAGENT_DB__PREFIX", "http://") + monkeypatch.setenv("NEUROAGENT_DB__USER", "John") + monkeypatch.setenv("NEUROAGENT_DB__PASSWORD", "Doe") + monkeypatch.setenv("NEUROAGENT_DB__HOST", "localhost") + monkeypatch.setenv("NEUROAGENT_DB__PORT", "5000") + monkeypatch.setenv("NEUROAGENT_DB__NAME", "test") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__USERNAME", "fake_username") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__PASSWORD", "fake_password") -def test_get_connection_string_full(): - settings = fake_get_settings()[0] + settings = get_settings() result = get_connection_string(settings) assert ( result == "http://John:Doe@localhost:5000/test" ), "must return fully formed connection string" -def test_get_connection_string_no_prefix(): - settings = fake_get_settings()[1] +def test_get_connection_string_no_prefix(monkeypatch): + monkeypatch.setenv("NEUROAGENT_TOOLS__LITERATURE__URL", "http://localhost") + monkeypatch.setenv("NEUROAGENT_KNOWLEDGE_GRAPH__BASE_URL", "http://localhost") + monkeypatch.setenv("NEUROAGENT_DB__PREFIX", "") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__USERNAME", "fake_username") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__PASSWORD", "fake_password") + + settings = get_settings() result = get_connection_string(settings) assert result is None, "should return None when prefix is not set" @patch("neuroagent.app.dependencies.create_engine") -def test_get_engine(create_engine_mock): +def test_get_engine(create_engine_mock, monkeypatch): create_engine_mock.return_value = Mock() - settings = Mock() - settings.db = Mock() - settings.db.prefix = "prefix" - settings.db.password = None + monkeypatch.setenv("NEUROAGENT_TOOLS__LITERATURE__URL", "http://localhost") + monkeypatch.setenv("NEUROAGENT_KNOWLEDGE_GRAPH__BASE_URL", "http://localhost") + monkeypatch.setenv("NEUROAGENT_DB__PREFIX", "prefix") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__USERNAME", "fake_username") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__PASSWORD", "fake_password") + + settings = get_settings() + connection_string = "https://localhost" retval = get_engine(settings=settings, connection_string=connection_string) assert retval is not None @patch("neuroagent.app.dependencies.create_engine") -def test_get_engine_no_connection_string(create_engine_mock): +def test_get_engine_no_connection_string(create_engine_mock, monkeypatch): create_engine_mock.return_value = Mock() - settings = Mock() - settings.db = Mock() - settings.db.prefix = "prefix" - settings.db.password = None + monkeypatch.setenv("NEUROAGENT_TOOLS__LITERATURE__URL", "http://localhost") + monkeypatch.setenv("NEUROAGENT_KNOWLEDGE_GRAPH__BASE_URL", "http://localhost") + monkeypatch.setenv("NEUROAGENT_DB__PREFIX", "prefix") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__USERNAME", "fake_username") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__PASSWORD", "fake_password") + + settings = get_settings() + retval = get_engine(settings=settings, connection_string=None) assert retval is None @patch("neuroagent.app.dependencies.create_engine") -def test_get_engine_error(create_engine_mock): +def test_get_engine_error(create_engine_mock, monkeypatch): create_engine_mock.side_effect = SQLAlchemyError("An error occurred") - settings = Mock() - settings.db = Mock() - settings.db.prefix = "prefix" - settings.db.password = None + monkeypatch.setenv("NEUROAGENT_TOOLS__LITERATURE__URL", "http://localhost") + monkeypatch.setenv("NEUROAGENT_KNOWLEDGE_GRAPH__BASE_URL", "http://localhost") + monkeypatch.setenv("NEUROAGENT_DB__PREFIX", "prefix") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__USERNAME", "fake_username") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__PASSWORD", "fake_password") + + settings = get_settings() + connection_string = "https://localhost" with pytest.raises(SQLAlchemyError): get_engine(settings=settings, connection_string=connection_string) @@ -471,17 +479,31 @@ def test_get_session_no_engine(): next(get_session(None)) -def test_get_kg_token_with_token(): - settings = Mock() +def test_get_kg_token_with_token(monkeypatch): + monkeypatch.setenv("NEUROAGENT_TOOLS__LITERATURE__URL", "http://localhost") + monkeypatch.setenv("NEUROAGENT_KNOWLEDGE_GRAPH__BASE_URL", "http://localhost") + monkeypatch.setenv("NEUROAGENT_DB__PREFIX", "prefix") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__USERNAME", "fake_username") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__PASSWORD", "fake_password") + + settings = get_settings() + token = "Test_Token" result = get_kg_token(settings, token) assert result == "Test_Token" -def test_get_kg_token_with_settings_knowledge_graph_token(): - settings = Mock() - settings.knowledge_graph.use_token = True - settings.knowledge_graph.token.get_secret_value.return_value = "Test_kg_Token" +def test_get_kg_token_with_settings_knowledge_graph_token(monkeypatch): + monkeypatch.setenv("NEUROAGENT_TOOLS__LITERATURE__URL", "http://localhost") + monkeypatch.setenv("NEUROAGENT_KNOWLEDGE_GRAPH__BASE_URL", "http://localhost") + monkeypatch.setenv("NEUROAGENT_DB__PREFIX", "prefix") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__USERNAME", "fake_username") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__PASSWORD", "fake_password") + monkeypatch.setenv("NEUROAGENT_KNOWLEDGE_GRAPH__USE_TOKEN", "true") + monkeypatch.setenv("NEUROAGENT_KNOWLEDGE_GRAPH__TOKEN", "Test_kg_Token") + + settings = get_settings() + token = None result = get_kg_token(settings, token)