Skip to content

Commit

Permalink
Merge pull request #11 from datastax/multi-event-streaming
Browse files Browse the repository at this point in the history
Multi event streaming
  • Loading branch information
phact authored Mar 16, 2024
2 parents a9f5b1b + 6944125 commit 22d200a
Show file tree
Hide file tree
Showing 221 changed files with 18,899 additions and 3,494 deletions.
102 changes: 102 additions & 0 deletions .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,40 @@ jobs:
- name: run tests
run: |
poetry run pytest -s --disable-warnings tests/http/
run-async-http-tests:
runs-on: ubuntu-latest
name: run async http tests
env:
ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_DB_APPLICATION_TOKEN }}
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
base_url: ${{ secrets.BASE_URL }}
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
PERPLEXITYAI_API_KEY: ${{ secrets.PERPLEXITYAI_API_KEY }}

steps:
- name: Git checkout
uses: actions/checkout@v3
- name: Set up Python 3.10.12
uses: actions/setup-python@v2
with:
python-version: '3.10.12'
- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python3 -
- name: Check Poetry Version
run: poetry --version
- name: Configure Poetry to Use Python 3.10.12
run: poetry env use python3.10
- name: get dependencies
run: |
poetry install
- name: run tests
run: |
poetry run pytest -s --disable-warnings tests/async_http/
run-openai-sdk-tests:
runs-on: ubuntu-latest
name: run openai-sdk tests
Expand Down Expand Up @@ -278,3 +312,71 @@ jobs:
- name: run tests
run: |
poetry run pytest -s --disable-warnings tests/streaming-assistants/test_run_retreival.py
run-streaming-assistants-tests-streaming-run-retrieval:
runs-on: ubuntu-latest
name: run streaming-assistants streaming run retrieval tests
env:
ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_DB_APPLICATION_TOKEN }}
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
base_url: ${{ secrets.BASE_URL }}
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
PERPLEXITYAI_API_KEY: ${{ secrets.PERPLEXITYAI_API_KEY }}

steps:
- name: Git checkout
uses: actions/checkout@v3
- name: Set up Python 3.10.12
uses: actions/setup-python@v2
with:
python-version: '3.10.12'
- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python3 -
- name: Check Poetry Version
run: poetry --version
- name: Configure Poetry to Use Python 3.10.12
run: poetry env use python3.10
- name: get dependencies
run: |
poetry install
- name: run tests
run: |
poetry run pytest -s --disable-warnings tests/streaming-assistants/test_streaming_run_retrieval.py
run-streaming-assistants-tests-streaming-run:
runs-on: ubuntu-latest
name: run streaming-assistants streaming run tests
env:
ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_DB_APPLICATION_TOKEN }}
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
base_url: ${{ secrets.BASE_URL }}
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
PERPLEXITYAI_API_KEY: ${{ secrets.PERPLEXITYAI_API_KEY }}

steps:
- name: Git checkout
uses: actions/checkout@v3
- name: Set up Python 3.10.12
uses: actions/setup-python@v2
with:
python-version: '3.10.12'
- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python3 -
- name: Check Poetry Version
run: poetry --version
- name: Configure Poetry to Use Python 3.10.12
run: poetry env use python3.10
- name: get dependencies
run: |
poetry install
- name: run tests
run: |
poetry run pytest -s --disable-warnings tests/streaming-assistants/test_streaming_run.py
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
v0.1.1
v0.1.2
110 changes: 110 additions & 0 deletions examples/python/streaming_runs/basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import time
from openai import OpenAI
from dotenv import load_dotenv
from openai.types.beta.assistant_stream_event import ThreadMessageDelta
from streaming_assistants import patch
from openai.lib.streaming import AssistantEventHandler
from typing_extensions import override


load_dotenv("./.env")
load_dotenv("../../../.env")

