refactoring async
Anton Kulaga committed Jun 12, 2024
1 parent 1dc1394 commit 387ae1f
Showing 7 changed files with 73 additions and 29 deletions.
2 changes: 1 addition & 1 deletion environment.yaml
@@ -13,7 +13,7 @@ dependencies:
  - starlette
  - jupyter
  - pip:
-     - litellm>=1.40.78
+     - litellm>=1.40.8
      - numpydoc
      - semanticscholar>=0.8.1
      - Mako>=1.3.5
13 changes: 11 additions & 2 deletions examples/function_calling.py
@@ -1,4 +1,5 @@
#from just_agents.chat_agent import ChatAgent
+import asyncio
import json
import os
import pprint
@@ -24,8 +25,12 @@ def get_current_weather(location: str):

llm_options = just_agents.llm_options.LLAMA3
key_getter = rotate_env_keys
+prompt = "What's the weather like in San Francisco, Tokyo, and Paris?"

+#QWEN 2 does not work!
+#llm_options = just_agents.llm_options.OPEN_ROUTER_Qwen_2_72B_Instruct
+#key_getter=lambda: os.getenv("OPEN_ROUTER_KEY")

#llm_options = just_agents.llm_options.FIREWORKS_Qwen_2_72B_Instruct
#key_getter=lambda: os.getenv("FIREWORKS_AI_API_KEY")

@@ -35,5 +40,9 @@ def get_current_weather(location: str):
)
session.memory.add_on_message(lambda m: pprint.pprint(m) if m.content is not None else None)
#session.memory.add_on_message(lambda m: pprint.pprint(m.content) if m.content is not None else None)
-session.query("What's the weather like in San Francisco, Tokyo, and Paris?", key_getter=key_getter)
+session.query(prompt, key_getter=key_getter)
#for QWEN we get: Message(content='{\n "function": "get_current_weather",\n "parameters": {\n "location": ["San Francisco", "Tokyo", "Paris"]\n }\n}', role='assistant')


+result = asyncio.run(session.stream_async(prompt, key_getter=key_getter))
+print("stream finished")
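
For reference, the rotate_env_keys helper assigned to key_getter above comes from the repository itself. A hypothetical key rotator with the same calling convention (a zero-argument callable returning an API key) might look like the sketch below; make_key_rotator and the environment-variable names are illustrative, not part of the library:

import os
from itertools import cycle

def make_key_rotator(*env_vars: str):
    """Return a zero-argument key_getter that cycles through the given env vars."""
    keys = [os.environ[name] for name in env_vars if name in os.environ]
    if not keys:
        raise ValueError("none of the given environment variables are set")
    pool = cycle(keys)
    return lambda: next(pool)

# usage: session.query(prompt, key_getter=make_key_rotator("KEY_A", "KEY_B"))
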
43 changes: 26 additions & 17 deletions just_agents/llm_session.py
@@ -1,20 +1,15 @@
import copy
import pprint
import json
from pathlib import Path
from typing import Any, AsyncGenerator

from litellm.utils import ChatCompletionMessageToolCall, Function
from litellm import ModelResponse, completion, acompletion, Message
from typing import Any, Dict, List, Optional, Callable
import litellm
import json
from just_agents.memory import *
from litellm import ModelResponse, completion
from litellm.utils import Choices

from just_agents.llm_options import LLAMA3
from just_agents.memory import *
from just_agents.memory import Memory
from starlette.responses import ContentStream
import time

from just_agents.streaming.abstract_streaming import AbstractStreaming
from just_agents.streaming.openai_streaming import AsyncSession
from just_agents.streaming.qwen_streaming import QwenStreaming
@@ -87,16 +82,27 @@ def query(self, prompt: str, run_callbacks: bool = True, output: Optional[Path]
    return self._query(run_callbacks, output, key_getter=key_getter)


-def query_all_messages(self, messages: list[dict], run_callbacks: bool = True, output: Optional[Path] = None) -> str:
+def query_all_messages(self, messages: list[dict], run_callbacks: bool = True, output: Optional[Path] = None, key_getter: Optional[GetKey] = None) -> str:
    self.memory.add_messages(messages, run_callbacks)
-    return self._query(run_callbacks, output)
+    return self._query(run_callbacks, output, key_getter=key_getter)


-def stream_all(self, messages: list, run_callbacks: bool = True) -> ContentStream:
+def stream_all(self, messages: list, run_callbacks: bool = True): # -> ContentStream:
+    #TODO this function is super-dangerous as it does not seem to clean memory!
+    #TODO: should we add memory cleanup?
    self.memory.add_messages(messages, run_callbacks)
    return self._stream()

-def stream(self, prompt: str, run_callbacks: bool = True, output: Optional[Path] = None) -> ContentStream:
+async def stream_async(self, prompt: str, run_callbacks: bool = True, output: Optional[Path] = None, key_getter: Callable[[], str] = None) -> List[Any]:
+    """temporary function that allows testing the stream function which Alex wrote but I do not fully understand"""
+    collected_data = []
+    async for item in self.stream(prompt, run_callbacks, output, key_getter=key_getter):
+        collected_data.append(item)
+        # You can also process each item here if needed
+    return collected_data
+
+
+def stream(self, prompt: str, run_callbacks: bool = True, output: Optional[Path] = None, key_getter: Callable[[], str] = None) -> AsyncGenerator[Any, None]: # -> ContentStream:
    """
    streaming method
    :param prompt:
@@ -108,13 +114,16 @@ def stream(self, prompt: str, run_callbacks: bool = True, output: Optional[Path]
    self.memory.add_message(question, run_callbacks)

    # Start the streaming process
-    content_stream = self._stream()
+    content_stream = self._stream(key_getter=key_getter)

    # If output file is provided, write the stream to the file
    if output is not None:
        try:
            with output.open('w') as file:
-                if isinstance(content_stream, ContentStream):
+                if True: #if isinstance(content_stream, ContentStream):
+                    # looks like ContentStream is only used for type hinting,
+                    # while it brings in the pretty heavy starlette dependency,
+                    # so let's temporarily comment it out
                    for content in content_stream:
                        file.write(content)
                else:
@@ -126,8 +135,8 @@ def stream(self, prompt: str, run_callbacks: bool = True, output: Optional[Path]



-def _stream(self) -> ContentStream:
-    return self.streaming.resp_async_generator(self.memory, self.llm_options, self.available_tools)
+def _stream(self, key_getter: Optional[GetKey] = None) -> AsyncGenerator[Any, None]: # -> ContentStream:
+    return self.streaming.resp_async_generator(self.memory, self.llm_options, self.available_tools, key_getter=key_getter)


def _query(self, run_callbacks: bool = True, output: Optional[Path] = None, key_getter: Optional[GetKey] = None) -> str:
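
The new stream_async method is an instance of the standard collect-an-async-generator pattern. A self-contained sketch of that pattern, with token_stream standing in for the session's stream() generator (both names are illustrative):

import asyncio
from typing import Any, AsyncGenerator, List

async def token_stream() -> AsyncGenerator[str, None]:
    # stand-in for LLMSession.stream(): yields chunks as they arrive
    for token in ("Hello", ", ", "world"):
        yield token

async def collect(gen: AsyncGenerator[Any, None]) -> List[Any]:
    collected: List[Any] = []
    async for item in gen:  # drain the generator to exhaustion
        collected.append(item)
    return collected

print(asyncio.run(collect(token_stream())))  # ['Hello', ', ', 'world']
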
24 changes: 21 additions & 3 deletions just_agents/streaming/abstract_streaming.py
@@ -1,8 +1,8 @@
import json
import time
-from abc import ABC
+from abc import ABC, abstractmethod
from dataclasses import dataclass
-from typing import Dict, Callable
+from typing import Dict, Callable, AsyncGenerator, Optional

from litellm import Message

@@ -25,8 +25,26 @@ def parsed(self, name: str, arguments: str):


class AbstractStreaming(ABC):
+    """
+    Class that is required to implement the streaming logic
+    """

-    async def resp_async_generator(self, memory: Memory, options: Dict, available_tools: Dict[str, Callable]):
+    @abstractmethod
+    async def resp_async_generator(
+            self,
+            memory: Memory,
+            options: Dict,
+            available_tools: Dict[str, Callable],
+            key_getter: Callable[[], str] = None
+    ) -> AsyncGenerator:
+        """
+        Async generator that fills memory with streaming data
+        :param memory:
+        :param options:
+        :param available_tools:
+        :param key_getter:
+        :return:
+        """
+        pass

    def _process_function(self, parser: FunctionParser, available_tools: Dict[str, Callable]):
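
With resp_async_generator now marked @abstractmethod, every concrete streaming backend must override it with an async generator of this shape. A minimal illustrative subclass; EchoStreaming is hypothetical, and the memory access assumes litellm-style Message objects with a .content attribute:

from just_agents.memory import Memory
from just_agents.streaming.abstract_streaming import AbstractStreaming

class EchoStreaming(AbstractStreaming):
    # hypothetical backend that streams the last stored message back unchanged
    async def resp_async_generator(self, memory: Memory, options, available_tools,
                                   key_getter=None):
        last = memory.messages[-1] if memory.messages else None
        yield getattr(last, "content", "") or ""
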
11 changes: 9 additions & 2 deletions just_agents/streaming/openai_streaming.py
@@ -1,3 +1,5 @@
+from typing import AsyncGenerator

from litellm import ModelResponse, completion

from just_agents.memory import *
@@ -7,8 +9,13 @@

class AsyncSession(AbstractStreaming):

-    async def resp_async_generator(self, memory: Memory, options: Dict, available_tools: Dict[str, Callable]):
-        response: ModelResponse = completion(messages=memory.messages, stream=True, **options)
+    async def resp_async_generator(self, memory: Memory,
+                                   options: Dict,
+                                   available_tools: Dict[str, Callable],
+                                   key_getter: Callable[[], str] = None
+                                   ) -> AsyncGenerator[str, None]:
+        api_key = key_getter() if key_getter is not None else None
+        response: ModelResponse = completion(messages=memory.messages, stream=True, api_key=api_key, **options)
        parser: Optional[FunctionParser] = None
        tool_messages: list[Message] = []
        parsers: list[FunctionParser] = []
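
The api_key threading above relies on litellm's streaming interface: with stream=True, completion returns an iterator of chunks whose deltas carry the incremental content. A minimal sketch of consuming such a stream, with a placeholder model id:

from litellm import completion

def print_stream(messages, api_key=None):
    response = completion(
        model="gpt-4o-mini",   # placeholder; any litellm-supported model id works
        messages=messages,
        stream=True,
        api_key=api_key,       # None falls back to litellm's env-var lookup
    )
    for chunk in response:
        delta = chunk.choices[0].delta.content
        if delta is not None:  # some chunks (e.g. tool calls) carry no text
            print(delta, end="", flush=True)
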
7 changes: 4 additions & 3 deletions just_agents/streaming/qwen_streaming.py
@@ -1,6 +1,6 @@
import json
from dataclasses import dataclass
-from typing import Dict, Callable, Optional
+from typing import Dict, Callable, Optional, AsyncGenerator
from enum import Enum, auto

from litellm import ModelResponse, completion, Message
@@ -86,15 +86,16 @@ def _process_function(self, parser: FunctionParser, available_tools: Dict[str, C
            tool_call_id=parser.id) # TODO need to track arguments , arguments=function_args
        return message

-    async def resp_async_generator(self, memory: Memory, options: Dict, available_tools: Dict[str, Callable]):
+    async def resp_async_generator(self, memory: Memory, options: Dict, available_tools: Dict[str, Callable], key_getter: Callable[[], str] = None) -> AsyncGenerator[str, None]:
        """
        parses and streams results of the function
        :param memory:
        :param options:
        :param available_tools:
        :return:
        """
-        response: ModelResponse = completion(messages=memory.messages, stream=True, **options)
+        api_key = key_getter() if key_getter is not None else None
+        response: ModelResponse = completion(messages=memory.messages, stream=True, api_key=api_key, **options)
        parser: QwenFunctionParser = QwenFunctionParser()
        deltas: list[str] = []
        tool_messages: list[Message] = []
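
Both streaming implementations now share one convention: resolve the key once per request from the optional key_getter, then hand it to completion, so omitting the getter preserves the old environment-based behaviour. The convention in isolation (resolve_api_key is an illustrative name):

from typing import Callable, Optional

def resolve_api_key(key_getter: Optional[Callable[[], str]] = None) -> Optional[str]:
    # a fresh key per request makes rotation possible; None defers to
    # litellm's own environment-based credential lookup
    return key_getter() if key_getter is not None else None

assert resolve_api_key() is None
assert resolve_api_key(lambda: "sk-test") == "sk-test"
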
2 changes: 1 addition & 1 deletion setup.py
@@ -21,7 +21,7 @@
    long_description_content_type="text/markdown",
    long_description=long_description,
    packages=find_packages(),
-    install_requires=["litellm>=1.40.7", "numpydoc", "loguru", "requests", "Mako", "starlette"],
+    install_requires=["litellm>=1.40.8", "numpydoc", "loguru", "requests", "Mako"],
    extras_require={
        'tools': [
            # some default tools
