feat: improve cache system to collect the last output #497

Merged · 1 commit · Jan 7, 2025
4 changes: 2 additions & 2 deletions src/codegate/dashboard/dashboard.py
@@ -1,5 +1,5 @@
import asyncio
from typing import AsyncGenerator, List
from typing import AsyncGenerator, List, Optional

import structlog
from fastapi import APIRouter, Depends
@@ -36,7 +36,7 @@ def get_messages(db_reader: DbReader = Depends(get_db_reader)) -> List[Conversat


@dashboard_router.get("/dashboard/alerts")
def get_alerts(db_reader: DbReader = Depends(get_db_reader)) -> List[AlertConversation]:
def get_alerts(db_reader: DbReader = Depends(get_db_reader)) -> List[Optional[AlertConversation]]:
"""
Get all the alerts from the database and return them as a list of alert conversations.
"""
99 changes: 69 additions & 30 deletions src/codegate/db/connection.py
@@ -5,7 +5,7 @@

import structlog
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy import TextClause, text
from sqlalchemy.ext.asyncio import create_async_engine

from codegate.db.fim_cache import FimCache
@@ -30,8 +30,8 @@ def __init__(self, sqlite_path: Optional[str] = None):
current_dir = Path(__file__).parent
sqlite_path = (
current_dir.parent.parent.parent / "codegate_volume" / "db" / "codegate.db"
)
self._db_path = Path(sqlite_path).absolute()
) # type: ignore
self._db_path = Path(sqlite_path).absolute() # type: ignore
self._db_path.parent.mkdir(parents=True, exist_ok=True)
logger.debug(f"Initializing DB from path: {self._db_path}")
engine_dict = {
@@ -82,15 +82,15 @@ async def init_db(self):
finally:
await self._async_db_engine.dispose()

async def _insert_pydantic_model(
self, model: BaseModel, sql_insert: text
async def _execute_update_pydantic_model(
self, model: BaseModel, sql_command: TextClause #
) -> Optional[BaseModel]:
# There are create methods in queries.py automatically generated by sqlc.
# However, the methods are buggy for Pydantic models and don't work as expected.
# Manually writing the SQL query to insert Pydantic models.
async with self._async_db_engine.begin() as conn:
try:
result = await conn.execute(sql_insert, model.model_dump())
result = await conn.execute(sql_command, model.model_dump())
row = result.first()
if row is None:
return None
@@ -99,7 +99,7 @@ async def _insert_pydantic_model(
model_class = model.__class__
return model_class(**row._asdict())
except Exception as e:
logger.error(f"Failed to insert model: {model}.", error=str(e))
logger.error(f"Failed to update model: {model}.", error=str(e))
return None

async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
@@ -112,18 +112,39 @@ async def record_request(self, prompt_params: Optional[Prompt] = None) -> Option
RETURNING *
"""
)
recorded_request = await self._insert_pydantic_model(prompt_params, sql)
recorded_request = await self._execute_update_pydantic_model(prompt_params, sql)
# Uncomment to debug the recorded request
# logger.debug(f"Recorded request: {recorded_request}")
return recorded_request
return recorded_request # type: ignore

async def record_outputs(self, outputs: List[Output]) -> Optional[Output]:
async def update_request(self, initial_id: str,
prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
if prompt_params is None:
return None
prompt_params.id = initial_id # overwrite the initial id of the request
sql = text(
"""
UPDATE prompts
SET timestamp = :timestamp, provider = :provider, request = :request, type = :type
WHERE id = :id
RETURNING *
"""
)
updated_request = await self._execute_update_pydantic_model(prompt_params, sql)
# Uncomment to debug the updated request
# logger.debug(f"Updated request: {updated_request}")
return updated_request # type: ignore

async def record_outputs(self, outputs: List[Output],
initial_id: Optional[str]) -> Optional[Output]:
if not outputs:
return

first_output = outputs[0]
# Create a single entry on DB but encode all of the chunks in the stream as a list
# of JSON objects in the field `output`
if initial_id:
first_output.prompt_id = initial_id
output_db = Output(
id=first_output.id,
prompt_id=first_output.prompt_id,
@@ -143,14 +164,14 @@ async def record_outputs(self, outputs: List[Output]) -> Optional[Output]:
RETURNING *
"""
)
recorded_output = await self._insert_pydantic_model(output_db, sql)
recorded_output = await self._execute_update_pydantic_model(output_db, sql)
# Uncomment to debug
# logger.debug(f"Recorded output: {recorded_output}")
return recorded_output
return recorded_output # type: ignore

async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
async def record_alerts(self, alerts: List[Alert], initial_id: Optional[str]) -> List[Alert]:
if not alerts:
return
return []
sql = text(
"""
INSERT INTO alerts (
@@ -167,7 +188,9 @@ async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
async with asyncio.TaskGroup() as tg:
for alert in alerts:
try:
result = tg.create_task(self._insert_pydantic_model(alert, sql))
if initial_id:
alert.prompt_id = initial_id
result = tg.create_task(self._execute_update_pydantic_model(alert, sql))
alerts_tasks.append(result)
except Exception as e:
logger.error(f"Failed to record alert: {alert}.", error=str(e))
@@ -182,33 +205,49 @@ async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
# logger.debug(f"Recorded alerts: {recorded_alerts}")
return recorded_alerts

def _should_record_context(self, context: Optional[PipelineContext]) -> bool:
"""Check if the context should be recorded in DB"""
def _should_record_context(self, context: Optional[PipelineContext]) -> tuple:
"""Check if the context should be recorded in DB and determine the action."""
if context is None or context.metadata.get("stored_in_db", False):
return False
return False, None, None

if not context.input_request:
logger.warning("No input request found. Skipping recording context.")
return False
return False, None, None

# If it's not a FIM prompt, we don't need to check anything else.
if context.input_request.type != "fim":
return True
return True, 'add', '' # Default to add if not FIM, since no cache check is required

return fim_cache.could_store_fim_request(context)
return fim_cache.could_store_fim_request(context) # type: ignore

async def record_context(self, context: Optional[PipelineContext]) -> None:
try:
if not self._should_record_context(context):
if not context:
logger.info("No context provided, skipping")
return
await self.record_request(context.input_request)
await self.record_outputs(context.output_responses)
await self.record_alerts(context.alerts_raised)
context.metadata["stored_in_db"] = True
logger.info(
f"Recorded context in DB. Output chunks: {len(context.output_responses)}. "
f"Alerts: {len(context.alerts_raised)}."
)
should_record, action, initial_id = self._should_record_context(context)
if not should_record:
logger.info("Skipping record of context, not needed")
return
if action == 'add':
await self.record_request(context.input_request)
await self.record_outputs(context.output_responses, None)
await self.record_alerts(context.alerts_raised, None)
context.metadata["stored_in_db"] = True
logger.info(
f"Recorded context in DB. Output chunks: {len(context.output_responses)}. "
f"Alerts: {len(context.alerts_raised)}."
)
else:
# Update the previously recorded rows, reusing the initial request id
await self.update_request(initial_id, context.input_request)
await self.record_outputs(context.output_responses, initial_id)
await self.record_alerts(context.alerts_raised, initial_id)
context.metadata["stored_in_db"] = True
logger.info(
f"Recorded context in DB. Output chunks: {len(context.output_responses)}. "
f"Alerts: {len(context.alerts_raised)}."
)
except Exception as e:
logger.error(f"Failed to record context: {context}.", error=str(e))

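As an aside, the sketch below shows the shape of the new add/update flow that record_context now follows: a first-time prompt is inserted, while a repeated FIM prompt is rewritten in place under its initial id. The SketchRecorder class and its method are hypothetical placeholders for illustration only, not the codegate API.

# Minimal sketch (assumed names, not the real DbRecorder) of the add/update branching.
from typing import Optional


class SketchRecorder:
    """Hypothetical stand-in showing only the decision made in record_context."""

    def record_context(self, should_record: bool, action: Optional[str],
                       initial_id: Optional[str]) -> str:
        if not should_record:
            # Nothing to store: either no context or the FIM cache said to skip it.
            return "skipped"
        if action == "add":
            # First occurrence: insert the prompt, its outputs and its alerts.
            return "inserted new prompt, outputs and alerts"
        # Repeated FIM prompt: overwrite the rows created under initial_id.
        return f"updated prompt {initial_id} and re-recorded its outputs/alerts"


recorder = SketchRecorder()
print(recorder.record_context(True, "add", "abc-123"))     # first occurrence
print(recorder.record_context(True, "update", "abc-123"))  # later occurrence with the same hash
print(recorder.record_context(False, None, None))          # skipped entirely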
69 changes: 51 additions & 18 deletions src/codegate/db/fim_cache.py
@@ -18,6 +18,7 @@ class CachedFim(BaseModel):

timestamp: datetime.datetime
critical_alerts: List[Alert]
initial_id: str


class FimCache:
@@ -86,16 +87,42 @@ def _calculate_hash_key(self, message: str, provider: str) -> str:

def _add_cache_entry(self, hash_key: str, context: PipelineContext):
"""Add a new cache entry"""
if not context.input_request:
logger.warning("No input request found. Skipping creating a mapping entry")
return
critical_alerts = [
alert
for alert in context.alerts_raised
if alert.trigger_category == AlertSeverity.CRITICAL.value
]
new_cache = CachedFim(
timestamp=context.input_request.timestamp, critical_alerts=critical_alerts
timestamp=context.input_request.timestamp, critical_alerts=critical_alerts,
initial_id=context.input_request.id
)
self.cache[hash_key] = new_cache
logger.info(f"Added cache entry for hash key: {hash_key}")
return self.cache[hash_key]

def _update_cache_entry(self, hash_key: str, context: PipelineContext):
"""Update an existing cache entry without changing the timestamp."""
existing_entry = self.cache.get(hash_key)
if existing_entry is not None:
# Update critical alerts while retaining the original timestamp.
critical_alerts = [
alert
for alert in context.alerts_raised
if alert.trigger_category == AlertSeverity.CRITICAL.value
]
# Update the entry in the cache with new critical alerts but keep the old timestamp.
updated_cache = CachedFim(
timestamp=existing_entry.timestamp, critical_alerts=critical_alerts,
initial_id=existing_entry.initial_id
)
self.cache[hash_key] = updated_cache
logger.info(f"Updated cache entry for hash key: {hash_key}")
else:
# Log a warning if trying to update a non-existent entry - ideally should not happen.
logger.warning(f"Attempted to update non-existent cache entry for hash key: {hash_key}")

def _are_new_alerts_present(self, context: PipelineContext, cached_entry: CachedFim) -> bool:
"""Check if there are new alerts present"""
@@ -108,29 +135,35 @@ def _are_new_alerts_present(self, context: PipelineContext, cached_entry: Cached

def _is_cached_entry_old(self, context: PipelineContext, cached_entry: CachedFim) -> bool:
"""Check if the cached entry is old"""
if not context.input_request:
logger.warning("No input request found. Skipping checking if the cache entry is old")
return False
elapsed_seconds = (context.input_request.timestamp - cached_entry.timestamp).total_seconds()
return elapsed_seconds > Config.get_config().max_fim_hash_lifetime
config = Config.get_config()
if config is None:
logger.warning("No configuration found. Skipping checking if the cache entry is old")
return True
return elapsed_seconds > Config.get_config().max_fim_hash_lifetime # type: ignore

def could_store_fim_request(self, context: PipelineContext):
if not context.input_request:
logger.warning("No input request found. Skipping creating a mapping entry")
return False, '', ''
# Couldn't process the user message. Skip creating a mapping entry.
message = self._extract_message_from_fim_request(context.input_request.request)
if message is None:
logger.warning(f"Couldn't read FIM message: {message}. Will not record to DB.")
return False
return False, '', ''

hash_key = self._calculate_hash_key(message, context.input_request.provider)
hash_key = self._calculate_hash_key(message, context.input_request.provider) # type: ignore
cached_entry = self.cache.get(hash_key, None)
if cached_entry is None:
self._add_cache_entry(hash_key, context)
return True

if self._is_cached_entry_old(context, cached_entry):
self._add_cache_entry(hash_key, context)
return True

if self._are_new_alerts_present(context, cached_entry):
self._add_cache_entry(hash_key, context)
return True

logger.debug(f"FIM entry already in cache: {hash_key}.")
return False
if cached_entry is None or self._is_cached_entry_old(
context, cached_entry) or self._are_new_alerts_present(context, cached_entry):
cached_entry = self._add_cache_entry(hash_key, context)
if cached_entry is None:
logger.warning("Failed to add cache entry")
return False, '', ''
return True, 'add', cached_entry.initial_id

self._update_cache_entry(hash_key, context)
return True, 'update', cached_entry.initial_id
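For reference, here is a minimal, self-contained sketch of the tuple contract that could_store_fim_request now returns, (should_record, action, initial_id). The Entry and SketchFimCache names are simplified stand-ins for illustration, not the real CachedFim/FimCache classes.

import datetime
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass
class Entry:
    timestamp: datetime.datetime
    initial_id: str


class SketchFimCache:
    """Hypothetical cache illustrating the add/update decision and the returned tuple."""

    def __init__(self, max_age_seconds: float = 60.0):
        self.cache: Dict[str, Entry] = {}
        self.max_age_seconds = max_age_seconds

    def could_store(self, hash_key: str, request_id: str,
                    now: datetime.datetime) -> Tuple[bool, str, str]:
        entry = self.cache.get(hash_key)
        is_stale = entry is not None and (
            (now - entry.timestamp).total_seconds() > self.max_age_seconds
        )
        if entry is None or is_stale:
            # First sighting (or stale entry): record a fresh DB row under this id.
            self.cache[hash_key] = Entry(timestamp=now, initial_id=request_id)
            return True, "add", request_id
        # Repeated request: keep the original id so the DB rows get updated in place.
        return True, "update", entry.initial_id


cache = SketchFimCache()
now = datetime.datetime.now(datetime.timezone.utc)
print(cache.could_store("k1", "prompt-1", now))   # -> (True, 'add', 'prompt-1')
print(cache.could_store("k1", "prompt-2", now))   # -> (True, 'update', 'prompt-1')
print(cache.could_store("k1", "prompt-3", now + datetime.timedelta(seconds=200)))  # stale -> (True, 'add', 'prompt-3')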
11 changes: 7 additions & 4 deletions tests/db/test_fim_cache.py
@@ -127,7 +127,7 @@ def test_extract_message_from_fim_request(test_request, expected_result_content)

def test_are_new_alerts_present():
fim_cache = FimCache()
cached_entry = CachedFim(timestamp=datetime.now(), critical_alerts=[])
cached_entry = CachedFim(timestamp=datetime.now(), critical_alerts=[], initial_id="1")
context = PipelineContext()
context.alerts_raised = [mock.MagicMock(trigger_category=AlertSeverity.CRITICAL.value)]
result = fim_cache._are_new_alerts_present(context, cached_entry)
@@ -146,6 +146,7 @@ def test_are_new_alerts_present():
trigger_string=None,
)
],
initial_id='2'
)
result = fim_cache._are_new_alerts_present(context, populated_cache)
assert result is False
@@ -155,15 +156,17 @@ def test_are_new_alerts_present():
"cached_entry, is_old",
[
(
CachedFim(timestamp=datetime.now(timezone.utc) - timedelta(days=1), critical_alerts=[]),
CachedFim(timestamp=datetime.now(timezone.utc) - timedelta(days=1),
critical_alerts=[], initial_id='1'),
True,
),
(CachedFim(timestamp=datetime.now(timezone.utc), critical_alerts=[]), False),
(CachedFim(timestamp=datetime.now(timezone.utc), critical_alerts=[],
initial_id='2'), False),
],
)
def test_is_cached_entry_old(cached_entry, is_old):
context = PipelineContext()
context.add_input_request("test", True, "test_provider")
context.add_input_request("test", True, "test_provider") # type: ignore
fim_cache = FimCache()
result = fim_cache._is_cached_entry_old(context, cached_entry)
assert result == is_old