Skip to content

Commit

Permalink
Removed all classes derived from BaseMessage. Removed role field from BaseMessage
Browse files Browse the repository at this point in the history
  • Loading branch information
Obs01ete committed Jun 20, 2023
1 parent 1a08bbe commit 4f38336
Show file tree
Hide file tree
Showing 33 changed files with 239 additions and 517 deletions.
13 changes: 6 additions & 7 deletions apps/agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from apps.agents.text_utils import split_markdown_code
from camel.agents import TaskSpecifyAgent
from camel.messages import AssistantChatMessage
from camel.messages import BaseMessage
from camel.societies import RolePlaying

REPO_ROOT = os.path.realpath(
Expand All @@ -43,17 +43,16 @@ class State:
session: Optional[RolePlaying]
max_messages: int
chat: ChatBotHistory
saved_assistant_msg: Optional[AssistantChatMessage]
saved_assistant_msg: Optional[BaseMessage]

# Factory: build a blank State — no active session, zero max_messages,
# empty chat history, and no saved assistant message.
# NOTE(review): positional args assume the field order declared above
# (session, max_messages, chat, saved_assistant_msg) — confirm against
# the class constructor, which is not fully visible here.
@classmethod
def empty(cls) -> 'State':
return cls(None, 0, [], None)

@staticmethod
def construct_inplace(
state: 'State', session: Optional[RolePlaying], max_messages: int,
chat: ChatBotHistory,
saved_assistant_msg: Optional[AssistantChatMessage]) -> None:
def construct_inplace(state: 'State', session: Optional[RolePlaying],
max_messages: int, chat: ChatBotHistory,
saved_assistant_msg: Optional[BaseMessage]) -> None:
state.session = session
state.max_messages = max_messages
state.chat = chat
Expand Down Expand Up @@ -216,7 +215,7 @@ def role_playing_chat_init(state) -> \

try:
init_assistant_msg, _ = session.init_chat()
init_assistant_msg: AssistantChatMessage
init_assistant_msg: BaseMessage
except (openai.error.RateLimitError, tenacity.RetryError,
RuntimeError) as ex:
print("OpenAI API exception 1 " + str(ex))
Expand Down
60 changes: 38 additions & 22 deletions camel/agents/chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from camel.agents import BaseAgent
from camel.configs import ChatGPTConfig
from camel.messages import ChatMessage, MessageType, SystemMessage
from camel.messages import BaseMessage
from camel.models import BaseModelBackend, ModelFactory
from camel.typing import ModelType, RoleType
from camel.utils import num_tokens_from_messages, openai_api_key_required
Expand All @@ -31,15 +31,15 @@ class ChatAgentResponse:
r"""Response of a ChatAgent.
Attributes:
msgs (List[ChatMessage]): A list of zero, one or several messages.
msgs (List[BaseMessage]): A list of zero, one or several messages.
If the list is empty, there is some error in message generation.
If the list has one message, this is normal mode.
If the list has several messages, this is the critic mode.
terminated (bool): A boolean indicating whether the agent decided
to terminate the chat session.
info (Dict[str, Any]): Extra information about the chat message.
"""
msgs: List[ChatMessage]
msgs: List[BaseMessage]
terminated: bool
info: Dict[str, Any]

Expand All @@ -51,11 +51,17 @@ def msg(self):
return self.msgs[0]


# Immutable record pairing a backend chat role with a message, replacing
# the removed role field that previously lived on the message classes.
# role_at_backend: chat role string sent to the backend; update_messages()
#   accepts only 'system', 'user' or 'assistant'.
# message: the BaseMessage payload associated with that role.
@dataclass(frozen=True)
class ChatRecord:
role_at_backend: str
message: BaseMessage


class ChatAgent(BaseAgent):
r"""Class for managing conversations of CAMEL Chat Agents.
Args:
system_message (SystemMessage): The system message for the chat agent.
system_message (BaseMessage): The system message for the chat agent.
model (ModelType, optional): The LLM model to use for generating
responses. (default :obj:`ModelType.GPT_3_5_TURBO`)
model_config (Any, optional): Configuration options for the LLM model.
Expand All @@ -67,13 +73,13 @@ class ChatAgent(BaseAgent):

def __init__(
self,
system_message: SystemMessage,
system_message: BaseMessage,
model: Optional[ModelType] = None,
model_config: Optional[Any] = None,
message_window_size: Optional[int] = None,
) -> None:

self.system_message: SystemMessage = system_message
self.system_message: BaseMessage = system_message
self.role_name: str = system_message.role_name
self.role_type: RoleType = system_message.role_type

Expand All @@ -87,14 +93,15 @@ def __init__(
self.model_token_limit: int = self.model_backend.token_limit

self.terminated: bool = False
self.stored_messages: List[ChatRecord]
self.init_messages()

def reset(self) -> List[MessageType]:
def reset(self) -> List[ChatRecord]:
r"""Resets the :obj:`ChatAgent` to its initial state and returns the
stored messages.
Returns:
List[MessageType]: The stored messages.
List[BaseMessage]: The stored messages.
"""
self.terminated = False
self.init_messages()
Expand Down Expand Up @@ -131,32 +138,38 @@ def init_messages(self) -> None:
r"""Initializes the stored messages list with the initial system
message.
"""
self.stored_messages: List[MessageType] = [self.system_message]
self.stored_messages = [ChatRecord('system', self.system_message)]

def update_messages(self, message: ChatMessage) -> List[MessageType]:
def update_messages(self, role: str,
message: BaseMessage) -> List[ChatRecord]:
r"""Updates the stored messages list with a new message.
Args:
message (ChatMessage): The new message to add to the stored
message (BaseMessage): The new message to add to the stored
messages.
Returns:
List[ChatMessage]: The updated stored messages.
List[BaseMessage]: The updated stored messages.
"""
self.stored_messages.append(message)
if role not in {'system', 'user', 'assistant'}:
raise ValueError(f"Unsupported role {role}")
self.stored_messages.append(ChatRecord(role, message))
return self.stored_messages

# Append a critic-selected message to the stored history as an 'assistant'
# turn. Appends directly rather than going through update_messages(), since
# the role is fixed here and needs no validation.
def submit_critic_choice(self, message: BaseMessage) -> None:
self.stored_messages.append(ChatRecord('assistant', message))

@retry(wait=wait_exponential(min=5, max=60), stop=stop_after_attempt(5))
@openai_api_key_required
def step(
self,
input_message: ChatMessage,
input_message: BaseMessage,
) -> ChatAgentResponse:
r"""Performs a single step in the chat session by generating a response
to the input message.
Args:
input_message (ChatMessage): The input message to the agent.
input_message (BaseMessage): The input message to the agent.
Its `role` field that specifies the role at backen may be either
`user` or `assistant` but it will be set to `user` anyway since
for the self agent any incoming message is external.
Expand All @@ -167,25 +180,28 @@ def step(
the chat session has terminated, and information about the chat
session.
"""
msg_user_at_backend = input_message.set_user_role_at_backend()
messages = self.update_messages(msg_user_at_backend)
messages = self.update_messages('user', input_message)
if self.message_window_size is not None and len(
messages) > self.message_window_size:
messages = [self.system_message
messages = [ChatRecord('system', self.system_message)
] + messages[-self.message_window_size:]
openai_messages = [message.to_openai_message() for message in messages]
openai_messages = [
record.message.to_openai_message(record.role_at_backend)
for record in messages
]
num_tokens = num_tokens_from_messages(openai_messages, self.model)

output_messages: Optional[List[ChatMessage]]
output_messages: Optional[List[BaseMessage]]
info: Dict[str, Any]

if num_tokens < self.model_token_limit:
response = self.model_backend.run(openai_messages)
if not isinstance(response, dict):
raise RuntimeError("OpenAI returned unexpected struct")
output_messages = [
ChatMessage(role_name=self.role_name, role_type=self.role_type,
meta_dict=dict(), **dict(choice["message"]))
BaseMessage(role_name=self.role_name, role_type=self.role_type,
meta_dict=dict(),
content=choice["message"]['content'])
for choice in response["choices"]
]
info = self.get_info(
Expand Down
34 changes: 16 additions & 18 deletions camel/agents/critic_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from colorama import Fore

from camel.agents import ChatAgent
from camel.messages import ChatMessage, SystemMessage
from camel.messages import BaseMessage
from camel.typing import ModelType
from camel.utils import get_first_int, print_text_animated

Expand All @@ -28,7 +28,7 @@ class CriticAgent(ChatAgent):
r"""A class for the critic agent that assists in selecting an option.
Args:
system_message (SystemMessage): The system message for the critic
system_message (BaseMessage): The system message for the critic
agent.
model (ModelType, optional): The LLM model to use for generating
responses. (default :obj:`ModelType.GPT_3_5_TURBO`)
Expand All @@ -46,7 +46,7 @@ class CriticAgent(ChatAgent):

def __init__(
self,
system_message: SystemMessage,
system_message: BaseMessage,
model: ModelType = ModelType.GPT_3_5_TURBO,
model_config: Optional[Any] = None,
message_window_size: int = 6,
Expand All @@ -61,11 +61,11 @@ def __init__(
self.verbose = verbose
self.logger_color = logger_color

def flatten_options(self, messages: Sequence[ChatMessage]) -> str:
def flatten_options(self, messages: Sequence[BaseMessage]) -> str:
r"""Flattens the options to the critic.
Args:
messages (Sequence[ChatMessage]): A list of `ChatMessage` objects.
messages (Sequence[BaseMessage]): A list of `BaseMessage` objects.
Returns:
str: A string containing the flattened options to the critic.
Expand All @@ -83,11 +83,11 @@ def flatten_options(self, messages: Sequence[ChatMessage]) -> str:
"and then your explanation and comparison: ")
return flatten_options + format

def get_option(self, input_message: ChatMessage) -> str:
def get_option(self, input_message: BaseMessage) -> str:
r"""Gets the option selected by the critic.
Args:
input_message (ChatMessage): A `ChatMessage` object representing
input_message (BaseMessage): A `BaseMessage` object representing
the input message.
Returns:
Expand All @@ -104,8 +104,8 @@ def get_option(self, input_message: ChatMessage) -> str:
if critic_response.terminated:
raise RuntimeError("Critic step failed.")

critic_msg = critic_response.msgs[0]
self.update_messages(critic_msg)
critic_msg = critic_response.msg
self.update_messages('assistant', critic_msg)
if self.verbose:
print_text_animated(self.logger_color + "\n> Critic response: "
f"\x1b[3m{critic_msg.content}\x1b[0m\n")
Expand All @@ -114,11 +114,10 @@ def get_option(self, input_message: ChatMessage) -> str:
if choice in self.options_dict:
return self.options_dict[choice]
else:
input_message = ChatMessage(
input_message = BaseMessage(
role_name=input_message.role_name,
role_type=input_message.role_type,
meta_dict=input_message.meta_dict,
role=input_message.role,
content="> Invalid choice. Please choose again.\n" +
msg_content,
)
Expand All @@ -128,11 +127,11 @@ def get_option(self, input_message: ChatMessage) -> str:
"Returning a random option.")
return random.choice(list(self.options_dict.values()))

def parse_critic(self, critic_msg: ChatMessage) -> Optional[str]:
def parse_critic(self, critic_msg: BaseMessage) -> Optional[str]:
r"""Parses the critic's message and extracts the choice.
Args:
critic_msg (ChatMessage): A `ChatMessage` object representing the
critic_msg (BaseMessage): A `BaseMessage` object representing the
critic's response.
Returns:
Expand All @@ -142,22 +141,21 @@ def parse_critic(self, critic_msg: ChatMessage) -> Optional[str]:
choice = str(get_first_int(critic_msg.content))
return choice

def step(self, messages: Sequence[ChatMessage]) -> ChatMessage:
def step(self, messages: Sequence[BaseMessage]) -> BaseMessage:
r"""Performs one step of the conversation by flattening options to the
critic, getting the option, and parsing the choice.
Args:
messages (Sequence[ChatMessage]): A list of ChatMessage objects.
messages (Sequence[BaseMessage]): A list of BaseMessage objects.
Returns:
ChatMessage: A `ChatMessage` object representing the critic's
BaseMessage: A `BaseMessage` object representing the critic's
choice.
"""
meta_chat_message = ChatMessage(
meta_chat_message = BaseMessage(
role_name=messages[0].role_name,
role_type=messages[0].role_type,
meta_dict=messages[0].meta_dict,
role=messages[0].role,
content="",
)

Expand Down
19 changes: 9 additions & 10 deletions camel/agents/embodied_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from colorama import Fore

from camel.agents import BaseToolAgent, ChatAgent, HuggingFaceToolAgent
from camel.messages import ChatMessage, SystemMessage
from camel.messages import BaseMessage
from camel.typing import ModelType
from camel.utils import print_text_animated

Expand All @@ -25,7 +25,7 @@ class EmbodiedAgent(ChatAgent):
r"""Class for managing conversations of CAMEL Embodied Agents.
Args:
system_message (SystemMessage): The system message for the chat agent.
system_message (BaseMessage): The system message for the chat agent.
model (ModelType, optional): The LLM model to use for generating
responses. (default :obj:`ModelType.GPT_4`)
model_config (Any, optional): Configuration options for the LLM model.
Expand All @@ -42,7 +42,7 @@ class EmbodiedAgent(ChatAgent):

def __init__(
self,
system_message: SystemMessage,
system_message: BaseMessage,
model: ModelType = ModelType.GPT_4,
model_config: Optional[Any] = None,
message_window_size: Optional[int] = None,
Expand Down Expand Up @@ -79,15 +79,15 @@ def get_action_space_prompt(self) -> str:

def step(
self,
input_message: ChatMessage,
) -> Tuple[ChatMessage, bool, Dict[str, Any]]:
input_message: BaseMessage,
) -> Tuple[BaseMessage, bool, Dict[str, Any]]:
r"""Performs a step in the conversation.
Args:
input_message (ChatMessage): The input message.
input_message (BaseMessage): The input message.
Returns:
Tuple[ChatMessage, bool, Dict[str, Any]]: A tuple
Tuple[BaseMessage, bool, Dict[str, Any]]: A tuple
containing the output messages, termination status, and
additional information.
"""
Expand Down Expand Up @@ -126,7 +126,6 @@ def step(
# TODO: Handle errors
content = input_message.content + (Fore.RESET +
f"\n> Embodied Actions:\n{content}")
message = ChatMessage(input_message.role_name, input_message.role_type,
input_message.meta_dict, input_message.role,
content)
message = BaseMessage(input_message.role_name, input_message.role_type,
input_message.meta_dict, content)
return message, response.terminated, response.info
Loading

0 comments on commit 4f38336

Please sign in to comment.