Skip to content

Commit

Permalink
new_ids
Browse files Browse the repository at this point in the history
  • Loading branch information
phact committed Jun 13, 2024
1 parent 0955cdf commit aaaead0
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 44 deletions.
16 changes: 2 additions & 14 deletions impl/routes/files.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import hashlib
import time
import logging
import uuid
from datetime import datetime
from typing import Any, Dict
from uuid import uuid1

from fastapi import (
APIRouter,
Expand Down Expand Up @@ -35,22 +32,13 @@
)
from ..model.open_ai_file import OpenAIFile
from ..rate_limiter import limiter
from ..utils import generate_id_from_upload_file

router = APIRouter()

logger = logging.getLogger(__name__)


def upload_file_to_uuid_str(upload_file):
    """Derive a deterministic UUID string for an uploaded file.

    Feeds the filename (UTF-8) followed by the full file contents into a
    SHA-256 hash and folds the first 128 bits of the digest into a UUID,
    so the same name + content always maps to the same id.

    The underlying spooled file is rewound to position 0 both before
    reading and before returning, leaving it ready for later consumers.
    """
    stream = upload_file.file
    stream.seek(0)
    # Incrementally hashing name then content is equivalent to hashing
    # their concatenation.
    digest = hashlib.sha256()
    digest.update(upload_file.filename.encode("utf-8"))
    digest.update(stream.read())
    stream.seek(0)
    return str(uuid.UUID(digest.hexdigest()[:32]))

