Skip to content

Commit

Permalink
Merge pull request #8 from winternewt/gse-backport
Browse files Browse the repository at this point in the history
Gse backport
  • Loading branch information
antonkulaga authored Jan 3, 2025
2 parents c679921 + d0a7bad commit 6ab09df
Show file tree
Hide file tree
Showing 50 changed files with 2,094 additions and 744 deletions.
4 changes: 3 additions & 1 deletion coding/containers/biosandbox/env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,6 @@ dependencies:
- genomepy>=0.16.1
- pyensembl
- plotly
- GEOparse>=2.0.4
- GEOparse>=2.0.4
- pybiomart
- scanpy
28 changes: 27 additions & 1 deletion coding/just_agents/coding/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,33 @@ def run_bash_command(command: str):
with MicromambaSession(image="ghcr.io/longevity-genie/just-agents/biosandbox:main", lang="python", keep_template=True, verbose=True) as session:
result: ConsoleOutput = session.execute_command(command=command)
return result


def validate_python_code_syntax(code: str, filename: str) -> str:
    """
    Validate that a string of Python source code is syntactically correct.

    code: str # python code to validate
    filename: str # a filename to use in error messages

    Returns a human-readable status string: either a success message or the
    SyntaxError text (which includes the supplied filename and line number).
    """
    try:
        # compile() parses the code without executing it; the filename is
        # only used to label any SyntaxError raised during parsing.
        # Fix: the original ignored the `filename` parameter entirely and
        # bound the compiled code object to an unused local.
        compile(code, filename, "exec")
        return "Code syntax is correct"
    except SyntaxError as e:
        return f"Syntax error in code: {e}"

def save_text_to_runtime(text: str, filename: str):
    """
    Save a text string to a file inside the sandbox container runtime.

    text: str # text to be saved
    filename: str # a filename to use inside the container (/example/<filename>)

    Returns the ConsoleOutput result of the copy operation.
    """
    with MicromambaSession(image="ghcr.io/longevity-genie/just-agents/biosandbox:main", lang="python", keep_template=True, verbose=True) as session:
        # Fix: the original ignored the `filename` parameter (the f-strings
        # contained no placeholder), so every call clobbered the same file.
        # NOTE(review): `filename` is interpolated into host and container
        # paths unsanitized — a name containing "../" could escape /tmp or
        # /example; confirm callers pass bare filenames only.
        text_file = f"/tmp/{filename}"
        dest_file = f"/example/{filename}"
        # Write the text to a host-side temp file, then copy it into the
        # container's working directory.
        with open(text_file, "w") as f:
            f.write(text)
        result: ConsoleOutput = session.copy_to_runtime(src=text_file, dest=dest_file)
        return result


