Commit

final changes
Alex-Karmazin committed Oct 16, 2024
1 parent 483668b commit 188a5c5
Showing 7 changed files with 89 additions and 20 deletions.
4 changes: 4 additions & 0 deletions just_agents/config/cot_agent_prompt.yaml
@@ -1,3 +1,4 @@
class: "ChainOfThoughtAgent"
system_prompt: "You are an expert AI assistant that explains your reasoning step by step.
For each step, provide a title that describes what you're doing in that step, along with the content.
Decide if you need another step or if you're ready to give the final answer.
@@ -29,6 +30,9 @@ action_final: "final_answer"
thought_max_tokes: 300
max_steps: 25
final_max_tokens: 1200
tools:
# - package:
# function:
llm_session:
just_streaming_method: "openai"
completion_remove_key_on_error: True
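The new tools key above only hints at the expected entry shape through its commented placeholders: one package and one function per tool. As a minimal illustration of that shape — the math_helpers / multiply names are hypothetical, not from this repository — a filled-in section would parse like this in Python:

```python
import yaml

# Hypothetical tools section; package/function names are illustrative only.
schema_text = """
tools:
  - package: "math_helpers"
    function: "multiply"
"""

schema = yaml.safe_load(schema_text)
print(schema["tools"])  # [{'package': 'math_helpers', 'function': 'multiply'}]
```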
6 changes: 5 additions & 1 deletion just_agents/config/llm_session_schema.yaml
@@ -1,7 +1,11 @@
class: "LLMSession"
just_streaming_method: "openai"
system_prompt_path:
system_prompt:
completion_remove_key_on_error: True
completion_max_tries: 2
backup_options:
key_list_path:
key_list_path:
tools:
# - package:
# function:
13 changes: 6 additions & 7 deletions just_agents/cot_agent.py
@@ -1,11 +1,11 @@
from just_agents.interfaces.IAddAllMessages import IAddAllMessages
from just_agents.interfaces.IAgent import IAgent
from just_agents.llm_session import LLMSession
import json
from just_agents.streaming.protocols.openai_streaming import OpenaiStreamingProtocol
from just_agents.streaming.protocols.abstract_protocol import AbstractStreamingProtocol
from pathlib import Path, PurePath
import yaml
from just_agents.utils import resolve_agent_schema, resolve_llm_options, resolve_system_prompt
from just_agents.utils import resolve_agent_schema, resolve_llm_options, resolve_system_prompt, resolve_tools

# schema parameters:
LLM_SESSION = "llm_session"
@@ -17,19 +17,18 @@
FINAL_MAX_TOKENS = "final_max_tokens"
FINAL_PROMPT = "final_prompt"

#

class ChainOfThoughtAgent():
class ChainOfThoughtAgent(IAgent):

def __init__(self, llm_options: dict = None, agent_schema: str | Path | dict | None = None,
tools: list = None, output_streaming:AbstractStreamingProtocol = OpenaiStreamingProtocol()):
self.agent_schema: dict = resolve_agent_schema(agent_schema, "ChainOfThoughtAgent", "cot_agent_prompt.yaml")
if tools is None:
tools = resolve_tools(self.agent_schema)
self.session: LLMSession = LLMSession(llm_options=resolve_llm_options(self.agent_schema, llm_options),
system_prompt=resolve_system_prompt(self.agent_schema),
agent_schema=self.agent_schema.get(LLM_SESSION, None), tools=tools)

system_prompt = resolve_system_prompt(self.agent_schema)
if system_prompt is not None:
self.session.instruct(system_prompt)
self.output_streaming: AbstractStreamingProtocol = output_streaming


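Going by the constructor in this hunk, a ChainOfThoughtAgent still accepts tools explicitly and only falls back to resolve_tools(agent_schema) when none are given. A rough usage sketch — the multiply tool and the query() call are assumptions, since the agent's query/stream overrides are outside this diff:

```python
from just_agents.cot_agent import ChainOfThoughtAgent
from just_agents.llm_options import LLAMA3

def multiply(a: float, b: float) -> float:
    """Illustrative tool, not part of the repository."""
    return a * b

# Explicit tools take precedence; omitting them triggers resolve_tools(agent_schema).
agent = ChainOfThoughtAgent(llm_options=LLAMA3, tools=[multiply])
print(agent.query("What is 6 times 7? Think step by step."))
```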
7 changes: 0 additions & 7 deletions just_agents/interfaces/IAddAllMessages.py

This file was deleted.

44 changes: 44 additions & 0 deletions just_agents/interfaces/IAgent.py
@@ -0,0 +1,44 @@
from typing import AsyncGenerator, Any

def build_agent(agent_schema: dict):
from just_agents.cot_agent import ChainOfThoughtAgent
from just_agents.llm_session import LLMSession
class_name = agent_schema.get("class", None)
if class_name is None:
raise ValueError("Error class_name field should not be empty in agent_schema param during IAgent.build() call.")
elif class_name == "LLMSession":
return LLMSession(agent_schema=agent_schema)
elif class_name == "ChainOfThoughtAgent":
return ChainOfThoughtAgent(agent_schema=agent_schema)

class IAgent:

# @staticmethod
# def build(agent_schema: dict):
# import importlib
# try:
# package_name = agent_schema.get("package", None)
# class_name = agent_schema.get("class", None)
#
# if package_name is None:
# raise ValueError("Error package_name field should not be empty in agent_schema param during IAgent.build() call.")
# if class_name is None:
# raise ValueError("Error class_name field should not be empty in agent_schema param during IAgent.build() call.")
# # Dynamically import the package
# package = importlib.import_module(package_name)
# # Get the class from the package
# class_ = getattr(package, class_name)
# # Create an instance of the class
# instance = class_(agent_schema=agent_schema)
#
# return instance
# except (ImportError, AttributeError) as e:
# print(f"Error creating instance of {class_name} from {package_name}: {e}")
# return None


def stream(self, input: str | dict | list[dict]) -> AsyncGenerator[Any, None]:
raise NotImplementedError("You need to impelement stream_add_all() first!")

def query(self, input: str | dict | list[dict]) -> str:
raise NotImplementedError("You need to impelement query_add_all() first!")
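A short sketch of how the new build_agent() dispatcher is meant to be called; whether a schema containing only the class key is enough in practice depends on the defaults inside resolve_agent_schema, which this commit does not show:

```python
from just_agents.interfaces.IAgent import build_agent

# "class" selects the concrete agent: a missing key raises ValueError,
# while an unrecognized name falls through and yields None.
agent = build_agent({"class": "ChainOfThoughtAgent"})
session = build_agent({"class": "LLMSession"})
```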
13 changes: 8 additions & 5 deletions just_agents/llm_session.py
@@ -7,15 +7,15 @@
from litellm import ModelResponse, completion
from litellm.utils import Choices

from just_agents.interfaces.IAgent import IAgent
from just_agents.llm_options import LLAMA3
from just_agents.memory import Memory
from dataclasses import dataclass, field
from typing import Callable, Optional
from just_agents.streaming.abstract_streaming import AbstractStreaming
from just_agents.streaming.openai_streaming import AsyncSession
# from just_agents.utils import rotate_completion
from just_agents.interfaces.IAddAllMessages import IAddAllMessages
from just_agents.utils import resolve_agent_schema, resolve_llm_options, resolve_system_prompt
from just_agents.utils import resolve_agent_schema, resolve_llm_options, resolve_system_prompt, resolve_tools
from just_agents.rotate_keys import RotateKeys

OnCompletion = Callable[[ModelResponse], None]
@@ -35,7 +35,7 @@



class LLMSession():
class LLMSession(IAgent):
available_tools: dict[str, Callable] = dict()
memory: Memory = Memory()
streaming: AbstractStreaming = None
@@ -44,6 +44,7 @@ class LLMSession():


def __init__(self, llm_options: dict[str, Any] = None,
system_prompt:str = None,
agent_schema: str | Path | dict | None = None,
tools: list[Callable] = None):

@@ -52,14 +53,16 @@ def __init__(self, llm_options: dict[str, Any] = None,
if self.agent_schema.get(KEY_LIST_PATH, None) is not None:
self.key_getter = RotateKeys(self.agent_schema[KEY_LIST_PATH])
self.tools: list[Callable] = tools
if self.tools is None:
self.tools = resolve_tools(self.agent_schema)

if self.llm_options is not None:
self.llm_options = copy.deepcopy(self.llm_options)  # just a safety requirement to avoid shared dictionaries
if (self.key_getter is not None) and (self.llm_options.get("api_key", None) is not None):
print("Warning api_key will be rewriten by key_getter. Both are present in llm_options.")


system_prompt = resolve_system_prompt(self.agent_schema)
if system_prompt is None:
system_prompt = resolve_system_prompt(self.agent_schema)
if system_prompt is not None:
self.instruct(system_prompt)

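With the new system_prompt parameter, an explicitly passed prompt now takes precedence: resolve_system_prompt(agent_schema) is only consulted when the argument is None. A minimal sketch, assuming query() is implemented elsewhere in the class (it is not part of this hunk):

```python
from just_agents.llm_session import LLMSession
from just_agents.llm_options import LLAMA3

session = LLMSession(
    llm_options=LLAMA3,
    system_prompt="You are a terse assistant.",  # overrides any schema-resolved prompt
)
print(session.query("Name one prime number."))
```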
22 changes: 22 additions & 0 deletions just_agents/utils.py
@@ -6,6 +6,8 @@
from typing import Optional, Dict, Any
import importlib.resources as resources
from dotenv import load_dotenv
import importlib
from typing import Callable
from litellm import Message, ModelResponse, completion
import copy

@@ -69,6 +71,26 @@ def resolve_system_prompt(agent_schema: dict):
return system_prompt


def resolve_tools(agent_schema: dict) -> list[Callable]:
function_list:list[Callable] = []
tools = agent_schema.get('tools', None)
if tools is None:
return None
for entry in tools:
package_name: str = entry['package']
function_name: str = entry['function']
try:
# Dynamically import the package
package = importlib.import_module(package_name)
# Get the function from the package
func = getattr(package, function_name)
function_list.append(func)
except (ImportError, AttributeError) as e:
print(f"Error importing {function_name} from {package_name}: {e}")

return function_list


def rotate_env_keys() -> str:
load_dotenv()
keys = []
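resolve_tools() simply walks the schema's tools list and imports each package/function pair, so any importable module works as a stand-in. A quick check using the standard library (math.sqrt here is only a placeholder for a real tool package):

```python
from just_agents.utils import resolve_tools

schema = {"tools": [{"package": "math", "function": "sqrt"}]}
funcs = resolve_tools(schema)
print(funcs[0](9.0))      # 3.0

print(resolve_tools({}))  # None — schema without a 'tools' key
```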
