Skip to content

Commit

Permalink
chat agent and pretty printing messages
Browse files Browse the repository at this point in the history
  • Loading branch information
antonkulaga committed Jan 9, 2025
1 parent a89f95d commit 8c2b00a
Show file tree
Hide file tree
Showing 12 changed files with 307 additions and 74 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -166,4 +166,6 @@ cython_debug/

.micromamba/

*.egg-info
*.egg-info

config/agent_profiles.yaml #to deal with testing bug
59 changes: 0 additions & 59 deletions config/agent_profiles.yaml

This file was deleted.

54 changes: 52 additions & 2 deletions core/just_agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class BaseAgent(
Note: it is based on pydantic and the only required field is llm_options.
However, it is also recommended to set system_prompt.
"""


# Core configuration for the LLM
llm_options: LLMOptions = Field(
Expand Down Expand Up @@ -137,7 +138,7 @@ def model_post_init(self, __context: Any) -> None:
print("Warning api_key will be rewritten by key_getter. Both are present in llm_options.")

# Initialize the agent with its system prompt
self.instruct(self.system_prompt)
self.instruct(self.system_prompt) # TODO: THIS CAUSES HUGE ISSUES WHEN YOU INHERIT FROM THIS CLASS FIX!!!!!!!!!!!!!!!!!!!!!!!!!!!1

def _prepare_options(self, options: LLMOptions):
opt = options.copy()
Expand Down Expand Up @@ -290,4 +291,53 @@ def model_supported_parameters(self) -> list[str]:
def supports_response_format(self) -> bool:
"""Checks if the current model supports the response_format parameter"""
#TODO: implement provider specific check
return "response_format" in self.model_supported_parameters
return "response_format" in self.model_supported_parameters



class ChatAgent(BaseAgent):
"""
An agent that has role/goal/task attributes and can call other agents
"""

role: Optional[str] = Field(default=None, description="Defines the agent's persona or identity")
goal: Optional[str] = Field (default=None, description="Specifies the agent's broader objective.")
task: Optional[str] = Field (default=None, description="Describes the specific task the agent is responsible for.")

delegation_prompt: Optional[str] = Field(default="You can delegate your task by calling the delegate function to the following agents:", description="Defines the prompt for the delegation")
delegates: Optional[list[str, BaseAgent]] = Field(default_factory=dict, description="Defines the list of agents that this agent can delegate to with descriptions")


def _update_system_promptform_prompt(self):
# Create a prompt incorporating role, goal, task
prompt = (
f"You are a {self.role}.\n"
f"Your goal is to {self.goal}.\n"
f"Your task is to {self.task}.\n"
"Respond appropriately."
)
return prompt

def model_post_init(self, __context: Any) -> None:
# Call parent's post_init to maintain core functionality
super().model_post_init(__context)
if self.system_prompt == self.DEFAULT_GENERIC_PROMPT:
self.system_prompt = ""

if self.role is not None:
self.system_prompt = self.system_prompt + "\n" + self.role
if self.goal is not None:
self.system_prompt = self.system_prompt + "\n" + self.goal
if self.task is not None:
self.system_prompt = self.system_prompt + "\n" + self.task
if len(self.delegates) > 0:
self.system_prompt = self.system_prompt + "\n" + self.delegation_prompt
self.system_prompt = self.system_prompt + "\n" + "\n".join([f"{agent.shortname}: {agent.description}" for agent in self.delegates.values()])
self.clear_memory()

@property
def delegates_dict(self) -> dict[str, BaseAgent]:
"""Returns a dictionary mapping agent shortnames to their corresponding BaseAgent instances"""
return {agent.shortname: agent for agent in self.delegates.values()}


90 changes: 87 additions & 3 deletions core/just_agents/base_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,96 @@
from just_agents.interfaces.memory import IMemory
from just_agents.types import Role, MessageDict, SupportedMessages
from litellm.types.utils import Function
from abc import ABC
from abc import ABC, abstractmethod

from typing import Optional
from rich.console import Console
from rich.text import Text
from rich.panel import Panel

class IMessageFormatter(ABC):
@abstractmethod
def pretty_print_message(self, msg: MessageDict) -> Panel:
pass

@abstractmethod
def pretty_print_all_messages(self):
pass

class MessageFormatter(IMessageFormatter):

messages: List[MessageDict] = Field(default_factory=list, validation_alias='conversation')


def pretty_print_message(self, msg: MessageDict) -> Panel:
role = msg.get('role', 'unknown').capitalize()

# If the role is an enum, extract its value
if isinstance(role, str) and '<Role.' in role:
role = role.split('.')[-1].replace('>', '').capitalize()

# Define role-specific colors
role_colors = {
'User': 'green',
'Assistant': 'blue',
'System': 'yellow',
'Function': 'magenta',
'Tool': 'magenta',
}
border_colors = {
'User': 'bright_green',
'Assistant': 'bright_blue',
'System': 'bright_yellow',
'Function': 'bright_magenta',
'Tool': 'bright_magenta',
}

# Get colors for the role (default to cyan/bright_yellow if role not found)
role_color = role_colors.get(role, 'cyan')
border_color = border_colors.get(role, 'bright_yellow')

# Create a title with bold text for the role
role_title = Text(f"[{role}]", style=f"bold {role_color}")

# Process tool call details if present
if 'tool_calls' in msg:
for tool_call in msg['tool_calls']:
tool_name = tool_call.get('function', {}).get('name', 'unknown tool')
arguments = tool_call.get('function', {}).get('arguments', '{}')
return Panel(
f"Tool Call to [bold magenta]{tool_name}[/bold magenta]:\n{arguments}",
title=role_title,
border_style=role_color,
)
elif 'tool_call_id' in msg:
tool_name = msg.get('name', 'unknown tool')
tool_result = msg.get('content', 'no content')
return Panel(
f"Response from [bold magenta]{tool_name}[/bold magenta]:\n{tool_result}",
title=role_title,
border_style=border_color,
)
else:
# Standard message
return Panel(
f"{msg.get('content', '')}",
title=role_title,
border_style=border_color,
)

def pretty_print_all_messages(self):
if not self.messages:
return

console = Console()
for msg in self.messages:
panel = self.pretty_print_message(msg)
console.print(panel)

OnMessageCallable = Callable[[MessageDict], None]
OnFunctionCallable = Callable[[Function], None]

class IBaseMemory(BaseModel, IMemory[Role, MessageDict], ABC):
class IBaseMemory(BaseModel, IMemory[Role, MessageDict], IMessageFormatter, ABC):
"""
Abstract Base Class to fulfill Pydantic schema requirements for concrete-attributes.
"""
Expand All @@ -27,7 +111,7 @@ class IBaseMemory(BaseModel, IMemory[Role, MessageDict], ABC):
def deepcopy(self) -> 'IBaseMemory':
return self.model_copy(deep=True)

class BaseMemory(IBaseMemory):
class BaseMemory(IBaseMemory, MessageFormatter):
"""
The Memory class provides storage and handling of messages for a language model session.
It supports adding, removing, and handling different types of messages and
Expand Down
12 changes: 12 additions & 0 deletions core/just_agents/just_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ class JustAgentProfile(JustSerializable):
description="A List[Callable] of tools s available to the agent and their descriptions")
"""A List[Callable] of tools s available to the agent and their descriptions"""