def run_with_assistant(assistant, client):
    """Attach the example PDF to *assistant* and stream one retrieval run.

    Uploads the paper, enables the retrieval tool on the assistant, creates a
    thread containing a single user question, then streams the run: text
    deltas are printed as they arrive, every other event is logged verbatim.

    Args:
        assistant: an already-created assistant object (must have .id/.name).
        client: a patched OpenAI client.
    """
    print(f"created assistant: {assistant.name}")
    print("Uploading file:")
    # Upload the file. The context manager closes the handle even if the
    # upload raises — the original leaked the open file object.
    with open(
        "./examples/python/language_models_are_unsupervised_multitask_learners.pdf",
        "rb",
    ) as pdf:
        file = client.files.create(file=pdf, purpose="assistants")
    print("adding file id to assistant")
    # Update Assistant: turn on retrieval and link the uploaded file.
    assistant = client.beta.assistants.update(
        assistant.id,
        tools=[{"type": "retrieval"}],
        file_ids=[file.id],
    )
    user_message = "What are some cool math concepts behind this ML paper pdf? Explain in two sentences."
    print("creating persistent thread and message")
    thread = client.beta.threads.create()
    client.beta.threads.messages.create(
        thread_id=thread.id, role="user", content=user_message
    )
    print(f"> {user_message}")

    class EventHandler(AssistantEventHandler):
        @override
        def on_text_delta(self, delta, snapshot):
            # Stream each text chunk to stdout as soon as it arrives.
            print(delta.value, end="", flush=True)

    print("creating run")
    with client.beta.threads.runs.create_and_stream(
        thread_id=thread.id,
        assistant_id=assistant.id,
        event_handler=EventHandler(),
    ) as stream:
        # Deltas are already printed by EventHandler.on_text_delta; log
        # everything else (run status changes etc.) for visibility.
        for part in stream:
            if not isinstance(part, ThreadMessageDelta):
                print(f'received event: {part}\n')

    print("\n")


client = patch(OpenAI())

instructions = "You are a personal math tutor. Answer thoroughly. The system will provide relevant context from files, use the context to respond and share the exact snippets from the file at the end of your response."

# One assistant per backend model; this loop replaces five copy-pasted
# create/run blocks while issuing exactly the same calls in the same order.
models = [
    "gpt-3.5-turbo",
    "cohere/command",
    "perplexity/mixtral-8x7b-instruct",
    "anthropic.claude-v2",
    "gemini/gemini-pro",
]

for model in models:
    assistant = client.beta.assistants.create(
        name=f"{model} Math Tutor",
        instructions=instructions,
        model=model,
    )
    run_with_assistant(assistant, client)
23 changes: 12 additions & 11 deletions impl/astra_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
)
from pydantic import BaseModel, Field

from impl.model.assistant_object import AssistantObject
from impl.model.assistant_object_tools_inner import AssistantObjectToolsInner
from impl.model.message_object import MessageObject
from impl.model.open_ai_file import OpenAIFile
from impl.model.run_object import RunObject
from impl.models import (
DocumentChunk,
DocumentChunkMetadata,
Expand All @@ -34,16 +39,11 @@
QueryWithEmbedding,
)
from impl.services.inference_utils import get_embeddings
from openapi_server.models.assistant_object import AssistantObject
from openapi_server.models.message_content_text_object import MessageContentTextObject
from openapi_server.models.message_content_text_object_text import MessageContentTextObjectText
from openapi_server.models.run_object_required_action import RunObjectRequiredAction
from openapi_server.models.message_object import MessageObject
from openapi_server.models.message_object_content_inner import MessageObjectContentInner
from openapi_server.models.open_ai_file import OpenAIFile
from openapi_server.models.run_object import RunObject
from openapi_server.models.thread_object import ThreadObject
from openapi_server.models.assistant_object_tools_inner import AssistantObjectToolsInner



# Create a logger for this module.
Expand Down Expand Up @@ -858,6 +858,7 @@ def upsert_run(
tools=tools,
file_ids=file_ids,
metadata=metadata,
usage=None,
)

