Merge pull request #15 from datastax/run-step-retrieval-details

Retrieval Details

phact authored Mar 20, 2024
2 parents ac16812 + c1f140d commit 4a72029

Showing 9 changed files with 446 additions and 1,417 deletions.
193 changes: 179 additions & 14 deletions impl/astra_vector.py
@@ -23,6 +23,7 @@
    dict_factory,
    named_tuple_factory, PreparedStatement,
)
from openai.types.beta.threads.runs import RunStep, ToolCallsStepDetails
from pydantic import BaseModel, Field

from impl.model.assistant_object import AssistantObject
@@ -572,6 +573,29 @@ def create_table(self):
); """
)

self.session.execute(
f"""create table if not exists {CASSANDRA_KEYSPACE}.run_steps(
id text,
assistant_id text,
cancelled_at timestamp,
completed_at timestamp,
created_at timestamp,
expired_at timestamp,
failed_at timestamp,
last_error text,
metadata map<text, text>,
object text,
run_id text,
status text,
step_details text,
thread_id text,
type text,
usage text,
PRIMARY KEY((run_id), id)
); """
)
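
The composite PRIMARY KEY((run_id), id) puts every step of a run into a single partition keyed by run_id, with id as the clustering column, so fetching one step or all steps of a run is a single-partition read. A minimal sketch of a list query over that partition (list_run_steps is a hypothetical helper, not part of this commit):

    def list_run_steps(self, run_id):
        # Single-partition read: all of a run's steps share the run_id partition
        # and come back ordered by the clustering column id.
        statement = self.session.prepare(
            f"SELECT * FROM {CASSANDRA_KEYSPACE}.run_steps WHERE run_id = ?;"
        )
        statement.consistency_level = ConsistencyLevel.QUORUM
        return list(self.session.execute(statement, (run_id,)))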


        statement = SimpleStatement(
            f"CREATE CUSTOM INDEX IF NOT EXISTS ON {CASSANDRA_KEYSPACE}.file_chunks (embedding) USING 'StorageAttachedIndex';",
            consistency_level=ConsistencyLevel.QUORUM,
@@ -593,6 +617,73 @@ def delete_assistant(self, id):
        self.session.execute(bound)
        return True

    def get_run_step(self, id, run_id):
        query_string = f"""
        SELECT * FROM {CASSANDRA_KEYSPACE}.run_steps WHERE id = ? and run_id = ?;
        """

        statement = self.session.prepare(query_string)
        statement.consistency_level = ConsistencyLevel.QUORUM
        self.session.row_factory = dict_factory
        bound = statement.bind(
            (
                id,
                run_id,
            )
        )
        rows = self.session.execute(bound)
        result = [dict(row) for row in rows]
        if result is None or len(result) == 0:
            return None
        json_rows = result[0]
        self.session.row_factory = named_tuple_factory

        metadata = json_rows["metadata"]
        if metadata is None:
            metadata = {}

        cancelled_at = json_rows["cancelled_at"]
        completed_at = json_rows["completed_at"]
        created_at = json_rows["created_at"]
        expired_at = json_rows["expired_at"]
        failed_at = json_rows["failed_at"]

        if cancelled_at is not None:
            cancelled_at = int(cancelled_at.timestamp() * 1000)
        if completed_at is not None:
            completed_at = int(completed_at.timestamp() * 1000)
        if created_at is not None:
            created_at = int(created_at.timestamp() * 1000)
        if expired_at is not None:
            expired_at = int(expired_at.timestamp() * 1000)
        if failed_at is not None:
            failed_at = int(failed_at.timestamp() * 1000)

        try:
            step_details = ToolCallsStepDetails.parse_raw(json_rows["step_details"])
            run_step = RunStep(
                id=json_rows["id"],
                assistant_id=json_rows["assistant_id"],
                cancelled_at=cancelled_at,
                completed_at=completed_at,
                created_at=created_at,
                expired_at=expired_at,
                failed_at=failed_at,
                last_error=json_rows["last_error"],
                metadata=metadata,
                object=json_rows["object"],
                run_id=json_rows["run_id"],
                status=json_rows["status"],
                step_details=step_details,
                thread_id=json_rows["thread_id"],
                type=json_rows["type"],
                usage=json_rows["usage"],
            )
            return run_step
        except Exception as e:
            logger.error(f"Error parsing run step: {e}")
            raise e
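
A sketch of how a caller might use the new reader; `client` is a hypothetical instance of this store class and the IDs are illustrative:

    # Hypothetical usage; `client` is an instance of the class above.
    step = client.get_run_step(id="step_abc123", run_id="run_xyz789")
    if step is not None:
        # step_details was rehydrated into a ToolCallsStepDetails model,
        # so the whole RunStep serializes back to JSON cleanly.
        payload = step.json()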

    def get_run(self, id, thread_id):
        query_string = f"""
        SELECT * FROM {CASSANDRA_KEYSPACE}.runs WHERE id = ? and thread_id = ?;
@@ -758,6 +849,71 @@ def update_run_status(self, id, thread_id, status):
        self.session.execute(bound)
        return True


    def upsert_run_step(self, run_step: RunStep):
        query_string = f"""insert into {CASSANDRA_KEYSPACE}.run_steps(
            id,
            assistant_id,
            cancelled_at,
            completed_at,
            created_at,
            expired_at,
            failed_at,
            last_error,
            metadata,
            object,
            run_id,
            status,
            step_details,
            thread_id,
            type,
            usage
        ) VALUES (
            ?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?
        );"""
        statement = self.session.prepare(query_string)
        statement.consistency_level = ConsistencyLevel.QUORUM

        id = run_step.id
        assistant_id = run_step.assistant_id
        cancelled_at = run_step.cancelled_at
        completed_at = run_step.completed_at
        created_at = run_step.created_at
        expired_at = run_step.expired_at
        failed_at = run_step.failed_at
        last_error = run_step.last_error
        metadata = run_step.metadata
        object = run_step.object
        run_id = run_step.run_id
        status = run_step.status
        step_details = run_step.step_details
        thread_id = run_step.thread_id
        type = run_step.type
        usage = run_step.usage

        self.session.execute(
            statement,
            (
                id,
                assistant_id,
                cancelled_at,
                completed_at,
                created_at,
                expired_at,
                failed_at,
                last_error,
                metadata,
                object,
                run_id,
                status,
                step_details.json(),
                thread_id,
                type,
                usage
            ),
        )
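
step_details is the one structured field that round-trips through a text column: upsert_run_step stores step_details.json() and get_run_step re-parses it with ToolCallsStepDetails.parse_raw. A minimal round-trip sketch (field values illustrative; exact model fields follow the installed openai package):

    # Illustrative round trip for the serialized details column.
    details = ToolCallsStepDetails.parse_raw('{"type": "tool_calls", "tool_calls": []}')
    assert ToolCallsStepDetails.parse_raw(details.json()) == details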


    def upsert_run(
            self,
            id,
@@ -1550,33 +1706,42 @@ def annSearch(
    ):
                litellm_kwargs_embedding = litellm_kwargs.copy()
                litellm_kwargs_embedding["api_key"] = embedding_api_key
                embeddings.append(get_embeddings([search_string], model=embedding_model, **litellm_kwargs_embedding))
                # TODO maybe support scores one day?
                # queryString += f"similarity_cosine(?, {column['column_name']}) as {column['column_name']}_score, "
            else:
                embeddings.append(get_embeddings([search_string], model=embedding_model, **litellm_kwargs_embedding)[0])
                queryString += f"similarity_cosine(?, {column['column_name']}) as score, "
            elif 'embedding' not in column['column_name'] and column['column_name'] != 'created_at':
                queryString += f"{column['column_name']}, "
        queryString = queryString[:-2]

        queryString += f""" FROM {CASSANDRA_KEYSPACE}.{table} """
        if len(partitions) > 0:
            queryString += f"WHERE file_id in ("
            for partition in partitions:
                queryString += f"'{partition}',"
            queryString = queryString[:-1]
            queryString += f") "
        queryString += f"ORDER BY "
        if len(partitions) > 1:
            return self.handle_multiple_partitions(embeddings, limit, queryString, vector_index_column, partitions)
        else:
            return self.finish_ann_query_and_get_json(embeddings, limit, queryString, vector_index_column, partitions)

    # TODO: make this async and/or fix the data model
    def handle_multiple_partitions(self, embeddings, limit, queryString, vector_index_column, partitions):
        json_rows = []
        for partition in partitions:
            ann_results = self.finish_ann_query_and_get_json(embeddings, limit, queryString, vector_index_column, [partition])
            json_rows.append(ann_results[0])
        # sort json_rows by score
        json_rows = sorted(json_rows, key=lambda x: x["score"], reverse=True)
        # trim to limit
        json_rows = json_rows[:limit]
        return json_rows

    def finish_ann_query_and_get_json(self, embeddings, limit, queryString, vector_index_column, partitions):
        queryString += f"WHERE file_id = '{partitions[0]}' "
        queryString += f"ORDER BY "
        queryString += f"""
        {vector_index_column} ann of ?
        """

        # TODO make limit configurable
        queryString += f"LIMIT {limit}"

        statement = self.session.prepare(queryString)
        statement.retry_policy = VectorRetryPolicy()
        statement.consistency_level = ConsistencyLevel.LOCAL_ONE
        boundStatement = statement.bind(embeddings[0])
        boundStatement = statement.bind([embeddings[0], embeddings[0]])
        self.session.row_factory = dictfactory if False else dict_factory
        json_rows = self.execute_and_get_json(boundStatement, vector_index_column)
        return json_rows
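
With the score projection added above, the finished single-partition statement carries two bind markers, one inside similarity_cosine and one in the ANN clause, which is why the same embedding vector is bound twice; handle_multiple_partitions then fans the same query out one partition at a time and re-sorts the merged hits by score client-side. The assembled CQL looks roughly like this (keyspace and column names illustrative; the real string is built incrementally in annSearch and finish_ann_query_and_get_json):

    # Illustrative shape of the finished per-partition query.
    example_query = (
        "SELECT similarity_cosine(?, embedding) as score, chunk_id, content "
        "FROM assistant_api.file_chunks "
        "WHERE file_id = 'file_123' "
        "ORDER BY embedding ann of ? "
        "LIMIT 10"
    )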
3 changes: 2 additions & 1 deletion impl/model/assistant_object.py
@@ -1,10 +1,11 @@
from typing import Optional, Annotated, List

from pydantic import Field
from pydantic import Field, StrictStr

from impl.model.assistant_object_tools_inner import AssistantObjectToolsInner
from openapi_server.models.assistant_object import AssistantObject as AssistantObjectGenerated


class AssistantObject(AssistantObjectGenerated):
    tools: Annotated[List[AssistantObjectToolsInner], Field(max_length=20)] = Field(description="The list of tools that the [assistant](/docs/api-reference/assistants) used for this run.")
    file_ids: Annotated[List[StrictStr], Field(max_length=1000)] = Field(description="A list of [file](/docs/api-reference/files) IDs attached to this assistant. There can be a maximum of 20 files attached to the assistant. Files are ordered by their creation date in ascending order. ")
5 changes: 3 additions & 2 deletions impl/model/create_assistant_request.py
@@ -1,10 +1,11 @@
from typing import Optional, Annotated, List

from pydantic import Field
from pydantic import Field, StrictStr

from impl.model.assistant_object_tools_inner import AssistantObjectToolsInner
from openapi_server.models.create_assistant_request import CreateAssistantRequest as CreateAssistantRequestGenerated


class CreateAssistantRequest(CreateAssistantRequestGenerated):
    tools: Optional[Annotated[List[AssistantObjectToolsInner], Field(max_length=128)]] = Field(default=None, description="assistant_tools_param_description")
    file_ids: Optional[Annotated[List[StrictStr], Field(max_length=1000)]] = Field(default=None, description="assistant_file_param_description")
32 changes: 32 additions & 0 deletions impl/model/list_messages_stream_response.py
@@ -0,0 +1,32 @@
try:
    from typing import Self
except ImportError:
    from typing_extensions import Self
from typing import List, Dict


from pydantic import Field

from impl.model.message_stream_response_object import MessageStreamResponseObject
from openapi_server.models.list_messages_stream_response import ListMessagesStreamResponse as ListMessagesStreamResponseGenerated


class ListMessagesStreamResponse(ListMessagesStreamResponseGenerated):
    data: List[MessageStreamResponseObject] = Field(description="The streamed chunks of messages, each representing a part of a message or a full message.")

    @classmethod
    def from_dict(cls, obj: Dict) -> Self:
        """Create an instance of ListMessagesStreamResponse from a dict"""
        if obj is None:
            return None

        if not isinstance(obj, dict):
            return cls.model_validate(obj)

        _obj = cls.model_validate({
            "object": obj.get("object"),
            "data": [MessageStreamResponseObject.from_dict(_item) for _item in obj.get("data")] if obj.get("data") is not None else None,
            "first_id": obj.get("first_id"),
            "last_id": obj.get("last_id")
        })
        return _obj
5 changes: 3 additions & 2 deletions impl/model/message_object.py
@@ -1,9 +1,10 @@
from typing import List
from typing import List, Annotated

from pydantic import Field
from pydantic import Field, StrictStr

from openapi_server.models.message_content_text_object import MessageContentTextObject
from openapi_server.models.message_object import MessageObject as MessageObjectGenerated

class MessageObject(MessageObjectGenerated):
    content: List[MessageContentTextObject]
    file_ids: Annotated[List[StrictStr], Field(max_length=1000)] = Field(description="A list of [file](/docs/api-reference/files) IDs that the assistant should use. Useful for tools like retrieval and code_interpreter that can access files. A maximum of 10 files can be attached to a message.")
9 changes: 9 additions & 0 deletions impl/model/message_stream_response_object.py
@@ -0,0 +1,9 @@
from typing import Annotated, List

from pydantic import StrictStr, Field

from openapi_server.models.message_stream_response_object import MessageStreamResponseObject as MessageStreamResponseObjectGenerated


class MessageStreamResponseObject(MessageStreamResponseObjectGenerated):
    file_ids: Annotated[List[StrictStr], Field(max_length=1000)] = Field(description="A list of [file](/docs/api-reference/files) IDs that the assistant should use. Useful for tools like retrieval and code_interpreter that can access files. A maximum of 10 files can be attached to a message.")