From e7e7a671b03e6e94fd62eb0a4ea51149217caf7c Mon Sep 17 00:00:00 2001
From: winternewt
Date: Tue, 7 Jan 2025 20:09:11 +0300
Subject: [PATCH] Streaming tests

---
 core/just_agents/base_agent.py                 |  10 +-
 .../interfaces/protocol_adapter.py             |  10 +-
 .../interfaces/streaming_protocol.py           | 120 +++++++++++++++++-
 core/just_agents/protocols/echo_protocol.py    |   2 +-
 .../just_agents/protocols/litellm_protocol.py  |   6 +-
 .../just_agents/protocols/openai_streaming.py  |  12 +-
 poetry.lock                                    |  12 +-
 tests/test_stream.py                           |  83 ++++++++++++
 8 files changed, 229 insertions(+), 26 deletions(-)
 create mode 100644 tests/test_stream.py

diff --git a/core/just_agents/base_agent.py b/core/just_agents/base_agent.py
index 54bb914..cd3e443 100644
--- a/core/just_agents/base_agent.py
+++ b/core/just_agents/base_agent.py
@@ -142,7 +142,7 @@ def _execute_completion(
             self,
             stream: bool,
             **kwargs
-    ) -> Union[SupportedMessages, BaseModelResponse]:
+    ) -> BaseModelResponse:
         opt = self._prepare_options(self.llm_options)
         opt.update(kwargs)

@@ -214,16 +214,16 @@ def streaming_query_with_current_memory(self, reconstruct_chunks = False, **kwar
         self._partial_streaming_chunks.clear()
         for step in range(self.max_tool_calls):
             response = self._execute_completion(stream=True, **kwargs)
-            tool_messages: list[SupportedMessages] = []
+            tool_messages: SupportedMessages = []
             for i, part in enumerate(response):
                 self._partial_streaming_chunks.append(part)
-                msg: SupportedMessage = self._protocol.message_from_delta(response) # type: ignore
+                msg : SupportedMessages = self._protocol.delta_from_response(part)
                 delta = self._protocol.content_from_delta(msg)
                 if delta:
                     if reconstruct_chunks:
                         yield self._protocol.get_chunk(i, delta, options={'model': part["model"]})
                     else:
-                        yield response
+                        yield self._protocol.sse_wrap(part.model_dump(mode='json'))
             if self.tools and not self._tool_fuse_broken:
                 tool_calls = self._protocol.tool_calls_from_message(msg)
                 if tool_calls:
@@ -243,7 +243,7 @@ def streaming_query_with_current_memory(self, reconstruct_chunks = False, **kwar
         yield self._protocol.done()
         if len(self._partial_streaming_chunks) > 0:
             response = self._protocol.response_from_deltas(self._partial_streaming_chunks)
-            msg: SupportedMessage = self._protocol.message_from_response(response) # type: ignore
+            msg: SupportedMessages = self._protocol.message_from_response(response) # type: ignore
             self.handle_on_response(msg)
             self.add_to_memory(msg)
             self._partial_streaming_chunks.clear()
diff --git a/core/just_agents/interfaces/protocol_adapter.py b/core/just_agents/interfaces/protocol_adapter.py
index 9d12e36..c24f3ae 100644
--- a/core/just_agents/interfaces/protocol_adapter.py
+++ b/core/just_agents/interfaces/protocol_adapter.py
@@ -33,11 +33,11 @@ def message_from_response(self, response: BaseModelResponse) -> AbstractMessage:
         raise NotImplementedError("You need to implement message_from_response first!")

     @abstractmethod
-    def message_from_delta(self, response: BaseModelStreamWrapper) -> AbstractMessage:
-        raise NotImplementedError("You need to implement message_from_delta first!")
+    def delta_from_response(self, response: BaseModelStreamWrapper) -> AbstractMessage:
+        raise NotImplementedError("You need to implement delta_from_response first!")

     @abstractmethod
-    def content_from_delta(self, delta: BaseModelStreamWrapper) -> str:
+    def content_from_delta(self, delta: AbstractMessage) -> str:
         raise NotImplementedError("You need to implement content_from_delta first!")

     @abstractmethod
@@ -45,10 +45,10 @@ def tool_calls_from_message(self, message: AbstractMessage) -> List[IFunctionCal
         raise NotImplementedError("You need to implement tool_calls_from_response first!")

     @abstractmethod
-    def response_from_deltas(self, deltas: List[BaseModelStreamWrapper]) -> BaseModelResponse:
+    def response_from_deltas(self, deltas: List[AbstractMessage]) -> BaseModelResponse:
         raise NotImplementedError("You need to implement message_from_deltas first!")

-    def get_chunk(self, index:int, delta:str, options:dict) -> BaseModelResponse:
+    def get_chunk(self, index:int, delta:str, options:dict) -> AbstractMessage:
         return self._output_streaming.get_chunk(index, delta, options)

     def done(self) -> str:
diff --git a/core/just_agents/interfaces/streaming_protocol.py b/core/just_agents/interfaces/streaming_protocol.py
index 06beb67..0772d6d 100644
--- a/core/just_agents/interfaces/streaming_protocol.py
+++ b/core/just_agents/interfaces/streaming_protocol.py
@@ -1,7 +1,125 @@
-from typing import Any
+from os import eventfd_read
+from typing import Any, Union, Optional, Dict
+import json
 from abc import ABC, abstractmethod

 class IAbstractStreamingProtocol(ABC):
+    # https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
+
+    @staticmethod
+    def sse_wrap(data: Union[Dict[str, Any], str], event: Optional[str] = None) -> str:
+        """
+        Prepare a Server-Sent Event (SSE) message string.
+
+        This function constructs a valid SSE message from the given data
+        and optional event name. The resulting string can be sent as
+        part of a server-sent event stream, following the required SSE
+        format:
+
+            event:
+            data:
+
+        A blank line is appended after the data line to separate this
+        message from subsequent messages.
+
+        Args:
+            data (Union[dict, str]):
+                The data to include in the SSE message body. If a dictionary is
+                provided, it will be serialized to JSON. If a string is provided,
+                it will be used as-is.
+            event (Optional[str]):
+                The SSE event name. If provided, an "event" field will be included
+                in the output.
+
+        Returns:
+            str:
+                A properly formatted SSE message, including a blank line at the end.
+
+        Raises:
+            NotImplementedError:
+                If the data type is not supported by the SSE protocol.
+        """
+        lines = []
+
+        if event:
+            # Insert the "event" field only if event is provided
+            lines.append(f"event: {event}")
+
+        if isinstance(data, str):
+            lines.append(f"data: {data}")
+        elif isinstance(data, dict):
+            # Serialize dictionaries to JSON
+            lines.append(f"data: {json.dumps(data)}")
+        else:
+            raise NotImplementedError("Data type not supported by the SSE protocol.")
+
+        # Append a blank line to separate events
+        lines.append("")
+        return "\n".join(lines)
+
+    @staticmethod
+    def sse_parse(sse_text: str) -> Dict[str, Any]:
+        """
+        Parse a single Server-Sent Event (SSE) message into a dictionary.
+
+        The function looks for the `event:` and `data:` lines in the given text,
+        extracts their values, and returns them. If multiple lines of `data:` are found,
+        they will be combined into one string, separated by newlines.
+        Finally, this function attempts to parse the data string as JSON; if parsing fails,
+        it will preserve the data as a plain string.
+
+        Example SSE message (single event):
+            event: chatMessage
+            data: {"user": "Alice", "message": "Hello!"}
+
+            (Blank line to terminate event)
+
+        Args:
+            sse_text (str):
+                The raw text of a single SSE message, including
+                any `event:` or `data:` lines.
+
+        Returns:
+            Dict[str, Any]:
+                A dictionary containing:
+                - "event" (Optional[str]): The parsed event name, if present.
+                - "data" (Union[str, dict]): The parsed data, either as JSON (dict) if valid,
+                  or a raw string if JSON parsing fails. Defaults to an empty string if no data is found.
+
+        Raises:
+            ValueError:
+                If the input does not contain any `data:` line (since SSE messages
+                typically contain at least one data line).
+        """
+        # Split lines and strip out empty trailing lines
+        lines = [line.strip() for line in sse_text.splitlines() if line.strip()]
+
+        event: Optional[str] = None
+        data_lines = []
+
+        for line in lines:
+            if line.startswith("event:"):
+                event = line.split(":", 1)[1].strip()  # Get text after "event:"
+            elif line.startswith("data:"):
+                data_lines.append(line.split(":", 1)[1].strip())  # Get text after "data:"
+
+        if not data_lines:
+            raise ValueError("No data field found in SSE message.")
+
+        # Combine all data lines into one
+        raw_data = "\n".join(data_lines)
+
+        # Attempt to parse the data as JSON
+        try:
+            parsed_data = json.loads(raw_data)
+        except json.JSONDecodeError:
+            parsed_data = raw_data
+
+        return {
+            "event": event,
+            "data": parsed_data,
+        }
+
     @abstractmethod
     def get_chunk(self, index:int, delta:str, options:dict) -> Any:
         raise NotImplementedError("You need to implement get_chunk() first!")
diff --git a/core/just_agents/protocols/echo_protocol.py b/core/just_agents/protocols/echo_protocol.py
index 333871d..8367873 100644
--- a/core/just_agents/protocols/echo_protocol.py
+++ b/core/just_agents/protocols/echo_protocol.py
@@ -186,7 +186,7 @@ def message_from_response(self, response: EchoModelResponse) -> AbstractMessage:
             "metadata": response.metadata
         }

-    def message_from_delta(self, response: EchoModelResponse) -> AbstractMessage:
+    def delta_from_response(self, response: EchoModelResponse) -> AbstractMessage:
         """
         Convert a delta EchoModelResponse to an abstract message.

diff --git a/core/just_agents/protocols/litellm_protocol.py b/core/just_agents/protocols/litellm_protocol.py
index 92a9412..43efcfd 100644
--- a/core/just_agents/protocols/litellm_protocol.py
+++ b/core/just_agents/protocols/litellm_protocol.py
@@ -1,7 +1,7 @@
 import json
 import pprint

-from litellm import ModelResponse, CustomStreamWrapper, completion, acompletion, stream_chunk_builder
+from litellm import ModelResponse, CustomStreamWrapper, GenericStreamingChunk, completion, acompletion, stream_chunk_builder
 from typing import Optional, Union, Coroutine, ClassVar, Type, Sequence, List, Any, AsyncGenerator
 from pydantic import HttpUrl, Field, AliasPath, PrivateAttr, BaseModel, Json, field_validator

@@ -38,6 +38,7 @@ class Message(BaseModel):

 class LiteLLMFunctionCall(BaseModel, IFunctionCall[MessageDict], extra="allow"):
     id: str = Field(...)
+    index: Optional[int] = Field(None)
     name: str = Field(..., validation_alias=AliasPath('function', 'name'))
     arguments: Json[dict] = Field(..., validation_alias=AliasPath('function', 'arguments'))
     type: Optional[str] = Field('function')
@@ -101,8 +102,7 @@ def message_from_response(self, response: ModelResponse) -> MessageDict:
         assert "function_call" not in message
         return message

-    # TODO: wrong old stuff, YOU DO NOT GET A RESPONSE BUT YOU GET CustomStreamWrapper
-    def message_from_delta(self, response: CustomStreamWrapper): # ModelResponse) -> MessageDict:
+    def delta_from_response(self, response: GenericStreamingChunk) -> MessageDict:
         message = response.choices[0].delta.model_dump(
             mode="json",
             exclude_none=True,
diff --git a/core/just_agents/protocols/openai_streaming.py b/core/just_agents/protocols/openai_streaming.py
index aaed24f..dc31eb5 100644
--- a/core/just_agents/protocols/openai_streaming.py
+++ b/core/just_agents/protocols/openai_streaming.py
@@ -1,19 +1,21 @@
 from just_agents.interfaces.streaming_protocol import IAbstractStreamingProtocol
-import json
+
 import time

 class OpenaiStreamingProtocol(IAbstractStreamingProtocol):
+
+    stop: str = "[DONE]"
+
     def get_chunk(self, index: int, delta: str, options: dict):
-        chunk = {
+        chunk : dict = {
             "id": index,
             "object": "chat.completion.chunk",
             "created": time.time(),
             "model": options["model"],
             "choices": [{"delta": {"content": delta}}],
         }
-        return f"data: {json.dumps(chunk)}\n\n"
+        return self.sse_wrap(chunk)

     def done(self):
-        # https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
-        return "data: [DONE]\n\n"
\ No newline at end of file
+        return self.sse_wrap(self.stop)
\ No newline at end of file
diff --git a/poetry.lock b/poetry.lock
index 386257c..3ecd5e9 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1080,13 +1080,13 @@ url = "web"

 [[package]]
 name = "litellm"
-version = "1.57.0"
+version = "1.57.1"
 description = "Library to easily interface with LLM API providers"
 optional = false
 python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8"
 files = [
-    {file = "litellm-1.57.0-py3-none-any.whl", hash = "sha256:339aec6f3ecac2035bf6311aa8913ce587c9aca2dc7d72a63a210c659e9721ca"},
-    {file = "litellm-1.57.0.tar.gz", hash = "sha256:53a6f2bd9575823e102f7d18dde5cbd2d48eed027cecbb585f18a208605b34c5"},
+    {file = "litellm-1.57.1-py3-none-any.whl", hash = "sha256:f9e93689f2d96df3bcebe723d44b6e2e71b9b047ec7ebd1054b6c9bc96cd9515"},
+    {file = "litellm-1.57.1.tar.gz", hash = "sha256:2ce6ce1707c92fb278f828a8ea058fa12b3eeb8081dd8c10776569995e03bb6f"},
 ]

 [package.dependencies]
@@ -1376,13 +1376,13 @@ test = ["matplotlib", "pytest", "pytest-cov"]

 [[package]]
 name = "openai"
-version = "1.59.3"
+version = "1.59.4"
 description = "The official Python library for the openai API"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "openai-1.59.3-py3-none-any.whl", hash = "sha256:b041887a0d8f3e70d1fc6ffbb2bf7661c3b9a2f3e806c04bf42f572b9ac7bc37"},
-    {file = "openai-1.59.3.tar.gz", hash = "sha256:7f7fff9d8729968588edf1524e73266e8593bb6cab09298340efb755755bb66f"},
+    {file = "openai-1.59.4-py3-none-any.whl", hash = "sha256:82113498699998e98104f87c19a890e82df9b01251a0395484360575d3a1d98a"},
+    {file = "openai-1.59.4.tar.gz", hash = "sha256:b946dc5a2308dc1e03efbda80bf1cd64b6053b536851ad519f57ee44401663d2"},
 ]

 [package.dependencies]
diff --git a/tests/test_stream.py b/tests/test_stream.py
new file mode 100644
index 0000000..96c3e3d
--- /dev/null
+++ b/tests/test_stream.py
@@ -0,0 +1,83 @@
+import json
+from dotenv import load_dotenv
+import pytest
+
+from just_agents.base_agent import BaseAgent
+from just_agents.llm_options import LLMOptions, LLAMA3_3, OPENAI_GPT4oMINI
+
+@pytest.fixture(scope="module", autouse=True)
+def load_env():
+    load_dotenv(override=True)
+
+def get_current_weather(location: str):
+    """Gets the current weather in a given location"""
+    if "tokyo" in location.lower():
+        return json.dumps({"location": "Tokyo", "temperature": "10", "unit": "celsius"})
+    elif "san francisco" in location.lower():
+        return json.dumps({"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"})
+    elif "paris" in location.lower():
+        return json.dumps({"location": "Paris", "temperature": "22", "unit": "celsius"})
+    else:
+        return json.dumps({"location": location, "temperature": "unknown"})
+
+def agent_query(prompt: str, options: LLMOptions):
+    session: BaseAgent = BaseAgent(
+        llm_options=options,
+        tools=[get_current_weather]
+    )
+    return session.query(prompt)
+
+def agent_call(prompt: str, options: LLMOptions, reconstruct_chunks: bool):
+    session: BaseAgent = BaseAgent(
+        llm_options=options,
+        tools=[get_current_weather]
+    )
+    chunks = []
+    gen = session.stream(prompt, reconstruct_chunks=reconstruct_chunks)
+    for sse_event in gen:
+        event = session._protocol.sse_parse(sse_event)
+        assert isinstance(event, dict)
+        data = event.get("data")
+        if isinstance(data, dict):
+            delta = data["choices"][0]["delta"]
+            chunk = session._protocol.content_from_delta(delta)
+        else:
+            continue
+        chunks.append(chunk)
+
+    full_response = ''.join(chunks)
+    last = session.memory.last_message_str
+    assert full_response == last
+    return full_response
+
+def test_stream():
+    result = agent_call("Why is the sky blue?", OPENAI_GPT4oMINI, False)
+    print(result)
+    assert "wavelength" in result
+
+def test_stream_grok():
+    result = agent_call("Why is the sky blue?", LLAMA3_3, False)
+    assert "wavelength" in result
+
+def test_stream_recon():
+    result = agent_call("Why is the grass green?", OPENAI_GPT4oMINI, True)
+    assert "chlorophyll" in result
+
+def test_stream_grok_recon():
+    result = agent_call("Why is the grass green?", LLAMA3_3, True)
+    assert "chlorophyll" in result
+
+def test_tool_only():
+    prompt = "What's the weather like in San Francisco, Tokyo, and Paris?"
+    non_stream = agent_query(prompt,OPENAI_GPT4oMINI)
+    assert "72" in non_stream
+    assert "22" in non_stream
+    assert "10" in non_stream
+
+#def test_stream_tool():
+#    prompt = "What's the weather like in San Francisco, Tokyo, and Paris?"
+#    result = agent_call(prompt, OPENAI_GPT4oMINI, False)
+#    assert "72" in result
+#    assert "22" in result
+#    assert "10" in result
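
For reference, a minimal usage sketch (not part of the patch) of how the sse_wrap / sse_parse helpers introduced above are expected to round-trip a streaming chunk; the chunk payload and model name below are illustrative assumptions, not values taken from the repository.

# Illustration only: round-trip a hypothetical OpenAI-style delta through the SSE helpers.
from just_agents.interfaces.streaming_protocol import IAbstractStreamingProtocol

# Hypothetical streaming chunk; payload values are made up for the example.
chunk = {
    "object": "chat.completion.chunk",
    "model": "gpt-4o-mini",
    "choices": [{"delta": {"content": "Hello"}}],
}

# sse_wrap serializes a dict payload to JSON and prefixes it with "data: "
# (an "event: ..." line is added only when an event name is passed).
frame = IAbstractStreamingProtocol.sse_wrap(chunk)

# sse_parse extracts the optional "event" field and decodes "data" back from JSON.
event = IAbstractStreamingProtocol.sse_parse(frame)
assert event["event"] is None
assert event["data"]["choices"][0]["delta"]["content"] == "Hello"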