From a5de6078fcc86075e0682c6e43b3769b06649dd4 Mon Sep 17 00:00:00 2001
From: Alejandro Ponce
Date: Wed, 8 Jan 2025 11:56:47 +0200
Subject: [PATCH 1/2] Remove sqlc and its classes

Closes: #512

This PR drops the sqlc-related files and adds the code needed to use
SQLAlchemy directly. In most parts of the code we were already using it,
hence the changes are not drastic.

See issue #512 for the reasoning on why to remove `sqlc`.
---
 poetry.lock                                |   4 +-
 pyproject.toml                             |   3 +-
 sql/queries/queries.sql                    |  24 ----
 sqlc.yaml                                  |  20 ---
 src/codegate/dashboard/post_processing.py  |   7 +-
 src/codegate/db/connection.py              |  88 +++++++++---
 src/codegate/db/fim_cache.py               |  27 ++--
 src/codegate/db/models.py                  |  34 ++++-
 src/codegate/db/queries.py                 | 148 ---------------------
 src/codegate/pipeline/secrets/secrets.py   |   2 +-
 src/codegate/providers/copilot/provider.py |   3 +-
 tests/dashboard/test_post_processing.py    |   2 +-
 tests/db/test_fim_cache.py                 |  15 ++-
 13 files changed, 134 insertions(+), 243 deletions(-)
 delete mode 100644 sql/queries/queries.sql
 delete mode 100644 sqlc.yaml
 delete mode 100644 src/codegate/db/queries.py

diff --git a/poetry.lock b/poetry.lock
index 7acd37ca..64ab8184 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -585,6 +585,7 @@ files = [
     {file = "cryptography-44.0.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:761817a3377ef15ac23cd7834715081791d4ec77f9297ee694ca1ee9c2c7e5eb"},
     {file = "cryptography-44.0.0-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3c672a53c0fb4725a29c303be906d3c1fa99c32f58abe008a82705f9ee96f40b"},
     {file = "cryptography-44.0.0-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:4ac4c9f37eba52cb6fbeaf5b59c152ea976726b865bd4cf87883a7e7006cc543"},
+    {file = "cryptography-44.0.0-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:60eb32934076fa07e4316b7b2742fa52cbb190b42c2df2863dbc4230a0a9b385"},
     {file = "cryptography-44.0.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ed3534eb1090483c96178fcb0f8893719d96d5274dfde98aa6add34614e97c8e"},
     {file = "cryptography-44.0.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f3f6fdfa89ee2d9d496e2c087cebef9d4fcbb0ad63c40e821b39f74bf48d9c5e"},
     {file = "cryptography-44.0.0-cp37-abi3-win32.whl", hash = "sha256:eb33480f1bad5b78233b0ad3e1b0be21e8ef1da745d8d2aecbb20671658b9053"},
@@ -595,6 +596,7 @@
     {file = "cryptography-44.0.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:c5eb858beed7835e5ad1faba59e865109f3e52b3783b9ac21e7e47dc5554e289"},
     {file = "cryptography-44.0.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f53c2c87e0fb4b0c00fa9571082a057e37690a8f12233306161c8f4b819960b7"},
     {file = "cryptography-44.0.0-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:9e6fc8a08e116fb7c7dd1f040074c9d7b51d74a8ea40d4df2fc7aa08b76b9e6c"},
+    {file = "cryptography-44.0.0-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:9abcc2e083cbe8dde89124a47e5e53ec38751f0d7dfd36801008f316a127d7ba"},
     {file = "cryptography-44.0.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:d2436114e46b36d00f8b72ff57e598978b37399d2786fd39793c36c6d5cb1c64"},
     {file = "cryptography-44.0.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a01956ddfa0a6790d594f5b34fc1bfa6098aca434696a03cfdbe469b8ed79285"},
     {file = "cryptography-44.0.0-cp39-abi3-win32.whl", hash = "sha256:eca27345e1214d1b9f9490d200f9db5a874479be914199194e746c893788d417"},
@@ -3056,4 +3058,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.12,<4.0"
-content-hash = "28f2781f75dc249b7d30d27e78e02c13e07c11bdd71f52c22aab043d301ba5c4"
+content-hash =
"647532c2d43b0cd85705e0e91a9a117339f595da7549b47419d35d7b40f24dca" diff --git a/pyproject.toml b/pyproject.toml index d391e09c..5772d396 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,6 @@ litellm = "^1.57.1" llama_cpp_python = ">=0.3.2" cryptography = "^44.0.0" sqlalchemy = "^2.0.28" -greenlet = "^3.0.3" aiosqlite = "^0.20.0" ollama = ">=0.4.4" pydantic-settings = "^2.7.1" @@ -27,8 +26,8 @@ tree-sitter-java = ">=0.23.5" tree-sitter-javascript = ">=0.23.1" tree-sitter-python = ">=0.23.6" tree-sitter-rust = ">=0.23.2" - sqlite-vec-sl-tmp = "^0.0.4" + [tool.poetry.group.dev.dependencies] pytest = ">=7.4.0" pytest-cov = ">=4.1.0" diff --git a/sql/queries/queries.sql b/sql/queries/queries.sql deleted file mode 100644 index eba9e83d..00000000 --- a/sql/queries/queries.sql +++ /dev/null @@ -1,24 +0,0 @@ --- name: GetPromptWithOutputs :many -SELECT - p.*, - o.id as output_id, - o.output, - o.timestamp as output_timestamp -FROM prompts p -LEFT JOIN outputs o ON p.id = o.prompt_id -ORDER BY o.timestamp DESC; - --- name: GetAlertsWithPromptAndOutput :many -SELECT - a.*, - p.timestamp as prompt_timestamp, - p.provider, - p.request, - p.type, - o.id as output_id, - o.output, - o.timestamp as output_timestamp -FROM alerts a -LEFT JOIN prompts p ON p.id = a.prompt_id -LEFT JOIN outputs o ON p.id = o.prompt_id -ORDER BY a.timestamp DESC; diff --git a/sqlc.yaml b/sqlc.yaml deleted file mode 100644 index 563a57a4..00000000 --- a/sqlc.yaml +++ /dev/null @@ -1,20 +0,0 @@ -version: "2" -plugins: - - name: "python" - wasm: - url: "https://downloads.sqlc.dev/plugin/sqlc-gen-python_1.2.0.wasm" - sha256: "a6c5d174c407007c3717eea36ff0882744346e6ba991f92f71d6ab2895204c0e" - -sql: - - engine: "sqlite" - schema: "sql/schema" - queries: "sql/queries" - codegen: - - plugin: "python" - out: "src/codegate/db" - options: - package: "codegate.db" - emit_sync_querier: true - emit_async_querier: true - query_parameter_limit: 5 - emit_pydantic_models: true diff --git a/src/codegate/dashboard/post_processing.py b/src/codegate/dashboard/post_processing.py index 4d93af33..9d5bbeb6 100644 --- a/src/codegate/dashboard/post_processing.py +++ b/src/codegate/dashboard/post_processing.py @@ -12,7 +12,7 @@ PartialConversation, QuestionAnswer, ) -from codegate.db.queries import GetAlertsWithPromptAndOutputRow, GetPromptWithOutputsRow +from codegate.db.models import GetAlertsWithPromptAndOutputRow, GetPromptWithOutputsRow logger = structlog.get_logger("codegate") @@ -183,7 +183,7 @@ async def parse_get_prompt_with_output( def parse_question_answer(input_text: str) -> str: # given a string, detect if we have a pattern of "Context: xxx \n\nQuery: xxx" and strip it - pattern = r'^Context:.*?\n\n\s*Query:\s*(.*)$' + pattern = r"^Context:.*?\n\n\s*Query:\s*(.*)$" # Search using the regex pattern match = re.search(pattern, input_text, re.DOTALL) @@ -226,7 +226,8 @@ async def match_conversations( if partial_conversation.question_answer.answer is not None: first_partial_conversation = partial_conversation partial_conversation.question_answer.question.message = parse_question_answer( - partial_conversation.question_answer.question.message) + partial_conversation.question_answer.question.message + ) questions_answers.append(partial_conversation.question_answer) # only add conversation if we have some answers diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index af7c3b98..443ab008 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -1,7 +1,7 @@ import asyncio import json from 
pathlib import Path -from typing import List, Optional +from typing import List, Optional, Type import structlog from pydantic import BaseModel @@ -9,11 +9,12 @@ from sqlalchemy.ext.asyncio import create_async_engine from codegate.db.fim_cache import FimCache -from codegate.db.models import Alert, Output, Prompt -from codegate.db.queries import ( - AsyncQuerier, +from codegate.db.models import ( + Alert, GetAlertsWithPromptAndOutputRow, GetPromptWithOutputsRow, + Output, + Prompt, ) from codegate.pipeline.base import PipelineContext @@ -83,11 +84,9 @@ async def init_db(self): await self._async_db_engine.dispose() async def _execute_update_pydantic_model( - self, model: BaseModel, sql_command: TextClause # + self, model: BaseModel, sql_command: TextClause ) -> Optional[BaseModel]: - # There are create method in queries.py automatically generated by sqlc - # However, the methods are buggy for Pydancti and don't work as expected. - # Manually writing the SQL query to insert Pydantic models. + """Execute an update or insert command for a Pydantic model.""" async with self._async_db_engine.begin() as conn: try: result = await conn.execute(sql_command, model.model_dump()) @@ -117,8 +116,9 @@ async def record_request(self, prompt_params: Optional[Prompt] = None) -> Option # logger.debug(f"Recorded request: {recorded_request}") return recorded_request # type: ignore - async def update_request(self, initial_id: str, - prompt_params: Optional[Prompt] = None) -> Optional[Prompt]: + async def update_request( + self, initial_id: str, prompt_params: Optional[Prompt] = None + ) -> Optional[Prompt]: if prompt_params is None: return None prompt_params.id = initial_id # overwrite the initial id of the request @@ -135,8 +135,9 @@ async def update_request(self, initial_id: str, # logger.debug(f"Recorded request: {recorded_request}") return updated_request # type: ignore - async def record_outputs(self, outputs: List[Output], - initial_id: Optional[str]) -> Optional[Output]: + async def record_outputs( + self, outputs: List[Output], initial_id: Optional[str] + ) -> Optional[Output]: if not outputs: return @@ -216,7 +217,7 @@ def _should_record_context(self, context: Optional[PipelineContext]) -> tuple: # If it's not a FIM prompt, we don't need to check anything else. 
if context.input_request.type != "fim":
-            return True, 'add', ''  # Default to add if not FIM, since no cache check is required
+            return True, "add", ""  # Default to add if not FIM, since no cache check is required

         return fim_cache.could_store_fim_request(context)  # type: ignore

@@ -229,7 +230,7 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
         if not should_record:
             logger.info("Skipping record of context, not needed")
             return
-        if action == 'add':
+        if action == "add":
             await self.record_request(context.input_request)
             await self.record_outputs(context.output_responses, None)
             await self.record_alerts(context.alerts_raised, None)
@@ -257,18 +258,61 @@ class DbReader(DbCodeGate):

     def __init__(self, sqlite_path: Optional[str] = None):
         super().__init__(sqlite_path)

+    async def _execute_select_pydantic_model(
+        self, model_type: Type[BaseModel], sql_command: TextClause
+    ) -> Optional[List[BaseModel]]:
+        async with self._async_db_engine.begin() as conn:
+            try:
+                result = await conn.execute(sql_command)
+                if not result:
+                    return None
+                rows = [model_type(**row._asdict()) for row in result.fetchall() if row]
+                return rows
+            except Exception as e:
+                logger.error(f"Failed to select model: {model_type}.", error=str(e))
+                return None
+
     async def get_prompts_with_output(self) -> List[GetPromptWithOutputsRow]:
-        conn = await self._async_db_engine.connect()
-        querier = AsyncQuerier(conn)
-        prompts = [prompt async for prompt in querier.get_prompt_with_outputs()]
-        await conn.close()
+        sql = text(
+            """
+            SELECT
+                p.id, p.timestamp, p.provider, p.request, p.type,
+                o.id as output_id,
+                o.output,
+                o.timestamp as output_timestamp
+            FROM prompts p
+            LEFT JOIN outputs o ON p.id = o.prompt_id
+            ORDER BY o.timestamp DESC
+            """
+        )
+        prompts = await self._execute_select_pydantic_model(GetPromptWithOutputsRow, sql)
         return prompts

     async def get_alerts_with_prompt_and_output(self) -> List[GetAlertsWithPromptAndOutputRow]:
-        conn = await self._async_db_engine.connect()
-        querier = AsyncQuerier(conn)
-        prompts = [prompt async for prompt in querier.get_alerts_with_prompt_and_output()]
-        await conn.close()
+        sql = text(
+            """
+            SELECT
+                a.id,
+                a.prompt_id,
+                a.code_snippet,
+                a.trigger_string,
+                a.trigger_type,
+                a.trigger_category,
+                a.timestamp,
+                p.timestamp as prompt_timestamp,
+                p.provider,
+                p.request,
+                p.type,
+                o.id as output_id,
+                o.output,
+                o.timestamp as output_timestamp
+            FROM alerts a
+            LEFT JOIN prompts p ON p.id = a.prompt_id
+            LEFT JOIN outputs o ON p.id = o.prompt_id
+            ORDER BY a.timestamp DESC
+            """
+        )
+        prompts = await self._execute_select_pydantic_model(GetAlertsWithPromptAndOutputRow, sql)
         return prompts

diff --git a/src/codegate/db/fim_cache.py b/src/codegate/db/fim_cache.py
index e5a488b6..482bbe7a 100644
--- a/src/codegate/db/fim_cache.py
+++ b/src/codegate/db/fim_cache.py
@@ -96,8 +96,9 @@ def _add_cache_entry(self, hash_key: str, context: PipelineContext):
             if alert.trigger_category == AlertSeverity.CRITICAL.value
         ]
         new_cache = CachedFim(
-            timestamp=context.input_request.timestamp, critical_alerts=critical_alerts,
-            initial_id=context.input_request.id
+            timestamp=context.input_request.timestamp,
+            critical_alerts=critical_alerts,
+            initial_id=context.input_request.id,
         )
         self.cache[hash_key] = new_cache
         logger.info(f"Added cache entry for hash key: {hash_key}")
@@ -115,8 +116,9 @@ def _update_cache_entry(self, hash_key: str, context: PipelineContext):
         ]
         # Update the entry in the cache with new critical alerts but keep the old timestamp.
updated_cache = CachedFim( - timestamp=existing_entry.timestamp, critical_alerts=critical_alerts, - initial_id=existing_entry.initial_id + timestamp=existing_entry.timestamp, + critical_alerts=critical_alerts, + initial_id=existing_entry.initial_id, ) self.cache[hash_key] = updated_cache logger.info(f"Updated cache entry for hash key: {hash_key}") @@ -148,22 +150,25 @@ def _is_cached_entry_old(self, context: PipelineContext, cached_entry: CachedFim def could_store_fim_request(self, context: PipelineContext): if not context.input_request: logger.warning("No input request found. Skipping creating a mapping entry") - return False, '', '' + return False, "", "" # Couldn't process the user message. Skip creating a mapping entry. message = self._extract_message_from_fim_request(context.input_request.request) if message is None: logger.warning(f"Couldn't read FIM message: {message}. Will not record to DB.") - return False, '', '' + return False, "", "" hash_key = self._calculate_hash_key(message, context.input_request.provider) # type: ignore cached_entry = self.cache.get(hash_key, None) - if cached_entry is None or self._is_cached_entry_old( - context, cached_entry) or self._are_new_alerts_present(context, cached_entry): + if ( + cached_entry is None + or self._is_cached_entry_old(context, cached_entry) + or self._are_new_alerts_present(context, cached_entry) + ): cached_entry = self._add_cache_entry(hash_key, context) if cached_entry is None: logger.warning("Failed to add cache entry") - return False, '', '' - return True, 'add', cached_entry.initial_id + return False, "", "" + return True, "add", cached_entry.initial_id self._update_cache_entry(hash_key, context) - return True, 'update', cached_entry.initial_id + return True, "update", cached_entry.initial_id diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index c84e207d..22859573 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -1,6 +1,3 @@ -# Code generated by sqlc. DO NOT EDIT. -# versions: -# sqlc v1.27.0 from typing import Any, Optional import pydantic @@ -38,3 +35,34 @@ class Setting(pydantic.BaseModel): llm_model: Optional[Any] system_prompt: Optional[Any] other_settings: Optional[Any] + + +# Models for select queries + + +class GetAlertsWithPromptAndOutputRow(pydantic.BaseModel): + id: Any + prompt_id: Any + code_snippet: Optional[Any] + trigger_string: Optional[Any] + trigger_type: Any + trigger_category: Optional[Any] + timestamp: Any + prompt_timestamp: Optional[Any] + provider: Optional[Any] + request: Optional[Any] + type: Optional[Any] + output_id: Optional[Any] + output: Optional[Any] + output_timestamp: Optional[Any] + + +class GetPromptWithOutputsRow(pydantic.BaseModel): + id: Any + timestamp: Any + provider: Optional[Any] + request: Any + type: Any + output_id: Optional[Any] + output: Optional[Any] + output_timestamp: Optional[Any] diff --git a/src/codegate/db/queries.py b/src/codegate/db/queries.py deleted file mode 100644 index 4fc1fa93..00000000 --- a/src/codegate/db/queries.py +++ /dev/null @@ -1,148 +0,0 @@ -# Code generated by sqlc. DO NOT EDIT. 
-# versions: -# sqlc v1.27.0 -# source: queries.sql -import pydantic -from typing import Any, AsyncIterator, Iterator, Optional - -import sqlalchemy -import sqlalchemy.ext.asyncio - -from codegate.db import models - - -GET_ALERTS_WITH_PROMPT_AND_OUTPUT = """-- name: get_alerts_with_prompt_and_output \\:many -SELECT - a.id, a.prompt_id, a.code_snippet, a.trigger_string, a.trigger_type, a.trigger_category, a.timestamp, - p.timestamp as prompt_timestamp, - p.provider, - p.request, - p.type, - o.id as output_id, - o.output, - o.timestamp as output_timestamp -FROM alerts a -LEFT JOIN prompts p ON p.id = a.prompt_id -LEFT JOIN outputs o ON p.id = o.prompt_id -ORDER BY a.timestamp DESC -""" - - -class GetAlertsWithPromptAndOutputRow(pydantic.BaseModel): - id: Any - prompt_id: Any - code_snippet: Optional[Any] - trigger_string: Optional[Any] - trigger_type: Any - trigger_category: Optional[Any] - timestamp: Any - prompt_timestamp: Optional[Any] - provider: Optional[Any] - request: Optional[Any] - type: Optional[Any] - output_id: Optional[Any] - output: Optional[Any] - output_timestamp: Optional[Any] - - -GET_PROMPT_WITH_OUTPUTS = """-- name: get_prompt_with_outputs \\:many -SELECT - p.id, p.timestamp, p.provider, p.request, p.type, - o.id as output_id, - o.output, - o.timestamp as output_timestamp -FROM prompts p -LEFT JOIN outputs o ON p.id = o.prompt_id -ORDER BY o.timestamp DESC -""" - - -class GetPromptWithOutputsRow(pydantic.BaseModel): - id: Any - timestamp: Any - provider: Optional[Any] - request: Any - type: Any - output_id: Optional[Any] - output: Optional[Any] - output_timestamp: Optional[Any] - - -class Querier: - def __init__(self, conn: sqlalchemy.engine.Connection): - self._conn = conn - - def get_alerts_with_prompt_and_output(self) -> Iterator[GetAlertsWithPromptAndOutputRow]: - result = self._conn.execute(sqlalchemy.text(GET_ALERTS_WITH_PROMPT_AND_OUTPUT)) - for row in result: - yield GetAlertsWithPromptAndOutputRow( - id=row[0], - prompt_id=row[1], - code_snippet=row[2], - trigger_string=row[3], - trigger_type=row[4], - trigger_category=row[5], - timestamp=row[6], - prompt_timestamp=row[7], - provider=row[8], - request=row[9], - type=row[10], - output_id=row[11], - output=row[12], - output_timestamp=row[13], - ) - - def get_prompt_with_outputs(self) -> Iterator[GetPromptWithOutputsRow]: - result = self._conn.execute(sqlalchemy.text(GET_PROMPT_WITH_OUTPUTS)) - for row in result: - yield GetPromptWithOutputsRow( - id=row[0], - timestamp=row[1], - provider=row[2], - request=row[3], - type=row[4], - output_id=row[5], - output=row[6], - output_timestamp=row[7], - ) - - -class AsyncQuerier: - def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection): - self._conn = conn - - async def get_alerts_with_prompt_and_output( - self, - ) -> AsyncIterator[GetAlertsWithPromptAndOutputRow]: - result = await self._conn.stream(sqlalchemy.text(GET_ALERTS_WITH_PROMPT_AND_OUTPUT)) - async for row in result: - yield GetAlertsWithPromptAndOutputRow( - id=row[0], - prompt_id=row[1], - code_snippet=row[2], - trigger_string=row[3], - trigger_type=row[4], - trigger_category=row[5], - timestamp=row[6], - prompt_timestamp=row[7], - provider=row[8], - request=row[9], - type=row[10], - output_id=row[11], - output=row[12], - output_timestamp=row[13], - ) - - async def get_prompt_with_outputs(self) -> AsyncIterator[GetPromptWithOutputsRow]: - result = await self._conn.stream(sqlalchemy.text(GET_PROMPT_WITH_OUTPUTS)) - async for row in result: - yield GetPromptWithOutputsRow( - id=row[0], - timestamp=row[1], 
- provider=row[2], - request=row[3], - type=row[4], - output_id=row[5], - output=row[6], - output_timestamp=row[7], - ) diff --git a/src/codegate/pipeline/secrets/secrets.py b/src/codegate/pipeline/secrets/secrets.py index 5dd22b95..0845c0f6 100644 --- a/src/codegate/pipeline/secrets/secrets.py +++ b/src/codegate/pipeline/secrets/secrets.py @@ -139,11 +139,11 @@ def obfuscate(self, text: str) -> tuple[str, int]: logger.info("\nFound secrets:") for start, end, match in absolute_matches: hidden_secret = self._hide_secret(match) - self._notify_secret(match, protected_text) # Replace the secret in the text protected_text[start:end] = hidden_secret + self._notify_secret(match, protected_text) found_secrets += 1 # Log the findings logger.info( diff --git a/src/codegate/providers/copilot/provider.py b/src/codegate/providers/copilot/provider.py index 737f43bc..d3652d3f 100644 --- a/src/codegate/providers/copilot/provider.py +++ b/src/codegate/providers/copilot/provider.py @@ -771,7 +771,6 @@ def _proxy_transport_write(self, data: bytes): # print(data) # print("DEBUG =================================") - def data_received(self, data: bytes) -> None: """Handle data received from target""" self._ensure_output_processor() @@ -788,7 +787,7 @@ def data_received(self, data: bytes) -> None: if header_end != -1: self.headers_sent = True # Send headers first - headers = data[: header_end] + headers = data[:header_end] # If Transfer-Encoding is not present, add it if b"Transfer-Encoding:" not in headers: diff --git a/tests/dashboard/test_post_processing.py b/tests/dashboard/test_post_processing.py index 5c387ed1..cbdb18a5 100644 --- a/tests/dashboard/test_post_processing.py +++ b/tests/dashboard/test_post_processing.py @@ -16,7 +16,7 @@ PartialConversation, QuestionAnswer, ) -from codegate.db.queries import GetPromptWithOutputsRow +from codegate.db.models import GetPromptWithOutputsRow @pytest.mark.asyncio diff --git a/tests/db/test_fim_cache.py b/tests/db/test_fim_cache.py index c6b5506e..5e2ad547 100644 --- a/tests/db/test_fim_cache.py +++ b/tests/db/test_fim_cache.py @@ -146,7 +146,7 @@ def test_are_new_alerts_present(): trigger_string=None, ) ], - initial_id='2' + initial_id="2", ) result = fim_cache._are_new_alerts_present(context, populated_cache) assert result is False @@ -156,12 +156,17 @@ def test_are_new_alerts_present(): "cached_entry, is_old", [ ( - CachedFim(timestamp=datetime.now(timezone.utc) - timedelta(days=1), - critical_alerts=[], initial_id='1'), + CachedFim( + timestamp=datetime.now(timezone.utc) - timedelta(days=1), + critical_alerts=[], + initial_id="1", + ), True, ), - (CachedFim(timestamp=datetime.now(timezone.utc), critical_alerts=[], - initial_id='2'), False), + ( + CachedFim(timestamp=datetime.now(timezone.utc), critical_alerts=[], initial_id="2"), + False, + ), ], ) def test_is_cached_entry_old(cached_entry, is_old): From e3829a621efdd4878bec8e096874db05edf191bf Mon Sep 17 00:00:00 2001 From: Yolanda Robla Date: Wed, 8 Jan 2025 16:52:28 +0100 Subject: [PATCH 2/2] fix: add monitoring for idle connections and close them need to find a better way --- src/codegate/providers/copilot/provider.py | 49 ++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/src/codegate/providers/copilot/provider.py b/src/codegate/providers/copilot/provider.py index 737f43bc..c07005f1 100644 --- a/src/codegate/providers/copilot/provider.py +++ b/src/codegate/providers/copilot/provider.py @@ -151,6 +151,20 @@ def __init__(self, loop: asyncio.AbstractEventLoop): self._closing = False 
self.pipeline_factory = PipelineFactory(SecretsManager())
         self.context_tracking: Optional[PipelineContext] = None
+        self.idle_timeout = 10  # seconds without traffic before closing
+        self.idle_timer: Optional[asyncio.TimerHandle] = None
+
+    def _reset_idle_timer(self) -> None:
+        if self.idle_timer:
+            self.idle_timer.cancel()
+        self.idle_timer = asyncio.get_event_loop().call_later(
+            self.idle_timeout, self._handle_idle_timeout
+        )
+
+    def _handle_idle_timeout(self) -> None:
+        logger.warning("Idle timeout reached, closing connection")
+        if self.transport and not self.transport.is_closing():
+            self.transport.close()

     def _select_pipeline(self, method: str, path: str) -> Optional[CopilotPipeline]:
         if method == "POST" and path == "v1/engines/copilot-codex/completions":
@@ -215,6 +229,7 @@ def connection_made(self, transport: asyncio.Transport) -> None:
         self.transport = transport
         self.peername = transport.get_extra_info("peername")
         logger.debug(f"Client connected from {self.peername}")
+        self._reset_idle_timer()

     def get_headers_dict(self) -> Dict[str, str]:
         """Convert raw headers to dictionary format"""
@@ -350,8 +365,10 @@ async def _forward_data_to_target(self, data: bytes) -> None:
             pipeline_output = pipeline_output.reconstruct()
         self.target_transport.write(pipeline_output)

+
     def data_received(self, data: bytes) -> None:
         """Handle received data from client"""
+        self._reset_idle_timer()
         try:
             if not self._check_buffer_size(data):
                 self.send_error_response(413, b"Request body too large")
@@ -556,6 +573,7 @@ async def connect_to_target(self) -> None:
             logger.error(f"Error during TLS handshake: {e}")
             self.send_error_response(502, b"TLS handshake failed")

+
     def send_error_response(self, status: int, message: bytes) -> None:
         """Send error response to client"""
         if self._closing:
@@ -593,6 +611,37 @@ def connection_lost(self, exc: Optional[Exception]) -> None:
         self.buffer.clear()
         self.ssl_context = None

+        if self.idle_timer:
+            self.idle_timer.cancel()
+
+    def eof_received(self) -> None:
+        """Handle EOF received from the client"""
+        # Treat EOF like connection loss: close the target and clear state.
+        if self._closing:
+            return
+
+        self._closing = True
+        logger.debug(f"EOF received from {self.peername}")
+
+        # Close target transport if it exists and isn't already closing
+        if self.target_transport and not self.target_transport.is_closing():
+            try:
+                self.target_transport.close()
+            except Exception as e:
+                logger.error(f"Error closing target transport when EOF: {e}")
+
+        # Clear references to help with cleanup
+        self.transport = None
+        self.target_transport = None
+        self.buffer.clear()
+        self.ssl_context = None
+
+    def pause_writing(self) -> None:
+        logger.debug("Transport buffer full, pausing writing")
+
+    def resume_writing(self) -> None:
+        logger.debug("Transport buffer ready, resuming writing")
+
     @classmethod
     async def create_proxy_server(
         cls, host: str, port: int, ssl_context: Optional[ssl.SSLContext] = None
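
For readers reviewing patch 1: the core pattern it adopts is hand-written SQL executed
through SQLAlchemy's async engine, with result rows hydrated into Pydantic models via
Row._asdict(). Below is a minimal, self-contained sketch of that pattern. The table,
model, and helper names are illustrative only, not codegate's actual schema; the
explicit StaticPool is only there so the in-memory SQLite database survives across
pooled connections in this standalone example.

    import asyncio
    from typing import List, Optional, Type

    import pydantic
    from sqlalchemy import TextClause, text
    from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
    from sqlalchemy.pool import StaticPool


    class PromptRow(pydantic.BaseModel):
        id: str
        request: Optional[str]


    async def select_models(
        engine: AsyncEngine, model_type: Type[pydantic.BaseModel], sql: TextClause
    ) -> List[pydantic.BaseModel]:
        # Row._asdict() maps column names to values; Pydantic consumes the dict directly.
        async with engine.begin() as conn:
            result = await conn.execute(sql)
            return [model_type(**row._asdict()) for row in result.fetchall()]


    async def main() -> None:
        # StaticPool shares a single connection so the in-memory DB persists
        # across separate begin() blocks.
        engine = create_async_engine("sqlite+aiosqlite:///:memory:", poolclass=StaticPool)
        async with engine.begin() as conn:
            await conn.execute(text("CREATE TABLE prompts (id TEXT, request TEXT)"))
            await conn.execute(text("INSERT INTO prompts VALUES ('1', 'hello')"))
        rows = await select_models(engine, PromptRow, text("SELECT id, request FROM prompts"))
        print(rows)  # [PromptRow(id='1', request='hello')]
        await engine.dispose()


    asyncio.run(main())

This mirrors what _execute_select_pydantic_model does in connection.py, minus
codegate's structured logging and error handling.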
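
Patch 2's idle monitoring reduces to a single re-armable loop.call_later timer: every
read cancels and re-arms it, so the callback only fires after a full window of silence.
A stripped-down sketch of that pattern follows, detached from the copilot provider;
the class name, port, and timeout value are illustrative.

    import asyncio
    from typing import Optional


    class IdleTimeoutProtocol(asyncio.Protocol):
        """Closes the transport after idle_timeout seconds without traffic."""

        def __init__(self, idle_timeout: float = 10.0) -> None:
            self.idle_timeout = idle_timeout
            self.idle_timer: Optional[asyncio.TimerHandle] = None
            self.transport: Optional[asyncio.Transport] = None

        def connection_made(self, transport: asyncio.Transport) -> None:
            self.transport = transport
            self._reset_idle_timer()

        def data_received(self, data: bytes) -> None:
            # Any activity pushes the deadline back by a full idle_timeout window.
            self._reset_idle_timer()

        def connection_lost(self, exc: Optional[Exception]) -> None:
            # Cancel on teardown so the callback never fires on a dead transport.
            if self.idle_timer:
                self.idle_timer.cancel()

        def _reset_idle_timer(self) -> None:
            if self.idle_timer:
                self.idle_timer.cancel()
            self.idle_timer = asyncio.get_running_loop().call_later(
                self.idle_timeout, self._handle_idle_timeout
            )

        def _handle_idle_timeout(self) -> None:
            # Only reached when no data arrived for the whole window.
            if self.transport and not self.transport.is_closing():
                self.transport.close()


    async def main() -> None:
        server = await asyncio.get_running_loop().create_server(
            lambda: IdleTimeoutProtocol(idle_timeout=10.0), "127.0.0.1", 8080
        )
        async with server:
            await server.serve_forever()


    asyncio.run(main())

call_later returns a TimerHandle, so cancelling and re-arming on every read is cheap
and avoids a watchdog task per connection; cancelling the handle in connection_lost,
as the patch also does, keeps the callback from firing on an already-closed transport.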