def upsert_message(
Expand Down Expand Up @@ -944,8 +945,8 @@ def get_message(self, thread_id, message_id):
if file_ids is None:
file_ids = []

created_at = row["created_at"].timestamp() * 1000
return MessageObject(
created_at = int(row["created_at"].timestamp() * 1000)
message_object = MessageObject(
id=row['id'],
object=row['object'],
created_at=created_at,
Expand All @@ -957,12 +958,13 @@ def get_message(self, thread_id, message_id):
file_ids=file_ids,
metadata=metadata
)
return message_object

def upsert_content_only_file(
self, id, created_at, object, purpose, filename, format, bytes, content, **litellm_kwargs,
):
self.upsert_chunks_content_only(id, content, created_at)
status = "success"
status = "uploaded"
query_string = f"""insert into {CASSANDRA_KEYSPACE}.files (
id,
object,
Expand Down Expand Up @@ -998,7 +1000,7 @@ def upsert_file(
self, id, created_at, object, purpose, filename, format, bytes, chunks, model, **litellm_kwargs,
):
self.upsert_chunks(chunks, model, **litellm_kwargs)
status = "success"
status = "processed"

query_string = f"""insert into {CASSANDRA_KEYSPACE}.files (
id,
Expand Down Expand Up @@ -1157,7 +1159,6 @@ def upsert_assistant(

def __del__(self):
# close the connection when the client is destroyed
logger.info("shutdown")
self.session.shutdown()

# TODO: make these async
Expand Down
38 changes: 38 additions & 0 deletions impl/background.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import asyncio
import logging

logger = logging.getLogger(__name__)

# Tasks currently running in the background. main.py iterates this set on
# shutdown to cancel anything still in flight; on_task_completion removes
# each task once it finishes.
background_task_set = set()

# NOTE(review): this loop is created but never used in this module — confirm
# whether a caller depends on it before removing.
event_loop = asyncio.new_event_loop()


async def add_background_task(function, run_id, thread_id, astradb):
    """Schedule *function* (a coroutine object) as a named background task.

    The task is named after *run_id*, tracked in ``background_task_set``, and
    wired to ``on_task_completion`` so a failure is persisted to *astradb*.
    """
    logger.debug("Creating background task")
    task = asyncio.create_task(function, name=run_id)
    background_task_set.add(task)
    task.add_done_callback(
        lambda t: on_task_completion(t, astradb=astradb, run_id=run_id, thread_id=thread_id)
    )


def on_task_completion(task, astradb, run_id, thread_id):
    """Done-callback: untrack *task* and persist a ``failed`` status on error.

    A successful task only logs; cancellation or an exception marks the run
    ``failed`` in *astradb*.
    """
    background_task_set.remove(task)
    logger.debug(f"Task stopped for run_id: {run_id} and thread_id: {thread_id}")

    if task.cancelled():
        logger.warning(f"Task cancelled, setting status to failed for run_id: {run_id} and thread_id: {thread_id}")
        astradb.update_run_status(id=run_id, thread_id=thread_id, status="failed")
        return
    try:
        exception = task.exception()
        if exception is not None:
            logger.warning(f"Task raised an exception, setting status to failed for run_id: {run_id} and thread_id: {thread_id}")
            logger.error(exception)
            astradb.update_run_status(id=run_id, thread_id=thread_id, status="failed")
            # Re-raised after the status update; inside a done-callback this
            # is presumably absorbed and logged by the event loop's exception
            # handler rather than propagated — confirm against asyncio docs.
            raise exception
        else:
            logger.debug(f"Task completed successfully for run_id: {run_id} and thread_id: {thread_id}")
    except asyncio.CancelledError:
        # Defensive: task.cancelled() above should already have handled this.
        logger.warning("why wasn't this caught in task.cancelled()")
        logger.debug(f"Task cancelled, setting status to failed for run_id: {run_id} and thread_id: {thread_id}")
        astradb.update_run_status(id=run_id, thread_id=thread_id, status="failed")
24 changes: 18 additions & 6 deletions impl/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
from typing import Callable, Sequence, Union

Expand All @@ -10,11 +11,12 @@
from prometheus_fastapi_instrumentator import Instrumentator
from prometheus_fastapi_instrumentator.metrics import Info

from impl.background import background_task_set
from impl.routes import assistants, files, health, stateless, threads

# Configure logging
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)',
format='%(asctime)s - %(levelname)s - %(message)s (%(module)s:%(filename)s:%(lineno)d)',
datefmt='%Y-%m-%d %H:%M:%S')

logger = logging.getLogger('cassandra')
Expand All @@ -24,12 +26,22 @@
logger = logging.getLogger(__name__)

app = FastAPI(
# TODO: Change these?
title="OpenAI API",
description="The OpenAI REST API. Please see https://platform.openai.com/docs/api-reference for more details.",
title="Astra Assistants API",
description="Drop-in replacement for the OpenAI Assistants API.",
version="2.0.0",
)

@app.on_event("shutdown")
async def shutdown_event():
    """Cancel and drain all outstanding background tasks on server shutdown."""
    logger.info("shutting down server")
    # Iterate over a snapshot: each task's done-callback removes the task
    # from background_task_set, so iterating the live set while awaiting
    # would mutate it mid-iteration and raise RuntimeError.
    for task in list(background_task_set):
        task.cancel()
        try:
            await task  # Give the task a chance to finish
        except asyncio.CancelledError:
            pass  # Cancellation is the expected outcome here


app.include_router(assistants.router, prefix="/v1")
app.include_router(files.router, prefix="/v1")
app.include_router(health.router, prefix="/v1")
Expand Down Expand Up @@ -163,5 +175,5 @@ async def unimplemented(request: Request, full_path: str):
)


if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
#if __name__ == "__main__":
# uvicorn.run(app, host="0.0.0.0", port=8000)
Loading

0 comments on commit 22d200a

Please sign in to comment.