diff --git a/just_agents/config/cot_agent_prompt.yaml b/just_agents/config/cot_agent_prompt.yaml index 2e0793a..913c171 100644 --- a/just_agents/config/cot_agent_prompt.yaml +++ b/just_agents/config/cot_agent_prompt.yaml @@ -1,3 +1,4 @@ +class: "ChainOfThoughtAgent" system_prompt: "You are an expert AI assistant that explains your reasoning step by step. For each step, provide a title that describes what you're doing in that step, along with the content. Decide if you need another step or if you're ready to give the final answer. @@ -29,6 +30,9 @@ action_final: "final_answer" thought_max_tokes: 300 max_steps: 25 final_max_tokens: 1200 +tools: +# - package: +# function: llm_session: just_streaming_method: "openai" completion_remove_key_on_error: True diff --git a/just_agents/config/llm_session_schema.yaml b/just_agents/config/llm_session_schema.yaml index 201e137..7319331 100644 --- a/just_agents/config/llm_session_schema.yaml +++ b/just_agents/config/llm_session_schema.yaml @@ -1,7 +1,11 @@ +class: "LLMSession" just_streaming_method: "openai" system_prompt_path: system_prompt: completion_remove_key_on_error: True completion_max_tries: 2 backup_options: -key_list_path: \ No newline at end of file +key_list_path: +tools: +# - package: +# function: \ No newline at end of file diff --git a/just_agents/cot_agent.py b/just_agents/cot_agent.py index e5f827e..4e91388 100644 --- a/just_agents/cot_agent.py +++ b/just_agents/cot_agent.py @@ -1,11 +1,11 @@ -from just_agents.interfaces.IAddAllMessages import IAddAllMessages +from just_agents.interfaces.IAgent import IAgent from just_agents.llm_session import LLMSession import json from just_agents.streaming.protocols.openai_streaming import OpenaiStreamingProtocol from just_agents.streaming.protocols.abstract_protocol import AbstractStreamingProtocol from pathlib import Path, PurePath import yaml -from just_agents.utils import resolve_agent_schema, resolve_llm_options, resolve_system_prompt +from just_agents.utils import 
resolve_agent_schema, resolve_llm_options, resolve_system_prompt, resolve_tools # schema parameters: LLM_SESSION = "llm_session" @@ -17,19 +17,18 @@ FINAL_MAX_TOKENS = "final_max_tokens" FINAL_PROMPT = "final_prompt" -# -class ChainOfThoughtAgent(): +class ChainOfThoughtAgent(IAgent): def __init__(self, llm_options: dict = None, agent_schema: str | Path | dict | None = None, tools: list = None, output_streaming:AbstractStreamingProtocol = OpenaiStreamingProtocol()): self.agent_schema: dict = resolve_agent_schema(agent_schema, "ChainOfThoughtAgent", "cot_agent_prompt.yaml") + if tools is None: + tools = resolve_tools(self.agent_schema) self.session: LLMSession = LLMSession(llm_options=resolve_llm_options(self.agent_schema, llm_options), + system_prompt=resolve_system_prompt(self.agent_schema), agent_schema=self.agent_schema.get(LLM_SESSION, None), tools=tools) - system_prompt = resolve_system_prompt(self.agent_schema) - if system_prompt is not None: - self.session.instruct(system_prompt) self.output_streaming: AbstractStreamingProtocol = output_streaming diff --git a/just_agents/interfaces/IAddAllMessages.py b/just_agents/interfaces/IAddAllMessages.py deleted file mode 100644 index 849236b..0000000 --- a/just_agents/interfaces/IAddAllMessages.py +++ /dev/null @@ -1,7 +0,0 @@ -class IAddAllMessages: - - def stream_add_all(self, messages: list): - raise NotImplementedError("You need to impelement stream_add_all() first!") - - def query_add_all(self, messages: list[dict]) -> str: - raise NotImplementedError("You need to impelement query_add_all() first!") \ No newline at end of file diff --git a/just_agents/interfaces/IAgent.py b/just_agents/interfaces/IAgent.py new file mode 100644 index 0000000..0a1aa93 --- /dev/null +++ b/just_agents/interfaces/IAgent.py @@ -0,0 +1,44 @@ +from typing import AsyncGenerator, Any + +def build_agent(agent_schema: dict): + from just_agents.cot_agent import ChainOfThoughtAgent + from just_agents.llm_session import LLMSession + class_name 
= agent_schema.get("class", None) + if class_name is None: + raise ValueError("Error class_name field should not be empty in agent_schema param during build_agent() call.") + elif class_name == "LLMSession": + return LLMSession(agent_schema=agent_schema) + elif class_name == "ChainOfThoughtAgent": + return ChainOfThoughtAgent(agent_schema=agent_schema) + +class IAgent: + + # @staticmethod + # def build(agent_schema: dict): + # import importlib + # try: + # package_name = agent_schema.get("package", None) + # class_name = agent_schema.get("class", None) + # + # if package_name is None: + # raise ValueError("Error package_name field should not be empty in agent_schema param during IAgent.build() call.") + # if class_name is None: + # raise ValueError("Error class_name field should not be empty in agent_schema param during IAgent.build() call.") + # # Dynamically import the package + # package = importlib.import_module(package_name) + # # Get the class from the package + # class_ = getattr(package, class_name) + # # Create an instance of the class + # instance = class_(agent_schema=agent_schema) + # + # return instance + # except (ImportError, AttributeError) as e: + # print(f"Error creating instance of {class_name} from {package_name}: {e}") + # return None + + + def stream(self, input: str | dict | list[dict]) -> AsyncGenerator[Any, None]: + raise NotImplementedError("You need to implement stream() first!") + + def query(self, input: str | dict | list[dict]) -> str: + raise NotImplementedError("You need to implement query() first!") \ No newline at end of file diff --git a/just_agents/llm_session.py b/just_agents/llm_session.py index 83169f4..b56824e 100644 --- a/just_agents/llm_session.py +++ b/just_agents/llm_session.py @@ -7,6 +7,7 @@ from litellm import ModelResponse, completion from litellm.utils import Choices +from just_agents.interfaces.IAgent import IAgent from just_agents.llm_options import LLAMA3 from just_agents.memory import Memory
from dataclasses import dataclass, field @@ -14,8 +15,7 @@ from just_agents.streaming.abstract_streaming import AbstractStreaming from just_agents.streaming.openai_streaming import AsyncSession # from just_agents.utils import rotate_completion -from just_agents.interfaces.IAddAllMessages import IAddAllMessages -from just_agents.utils import resolve_agent_schema, resolve_llm_options, resolve_system_prompt +from just_agents.utils import resolve_agent_schema, resolve_llm_options, resolve_system_prompt, resolve_tools from just_agents.rotate_keys import RotateKeys OnCompletion = Callable[[ModelResponse], None] @@ -35,7 +35,7 @@ -class LLMSession(): +class LLMSession(IAgent): available_tools: dict[str, Callable] = dict() memory: Memory = Memory() streaming: AbstractStreaming = None @@ -44,6 +44,7 @@ class LLMSession(): def __init__(self, llm_options: dict[str, Any] = None, + system_prompt:str = None, agent_schema: str | Path | dict | None = None, tools: list[Callable] = None): @@ -52,14 +53,16 @@ def __init__(self, llm_options: dict[str, Any] = None, if self.agent_schema.get(KEY_LIST_PATH, None) is not None: self.key_getter = RotateKeys(self.agent_schema[KEY_LIST_PATH]) self.tools: list[Callable] = tools + if self.tools is None: + self.tools = resolve_tools(self.agent_schema) if self.llm_options is not None: self.llm_options = copy.deepcopy(self.llm_options) #just a satefy requirement to avoid shared dictionaries if (self.key_getter is not None) and (self.llm_options.get("api_key", None) is not None): print("Warning api_key will be rewriten by key_getter. 
Both are present in llm_options.") - - system_prompt = resolve_system_prompt(self.agent_schema) + if system_prompt is None: + system_prompt = resolve_system_prompt(self.agent_schema) if system_prompt is not None: self.instruct(system_prompt) diff --git a/just_agents/utils.py b/just_agents/utils.py index 0367027..6085ffb 100644 --- a/just_agents/utils.py +++ b/just_agents/utils.py @@ -6,6 +6,8 @@ from typing import Optional, Dict, Any import importlib.resources as resources from dotenv import load_dotenv +import importlib +from typing import Callable from litellm import Message, ModelResponse, completion import copy @@ -69,6 +71,26 @@ def resolve_system_prompt(agent_schema: dict): return system_prompt +def resolve_tools(agent_schema: dict) -> list[Callable]: + function_list:list[Callable] = [] + tools = agent_schema.get('tools', None) + if tools is None: + return None + for entry in tools: + package_name: str = entry['package'] + function_name: str = entry['function'] + try: + # Dynamically import the package + package = importlib.import_module(package_name) + # Get the function from the package + func = getattr(package, function_name) + function_list.append(func) + except (ImportError, AttributeError) as e: + print(f"Error importing {function_name} from {package_name}: {e}") + + return function_list + + def rotate_env_keys() -> str: load_dotenv() keys = []