def _add_tool(self, fun: callable):
"""
Adds a tool to the agent's tools dictionary.
"""
tool = JustTool.from_callable(fun)
if self.tools is None:
self.tools = {
tool.name: tool
}
else:
self.tools[tool.name] = tool


@model_validator(mode='after')
def validate_model(self) -> 'JustAgentProfile':
Expand Down
9 changes: 8 additions & 1 deletion core/just_agents/just_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
from pydantic import BaseModel, Field, field_validator, ValidationError
from collections.abc import MutableMapping, MutableSequence
from pydantic import ConfigDict


class JustYaml:
Expand Down Expand Up @@ -157,7 +158,7 @@ def save_to_yaml(
# to use with safe_dump:
yaml.representer.SafeRepresenter.add_representer(str, JustYaml.str_presenter)

class JustSerializable(BaseModel, extra="allow", use_enum_values=True, validate_assignment=True, populate_by_name=True):
class JustSerializable(BaseModel):
"""
Pydantic2 wrapper class that implements semi-automated YAML and JSON serialization and deserialization
Expand All @@ -167,6 +168,12 @@ class JustSerializable(BaseModel, extra="allow", use_enum_values=True, validate_
DEFAULT_SECTION_NAME (str): Default section name to use when none is provided.
"""
model_config = ConfigDict(
extra="allow",
use_enum_values=True,
validate_assignment=True,
populate_by_name=True
)
DEFAULT_CONFIG_PATH : ClassVar[Path] = Path('./config/default_config.yaml')
DEFAULT_PARENT_SECTION : ClassVar[Optional[str]] = None
DEFAULT_SECTION_NAME : ClassVar[Optional[str]] = "Agent" #'RenameMe'
Expand Down
6 changes: 5 additions & 1 deletion core/just_agents/just_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from just_agents.just_bus import JustEventBus
from importlib import import_module
import inspect
from pydantic import ConfigDict

FunctionParamFields=Literal["kind","default","type_annotation"]
FunctionParams = List[Dict[str, Dict[FunctionParamFields,Optional[str]]]]
Expand All @@ -17,7 +18,10 @@ class JustToolsBus(JustEventBus):
"""
pass

class LiteLLMDescription(BaseModel, populate_by_name=True):
class LiteLLMDescription(BaseModel):

model_config = ConfigDict(populate_by_name=True)

name: Optional[str] = Field(..., alias='function', description="The name of the function")
description: Optional[str] = Field(None, description="The docstring of the function.")
parameters: Optional[Dict[str,Any]]= Field(None, description="Parameters of the function.")
Expand Down
6 changes: 4 additions & 2 deletions core/just_agents/llm_options.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, List, Optional
from pydantic import Field, HttpUrl, BaseModel

from pydantic import ConfigDict
LLMOptions = Dict[str, Any]

class ModelOptions(BaseModel):
Expand Down Expand Up @@ -38,7 +38,9 @@ class ModelOptions(BaseModel):
description="Frequency penalty, values from -2.0 to 2.0"
)

class LLMOptionsBase(ModelOptions, extra="allow"):
class LLMOptionsBase(ModelOptions):

model_config = ConfigDict(extra="allow")
api_key: Optional[str] = Field(None, examples=["sk-proj-...."])
api_base : Optional[HttpUrl] = Field(default=None,
examples=[
Expand Down
1 change: 1 addition & 0 deletions core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Deprecated = ">=1.2.15"
requests = "*"
numpydoc = "*"
python-dotenv = ">=1.0.1"
rich = ">=13.9.4"

[tool.poetry.group.dev.dependencies]
pytest = ">=8.3.4"
Expand Down
26 changes: 21 additions & 5 deletions examples/just_agents/examples/web/agent.yaml

Large diffs are not rendered by default.

Loading

0 comments on commit 8c2b00a

Please sign in to comment.