Streaming tests
winternewt committed Jan 7, 2025
1 parent 343924a commit e7e7a67
Showing 8 changed files with 229 additions and 26 deletions.
10 changes: 5 additions & 5 deletions core/just_agents/base_agent.py
@@ -142,7 +142,7 @@ def _execute_completion(
self,
stream: bool,
**kwargs
) -> Union[SupportedMessages, BaseModelResponse]:
) -> BaseModelResponse:

opt = self._prepare_options(self.llm_options)
opt.update(kwargs)
@@ -214,16 +214,16 @@ def streaming_query_with_current_memory(self, reconstruct_chunks = False, **kwar
self._partial_streaming_chunks.clear()
for step in range(self.max_tool_calls):
response = self._execute_completion(stream=True, **kwargs)
tool_messages: list[SupportedMessages] = []
tool_messages: SupportedMessages = []
for i, part in enumerate(response):
self._partial_streaming_chunks.append(part)
msg: SupportedMessage = self._protocol.message_from_delta(response) # type: ignore
msg : SupportedMessages = self._protocol.delta_from_response(part)
delta = self._protocol.content_from_delta(msg)
if delta:
if reconstruct_chunks:
yield self._protocol.get_chunk(i, delta, options={'model': part["model"]})
else:
yield response
yield self._protocol.sse_wrap(part.model_dump(mode='json'))
if self.tools and not self._tool_fuse_broken:
tool_calls = self._protocol.tool_calls_from_message(msg)
if tool_calls:
@@ -243,7 +243,7 @@ def streaming_query_with_current_memory(self, reconstruct_chunks = False, **kwar
yield self._protocol.done()
if len(self._partial_streaming_chunks) > 0:
response = self._protocol.response_from_deltas(self._partial_streaming_chunks)
msg: SupportedMessage = self._protocol.message_from_response(response) # type: ignore
msg: SupportedMessages = self._protocol.message_from_response(response) # type: ignore
self.handle_on_response(msg)
self.add_to_memory(msg)
self._partial_streaming_chunks.clear()
10 changes: 5 additions & 5 deletions core/just_agents/interfaces/protocol_adapter.py
@@ -33,22 +33,22 @@ def message_from_response(self, response: BaseModelResponse) -> AbstractMessage:
raise NotImplementedError("You need to implement message_from_response first!")

@abstractmethod
def message_from_delta(self, response: BaseModelStreamWrapper) -> AbstractMessage:
raise NotImplementedError("You need to implement message_from_delta first!")
def delta_from_response(self, response: BaseModelStreamWrapper) -> AbstractMessage:
raise NotImplementedError("You need to implement delta_from_response first!")

@abstractmethod
def content_from_delta(self, delta: BaseModelStreamWrapper) -> str:
def content_from_delta(self, delta: AbstractMessage) -> str:
raise NotImplementedError("You need to implement content_from_delta first!")

@abstractmethod
def tool_calls_from_message(self, message: AbstractMessage) -> List[IFunctionCall[AbstractMessage]]:
raise NotImplementedError("You need to implement tool_calls_from_response first!")

@abstractmethod
def response_from_deltas(self, deltas: List[BaseModelStreamWrapper]) -> BaseModelResponse:
def response_from_deltas(self, deltas: List[AbstractMessage]) -> BaseModelResponse:
raise NotImplementedError("You need to implement message_from_deltas first!")

def get_chunk(self, index:int, delta:str, options:dict) -> BaseModelResponse:
def get_chunk(self, index:int, delta:str, options:dict) -> AbstractMessage:
return self._output_streaming.get_chunk(index, delta, options)

def done(self) -> str:
120 changes: 119 additions & 1 deletion core/just_agents/interfaces/streaming_protocol.py
@@ -1,7 +1,125 @@
from typing import Any
from os import eventfd_read
from typing import Any, Union, Optional, Dict
import json
from abc import ABC, abstractmethod

class IAbstractStreamingProtocol(ABC):
# https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format

@staticmethod
def sse_wrap(data: Union[Dict[str, Any], str], event: Optional[str] = None) -> str:
"""
Prepare a Server-Sent Event (SSE) message string.
This function constructs a valid SSE message from the given data
and optional event name. The resulting string can be sent as
part of a server-sent event stream, following the required SSE
format:
event: <event_name>
data: <data>
A blank line is appended after the data line to separate this
message from subsequent messages.
Args:
data (Union[dict, str]):
The data to include in the SSE message body. If a dictionary is
provided, it will be serialized to JSON. If a string is provided,
it will be used as-is.
event (Optional[str]):
The SSE event name. If provided, an "event" field will be included
in the output.
Returns:
str:
A properly formatted SSE message, including a blank line at the end.
Raises:
NotImplementedError:
If the data type is not supported by the SSE protocol.
"""
lines = []

if event:
# Insert the "event" field only if event is provided
lines.append(f"event: {event}")

if isinstance(data, str):
lines.append(f"data: {data}")
elif isinstance(data, dict):
# Serialize dictionaries to JSON
lines.append(f"data: {json.dumps(data)}")
else:
raise NotImplementedError("Data type not supported by the SSE protocol.")

# Append a blank line to separate events
lines.append("")
return "\n".join(lines)

@staticmethod
def sse_parse(sse_text: str) -> Dict[str, Any]:
"""
Parse a single Server-Sent Event (SSE) message into a dictionary.
The function looks for the `event:` and `data:` lines in the given text,
extracts their values, and returns them. If multiple lines of `data:` are found,
they will be combined into one string, separated by newlines.
Finally, this function attempts to parse the data string as JSON; if parsing fails,
it will preserve the data as a plain string.
Example SSE message (single event):
event: chatMessage
data: {"user": "Alice", "message": "Hello!"}
(Blank line to terminate event)
Args:
sse_text (str):
The raw text of a single SSE message, including
any `event:` or `data:` lines.
Returns:
Dict[str, Any]:
A dictionary containing:
- "event" (Optional[str]): The parsed event name, if present.
- "data" (Union[str, dict]): The parsed data, either as JSON (dict) if valid,
or a raw string if JSON parsing fails. Defaults to an empty string if no data is found.
Raises:
ValueError:
If the input does not contain any `data:` line (since SSE messages
typically contain at least one data line).
"""
# Split lines and strip out empty trailing lines
lines = [line.strip() for line in sse_text.splitlines() if line.strip()]

event: Optional[str] = None
data_lines = []

for line in lines:
if line.startswith("event:"):
event = line.split(":", 1)[1].strip() # Get text after "event:"
elif line.startswith("data:"):
data_lines.append(line.split(":", 1)[1].strip()) # Get text after "data:"

if not data_lines:
raise ValueError("No data field found in SSE message.")

# Combine all data lines into one
raw_data = "\n".join(data_lines)

# Attempt to parse the data as JSON
try:
parsed_data = json.loads(raw_data)
except json.JSONDecodeError:
parsed_data = raw_data

return {
"event": event,
"data": parsed_data,
}

@abstractmethod
def get_chunk(self, index:int, delta:str, options:dict) -> Any:
raise NotImplementedError("You need to implement get_chunk() first!")
2 changes: 1 addition & 1 deletion core/just_agents/protocols/echo_protocol.py
@@ -186,7 +186,7 @@ def message_from_response(self, response: EchoModelResponse) -> AbstractMessage:
"metadata": response.metadata
}

def message_from_delta(self, response: EchoModelResponse) -> AbstractMessage:
def delta_from_response(self, response: EchoModelResponse) -> AbstractMessage:
"""
Convert a delta EchoModelResponse to an abstract message.
6 changes: 3 additions & 3 deletions core/just_agents/protocols/litellm_protocol.py
@@ -1,7 +1,7 @@
import json
import pprint

from litellm import ModelResponse, CustomStreamWrapper, completion, acompletion, stream_chunk_builder
from litellm import ModelResponse, CustomStreamWrapper, GenericStreamingChunk, completion, acompletion, stream_chunk_builder
from typing import Optional, Union, Coroutine, ClassVar, Type, Sequence, List, Any, AsyncGenerator
from pydantic import HttpUrl, Field, AliasPath, PrivateAttr, BaseModel, Json, field_validator

@@ -38,6 +38,7 @@ class Message(BaseModel):

class LiteLLMFunctionCall(BaseModel, IFunctionCall[MessageDict], extra="allow"):
id: str = Field(...)
index: Optional[int] = Field(None)
name: str = Field(..., validation_alias=AliasPath('function', 'name'))
arguments: Json[dict] = Field(..., validation_alias=AliasPath('function', 'arguments'))
type: Optional[str] = Field('function')
@@ -101,8 +102,7 @@ def message_from_response(self, response: ModelResponse) -> MessageDict:
assert "function_call" not in message
return message

# TODO: wrong old stuff, YOU DO NOT GET A RESPONSE BUT YOU GET CustomStreamWrapper
def message_from_delta(self, response: CustomStreamWrapper): # ModelResponse) -> MessageDict:
def delta_from_response(self, response: GenericStreamingChunk) -> MessageDict:
message = response.choices[0].delta.model_dump(
mode="json",
exclude_none=True,
12 changes: 7 additions & 5 deletions core/just_agents/protocols/openai_streaming.py
@@ -1,19 +1,21 @@
from just_agents.interfaces.streaming_protocol import IAbstractStreamingProtocol
import json

import time

class OpenaiStreamingProtocol(IAbstractStreamingProtocol):

stop: str = "[DONE]"

def get_chunk(self, index: int, delta: str, options: dict):
chunk = {
chunk : dict = {
"id": index,
"object": "chat.completion.chunk",
"created": time.time(),
"model": options["model"],
"choices": [{"delta": {"content": delta}}],
}
return f"data: {json.dumps(chunk)}\n\n"
return self.sse_wrap(chunk)


def done(self):
# https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
return "data: [DONE]\n\n"
return self.sse_wrap(self.stop)
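A hypothetical consumer sketch for the updated OpenaiStreamingProtocol (model name and delta strings are invented; sse_parse is the base-class helper shown earlier, and "[DONE]" is the stop sentinel defined above):

    from just_agents.protocols.openai_streaming import OpenaiStreamingProtocol

    protocol = OpenaiStreamingProtocol()
    events = [
        protocol.get_chunk(0, "Hello", options={"model": "example-model"}),
        protocol.get_chunk(1, " world", options={"model": "example-model"}),
        protocol.done(),
    ]

    text = ""
    for sse_event in events:
        parsed = protocol.sse_parse(sse_event)  # inherited from IAbstractStreamingProtocol
        data = parsed["data"]
        if data == protocol.stop:  # "[DONE]" stays a plain string after parsing
            break
        text += data["choices"][0]["delta"]["content"]

    assert text == "Hello world"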
12 changes: 6 additions & 6 deletions poetry.lock

Some generated files are not rendered by default.

83 changes: 83 additions & 0 deletions tests/test_stream.py
@@ -0,0 +1,83 @@
import json
from dotenv import load_dotenv
import pytest

from just_agents.base_agent import BaseAgent
from just_agents.llm_options import LLMOptions, LLAMA3_3, OPENAI_GPT4oMINI

@pytest.fixture(scope="module", autouse=True)
def load_env():
load_dotenv(override=True)

def get_current_weather(location: str):
"""Gets the current weather in a given location"""
if "tokyo" in location.lower():
return json.dumps({"location": "Tokyo", "temperature": "10", "unit": "celsius"})
elif "san francisco" in location.lower():
return json.dumps({"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"})
elif "paris" in location.lower():
return json.dumps({"location": "Paris", "temperature": "22", "unit": "celsius"})
else:
return json.dumps({"location": location, "temperature": "unknown"})

def agent_query(prompt: str, options: LLMOptions):
session: BaseAgent = BaseAgent(
llm_options=options,
tools=[get_current_weather]
)
return session.query(prompt)

def agent_call(prompt: str, options: LLMOptions, reconstruct_chunks: bool):
session: BaseAgent = BaseAgent(
llm_options=options,
tools=[get_current_weather]
)
chunks = []
gen = session.stream(prompt, reconstruct_chunks=reconstruct_chunks)
for sse_event in gen:
event = session._protocol.sse_parse(sse_event)
assert isinstance(event, dict)
data = event.get("data")
if isinstance(data, dict):
delta = data["choices"][0]["delta"]
chunk = session._protocol.content_from_delta(delta)
else:
continue
chunks.append(chunk)

full_response = ''.join(chunks)
last = session.memory.last_message_str
assert full_response == last
return full_response

def test_stream():
result = agent_call("Why is the sky blue?", OPENAI_GPT4oMINI, False)
print(result)
assert "wavelength" in result

def test_stream_grok():
result = agent_call("Why is the sky blue?", LLAMA3_3, False)
assert "wavelength" in result

def test_stream_recon():
result = agent_call("Why is the grass green?", OPENAI_GPT4oMINI, True)
assert "chlorophyll" in result

def test_stream_grok_recon():
result = agent_call("Why is the grass green?", LLAMA3_3, True)
assert "chlorophyll" in result

def test_tool_only():
prompt = "What's the weather like in San Francisco, Tokyo, and Paris?"
non_stream = agent_query(prompt,OPENAI_GPT4oMINI)
assert "72" in non_stream
assert "22" in non_stream
assert "10" in non_stream

#def test_stream_tool():
# prompt = "What's the weather like in San Francisco, Tokyo, and Paris?"
# result = agent_call(prompt, OPENAI_GPT4oMINI, False)
# assert "72" in result
# assert "22" in result
# assert "10" in result
