Skip to content

Commit

Permalink
Merge branch 'main' into fix-copilot-hang
Browse files Browse the repository at this point in the history
  • Loading branch information
lukehinds authored Jan 9, 2025
2 parents eaf9dbd + 1d0c5f5 commit 7008330
Show file tree
Hide file tree
Showing 13 changed files with 182 additions and 241 deletions.
4 changes: 3 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ litellm = "^1.57.1"
llama_cpp_python = ">=0.3.2"
cryptography = "^44.0.0"
sqlalchemy = "^2.0.28"
greenlet = "^3.0.3"
aiosqlite = "^0.20.0"
ollama = ">=0.4.4"
pydantic-settings = "^2.7.1"
Expand All @@ -27,8 +26,8 @@ tree-sitter-java = ">=0.23.5"
tree-sitter-javascript = ">=0.23.1"
tree-sitter-python = ">=0.23.6"
tree-sitter-rust = ">=0.23.2"

sqlite-vec-sl-tmp = "^0.0.4"

[tool.poetry.group.dev.dependencies]
pytest = ">=7.4.0"
pytest-cov = ">=4.1.0"
Expand Down
24 changes: 0 additions & 24 deletions sql/queries/queries.sql

This file was deleted.

20 changes: 0 additions & 20 deletions sqlc.yaml

This file was deleted.

7 changes: 4 additions & 3 deletions src/codegate/dashboard/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
PartialConversation,
QuestionAnswer,
)
from codegate.db.queries import GetAlertsWithPromptAndOutputRow, GetPromptWithOutputsRow
from codegate.db.models import GetAlertsWithPromptAndOutputRow, GetPromptWithOutputsRow

logger = structlog.get_logger("codegate")