@router.post(
"/files",
responses={
Expand Down Expand Up @@ -81,7 +69,7 @@ async def create_file(
else:
raise NotImplementedError("File upload is currently only supported for OpenAI")

file_id = upload_file_to_uuid_str(file)
file_id = generate_id_from_upload_file(file)
if purpose in ["auth"]:
created_at = int(time.mktime(datetime.now().timetuple()))
obj = "file"
Expand Down
7 changes: 3 additions & 4 deletions impl/routes/stateless.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
"""The stateless endpoints that do not depend on information from DB"""
import logging
import time
import uuid
from typing import Any, Dict

import litellm
from fastapi.encoders import jsonable_encoder
import json


from fastapi import APIRouter, Depends, Request, HTTPException
from fastapi import APIRouter, Depends, Request
from litellm import ModelResponse
from starlette.responses import StreamingResponse, JSONResponse

Expand All @@ -28,6 +26,7 @@
from ..model.create_chat_completion_response_choices_inner import CreateChatCompletionResponseChoicesInner
from ..model.create_embedding_request import CreateEmbeddingRequest
from ..model.embedding import Embedding
from ..utils import generate_id

router = APIRouter()

Expand Down Expand Up @@ -159,7 +158,7 @@ async def chat_completion_streamer(response, model):
# Logic for streaming chat completions
if (response.__class__.__name__ == "GenerateContentResponse"):
i = 0
id = str(uuid.uuid1())
id = generate_id("cmpl")
created_time = int(time.time())
for part in response:
choices = []
Expand Down
5 changes: 2 additions & 3 deletions impl/routes_v2/assistants_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
from datetime import datetime
import logging
import time
from uuid import uuid1

from fastapi import APIRouter, Body, Depends, HTTPException, Query, Path

from impl.astra_vector import CassandraClient
from impl.model_v2.create_assistant_request import CreateAssistantRequest
from impl.model_v2.modify_assistant_request import ModifyAssistantRequest
from impl.routes.utils import verify_db_client
from impl.utils import store_object, read_object, read_objects
from impl.utils import store_object, read_object, read_objects, generate_id
from openapi_server_v2.models.assistant_object import AssistantObject
from openapi_server_v2.models.assistants_api_response_format_option import AssistantsApiResponseFormatOption
from openapi_server_v2.models.delete_assistant_response import DeleteAssistantResponse
Expand Down Expand Up @@ -83,7 +82,7 @@ async def create_assistant(
create_assistant_request: CreateAssistantRequest = Body(None, description=""),
astradb: CassandraClient = Depends(verify_db_client),
) -> AssistantObject:
assistant_id = str(uuid1())
assistant_id = generate_id("asst")
created_at = int(time.mktime(datetime.now().timetuple()) * 1000)
logging.info(f"going to create assistant with id: {assistant_id} and details {create_assistant_request}")

Expand Down
32 changes: 17 additions & 15 deletions impl/routes_v2/threads_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import re
import time
from typing import Dict, Any, Union, get_origin, Type, List, Optional
from uuid import uuid1


from fastapi import APIRouter, Body, Depends, Path, HTTPException, Query
Expand All @@ -22,7 +21,7 @@
from impl.routes_v2.assistants_v2 import get_assistant_obj
from impl.routes_v2.vector_stores import read_vsf
from impl.services.inference_utils import get_chat_completion, get_async_chat_completion_response
from impl.utils import map_model, store_object, read_object, read_objects
from impl.utils import map_model, store_object, read_object, read_objects, generate_id
from openapi_server_v2.models.assistants_api_response_format_option import AssistantsApiResponseFormatOption
from openapi_server_v2.models.assistants_api_tool_choice_option import AssistantsApiToolChoiceOption
from openapi_server_v2.models.message_delta_object_delta_content_inner import MessageDeltaObjectDeltaContentInner
Expand Down Expand Up @@ -94,7 +93,7 @@ async def create_thread(
astradb: CassandraClient = Depends(verify_db_client),
) -> ThreadObject:
created_at = int(time.mktime(datetime.now().timetuple()) * 1000)
thread_id = str(uuid1())
thread_id = generate_id("thread")

messages = []
if create_thread_request.messages is not None:
Expand Down Expand Up @@ -193,7 +192,7 @@ async def create_message(
astradb: CassandraClient = Depends(verify_db_client),
) -> MessageObject:
created_at = int(time.mktime(datetime.now().timetuple()) * 1000)
message_id = str(uuid1())
message_id = generate_id("msg")

content = MessageContentTextObject(
text=MessageContentTextObjectText(
Expand Down Expand Up @@ -473,7 +472,8 @@ async def run_event_stream(run, message_id, astradb):
return

# this works because we make the run_step id the same as the message_id
run_step = astradb.get_run_step(run_id=run.id, id=message_id)
run_step_id = message_id.replace("msg_", "step_")
run_step = astradb.get_run_step(run_id=run.id, id=run_step_id)
if run_step is not None:
async for event in yield_events_from_object(
obj=run_step,
Expand All @@ -500,7 +500,7 @@ async def run_event_stream(run, message_id, astradb):
# tool_call_delta_object = ToolCallDeltaObject(type="tool_calls", tool_calls=retrieval_tool_call_deltas)

while run_step.status != "completed":
run_step = astradb.get_run_step(run_id=run.id, id=message_id)
run_step = astradb.get_run_step(run_id=run.id, id=run_step_id)
await asyncio.sleep(1)
tool_call_delta_object = RunStepDeltaStepDetailsToolCallsObject(type="tool_calls", tool_calls=None)
step_details = RunStepDeltaObjectDeltaStepDetails(actual_instance=tool_call_delta_object)
Expand Down Expand Up @@ -630,7 +630,7 @@ async def stream_message_events(astradb, thread_id, limit, order, after, before,
async def init_message(thread_id, assistant_id, run_id, astradb, created_at, content=None):
if content is None:
content = []
message_id = str(uuid1())
message_id = generate_id("msg")
message_obj = MessageObject(
id=message_id,
object="thread.message",
Expand Down Expand Up @@ -674,7 +674,7 @@ async def create_run(
# New Messages cannot be added to the Thread.
# New Runs cannot be created on the Thread.
created_at = int(time.mktime(datetime.now().timetuple()) * 1000)
run_id = str(uuid1())
run_id = generate_id("run")
status = "queued"

tools = create_run_request.tools
Expand Down Expand Up @@ -741,6 +741,7 @@ async def create_run(

# create run_step
# Note the run_step id is the same as the message_id
run_step_id = message_id.replace("msg_", "step_")
run_step = RunStepObject(
id=message_id,
assistant_id=assistant.id,
Expand All @@ -756,7 +757,7 @@ async def create_run(
tool_calls=[
RunStepDetailsToolCallsObjectToolCallsInner(
actual_instance=RunStepDetailsToolCallsFileSearchObject(
id=message_id,
id=run_step_id,
type="file_search",
file_search={},
)
Expand Down Expand Up @@ -806,7 +807,7 @@ async def create_run(
message_content = summarize_message_content(instructions, messages.data, False)
message = await get_chat_completion(messages=message_content, model=model, **litellm_kwargs)

tool_call_object_id = str(uuid1())
tool_call_object_id = generate_id("call")
run_tool_calls = []
if message.content is None:
for tool_call in message.tool_calls:
Expand All @@ -825,7 +826,7 @@ async def create_run(
except Exception as e:
logger.info("did not find function call in message content")
status = "completed"
message_id = str(uuid1())
message_id = generate_id("msg")
created_at = int(time.mktime(datetime.now().timetuple()) * 1000)

content = MessageContentTextObject(
Expand Down Expand Up @@ -859,7 +860,7 @@ async def create_run(
required_action = RunObjectRequiredAction(type='submit_tool_outputs', submit_tool_outputs=tool_outputs)
status = "requires_action"

message_id = str(uuid1())
message_id = generate_id("msg")
created_at = int(time.mktime(datetime.now().timetuple()) * 1000)

# groq can't handle an assistant call with no content and perplexity can't handle non-alternating user/assistant messages
Expand Down Expand Up @@ -1069,13 +1070,14 @@ async def process_rag(
completed_at = int(time.mktime(datetime.now().timetuple()))

# TODO: consider [optionally?] excluding the content payload because it can be big
run_step_id = message_id.replace("msg_", "step_")
details = RunStepObjectStepDetails(
actual_instance=RunStepDetailsToolCallsObject(
type="tool_calls",
tool_calls=[
RunStepDetailsToolCallsObjectToolCallsInner(
actual_instance=RunStepDetailsToolCallsFileSearchObject(
id=message_id,
id=run_step_id,
type="file_search",
file_search={"chunks": context_json_meta},
)
Expand All @@ -1085,7 +1087,7 @@ async def process_rag(
)

run_step = RunStepObject(
id=message_id,
id=run_step_id,
assistant_id=assistant_id,
completed_at=completed_at,
created_at=created_at,
Expand Down Expand Up @@ -1499,7 +1501,7 @@ async def submit_tool_ouputs_to_run(
)
text = message.content

id = str(uuid1())
id = generate_id("msg")
created_at = int(time.mktime(datetime.now().timetuple())*1000)


Expand Down
5 changes: 2 additions & 3 deletions impl/routes_v2/vector_stores.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from datetime import datetime
import logging
import time
from uuid import uuid1

from fastapi import APIRouter, Path, Depends, Body, Query

from impl.astra_vector import CassandraClient
from impl.routes.utils import verify_db_client
from impl.utils import read_object, store_object, read_objects
from impl.utils import read_object, store_object, read_objects, generate_id
from openapi_server_v2.models.create_vector_store_file_request import CreateVectorStoreFileRequest
from openapi_server_v2.models.create_vector_store_request import CreateVectorStoreRequest
from openapi_server_v2.models.list_vector_store_files_response import ListVectorStoreFilesResponse
Expand Down Expand Up @@ -59,7 +58,7 @@ async def create_vector_store(
create_vector_store_request: CreateVectorStoreRequest = Body(None, description=""),
astradb: CassandraClient = Depends(verify_db_client),
) -> VectorStoreObject:
vector_store_id = str(uuid1())
vector_store_id = generate_id("vs")
created_at = int(time.mktime(datetime.now().timetuple()) * 1000)

usage_bytes = 0
Expand Down
8 changes: 4 additions & 4 deletions impl/services/chunks.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import os
import concurrent.futures
import uuid
from typing import Any, Dict, List, Optional, Tuple

import tiktoken

from impl.models import Document, DocumentChunk, DocumentChunkMetadata
#from impl.services.code_chunks import get_code_chunks
from impl.services.inference_utils import get_embeddings
from impl.utils import generate_id

# Global variables
tokenizer = tiktoken.get_encoding(
Expand Down Expand Up @@ -119,10 +119,10 @@ def create_document_chunks(
"""
# Check if the document text is empty or whitespace
if not doc.text or doc.text.isspace():
return [], doc.id or str(uuid.uuid1())
return [], doc.id or generate_id("doc")

# Generate a document id if not provided
doc_id = doc.id or str(uuid.uuid1())
doc_id = doc.id or generate_id("doc")

# Split the document text into chunks
if format in ("c", "cpp", "css", "html", "java", "js", "json", "md", "php", "py", "rb", "ts", "xml"):
Expand All @@ -145,7 +145,7 @@ def create_document_chunks(

# Assign each chunk a sequential number and create a DocumentChunk object
for i, text_chunk in enumerate(text_chunks):
chunk_id = f"{doc_id}_{i}"
chunk_id = f"chunk_{doc_id}_{i}"
doc_chunk = DocumentChunk(
id=chunk_id,
text=text_chunk,
Expand Down
Loading

0 comments on commit aaaead0

Please sign in to comment.