def run_python_code(code: str):
"""
Expand Down
65 changes: 39 additions & 26 deletions core/just_agents/base_agent.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
from pydantic import Field, PrivateAttr
from typing import Optional, List, Union, Any, Generator

from just_agents.core.interfaces.IMemory import IMemory
from just_agents.core.types import Role, AbstractMessage, SupportedMessages, SupportedMessage
from just_agents.types import Role, SupportedMessages

from just_agents.llm_options import LLMOptions
from just_agents.streaming.protocols.interfaces.IFunctionCall import IFunctionCall
from just_agents.streaming.protocols.interfaces.IProtocolAdapter import IProtocolAdapter, BaseModelResponse
from just_agents.core.interfaces.IAgent import IAgentWithInterceptors, QueryListener, ResponseListener
from just_agents.interfaces.function_call import IFunctionCall
from just_agents.interfaces.protocol_adapter import IProtocolAdapter, BaseModelResponse
from just_agents.interfaces.agent import IAgentWithInterceptors, QueryListener, ResponseListener

from just_agents.base_memory import IBaseMemory, BaseMemory
from just_agents.just_profile import JustAgentProfile
from just_agents.core.rotate_keys import RotateKeys
from just_agents.streaming.protocol_factory import StreamingMode, ProtocolAdapterFactory
from just_agents.rotate_keys import RotateKeys
from just_agents.protocols.protocol_factory import StreamingMode, ProtocolAdapterFactory
from litellm.litellm_core_utils.get_supported_openai_params import get_supported_openai_params


Expand Down Expand Up @@ -64,10 +62,16 @@ class BaseAgent(
key_list_path: Optional[str] = Field(
default=None,
exclude=True,
description="path to text file with list of api keys, one key per line")
description="Path to text file with list of api keys, one key per line")

max_tool_calls: int = Field(
ge=1,
default=50,
description="A safeguard to prevent tool calls stuck in a loop")

drop_params: bool = Field(
default=True,
description=" drop params from the request, useful for some models that do not support them")
description="Drop params from the request, useful for some models that do not support them")

# Protected handlers implementation
_on_query : List[QueryListener] = PrivateAttr(default_factory=list)
Expand All @@ -78,7 +82,7 @@ class BaseAgent(
_partial_streaming_chunks: List[BaseModelResponse] = PrivateAttr(
default_factory=list) # Buffers streaming responses
_key_getter: Optional[RotateKeys] = PrivateAttr(None) # Manages API key rotation

_tool_fuse_broken: bool = PrivateAttr(False) #Fuse to prevent tool loops

def instruct(self, prompt: str): #backward compatibility
self.memory.add_message({"role": Role.system, "content": prompt})
Expand All @@ -87,7 +91,7 @@ def clear_memory(self) -> None:
self.memory.clear_messages()
self.instruct(self.system_prompt)

def deepcopy_memory(self) -> IMemory:
def deepcopy_memory(self) -> IBaseMemory:
return self.memory.deepcopy()

def add_to_memory(self, messages: SupportedMessages) -> None:
Expand Down Expand Up @@ -129,7 +133,7 @@ def model_post_init(self, __context: Any) -> None:

def _prepare_options(self, options: LLMOptions):
opt = options.copy()
if self.tools is not None: # populate llm_options based on available tools
if self.tools is not None and not self._tool_fuse_broken: # populate llm_options based on available tools
opt["tools"] = [{"type": "function",
"function": self.tools[tool].get_litellm_description()} for tool in self.tools]
return opt
Expand All @@ -138,7 +142,7 @@ def _execute_completion(
self,
stream: bool,
**kwargs
) -> Union[AbstractMessage, BaseModelResponse]:
) -> Union[SupportedMessages, BaseModelResponse]:

opt = self._prepare_options(self.llm_options)
opt.update(kwargs)
Expand Down Expand Up @@ -172,7 +176,7 @@ def _execute_completion(
return self._protocol.completion(messages=self.memory.messages, stream=stream, **opt)


def _process_function_calls(self, function_calls: List[IFunctionCall[AbstractMessage]]) -> SupportedMessages:
def _process_function_calls(self, function_calls: List[IFunctionCall[SupportedMessages]]) -> SupportedMessages:
messages: SupportedMessages = []
for call in function_calls:
msg = call.execute_function(lambda function_name: self.tools[function_name].get_callable())
Expand All @@ -182,54 +186,63 @@ def _process_function_calls(self, function_calls: List[IFunctionCall[AbstractMes
return messages

def query_with_current_memory(self, **kwargs): #former proceed() aka llm_think()
while True:
for step in range(self.max_tool_calls):
# individual llm call, unpacking the message, processing handlers
response = self._execute_completion(stream=False, **kwargs)
msg: AbstractMessage = self._protocol.message_from_response(response) # type: ignore
msg: SupportedMessage = self._protocol.message_from_response(response) # type: ignore
self.handle_on_response(msg)
self.add_to_memory(msg)

if not self.tools:
if not self.tools or self._tool_fuse_broken:
self._tool_fuse_broken = False
break
# If there are no tool calls or tools available, exit the loop
tool_calls = self._protocol.tool_calls_from_message(msg)
# Process each tool call if they exist and re-execute query
self._process_function_calls(
tool_calls) # NOTE: no kwargs here as tool calls might need different parameters

if not tool_calls:
break
# Process each tool call if they exist and re-execute query
self._process_function_calls(tool_calls) #NOTE: no kwargs here as tool calls might need different parameters
elif step == self.max_tool_calls - 2: #special case where we ran out of tool calls or stuck in a loop
self._tool_fuse_broken = True #one last attempt at graceful response


def streaming_query_with_current_memory(self, reconstruct_chunks = False, **kwargs):
try:
self._partial_streaming_chunks.clear()
while True: #TODO rewrite this super-ugly while-True-break stuff into proper recursion
for step in range(self.max_tool_calls):
response = self._execute_completion(stream=True, **kwargs)
tool_messages: list[AbstractMessage] = []
tool_messages: list[SupportedMessages] = []
for i, part in enumerate(response):
self._partial_streaming_chunks.append(part)
msg: AbstractMessage = self._protocol.message_from_delta(response) # type: ignore
msg: SupportedMessage = self._protocol.message_from_delta(response) # type: ignore
delta = self._protocol.content_from_delta(msg)
if delta:
if reconstruct_chunks:
yield self._protocol.get_chunk(i, delta, options={'model': part["model"]})
else:
yield response
if self.tools:
if self.tools and not self._tool_fuse_broken:
tool_calls = self._protocol.tool_calls_from_message(msg)
if tool_calls:
self.add_to_memory(
self._protocol.function_convention.reconstruct_tool_call_message(tool_calls)
)
self._process_function_calls(tool_calls)
tool_messages.append(self._process_function_calls(tool_calls))

if not tool_messages:
break
elif step == self.max_tool_calls - 2: # special case where we ran out of tool calls or stuck in a loop
self._tool_fuse_broken = True # one last attempt at graceful response

finally:
self._tool_fuse_broken = False #defuse
yield self._protocol.done()
if len(self._partial_streaming_chunks) > 0:
response = self._protocol.response_from_deltas(self._partial_streaming_chunks)
msg: AbstractMessage = self._protocol.message_from_response(response) # type: ignore
msg: SupportedMessage = self._protocol.message_from_response(response) # type: ignore
self.handle_on_response(msg)
self.add_to_memory(msg)
self._partial_streaming_chunks.clear()
Expand All @@ -248,7 +261,7 @@ def query(self, query_input: SupportedMessages, **kwargs) -> str:
return result


def stream(self, query_input: SupportedMessages, reconstruct_chunks = False, **kwargs) -> Generator[Union[BaseModelResponse, AbstractMessage],None,None]:
def stream(self, query_input: SupportedMessages, reconstruct_chunks = False, **kwargs) -> Generator[Union[BaseModelResponse, SupportedMessages],None,None]:
self.handle_on_query(query_input)
self.add_to_memory(query_input)
return self.streaming_query_with_current_memory(reconstruct_chunks=reconstruct_chunks, **kwargs)
Expand Down
23 changes: 11 additions & 12 deletions core/just_agents/base_memory.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from pydantic import BaseModel, Field, PrivateAttr
from typing import Optional, Callable, List, Dict, Union
from typing import Optional, Callable, List, Dict
from functools import singledispatchmethod
from just_agents.core.interfaces.IMemory import IMemory
from just_agents.core.types import Role, AbstractMessage, SupportedMessages, SupportedMessage
from just_agents.interfaces.memory import IMemory
from just_agents.types import Role, MessageDict, SupportedMessages
from litellm.types.utils import Function
from abc import ABC

OnMessageCallable = Callable[[AbstractMessage], None]
OnMessageCallable = Callable[[MessageDict], None]
OnFunctionCallable = Callable[[Function], None]

class IBaseMemory(BaseModel, IMemory[Role, AbstractMessage], ABC):
class IBaseMemory(BaseModel, IMemory[Role, MessageDict], ABC):
"""
Abstract Base Class to fulfill Pydantic schema requirements for concrete-attributes.
"""

messages: List[AbstractMessage] = Field(default_factory=list, validation_alias='conversation')
messages: List[MessageDict] = Field(default_factory=list, validation_alias='conversation')

# Private dict of message handlers for each role
_on_message: Dict[Role, List[OnMessageCallable]] = PrivateAttr(default_factory=lambda: {
Expand All @@ -24,14 +24,17 @@ class IBaseMemory(BaseModel, IMemory[Role, AbstractMessage], ABC):
Role.system: [],
})

def deepcopy(self) -> 'IBaseMemory':
return self.model_copy(deep=True)

class BaseMemory(IBaseMemory):
"""
The Memory class provides storage and handling of messages for a language model session.
It supports adding, removing, and handling different types of messages and
function calls categorized by roles: assistant, tool, user, and system.
"""

def handle_message(self, message: AbstractMessage) -> None:
def handle_message(self, message: MessageDict) -> None:
"""
Implements the abstract method to handle messages based on their roles.
"""
Expand Down Expand Up @@ -92,7 +95,7 @@ def add_on_tool_call(self, fun: OnFunctionCallable) -> None:
Adds a handler to track function calls.
"""

def tool_handler(message: AbstractMessage) -> None:
def tool_handler(message: MessageDict) -> None:
tool_calls = message.get('tool_calls', [])
for call in tool_calls:
function_name = call.get('function')
Expand Down Expand Up @@ -169,7 +172,3 @@ def remove_on_system_message(self, handler: OnMessageCallable) -> None:
"""
self._remove_on_message(handler, Role.system)

def deepcopy(self) -> 'BaseMemory':
return self.model_copy(deep=True)


41 changes: 0 additions & 41 deletions core/just_agents/config/agent_profiles.yaml

This file was deleted.

68 changes: 0 additions & 68 deletions core/just_agents/core/types.py

This file was deleted.

File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 6ab09df

Please sign in to comment.