Expand Down Expand Up @@ -183,7 +183,7 @@ async def parse_get_prompt_with_output(

def parse_question_answer(input_text: str) -> str:
# given a string, detect if we have a pattern of "Context: xxx \n\nQuery: xxx" and strip it
pattern = r'^Context:.*?\n\n\s*Query:\s*(.*)$'
pattern = r"^Context:.*?\n\n\s*Query:\s*(.*)$"

# Search using the regex pattern
match = re.search(pattern, input_text, re.DOTALL)
Expand Down Expand Up @@ -226,7 +226,8 @@ async def match_conversations(
if partial_conversation.question_answer.answer is not None:
first_partial_conversation = partial_conversation
partial_conversation.question_answer.question.message = parse_question_answer(
partial_conversation.question_answer.question.message)
partial_conversation.question_answer.question.message
)
questions_answers.append(partial_conversation.question_answer)

# only add conversation if we have some answers
Expand Down
88 changes: 66 additions & 22 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import asyncio
import json
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Type

import structlog
from pydantic import BaseModel
from sqlalchemy import TextClause, text
from sqlalchemy.ext.asyncio import create_async_engine

from codegate.db.fim_cache import FimCache
from codegate.db.models import Alert, Output, Prompt
from codegate.db.queries import (
AsyncQuerier,
from codegate.db.models import (
Alert,
GetAlertsWithPromptAndOutputRow,
GetPromptWithOutputsRow,
Output,
Prompt,
)
from codegate.pipeline.base import PipelineContext

Expand Down Expand Up @@ -83,11 +84,9 @@ async def init_db(self):
await self._async_db_engine.dispose()

async def _execute_update_pydantic_model(
self, model: BaseModel, sql_command: TextClause #
self, model: BaseModel, sql_command: TextClause
) -> Optional[BaseModel]:
# There are create method in queries.py automatically generated by sqlc
# However, the methods are buggy for Pydancti and don't work as expected.
# Manually writing the SQL query to insert Pydantic models.
"""Execute an update or insert command for a Pydantic model."""
async with self._async_db_engine.begin() as conn:
try:
result = await conn.execute(sql_command, model.model_dump())
Expand Down Expand Up @@ -117,8 +116,9 @@ async def record_request(self, prompt_params: Optional[Prompt] = None) -> Option
# logger.debug(f"Recorded request: {recorded_request}")
return recorded_request # type: ignore

async def update_request(self, initial_id: str,
prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
async def update_request(
self, initial_id: str, prompt_params: Optional[Prompt] = None
) -> Optional[Prompt]:
if prompt_params is None:
return None
prompt_params.id = initial_id # overwrite the initial id of the request
Expand All @@ -135,8 +135,9 @@ async def update_request(self, initial_id: str,
# logger.debug(f"Recorded request: {recorded_request}")
return updated_request # type: ignore

async def record_outputs(self, outputs: List[Output],
initial_id: Optional[str]) -> Optional[Output]:
async def record_outputs(
self, outputs: List[Output], initial_id: Optional[str]
) -> Optional[Output]:
if not outputs:
return

Expand Down Expand Up @@ -216,7 +217,7 @@ def _should_record_context(self, context: Optional[PipelineContext]) -> tuple:

# If it's not a FIM prompt, we don't need to check anything else.
if context.input_request.type != "fim":
return True, 'add', '' # Default to add if not FIM, since no cache check is required
return True, "add", "" # Default to add if not FIM, since no cache check is required

return fim_cache.could_store_fim_request(context) # type: ignore

Expand All @@ -229,7 +230,7 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
if not should_record:
logger.info("Skipping record of context, not needed")
return
if action == 'add':
if action == "add":
await self.record_request(context.input_request)
await self.record_outputs(context.output_responses, None)
await self.record_alerts(context.alerts_raised, None)
Expand Down Expand Up @@ -257,18 +258,61 @@ class DbReader(DbCodeGate):
def __init__(self, sqlite_path: Optional[str] = None):
super().__init__(sqlite_path)

async def _execute_select_pydantic_model(
    self, model_type: Type[BaseModel], sql_command: TextClause
) -> Optional[List[BaseModel]]:
    """Run a SELECT statement and map each returned row onto ``model_type``.

    Args:
        model_type: Pydantic model class instantiated once per row; the row's
            column names are passed as keyword arguments.
        sql_command: SQL text clause to execute.

    Returns:
        A list with one ``model_type`` instance per returned row (empty when
        no rows come back), or ``None`` if executing the query or building a
        model raises.
    """
    async with self._async_db_engine.begin() as conn:
        try:
            result = await conn.execute(sql_command)
            if not result:
                return None
            # Row._asdict() keys are the selected column names/aliases, so the
            # query's aliases must match the model's field names.
            return [model_type(**row._asdict()) for row in result.fetchall() if row]
        except Exception as e:
            # Best-effort read: log the failure and signal it with None
            # instead of propagating the exception to the caller.
            logger.error(f"Failed to select model: {model_type}.", error=str(e))
            return None

async def get_prompts_with_output(self) -> List[GetPromptWithOutputsRow]:
    """Fetch every prompt together with its outputs, if any.

    Prompts are LEFT JOINed with outputs, so a prompt without an output still
    yields a row. Rows are ordered by output timestamp, newest first.

    Returns:
        Whatever ``_execute_select_pydantic_model`` returns for this query:
        a list of ``GetPromptWithOutputsRow`` on success, ``None`` on failure.
    """
    # A single text query replaces the former AsyncQuerier path, whose result
    # was discarded and whose connection was not closed if iteration raised.
    sql = text(
        """
        SELECT
            p.id, p.timestamp, p.provider, p.request, p.type,
            o.id as output_id,
            o.output,
            o.timestamp as output_timestamp
        FROM prompts p
        LEFT JOIN outputs o ON p.id = o.prompt_id
        ORDER BY o.timestamp DESC
        """
    )
    prompts = await self._execute_select_pydantic_model(GetPromptWithOutputsRow, sql)
    return prompts

async def get_alerts_with_prompt_and_output(self) -> List[GetAlertsWithPromptAndOutputRow]:
    """Fetch every alert together with its prompt and that prompt's outputs.

    Alerts are LEFT JOINed with prompts and outputs, so an alert still yields
    a row when the related records are missing. Rows are ordered by alert
    timestamp, newest first.

    Returns:
        Whatever ``_execute_select_pydantic_model`` returns for this query:
        a list of ``GetAlertsWithPromptAndOutputRow`` on success, ``None`` on
        failure.
    """
    # A single text query replaces the former AsyncQuerier path, whose result
    # was discarded and whose connection was not closed if iteration raised.
    sql = text(
        """
        SELECT
            a.id,
            a.prompt_id,
            a.code_snippet,
            a.trigger_string,
            a.trigger_type,
            a.trigger_category,
            a.timestamp,
            p.timestamp as prompt_timestamp,
            p.provider,
            p.request,
            p.type,
            o.id as output_id,
            o.output,
            o.timestamp as output_timestamp
        FROM alerts a
        LEFT JOIN prompts p ON p.id = a.prompt_id
        LEFT JOIN outputs o ON p.id = o.prompt_id
        ORDER BY a.timestamp DESC
        """
    )
    prompts = await self._execute_select_pydantic_model(GetAlertsWithPromptAndOutputRow, sql)
    return prompts


Expand Down
27 changes: 16 additions & 11 deletions src/codegate/db/fim_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,9 @@ def _add_cache_entry(self, hash_key: str, context: PipelineContext):
if alert.trigger_category == AlertSeverity.CRITICAL.value
]
new_cache = CachedFim(
timestamp=context.input_request.timestamp, critical_alerts=critical_alerts,
initial_id=context.input_request.id
timestamp=context.input_request.timestamp,
critical_alerts=critical_alerts,
initial_id=context.input_request.id,
)
self.cache[hash_key] = new_cache
logger.info(f"Added cache entry for hash key: {hash_key}")
Expand All @@ -115,8 +116,9 @@ def _update_cache_entry(self, hash_key: str, context: PipelineContext):
]
# Update the entry in the cache with new critical alerts but keep the old timestamp.
updated_cache = CachedFim(
timestamp=existing_entry.timestamp, critical_alerts=critical_alerts,
initial_id=existing_entry.initial_id
timestamp=existing_entry.timestamp,
critical_alerts=critical_alerts,
initial_id=existing_entry.initial_id,
)
self.cache[hash_key] = updated_cache
logger.info(f"Updated cache entry for hash key: {hash_key}")
Expand Down Expand Up @@ -148,22 +150,25 @@ def _is_cached_entry_old(self, context: PipelineContext, cached_entry: CachedFim
def could_store_fim_request(self, context: PipelineContext):
    """Decide whether a FIM request should be recorded, and with which action.

    Args:
        context: Pipeline context holding the input request to evaluate.

    Returns:
        A ``(should_record, action, initial_id)`` tuple:
        ``(False, "", "")`` when there is no input request, the FIM message
        cannot be extracted, or a cache entry could not be added;
        ``(True, "add", initial_id)`` when a new cache entry was created;
        ``(True, "update", initial_id)`` when an existing entry was refreshed.
    """
    if not context.input_request:
        logger.warning("No input request found. Skipping creating a mapping entry")
        return False, "", ""
    # Couldn't process the user message. Skip creating a mapping entry.
    message = self._extract_message_from_fim_request(context.input_request.request)
    if message is None:
        logger.warning(f"Couldn't read FIM message: {message}. Will not record to DB.")
        return False, "", ""

    hash_key = self._calculate_hash_key(message, context.input_request.provider)  # type: ignore
    cached_entry = self.cache.get(hash_key, None)
    # A fresh entry is needed when nothing is cached for this key, the cached
    # entry is considered old, or the context carries alerts not yet cached.
    if (
        cached_entry is None
        or self._is_cached_entry_old(context, cached_entry)
        or self._are_new_alerts_present(context, cached_entry)
    ):
        cached_entry = self._add_cache_entry(hash_key, context)
        if cached_entry is None:
            logger.warning("Failed to add cache entry")
            return False, "", ""
        return True, "add", cached_entry.initial_id

    self._update_cache_entry(hash_key, context)
    return True, "update", cached_entry.initial_id
34 changes: 31 additions & 3 deletions src/codegate/db/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
# Code generated by sqlc. DO NOT EDIT.
# versions:
# sqlc v1.27.0
from typing import Any, Optional

import pydantic
Expand Down Expand Up @@ -38,3 +35,34 @@ class Setting(pydantic.BaseModel):
llm_model: Optional[Any]
system_prompt: Optional[Any]
other_settings: Optional[Any]


# Models for select queries


class GetAlertsWithPromptAndOutputRow(pydantic.BaseModel):
    """One result row of a select that combines an alert with prompt and output data.

    Field names are expected to match the column names/aliases of the query
    that populates this model.
    """

    # Alert fields.
    id: Any
    prompt_id: Any
    code_snippet: Optional[Any]
    trigger_string: Optional[Any]
    trigger_type: Any
    trigger_category: Optional[Any]
    timestamp: Any
    # Prompt fields; declared optional.
    prompt_timestamp: Optional[Any]
    provider: Optional[Any]
    request: Optional[Any]
    type: Optional[Any]
    # Output fields; declared optional.
    output_id: Optional[Any]
    output: Optional[Any]
    output_timestamp: Optional[Any]


class GetPromptWithOutputsRow(pydantic.BaseModel):
    """One result row of a select that combines a prompt with output data.

    Field names are expected to match the column names/aliases of the query
    that populates this model.
    """

    # Prompt fields.
    id: Any
    timestamp: Any
    provider: Optional[Any]
    request: Any
    type: Any
    # Output fields; declared optional.
    output_id: Optional[Any]
    output: Optional[Any]
    output_timestamp: Optional[Any]
Loading

0 comments on commit 7008330

Please sign in to comment.