Streaming tests complete and passing
winternewt committed Jan 7, 2025
1 parent e7e7a67 commit e431395
Showing 5 changed files with 73 additions and 54 deletions.
82 changes: 46 additions & 36 deletions core/just_agents/base_agent.py
@@ -195,9 +195,8 @@ def query_with_current_memory(self, **kwargs): #former proceed() aka llm_think()
             self.add_to_memory(msg)
 
             if not self.tools or self._tool_fuse_broken:
-                self._tool_fuse_broken = False
-                break
-                # If there are no tool calls or tools available, exit the loop
+                break # If there are no tool calls or tools available, exit the loop
 
             tool_calls = self._protocol.tool_calls_from_message(msg)
             # Process each tool call if they exist and re-execute query
             self._process_function_calls(
@@ -208,46 +207,57 @@ def query_with_current_memory(self, **kwargs): #former proceed() aka llm_think()
             elif step == self.max_tool_calls - 2: #special case where we ran out of tool calls or stuck in a loop
                 self._tool_fuse_broken = True #one last attempt at graceful response
 
+        self._tool_fuse_broken = False
 
     def streaming_query_with_current_memory(self, reconstruct_chunks = False, **kwargs):
-        try:
-
-            for step in range(self.max_tool_calls):
+        self._partial_streaming_chunks.clear()
+        for step in range(self.max_tool_calls):
-                response = self._execute_completion(stream=True, **kwargs)
-                tool_messages: SupportedMessages = []
-                for i, part in enumerate(response):
-                    self._partial_streaming_chunks.append(part)
-                    msg : SupportedMessages = self._protocol.delta_from_response(part)
-                    delta = self._protocol.content_from_delta(msg)
-                    if delta:
-                        if reconstruct_chunks:
-                            yield self._protocol.get_chunk(i, delta, options={'model': part["model"]})
-                        else:
-                            yield self._protocol.sse_wrap(part.model_dump(mode='json'))
-                    if self.tools and not self._tool_fuse_broken:
-                        tool_calls = self._protocol.tool_calls_from_message(msg)
-                        if tool_calls:
-                            self.add_to_memory(
-                                self._protocol.function_convention.reconstruct_tool_call_message(tool_calls)
-                            )
-                            self._process_function_calls(tool_calls)
-                            tool_messages.append(self._process_function_calls(tool_calls))
-
-                if not tool_messages:
-                    break
-                elif step == self.max_tool_calls - 2: # special case where we ran out of tool calls or stuck in a loop
-                    self._tool_fuse_broken = True # one last attempt at graceful response
-
-        finally:
-            self._tool_fuse_broken = False #defuse
-            yield self._protocol.done()
+            response = self._execute_completion(stream=True, **kwargs)
+            yielded = False
+            tool_calls = []
+            for i, part in enumerate(response):
+                self._partial_streaming_chunks.append(part)
+                msg : SupportedMessages = self._protocol.delta_from_response(part)
+                delta = self._protocol.content_from_delta(msg)
+                if delta: #stream content as is
+                    yielded = True
+                    if reconstruct_chunks:
+                        yield self._protocol.get_chunk(i, delta, options={'model': part["model"]})
+                    else:
+                        yield self._protocol.sse_wrap(part.model_dump(mode='json'))
+
+            # if self.tools and not self._tool_fuse_broken:
+            #     tool_calls = self._protocol.tool_calls_from_message(msg)
+            #     if tool_calls:
+            #         self.add_to_memory(
+            #             self._protocol.function_convention.reconstruct_tool_call_message(tool_calls)
+            #         )
+            #         self._process_function_calls(tool_calls)
+            #         tool_messages.append(self._process_function_calls(tool_calls))
+
             if len(self._partial_streaming_chunks) > 0:
-                response = self._protocol.response_from_deltas(self._partial_streaming_chunks)
-                msg: SupportedMessages = self._protocol.message_from_response(response) # type: ignore
+                assembly = self._protocol.response_from_deltas(self._partial_streaming_chunks)
+                self._partial_streaming_chunks.clear()
+                msg: SupportedMessages = self._protocol.message_from_response(assembly) # type: ignore
                 self.handle_on_response(msg)
                 self.add_to_memory(msg)
-                self._partial_streaming_chunks.clear()
+
+                tool_calls = self._protocol.tool_calls_from_message(msg)
+                if not tool_calls and not yielded:
+                    yield self._protocol.sse_wrap(assembly.model_dump(mode='json')) #not delta and not tool, pass as is
+
+            if not self.tools or self._tool_fuse_broken or not tool_calls:
+                self._tool_fuse_broken = False
+                break # If there are no tool calls or tools available, exit the loop
+            else:
+                self._process_function_calls(tool_calls) # NOTE: no kwargs here as tool calls might need different parameters
+
+            if step == self.max_tool_calls - 2: # special case where we ran out of tool calls or stuck in a loop
+                self._tool_fuse_broken = True # one last attempt at graceful response without tools
+
+        self._tool_fuse_broken = False #defuse
+        yield self._protocol.done()
 
     def query(self, query_input: SupportedMessages, **kwargs) -> str:
         """
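
The net effect of the rewrite: tool handling moves out of the chunk loop, so tool calls are inspected only after the full message is reassembled from the accumulated deltas, and the tool fuse is always defused on exit. A minimal, self-contained sketch of the new control flow (fake_stream and MAX_TOOL_CALLS are hypothetical stand-ins for _execute_completion and self.max_tool_calls; only the loop structure mirrors the diff):

from typing import Iterator, List

MAX_TOOL_CALLS = 4  # stand-in for self.max_tool_calls

def fake_stream(step: int) -> Iterator[dict]:
    # Hypothetical stand-in for _execute_completion(stream=True):
    # step 0 requests a tool, step 1 returns plain content.
    if step == 0:
        yield {"content": None, "tool_calls": [{"name": "get_current_weather"}]}
    else:
        yield {"content": "It is 72F in San Francisco."}

def streaming_loop() -> Iterator[str]:
    tool_fuse_broken = False
    for step in range(MAX_TOOL_CALLS):
        chunks: List[dict] = []
        yielded = False
        tool_calls: List[dict] = []
        for part in fake_stream(step):
            chunks.append(part)           # keep every delta for reassembly
            if part.get("content"):       # stream plain content as it arrives
                yielded = True
                yield part["content"]
        if chunks:
            assembly = chunks[-1]         # stand-in for response_from_deltas()
            tool_calls = assembly.get("tool_calls") or []
            if not tool_calls and not yielded:
                yield str(assembly)       # not a delta and not a tool, pass as is
        if tool_fuse_broken or not tool_calls:
            tool_fuse_broken = False
            break                         # no tool calls: exit the loop
        # ...execute tool_calls here, then loop for one more completion...
        if step == MAX_TOOL_CALLS - 2:
            tool_fuse_broken = True       # one last attempt without tools
    yield "[DONE]"

for event in streaming_loop():
    print(event)
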
2 changes: 1 addition & 1 deletion core/just_agents/just_bus.py
@@ -11,7 +11,7 @@ def __call__(cls, *args, **kwargs):
             cls._instances[cls] = super(SingletonMeta, cls).__call__(*args, **kwargs)
         return cls._instances[cls]
 
-SubscriberCallback = Callable[[...],None]
+SubscriberCallback = Callable[[str,...],None]
 
 class JustEventBus(metaclass=SingletonMeta):
     """
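
The widened alias reflects that subscribers now receive the event name as their first positional argument. A condensed sketch of that calling convention (MiniEventBus is a hypothetical stand-in, not the actual JustEventBus implementation):

from typing import Callable, Dict, List

SubscriberCallback = Callable[..., None]  # invoked as callback(event_name, **kwargs)

class MiniEventBus:
    def __init__(self) -> None:
        self._subscribers: Dict[str, List[SubscriberCallback]] = {}

    def subscribe(self, event: str, callback: SubscriberCallback) -> None:
        self._subscribers.setdefault(event, []).append(callback)

    def publish(self, event: str, **kwargs) -> None:
        # Pass the event name first so one callback can serve many topics.
        for callback in self._subscribers.get(event, []):
            callback(event, **kwargs)

bus = MiniEventBus()
bus.subscribe("demo.result", lambda event, **kw: print(event, kw))
bus.publish("demo.result", result_interceptor=42)  # demo.result {'result_interceptor': 42}
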
2 changes: 1 addition & 1 deletion core/just_agents/just_tool.py
@@ -45,7 +45,7 @@ def __wrapper(*args, **kwargs):
             bus = JustToolsBus()
             bus.publish(f"{name}.call", args=args, kwargs=kwargs)
             result = func(*args, **kwargs)
-            bus.publish(f"{name}.result", args=args, kwargs=kwargs, result=result)
+            bus.publish(f"{name}.result", result_interceptor=result)
             return result
         return __wrapper
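
With this change the tool's result event carries the return value under a dedicated result_interceptor keyword instead of echoing args and kwargs back. A subscriber then looks like this (mirroring the pattern used by the updated test below; get_current_weather is the demo tool from the tests):

from just_agents.just_tool import JustToolsBus

def on_result(event_name: str, result_interceptor) -> None:
    print(f"{event_name} returned {result_interceptor!r}")

bus = JustToolsBus()
bus.subscribe("get_current_weather.result", on_result)
# Any subsequent call of the wrapped tool now triggers on_result.
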
3 changes: 1 addition & 2 deletions core/just_agents/protocols/litellm_protocol.py
@@ -110,7 +110,6 @@ def delta_from_response(self, response: GenericStreamingChunk) -> MessageDict:
             by_alias=True,
             exclude={"function_call"} if not response.choices[0].delta.function_call else {} # failsafe
         )
-        assert "function_call" not in message
         return message
 
     def content_from_delta(self, delta: MessageDict) -> str:
@@ -130,5 +129,5 @@ def tool_calls_from_message(self, message: MessageDict) -> List[LiteLLMFunctionCall]:
 
     def response_from_deltas(self, chunks: List[Any]) -> ModelResponse:
         return stream_chunk_builder(chunks)
-        complete_response = litellm.stream_chunk_builder(chunks=chunks, messages=messages)
+        #complete_response = litellm.stream_chunk_builder(chunks=chunks, messages=messages)
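
response_from_deltas now delegates directly to litellm's stream_chunk_builder, which merges accumulated streaming chunks back into a single non-streaming ModelResponse. A usage sketch of that helper in isolation (assumes a configured OPENAI_API_KEY; the model name is illustrative):

import litellm

chunks = []
response = litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Why is the sky blue?"}],
    stream=True,
)
for part in response:
    chunks.append(part)  # keep raw chunks, exactly what the agent accumulates

# Merge the deltas back into a complete response object.
complete = litellm.stream_chunk_builder(chunks)
print(complete.choices[0].message.content)
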

38 changes: 24 additions & 14 deletions tests/test_stream.py
@@ -1,10 +1,10 @@
-import json
 from dotenv import load_dotenv
 import pytest
 
+from typing import Callable, Any
 from just_agents.base_agent import BaseAgent
 from just_agents.llm_options import LLMOptions, LLAMA3_3, OPENAI_GPT4oMINI
 
+from just_agents.just_tool import JustToolsBus
 @pytest.fixture(scope="module", autouse=True)
 def load_env():
     load_dotenv(override=True)
@@ -52,7 +52,6 @@ def agent_call(prompt: str, options: LLMOptions, reconstruct_chunks: bool):
 
 def test_stream():
     result = agent_call("Why is the sky blue?", OPENAI_GPT4oMINI, False)
-    print(result)
     assert "wavelength" in result
 
 def test_stream_grok():
@@ -67,17 +66,28 @@ def test_stream_grok_recon():
     result = agent_call("Why is the grass green?", LLAMA3_3, True)
     assert "chlorophyll" in result
 
-def test_tool_only():
+def validate_tool_call(call : Callable[[Any,...],str],*args,**kwargs):
     prompt = "What's the weather like in San Francisco, Tokyo, and Paris?"
-    non_stream = agent_query(prompt,OPENAI_GPT4oMINI)
-    assert "72" in non_stream
-    assert "22" in non_stream
-    assert "10" in non_stream
+    bus = JustToolsBus()
+    results = []
+    result_callback = 'get_current_weather.result'
+    def callback(event_name: str, result_interceptor: str):
+        assert event_name == result_callback
+        results.append(result_interceptor)
+    bus.subscribe(result_callback,callback)
+    result = call(prompt,*args,**kwargs)
+    assert len(results) == 3
+    assert "72" in result
+    assert "22" in result
+    assert "10" in result
+    assert any('72' in item for item in results), "San Francisco weather call missing"
+    assert any('22' in item for item in results), "Paris weather call missing"
+    assert any('10' in item for item in results), "Tokyo weather call missing"
 
+def test_query_tool():
+    validate_tool_call(agent_query, OPENAI_GPT4oMINI)
+
+def test_stream_tool():
+    validate_tool_call(agent_call, OPENAI_GPT4oMINI, False)
 
-#def test_stream_tool():
-#    prompt = "What's the weather like in San Francisco, Tokyo, and Paris?"
-#    result = agent_call(prompt, OPENAI_GPT4oMINI, False)
-#    assert "72" in result
-#    assert "22" in result
-#    assert "10" in result
