From 1dc139436891a5c2e9b2865cd4fc5ecd457221bb Mon Sep 17 00:00:00 2001
From: Anton Kulaga
Date: Tue, 11 Jun 2024 15:57:16 +0300
Subject: [PATCH] llmoptions to dict

---
 environment.yaml                            |  2 +-
 examples/function_calling.py                |  4 +-
 just_agents/chat_agent.py                   |  9 +--
 just_agents/llm_options.py                  | 48 ++++++--------
 just_agents/llm_session.py                  | 48 ++++++++++----
 just_agents/memory.py                       |  9 ++-
 just_agents/streaming/abstract_streaming.py |  3 +-
 just_agents/streaming/openai_streaming.py   |  6 +-
 just_agents/streaming/qwen_streaming.py     | 72 ++++++++++++---------
 just_agents/tools/weather.py                |  0
 tests/test_session.py                       | 12 +++-
 11 files changed, 127 insertions(+), 86 deletions(-)
 delete mode 100644 just_agents/tools/weather.py

diff --git a/environment.yaml b/environment.yaml
index 9606a30..8b4ae7e 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -13,7 +13,7 @@ dependencies:
   - starlette
   - jupyter
   - pip:
-      - litellm>=1.40.7
+      - litellm>=1.40.78
       - numpydoc
       - semanticscholar>=0.8.1
       - Mako>=1.3.5
\ No newline at end of file
diff --git a/examples/function_calling.py b/examples/function_calling.py
index b2041ca..1d78609 100644
--- a/examples/function_calling.py
+++ b/examples/function_calling.py
@@ -33,5 +33,7 @@ def get_current_weather(location: str):
     llm_options=llm_options,
     tools=[get_current_weather]
 )
-session.memory.add_on_message(lambda m: pprint.pprint(m.content) if m.content is not None else None)
+session.memory.add_on_message(lambda m: pprint.pprint(m) if m.content is not None else None)
+#session.memory.add_on_message(lambda m: pprint.pprint(m.content) if m.content is not None else None)
 session.query("What's the weather like in San Francisco, Tokyo, and Paris?", key_getter=key_getter)
+#for QWEN we get: Message(content='{\n  "function": "get_current_weather",\n  "parameters": {\n    "location": ["San Francisco", "Tokyo", "Paris"]\n  }\n}', role='assistant')
\ No newline at end of file
diff --git a/just_agents/chat_agent.py b/just_agents/chat_agent.py
index bc5eb1a..c788598 100644
--- a/just_agents/chat_agent.py
+++ b/just_agents/chat_agent.py
@@ -1,13 +1,10 @@
-import dataclasses
-import pathlib
-import pprint
-from string import Template
-
 from dataclasses import dataclass, field
+from typing import Any, Dict, Optional
+
 from just_agents.llm_session import LLMSession
-from typing import Any, Dict, List, Optional
 from just_agents.utils import load_config
 
+
 @dataclass
 class ChatAgent(LLMSession):
     """
diff --git a/just_agents/llm_options.py b/just_agents/llm_options.py
index 192f4ff..89d3736 100644
--- a/just_agents/llm_options.py
+++ b/just_agents/llm_options.py
@@ -5,33 +5,27 @@
 from dataclasses import dataclass, asdict
 
-@dataclass(frozen=True)
-class LLMOptions:
-    """
-    Class for additional LLM options
-    """
-    model: str
-    api_base: Optional[str] = None
-    temperature: float = 0.0
-    extras: Dict[str, Any] = field(default_factory=lambda: {})
-    tools: list = field(default_factory=lambda : [])
-    tool_choice: Optional[str] = field(default_factory=lambda: None) #"auto"
-    api_key: str = ""
-
-    def to_dict(self):
-        data = asdict(self)
-        extras = data.pop('extras', {})
-        return {**data, **extras}
+LLAMA3: Dict = {
+    "model": "groq/llama3-70b-8192",
+    "api_base": "https://api.groq.com/openai/v1",
+    "temperature": 0.0,
+    "tools": []
+}
+
+OPEN_ROUTER_Qwen_2_72B_Instruct = {
+    "model": "openrouter/qwen/qwen-2-72b-instruct",
+    "temperature": 0.0,
+    "tools": []
+}
 
-    def __getitem__(self, key):
-        if key in self.extras:
-            return self.extras[key]
-        return getattr(self, key)
+TOGETHER_Qwen_2_72B_Instruct = {
+    "model": "together_ai/Qwen/Qwen2-72B-Instruct",
+    "temperature": 0.0,
+}
 
-    def copy(self, **changes):
-        return replace(self, **changes)
-
-LLAMA3: LLMOptions = LLMOptions("groq/llama3-70b-8192", "https://api.groq.com/openai/v1")
-OPEN_ROUTER_Qwen_2_72B_Instruct: LLMOptions = LLMOptions("openrouter/qwen/qwen-2-72b-instruct", "https://openrouter.ai/api/v1")
-TOGETHER_Qwen_2_72B_Instruct: LLMOptions = LLMOptions("together_ai/Qwen/Qwen2-72B-Instruct")
-FIREWORKS_Qwen_2_72B_Instruct: LLMOptions = LLMOptions("fireworks_ai/wen/Qwen2-72B-Instruct")
\ No newline at end of file
+FIREWORKS_Qwen_2_72B_Instruct = {
+    "model": "fireworks_ai/Qwen/Qwen2-72B-Instruct",
+    "temperature": 0.0,
+}
\ No newline at end of file
diff --git a/just_agents/llm_session.py b/just_agents/llm_session.py
index ccbbcdd..e293405 100644
--- a/just_agents/llm_session.py
+++ b/just_agents/llm_session.py
@@ -1,3 +1,4 @@
+import copy
 import pprint
 from pathlib import Path
 
@@ -9,7 +10,7 @@
 from just_agents.memory import *
 from litellm.utils import Choices
 
-from just_agents.llm_options import LLAMA3, LLMOptions
+from just_agents.llm_options import LLAMA3
 from just_agents.memory import Memory
 from starlette.responses import ContentStream
 import time
@@ -27,7 +28,7 @@
 
 @dataclass(kw_only=True)
 class LLMSession:
-    llm_options: LLMOptions = field(default_factory=lambda: LLAMA3)
+    llm_options: Dict[str, Any] = field(default_factory=lambda: LLAMA3)
     tools: List[Callable] = field(default_factory=list)
     available_tools: Dict[str, Callable] = field(default_factory=lambda: {})
 
@@ -36,7 +37,9 @@ class LLMSession:
     streaming: AbstractStreaming = None
 
     def __post_init__(self):
-        if self.llm_options.model.find("qwen") != -1:
+        if self.llm_options is not None:
+            self.llm_options = copy.deepcopy(self.llm_options)  # just a safety measure to avoid shared dictionaries
+        if "qwen" in self.llm_options["model"].lower():
             self.streaming = QwenStreaming()
         else:
             self.streaming = AsyncSession()
@@ -84,7 +87,7 @@ def query(self, prompt: str, run_callbacks: bool = True, output: Optional[Path]
         return self._query(run_callbacks, output, key_getter=key_getter)
 
-    def query_all(self, messages: list, run_callbacks: bool = True, output: Optional[Path] = None) -> str:
+    def query_all_messages(self, messages: list[dict], run_callbacks: bool = True, output: Optional[Path] = None) -> str:
         self.memory.add_messages(messages, run_callbacks)
         return self._query(run_callbacks, output)
 
@@ -93,21 +96,43 @@ def stream_all(self, messages: list, run_callbacks: bool = True) -> ContentStrea
         self.memory.add_messages(messages, run_callbacks)
         return self._stream()
 
-
     def stream(self, prompt: str, run_callbacks: bool = True, output: Optional[Path] = None) -> ContentStream:
+        """
+        Streams the response to the prompt and, if an output path is given, mirrors the stream to that file.
+        :param prompt: user prompt that is appended to memory
+        :param run_callbacks: whether to run the on_message callbacks
+        :param output: optional file path the streamed content is also written to
+        :return: the content stream
+        """
         question = Message(role="user", content=prompt)
         self.memory.add_message(question, run_callbacks)
-        return self._stream()
+
+        # Start the streaming process
+        content_stream = self._stream()
+
+        # If an output file is provided, write the stream to the file
+        if output is not None:
+            try:
+                with output.open('w') as file:
+                    if isinstance(content_stream, ContentStream):
+                        for content in content_stream:
+                            file.write(content)
+                    else:
+                        raise TypeError("ContentStream expected from self._stream()")
+            except Exception as e:
+                print(f"Error writing to file: {e}")
+
+        return content_stream
+
     def _stream(self) -> ContentStream:
-        return self.streaming.resp_async_generator(self.memory, self.llm_options.to_dict(), self.available_tools)
+        return self.streaming.resp_async_generator(self.memory, self.llm_options, self.available_tools)
 
     def _query(self, run_callbacks: bool = True, output: Optional[Path] = None, key_getter: Optional[GetKey] = None) -> str:
-        options: Dict = self.llm_options.to_dict()
         api_key = key_getter() if key_getter is not None else None
-        response: ModelResponse = completion(messages=self.memory.messages, stream=False, api_key=api_key, **options)
+        response: ModelResponse = completion(messages=self.memory.messages, stream=False, api_key=api_key, **self.llm_options)
         self._process_response(response)
         executed_response = self._process_function_calls(response)
         if executed_response is not None:
@@ -146,7 +171,7 @@ def _process_function_calls(self, response: ModelResponse) -> Optional[ModelResp
                     function_response = str(e)
                 result = Message(role="tool", content=function_response, name=function_name, tool_call_id=tool_call.id)
                 self.memory.add_message(result)
-            return completion(messages=self.memory.messages, stream=False, **self.llm_options.to_dict())
+            return completion(messages=self.memory.messages, stream=False, **self.llm_options)
         return None
 
     def _prepare_tools(self, functions: List[Any]):
@@ -162,4 +187,5 @@ def _prepare_tools(self, functions: List[Any]):
             function_description = litellm.utils.function_to_dict(fun)
             self.available_tools[function_description["name"]] = fun
             tools.append({"type": "function", "function": function_description})
-        self.llm_options = self.llm_options.copy(tools=tools, tool_choice="auto")
\ No newline at end of file
+        self.llm_options["tools"] = tools
+        self.llm_options["tool_choice"] = "auto"
\ No newline at end of file
diff --git a/just_agents/memory.py b/just_agents/memory.py
index 27e492b..c261a9c 100644
--- a/just_agents/memory.py
+++ b/just_agents/memory.py
@@ -62,10 +62,13 @@ def last_message(self) -> Optional[Message]:
         return self.messages[-1] if len(self.messages) > 0 else None
 
-    def add_messages(self, messages: list, run_callbacks: bool = True):
+    def add_messages(self, messages: list[Message | dict], run_callbacks: bool = True):
         for message in messages:
-            msg = Message(content=message["content"], role=message["role"])
-            self.messages.append(msg)
+            if isinstance(message, Message):
+                self.messages.append(message)
+            else:
+                msg = Message(content=message["content"], role=message["role"])
+                self.messages.append(msg)
             if run_callbacks:
                 for handler in self.on_message:
                     handler(message)
diff --git a/just_agents/streaming/abstract_streaming.py b/just_agents/streaming/abstract_streaming.py
index 1773c62..2c126e0 100644
--- a/just_agents/streaming/abstract_streaming.py
+++ b/just_agents/streaming/abstract_streaming.py
@@ -1,5 +1,6 @@
 import json
 import time
+from abc import ABC
 from dataclasses import dataclass
 from typing import Dict, Callable
 
@@ -23,7 +24,7 @@ def parsed(self, name: str, arguments: str):
         return False
 
 
-class AbstractStreaming:
+class AbstractStreaming(ABC):
 
     async def resp_async_generator(self, memory: Memory, options: Dict, available_tools: Dict[str, Callable]):
         pass
diff --git a/just_agents/streaming/openai_streaming.py b/just_agents/streaming/openai_streaming.py
index ddf3ff5..cfe409b 100644
--- a/just_agents/streaming/openai_streaming.py
+++ b/just_agents/streaming/openai_streaming.py
@@ -1,9 +1,7 @@
-from litellm import ModelResponse, completion, acompletion, Message
-import json
+from litellm import ModelResponse, completion
+
 from just_agents.memory import *
 from just_agents.memory import Memory
-import time
-
 from just_agents.streaming.abstract_streaming import AbstractStreaming, FunctionParser
diff --git a/just_agents/streaming/qwen_streaming.py b/just_agents/streaming/qwen_streaming.py
index 6e7dc6c..4b87de1 100644
--- a/just_agents/streaming/qwen_streaming.py
+++ b/just_agents/streaming/qwen_streaming.py
@@ -1,62 +1,68 @@
 import json
 from dataclasses import dataclass
 from typing import Dict, Callable, Optional
+from enum import Enum, auto
 
 from litellm import ModelResponse, completion, Message
-
 from just_agents.memory import Memory
 from just_agents.streaming.abstract_streaming import AbstractStreaming, FunctionParser
 
-CLEARED:int = -2
-STOPED:int = -1
-WAITING:int = 0
-UNDETERMIND:int = 1
-PARSING:int = 2
+class ParserState(Enum):
+    """
+    States of the Qwen function-call parser
+    """
+    CLEARED = auto()
+    STOPPED = auto()
+    WAITING = auto()
+    UNDETERMINED = auto()
+    PARSING = auto()
 
 
 @dataclass
 class QwenFunctionParser:
-    name:str = ""
-    arguments:str = ""
-    buffer:str = ""
-    state:int = WAITING
-
-    def parsing(self, token:str) -> bool:
-        if self.state == STOPED or self.state == CLEARED:
+    """
+    Qwen streams function calls in its own format, so it needs a dedicated parser
+    """
+    name: str = ""
+    arguments: str = ""
+    buffer: str = ""
+    state: ParserState = ParserState.WAITING
+
+    def parsing(self, token: str) -> bool:
+        if self.state in {ParserState.STOPPED, ParserState.CLEARED}:
             return False
         self.buffer += token
-        if self.state == PARSING:
+        if self.state == ParserState.PARSING:
             return True
-        if self.state == WAITING and token.startswith("{"):
-            self.state = UNDETERMIND
+        if self.state == ParserState.WAITING and token.startswith("{"):
+            self.state = ParserState.UNDETERMINED
             return True
-        if self.state == UNDETERMIND and len(self.buffer) < 12:
+        if self.state == ParserState.UNDETERMINED and len(self.buffer) < 12:
            return True
-        if self.state == UNDETERMIND and len(self.buffer) >= 12:
-            if str(self.buffer).find("function") != -1:
-                self.state = PARSING
+        if self.state == ParserState.UNDETERMINED and len(self.buffer) >= 12:
+            if "function" in self.buffer:
+                self.state = ParserState.PARSING
                 return True
             else:
-                self.state = STOPED
+                self.state = ParserState.STOPPED
                 return False
 
     def need_cleared(self) -> bool:
-        return self.state == STOPED
-
+        return self.state == ParserState.STOPPED
 
     def clear(self) -> str:
-        self.state = CLEARED
+        self.state = ParserState.CLEARED
         return self.buffer
 
     def is_ready(self):
-        return self.state == PARSING
+        return self.state == ParserState.PARSING
 
     def get_function_parsers(self):
-        if self.state == PARSING:
+        if self.state == ParserState.PARSING:
             res = []
             data = self.buffer.replace("\n", "")
             data = "[" + data.replace("}}", "}},")[:-1] + "]"
@@ -66,7 +72,6 @@ def get_function_parsers(self):
                 return res
         return []
 
-
 @dataclass
 class QwenStreaming(AbstractStreaming):
 
@@ -78,16 +83,21 @@ def _process_function(self, parser: FunctionParser, available_tools: Dict[str, C
         except Exception as e:
             function_response = str(e)
         message = Message(role="function", content=function_response, name=parser.name,
-                          tool_call_id=parser.id) # TODO need to track arguments , arguments=function_args
+                          tool_call_id=parser.id)  # TODO: need to track arguments (arguments=function_args)
         return message
 
-
     async def resp_async_generator(self, memory: Memory, options: Dict, available_tools: Dict[str, Callable]):
+        """
+        Parses the streamed Qwen response, executes any detected function calls, and streams the results.
+        :param memory: conversation memory that parsed messages are appended to
+        :param options: completion options passed through to litellm
+        :param available_tools: mapping from function name to the callable to execute
+        :return: an async generator of server-sent events
+        """
         response: ModelResponse = completion(messages=memory.messages, stream=True, **options)
         parser: QwenFunctionParser = QwenFunctionParser()
         deltas: list[str] = []
         tool_messages: list[Message] = []
-        parsers: list[FunctionParser] = []
 
         for i, part in enumerate(response):
             delta: str = part["choices"][0]["delta"].get("content")  # type: ignore
@@ -118,4 +128,4 @@
             elif len(deltas) > 0:
                 memory.add_message(Message(role="assistant", content="".join(deltas)))
 
-        yield "data: [DONE]\n\n"
+        yield "data: [DONE]\n\n"
\ No newline at end of file
diff --git a/just_agents/tools/weather.py b/just_agents/tools/weather.py
deleted file mode 100644
index e69de29..0000000
diff --git a/tests/test_session.py b/tests/test_session.py
index 3f3a9da..6f31c1a 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -19,7 +19,7 @@ def get_current_weather(location: str):
     else:
         return json.dumps({"location": location, "temperature": "unknown"})
 
-def test_function_calling():
+def test_sync_llama_function_calling():
     load_dotenv()
     session: LLMSession = LLMSession(
         llm_options=just_agents.llm_options.LLAMA3,
@@ -27,4 +27,14 @@
     )
     result = session.query("What's the weather like in San Francisco, Tokyo, and Paris?")
     assert "72°F" in result
+    assert "22" in result
+
+def test_async_qwen2_function_calling():
+    load_dotenv()
+    session: LLMSession = LLMSession(
+        llm_options=just_agents.llm_options.OPEN_ROUTER_Qwen_2_72B_Instruct,
+        tools=[get_current_weather]
+    )
+    result = session.query_all_messages([{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}])
+    assert "72°F" in result
     assert "22" in result
\ No newline at end of file
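
Usage sketch (editor's illustration, not part of the patch): with the LLMOptions dataclass gone, the presets in just_agents/llm_options.py are plain dicts, so per-session tweaks are ordinary dict operations. The custom_options name and the get_current_weather stub below are hypothetical; LLMSession and the LLAMA3 preset come from the patch itself.

    import json
    import just_agents.llm_options
    from just_agents.llm_session import LLMSession

    def get_current_weather(location: str):
        """Stub tool in the spirit of tests/test_session.py."""
        return json.dumps({"location": location, "temperature": "72"})

    # Start from the shared preset and override one field via a dict merge;
    # LLMSession.__post_init__ deep-copies the dict, so LLAMA3 itself stays intact.
    custom_options = {**just_agents.llm_options.LLAMA3, "temperature": 0.5}

    session = LLMSession(
        llm_options=custom_options,
        tools=[get_current_weather],  # _prepare_tools fills "tools" and "tool_choice" in the copied dict
    )
    result = session.query("What's the weather like in San Francisco, Tokyo, and Paris?")

Keeping the options as dicts means they can be passed straight to litellm's completion(**options) without a to_dict() conversion, which is exactly what _query and _process_function_calls now do.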