-
Notifications
You must be signed in to change notification settings - Fork 680
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into memory_liubo
- Loading branch information
Showing
15 changed files
with
499 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
# Licensed under the Apache License, Version 2.0 (the “License”); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an “AS IS” BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
import re | ||
from typing import Any, Dict, Optional, Union | ||
|
||
from tenacity import retry, stop_after_attempt, wait_exponential | ||
|
||
from camel.agents import ChatAgent | ||
from camel.messages import BaseMessage | ||
from camel.prompts import TextPrompt | ||
from camel.typing import ModelType, RoleType | ||
|
||
|
||
class RoleAssignmentAgent(ChatAgent): | ||
r"""An agent that generates role names based on the task prompt. | ||
Attributes: | ||
role_assignment_prompt (TextPrompt): A prompt for the agent to generate | ||
role names. | ||
Args: | ||
model (ModelType, optional): The type of model to use for the agent. | ||
(default: :obj:`ModelType.GPT_3_5_TURBO`) | ||
model_config (Any, optional): The configuration for the model. | ||
(default: :obj:`None`) | ||
""" | ||
|
||
def __init__( | ||
self, | ||
model: ModelType = ModelType.GPT_3_5_TURBO, | ||
model_config: Optional[Any] = None, | ||
) -> None: | ||
system_message = BaseMessage( | ||
role_name="Role Assigner", | ||
role_type=RoleType.ASSISTANT, | ||
meta_dict=None, | ||
content="You assign roles based on tasks.", | ||
) | ||
super().__init__(system_message, model, model_config) | ||
|
||
@retry(wait=wait_exponential(min=5, max=60), stop=stop_after_attempt(5)) | ||
def run( | ||
self, | ||
task_prompt: Union[str, TextPrompt], | ||
num_roles: int = 2, | ||
) -> Dict[str, str]: | ||
r"""Generate role names based on the input task prompt. | ||
Args: | ||
task_prompt (Union[str, TextPrompt]): The prompt | ||
for the task based on which the roles are to be generated. | ||
num_roles (int, optional): The number of roles to generate. | ||
(default: :obj:`2`) | ||
Returns: | ||
Dict[str, str]: A dictionary mapping role names to their | ||
descriptions. | ||
""" | ||
self.reset() | ||
|
||
expert_prompt = "===== ANSWER PROMPT =====\n" + "\n".join( | ||
f"Domain expert {i + 1}: <BLANK>\n" | ||
f"Associated competencies, characteristics, duties " | ||
f"and workflows: <BLANK>. End." for i in range(num_roles or 0)) | ||
role_assignment_generation_prompt = TextPrompt( | ||
"You are a role assignment agent, and you're in charge of " + | ||
"recruiting {num_roles} experts for the following task." + | ||
"\n==== TASK =====\n {task}\n\n" + | ||
"Identify the domain experts you'd recruit and detail their " + | ||
"associated competencies, characteristics, duties and workflows " + | ||
"to complete the task.\n " + | ||
"Your answer MUST adhere to the format of ANSWER PROMPT, and " + | ||
"ONLY answer the BLANKs.\n" + expert_prompt) | ||
role_assignment_generation = role_assignment_generation_prompt.format( | ||
num_roles=num_roles, task=task_prompt) | ||
|
||
role_assignment_generation_msg = BaseMessage.make_user_message( | ||
role_name="Role Assigner", content=role_assignment_generation) | ||
|
||
response = self.step(input_message=role_assignment_generation_msg) | ||
|
||
msg = response.msg # type: BaseMessage | ||
terminated = response.terminated | ||
|
||
# Distribute the output completions into role names and descriptions | ||
role_names = [ | ||
desc.replace("<|", "").replace("|>", "") for desc in re.findall( | ||
r"Domain expert \d: (.+?)\nAssociated competencies,", | ||
msg.content, | ||
re.DOTALL, | ||
) | ||
] | ||
role_descriptions = [ | ||
desc.replace("<|", "").replace("|>", "") for desc in re.findall( | ||
r"Associated competencies, characteristics, " | ||
r"duties and workflows: (.+?) End.", msg.content, re.DOTALL) | ||
] | ||
|
||
if len(role_names) != num_roles or len(role_descriptions) != num_roles: | ||
raise RuntimeError( | ||
"Got None or insufficient information of roles.") | ||
if terminated: | ||
raise RuntimeError("Role assignment failed.") | ||
|
||
role_descriptions_dict = { | ||
role_name: description | ||
for role_name, description in zip(role_names, role_descriptions) | ||
} | ||
|
||
return role_descriptions_dict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
# Licensed under the Apache License, Version 2.0 (the “License”); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an “AS IS” BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
from typing import Any | ||
|
||
from camel.prompts import AISocietyPromptTemplateDict, TextPrompt | ||
from camel.typing import RoleType | ||
|
||
|
||
# flake8: noqa :E501 | ||
class RoleDescriptionPromptTemplateDict(AISocietyPromptTemplateDict): | ||
r"""A dictionary containing :obj:`TextPrompt` used in the `role description` | ||
task. | ||
Attributes: | ||
ROLE_DESCRIPTION_PROMPT (TextPrompt): A default prompt to | ||
describe the role descriptions. | ||
ASSISTANT_PROMPT (TextPrompt): A system prompt for the AI assistant | ||
that outlines the rules of the conversation and provides | ||
instructions for completing tasks. | ||
USER_PROMPT (TextPrompt): A system prompt for the AI user that | ||
outlines the rules of the conversation and provides instructions | ||
for giving instructions to the AI assistant. | ||
""" | ||
ROLE_DESCRIPTION_PROMPT = TextPrompt("""===== ROLES WITH DESCRIPTION ===== | ||
{user_role} and {assistant_role} are collaborating to complete a task: {task}. | ||
Competencies, characteristics, duties and workflows of {user_role} to complete the task: {user_description} | ||
{assistant_role}'s competencies, characteristics, duties and workflows to complete the task: {assistant_description} | ||
""") | ||
|
||
ASSISTANT_PROMPT = TextPrompt(ROLE_DESCRIPTION_PROMPT + | ||
AISocietyPromptTemplateDict.ASSISTANT_PROMPT) | ||
|
||
USER_PROMPT = TextPrompt(ROLE_DESCRIPTION_PROMPT + | ||
AISocietyPromptTemplateDict.USER_PROMPT) | ||
|
||
def __init__(self, *args: Any, **kwargs: Any) -> None: | ||
super().__init__(*args, **kwargs) | ||
self.update({ | ||
"role_description": self.ROLE_DESCRIPTION_PROMPT, | ||
RoleType.ASSISTANT: self.ASSISTANT_PROMPT, | ||
RoleType.USER: self.USER_PROMPT, | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
# Licensed under the Apache License, Version 2.0 (the “License”); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an “AS IS” BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
from colorama import Fore | ||
|
||
from camel.agents import RoleAssignmentAgent | ||
from camel.configs import ChatGPTConfig | ||
|
||
|
||
def main(model_type=None, num_roles=3) -> None: | ||
task_prompt = "Develop a trading bot for the stock market." | ||
|
||
model_config_description = ChatGPTConfig() | ||
role_description_agent = RoleAssignmentAgent( | ||
model=model_type, model_config=model_config_description) | ||
|
||
role_description_dict = role_description_agent.run(task_prompt=task_prompt, | ||
num_roles=num_roles) | ||
|
||
if (len(role_description_dict) != num_roles): | ||
raise ValueError( | ||
f"Length of role_names ({len(role_description_dict)}) " | ||
f"does not equal to num_roles ({num_roles}).") | ||
|
||
print(Fore.YELLOW + f"Original task prompt:\n{task_prompt}\n") | ||
print(Fore.GREEN + f"List of {num_roles} roles with description:") | ||
for role_name in role_description_dict.keys(): | ||
print(Fore.BLUE + f"{role_name}:\n" | ||
f"{role_description_dict[role_name]}\n") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Oops, something went wrong.