Commit: llmoptions to dict
Anton Kulaga committed Jun 11, 2024
1 parent a77ba19 commit 1dc1394
Showing 11 changed files with 127 additions and 86 deletions.
2 changes: 1 addition & 1 deletion environment.yaml
@@ -13,7 +13,7 @@ dependencies:
   - starlette
   - jupyter
   - pip:
-    - litellm>=1.40.7
+    - litellm>=1.40.78
     - numpydoc
     - semanticscholar>=0.8.1
     - Mako>=1.3.5
4 changes: 3 additions & 1 deletion examples/function_calling.py
@@ -33,5 +33,7 @@ def get_current_weather(location: str):
     llm_options=llm_options,
     tools=[get_current_weather]
 )
-session.memory.add_on_message(lambda m: pprint.pprint(m.content) if m.content is not None else None)
+session.memory.add_on_message(lambda m: pprint.pprint(m) if m.content is not None else None)
+#session.memory.add_on_message(lambda m: pprint.pprint(m.content) if m.content is not None else None)
 session.query("What's the weather like in San Francisco, Tokyo, and Paris?", key_getter=key_getter)
+#for QWEN we get: Message(content='{\n "function": "get_current_weather",\n "parameters": {\n "location": ["San Francisco", "Tokyo", "Paris"]\n }\n}', role='assistant')
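
The new callback prints the whole Message object rather than only its content field, which makes Qwen's plain-JSON function calls visible in the log. A minimal sketch of the same pattern as a named function, assuming a session built as in the example; the print_message name is ours:

import pprint

# Hypothetical named version of the lambda in the diff above: print the whole
# message whenever it carries textual content; content-less chunks are skipped.
def print_message(m):
    if m.content is not None:
        pprint.pprint(m)

# session.memory.add_on_message(print_message)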
9 changes: 3 additions & 6 deletions just_agents/chat_agent.py
@@ -1,13 +1,10 @@
-import dataclasses
-import pathlib
-import pprint
-from string import Template
-
 from dataclasses import dataclass, field
-from typing import Any, Dict, Optional
 
 from just_agents.llm_session import LLMSession
+from typing import Any, Dict, List, Optional
+from just_agents.utils import load_config
 
 
 @dataclass
 class ChatAgent(LLMSession):
     """
48 changes: 21 additions & 27 deletions just_agents/llm_options.py
@@ -5,33 +5,27 @@

 from dataclasses import dataclass, asdict

-@dataclass(frozen=True)
-class LLMOptions:
-    """
-    Class for additional LLM options
-    """
-    model: str
-    api_base: Optional[str] = None
-    temperature: float = 0.0
-    extras: Dict[str, Any] = field(default_factory=lambda: {})
-    tools: list = field(default_factory=lambda : [])
-    tool_choice: Optional[str] = field(default_factory=lambda: None) #"auto"
-    api_key: str = ""
-
-    def to_dict(self):
-        data = asdict(self)
-        extras = data.pop('extras', {})
-        return {**data, **extras}
+LLAMA3: Dict = {
+    "model": "groq/llama3-70b-8192",
+    "api_base": "https://api.groq.com/openai/v1",
+    "temperature": 0.0,
+    "tools": []
+}
+OPEN_ROUTER_Qwen_2_72B_Instruct = {
+    "model": "openrouter/qwen/qwen-2-72b-instruct",
+    "temperature": 0.0,
+    "tools": []
+}

-    def __getitem__(self, key):
-        if key in self.extras:
-            return self.extras[key]
-        return getattr(self, key)
+TOGETHER_Qwen_2_72B_Instruct = {
+    "model": "openrouter/qwen/qwen-2-72b-instruct",
+    "api_base": "https://api.groq.com/openai/v1",
+    "temperature": 0.0,
+}

-    def copy(self, **changes):
-        return replace(self, **changes)
-
-LLAMA3: LLMOptions = LLMOptions("groq/llama3-70b-8192", "https://api.groq.com/openai/v1")
-OPEN_ROUTER_Qwen_2_72B_Instruct: LLMOptions = LLMOptions("openrouter/qwen/qwen-2-72b-instruct", "https://openrouter.ai/api/v1")
-TOGETHER_Qwen_2_72B_Instruct: LLMOptions = LLMOptions("together_ai/Qwen/Qwen2-72B-Instruct")
-FIREWORKS_Qwen_2_72B_Instruct: LLMOptions = LLMOptions("fireworks_ai/wen/Qwen2-72B-Instruct")
+FIREWORKS_Qwen_2_72B_Instruct = {
+    "model": "fireworks_ai/wen/Qwen2-72B-Instruct",
+    "api_base": "https://api.groq.com/openai/v1",
+    "temperature": 0.0,
+}
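
With the frozen LLMOptions dataclass gone, the presets above are plain dicts, so the .to_dict()/.copy() calls disappear and callers work on copies directly. A minimal usage sketch, assuming the presets import as shown in the diff; the override values are illustrative:

import copy

from just_agents.llm_options import LLAMA3

# Deep-copy the shared module-level preset before changing it, mirroring the
# deepcopy safeguard added in LLMSession.__post_init__ below.
options = copy.deepcopy(LLAMA3)
options["temperature"] = 0.7      # illustrative override
options["tool_choice"] = "auto"   # plain keys replace dataclass fields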
48 changes: 37 additions & 11 deletions just_agents/llm_session.py
@@ -1,3 +1,4 @@
+import copy
 import pprint
 from pathlib import Path
 
@@ -9,7 +10,7 @@
 from just_agents.memory import *
 from litellm.utils import Choices
 
-from just_agents.llm_options import LLAMA3, LLMOptions
+from just_agents.llm_options import LLAMA3
 from just_agents.memory import Memory
 from starlette.responses import ContentStream
 import time
@@ -27,7 +28,7 @@
 
 @dataclass(kw_only=True)
 class LLMSession:
-    llm_options: LLMOptions = field(default_factory=lambda: LLAMA3)
+    llm_options: Dict[str, Any] = field(default_factory=lambda: LLAMA3)
     tools: List[Callable] = field(default_factory=list)
     available_tools: Dict[str, Callable] = field(default_factory=lambda: {})
 
@@ -36,7 +37,9 @@ class LLMSession:
     streaming: AbstractStreaming = None
 
     def __post_init__(self):
-        if self.llm_options.model.find("qwen") != -1:
+        if self.llm_options is not None:
+            self.llm_options = copy.deepcopy(self.llm_options)  # just a safety measure to avoid shared dictionaries
+        if "qwen" in self.llm_options["model"].lower():
             self.streaming = QwenStreaming()
         else:
             self.streaming = AsyncSession()
@@ -84,7 +87,7 @@ def query(self, prompt: str, run_callbacks: bool = True, output: Optional[Path] = None, key_getter: Optional[GetKey] = None) -> str:
         return self._query(run_callbacks, output, key_getter=key_getter)
 
 
-    def query_all(self, messages: list, run_callbacks: bool = True, output: Optional[Path] = None) -> str:
+    def query_all_messages(self, messages: list[dict], run_callbacks: bool = True, output: Optional[Path] = None) -> str:
         self.memory.add_messages(messages, run_callbacks)
         return self._query(run_callbacks, output)
 
@@ -93,21 +96,43 @@ def stream_all(self, messages: list, run_callbacks: bool = True) -> ContentStream:
         self.memory.add_messages(messages, run_callbacks)
         return self._stream()
 
-
     def stream(self, prompt: str, run_callbacks: bool = True, output: Optional[Path] = None) -> ContentStream:
+        """
+        streaming method
+        :param prompt:
+        :param run_callbacks:
+        :param output:
+        :return:
+        """
         question = Message(role="user", content=prompt)
         self.memory.add_message(question, run_callbacks)
-        return self._stream()
+
+        # Start the streaming process
+        content_stream = self._stream()
+
+        # If output file is provided, write the stream to the file
+        if output is not None:
+            try:
+                with output.open('w') as file:
+                    if isinstance(content_stream, ContentStream):
+                        for content in content_stream:
+                            file.write(content)
+                    else:
+                        raise TypeError("ContentStream expected from self._stream()")
+            except Exception as e:
+                print(f"Error writing to file: {e}")
+
+        return content_stream
+
+
 
     def _stream(self) -> ContentStream:
-        return self.streaming.resp_async_generator(self.memory, self.llm_options.to_dict(), self.available_tools)
+        return self.streaming.resp_async_generator(self.memory, self.llm_options, self.available_tools)
 
 
     def _query(self, run_callbacks: bool = True, output: Optional[Path] = None, key_getter: Optional[GetKey] = None) -> str:
-        options: Dict = self.llm_options.to_dict()
         api_key = key_getter() if key_getter is not None else None
-        response: ModelResponse = completion(messages=self.memory.messages, stream=False, api_key=api_key, **options)
+        response: ModelResponse = completion(messages=self.memory.messages, stream=False, api_key=api_key, **self.llm_options)
         self._process_response(response)
         executed_response = self._process_function_calls(response)
         if executed_response is not None:
@@ -146,7 +171,7 @@ def _process_function_calls(self, response: ModelResponse) -> Optional[ModelResponse]:
                 function_response = str(e)
             result = Message(role="tool", content=function_response, name=function_name, tool_call_id=tool_call.id)
             self.memory.add_message(result)
-        return completion(messages=self.memory.messages, stream=False, **self.llm_options.to_dict())
+        return completion(messages=self.memory.messages, stream=False, **self.llm_options)
         return None
 
@@ -162,4 +187,5 @@ def _prepare_tools(self, functions: List[Any]):
             function_description = litellm.utils.function_to_dict(fun)
             self.available_tools[function_description["name"]] = fun
             tools.append({"type": "function", "function": function_description})
-        self.llm_options = self.llm_options.copy(tools=tools, tool_choice="auto")
+        self.llm_options["tools"] = tools
+        self.llm_options["tool_choice"] = "auto"
9 changes: 6 additions & 3 deletions just_agents/memory.py
@@ -62,10 +62,13 @@ def last_message(self) -> Optional[Message]:
         return self.messages[-1] if len(self.messages) > 0 else None
 
 
-    def add_messages(self, messages: list, run_callbacks: bool = True):
+    def add_messages(self, messages: list[Message | dict], run_callbacks: bool = True):
         for message in messages:
-            msg = Message(content=message["content"], role=message["role"])
-            self.messages.append(msg)
+            if isinstance(message, Message):
+                self.messages.append(message)
+            else:
+                msg = Message(content=message["content"], role=message["role"])
+                self.messages.append(msg)
         if run_callbacks:
             for handler in self.on_message:
                 handler(message)
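
After this change add_messages accepts Message instances and raw role/content dicts in the same list. A minimal usage sketch, assuming Memory() is constructible with its defaults and that Message is importable from just_agents.memory as the diff suggests:

from just_agents.memory import Memory, Message

memory = Memory()
# Mixing raw dicts and Message objects in one call is now supported:
memory.add_messages([
    {"role": "user", "content": "What's the weather in Paris?"},
    Message(role="assistant", content="Let me check."),
])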
3 changes: 2 additions & 1 deletion just_agents/streaming/abstract_streaming.py
@@ -1,5 +1,6 @@
 import json
 import time
+from abc import ABC
 from dataclasses import dataclass
 from typing import Dict, Callable
 
@@ -23,7 +24,7 @@ def parsed(self, name: str, arguments: str):
         return False
 
 
-class AbstractStreaming:
+class AbstractStreaming(ABC):
 
     async def resp_async_generator(self, memory: Memory, options: Dict, available_tools: Dict[str, Callable]):
         pass
6 changes: 2 additions & 4 deletions just_agents/streaming/openai_streaming.py
@@ -1,9 +1,7 @@
-from litellm import ModelResponse, completion, acompletion, Message
 import json
+from litellm import ModelResponse, completion
 
-from just_agents.memory import *
+from just_agents.memory import Memory
 import time
-
 
 from just_agents.streaming.abstract_streaming import AbstractStreaming, FunctionParser
-
72 changes: 41 additions & 31 deletions just_agents/streaming/qwen_streaming.py
@@ -1,62 +1,68 @@
 import json
 from dataclasses import dataclass
 from typing import Dict, Callable, Optional
+from enum import Enum, auto
 
 from litellm import ModelResponse, completion, Message
 
 from just_agents.memory import Memory
 from just_agents.streaming.abstract_streaming import AbstractStreaming, FunctionParser
 
-CLEARED:int = -2
-STOPED:int = -1
-WAITING:int = 0
-UNDETERMIND:int = 1
-PARSING:int = 2
+class ParserState(Enum):
+    """
+    States for the Qwen parser
+    """
+    CLEARED = auto()
+    STOPPED = auto()
+    WAITING = auto()
+    UNDETERMINED = auto()
+    PARSING = auto()
 
 @dataclass
 class QwenFunctionParser:
-    name:str = ""
-    arguments:str = ""
-    buffer:str = ""
-    state:int = WAITING
-
-    def parsing(self, token:str) -> bool:
-        if self.state == STOPED or self.state == CLEARED:
+    """
+    Qwen has differences in formats of how it streams stuff
+    """
+    name: str = ""
+    arguments: str = ""
+    buffer: str = ""
+    state: ParserState = ParserState.WAITING
+
+    def parsing(self, token: str) -> bool:
+        if self.state in {ParserState.STOPPED, ParserState.CLEARED}:
             return False
 
         self.buffer += token
 
-        if self.state == PARSING:
+        if self.state == ParserState.PARSING:
             return True
 
-        if self.state == WAITING and token.startswith("{"):
-            self.state = UNDETERMIND
+        if self.state == ParserState.WAITING and token.startswith("{"):
+            self.state = ParserState.UNDETERMINED
             return True
 
-        if self.state == UNDETERMIND and len(self.buffer) < 12:
+        if self.state == ParserState.UNDETERMINED and len(self.buffer) < 12:
             return True
 
-        if self.state == UNDETERMIND and len(self.buffer) >= 12:
-            if str(self.buffer).find("function") != -1:
-                self.state = PARSING
+        if self.state == ParserState.UNDETERMINED and len(self.buffer) >= 12:
+            if "function" in self.buffer:
+                self.state = ParserState.PARSING
                 return True
             else:
-                self.state = STOPED
+                self.state = ParserState.STOPPED
                 return False
 
     def need_cleared(self) -> bool:
-        return self.state == STOPED
-
+        return self.state == ParserState.STOPPED
 
     def clear(self) -> str:
-        self.state = CLEARED
+        self.state = ParserState.CLEARED
         return self.buffer
 
     def is_ready(self):
-        return self.state == PARSING
+        return self.state == ParserState.PARSING
 
     def get_function_parsers(self):
-        if self.state == PARSING:
+        if self.state == ParserState.PARSING:
             res = []
             data = self.buffer.replace("\n", "")
             data = "[" + data.replace("}}", "}},")[:-1] + "]"
@@ -66,7 +72,6 @@ def get_function_parsers(self):
             return res
         return []
 
-
 @dataclass
 class QwenStreaming(AbstractStreaming):
 
@@ -78,16 +83,21 @@ def _process_function(self, parser: FunctionParser, available_tools: Dict[str, Callable]):
         except Exception as e:
             function_response = str(e)
         message = Message(role="function", content=function_response, name=parser.name,
-                           tool_call_id=parser.id) # TODO need to track arguments , arguments=function_args
+                          tool_call_id=parser.id)  # TODO need to track arguments , arguments=function_args
         return message
 
-
     async def resp_async_generator(self, memory: Memory, options: Dict, available_tools: Dict[str, Callable]):
+        """
+        parses and streams results of the function
+        :param memory:
+        :param options:
+        :param available_tools:
+        :return:
+        """
         response: ModelResponse = completion(messages=memory.messages, stream=True, **options)
         parser: QwenFunctionParser = QwenFunctionParser()
         deltas: list[str] = []
         tool_messages: list[Message] = []
         parsers: list[FunctionParser] = []
 
         for i, part in enumerate(response):
             delta: str = part["choices"][0]["delta"].get("content")  # type: ignore
@@ -118,4 +128,4 @@
         elif len(deltas) > 0:
             memory.add_message(Message(role="assistant", content="".join(deltas)))
 
-        yield "data: [DONE]\n\n"
\ No newline at end of file
+        yield "data: [DONE]\n\n"
Empty file removed just_agents/tools/weather.py
