-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #11 from datastax/multi-event-streaming
Multi event streaming
- Loading branch information
Showing
221 changed files
with
18,899 additions
and
3,494 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
v0.1.1 | ||
v0.1.2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import time | ||
from openai import OpenAI | ||
from dotenv import load_dotenv | ||
from openai.types.beta.assistant_stream_event import ThreadMessageDelta | ||
from streaming_assistants import patch | ||
from openai.lib.streaming import AssistantEventHandler | ||
from typing_extensions import override | ||
|
||
|
||
load_dotenv("./.env") | ||
load_dotenv("../../../.env") | ||
|
||
def run_with_assistant(assistant, client): | ||
print(f"created assistant: {assistant.name}") | ||
print("Uploading file:") | ||
# Upload the file | ||
file = client.files.create( | ||
file=open( | ||
"./examples/python/language_models_are_unsupervised_multitask_learners.pdf", | ||
"rb", | ||
), | ||
purpose="assistants", | ||
) | ||
print("adding file id to assistant") | ||
# Update Assistant | ||
assistant = client.beta.assistants.update( | ||
assistant.id, | ||
tools=[{"type": "retrieval"}], | ||
file_ids=[file.id], | ||
) | ||
user_message = "What are some cool math concepts behind this ML paper pdf? Explain in two sentences." | ||
print("creating persistent thread and message") | ||
thread = client.beta.threads.create() | ||
client.beta.threads.messages.create( | ||
thread_id=thread.id, role="user", content=user_message | ||
) | ||
print(f"> {user_message}") | ||
|
||
class EventHandler(AssistantEventHandler): | ||
@override | ||
def on_text_delta(self, delta, snapshot): | ||
# Increment the counter each time the method is called | ||
print(delta.value, end="", flush=True) | ||
|
||
print(f"creating run") | ||
with client.beta.threads.runs.create_and_stream( | ||
thread_id=thread.id, | ||
assistant_id=assistant.id, | ||
event_handler=EventHandler(), | ||
) as stream: | ||
for part in stream: | ||
if not isinstance(part, ThreadMessageDelta): | ||
print(f'received event: {part}\n') | ||
|
||
print("\n") | ||
|
||
|
||
client = patch(OpenAI()) | ||
|
||
instructions = "You are a personal math tutor. Answer thoroughly. The system will provide relevant context from files, use the context to respond and share the exact snippets from the file at the end of your response." | ||
|
||
model = "gpt-3.5-turbo" | ||
name = f"{model} Math Tutor" | ||
|
||
gpt3_assistant = client.beta.assistants.create( | ||
name=name, | ||
instructions=instructions, | ||
model=model, | ||
) | ||
run_with_assistant(gpt3_assistant, client) | ||
|
||
model = "cohere/command" | ||
name = f"{model} Math Tutor" | ||
|
||
cohere_assistant = client.beta.assistants.create( | ||
name=name, | ||
instructions=instructions, | ||
model=model, | ||
) | ||
run_with_assistant(cohere_assistant, client) | ||
|
||
model = "perplexity/mixtral-8x7b-instruct" | ||
name = f"{model} Math Tutor" | ||
|
||
perplexity_assistant = client.beta.assistants.create( | ||
name=name, | ||
instructions=instructions, | ||
model=model, | ||
) | ||
run_with_assistant(perplexity_assistant, client) | ||
|
||
model = "anthropic.claude-v2" | ||
name = f"{model} Math Tutor" | ||
|
||
claude_assistant = client.beta.assistants.create( | ||
name=name, | ||
instructions=instructions, | ||
model=model, | ||
) | ||
run_with_assistant(claude_assistant, client) | ||
|
||
model = "gemini/gemini-pro" | ||
name = f"{model} Math Tutor" | ||
|
||
gemini_assistant = client.beta.assistants.create( | ||
name=name, | ||
instructions=instructions, | ||
model=model, | ||
) | ||
run_with_assistant(gemini_assistant, client) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import asyncio | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
background_task_set = set() | ||
|
||
event_loop = asyncio.new_event_loop() | ||
|
||
async def add_background_task(function, run_id, thread_id, astradb): | ||
logger.debug("Creating background task") | ||
task = asyncio.create_task( | ||
function, name=run_id | ||
) | ||
background_task_set.add(task) | ||
task.add_done_callback(lambda t: on_task_completion(t, astradb=astradb, run_id=run_id, thread_id=thread_id)) | ||
|
||
|
||
def on_task_completion(task, astradb, run_id, thread_id): | ||
background_task_set.remove(task) | ||
logger.debug(f"Task stopped for run_id: {run_id} and thread_id: {thread_id}") | ||
|
||
if task.cancelled(): | ||
logger.warning(f"Task cancelled, setting status to failed for run_id: {run_id} and thread_id: {thread_id}") | ||
astradb.update_run_status(id=run_id, thread_id=thread_id, status="failed"); | ||
return | ||
try: | ||
exception = task.exception() | ||
if exception is not None: | ||
logger.warning(f"Task raised an exception, setting status to failed for run_id: {run_id} and thread_id: {thread_id}") | ||
logger.error(exception) | ||
astradb.update_run_status(id=run_id, thread_id=thread_id, status="failed"); | ||
raise exception | ||
else: | ||
logger.debug(f"Task completed successfully for run_id: {run_id} and thread_id: {thread_id}") | ||
except asyncio.CancelledError: | ||
logger.warning(f"why wasn't this caught in task.cancelled()") | ||
logger.debug(f"Task cancelled, setting status to failed for run_id: {run_id} and thread_id: {thread_id}") | ||
astradb.update_run_status(id=run_id, thread_id=thread_id, status="failed"); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.