Added chain of thought agent with function calling support.
Alex-Karmazin committed Sep 29, 2024
1 parent d821158 commit 3137585
Showing 10 changed files with 202 additions and 65 deletions.
75 changes: 75 additions & 0 deletions just_agents/cot_agent.py
@@ -0,0 +1,75 @@
from just_agents.llm_session import LLMSession
import json
from just_agents.streaming.protocols.openai_streaming import OpenaiStreamingProtocol
from just_agents.streaming.protocols.abstract_protocol import AbstractStreamingProtocol

FINAL_PROMPT = "Please provide the final answer based solely on your reasoning above."
DEFAULT_SYSTEM_PROMPT = """You are an expert AI assistant that explains your reasoning step by step.
For each step, provide a title that describes what you're doing in that step, along with the content.
Decide if you need another step or if you're ready to give the final answer.
Respond in JSON format with 'title', 'content', and 'next_action' (either 'continue' or 'final_answer') keys.
Make sure you send only one JSON step object.
USE AS MANY REASONING STEPS AS POSSIBLE. AT LEAST 3.
BE AWARE OF YOUR LIMITATIONS AS AN LLM AND WHAT YOU CAN AND CANNOT DO.
IN YOUR REASONING, INCLUDE EXPLORATION OF ALTERNATIVE ANSWERS.
CONSIDER THAT YOU MAY BE WRONG, AND IDENTIFY WHERE IN YOUR REASONING AN ERROR WOULD MOST LIKELY BE.
FULLY TEST ALL OTHER POSSIBILITIES.
YOU CAN BE WRONG. WHEN YOU SAY YOU ARE RE-EXAMINING, ACTUALLY RE-EXAMINE, AND USE ANOTHER APPROACH TO DO SO.
DO NOT JUST SAY YOU ARE RE-EXAMINING. USE AT LEAST 3 METHODS TO DERIVE THE ANSWER. USE BEST PRACTICES.
Example of a valid JSON response:
```json
{
"title": "Identifying Key Information",
"content": "To begin solving this problem, we need to carefully examine the given information and identify the crucial elements that will guide our solution process. This involves...",
"next_action": "continue"
}```
"""

class ChainOfThoughtAgent:

def __init__(self, llm_options, tools=None, system_prompt: str = DEFAULT_SYSTEM_PROMPT, output_streaming: AbstractStreamingProtocol = OpenaiStreamingProtocol()):
self.session: LLMSession = LLMSession(llm_options=llm_options, tools=tools)
if system_prompt is not None:
self.session.instruct(system_prompt)
self.output_streaming: AbstractStreamingProtocol = output_streaming


def stream(self, prompt, max_steps: int = 25, thought_max_tokens: int = 300, final_max_tokens: int = 1200, final_prompt: str = FINAL_PROMPT):
self.session.update_options("max_tokens", thought_max_tokens)
self.session.update_options("response_format", {"type": "json_object"})
step_data = json.loads(self.session.query(prompt))
content = step_data['content'] + "\n"
yield self.output_streaming.get_chunk(0, content, self.session.llm_options)
step_count = 0 # keeps the final chunk index defined even if the loop below never runs
for step_count in range(1, max_steps):
step_data = json.loads(self.session.proceed())
content = step_data['content'] + "\n"
yield self.output_streaming.get_chunk(step_count, content, self.session.llm_options)
if step_data['next_action'] == 'final_answer':
break

self.session.update_options("max_tokens", final_max_tokens)
final_data = json.loads(self.session.query(final_prompt))
yield self.output_streaming.get_chunk(step_count + 1, final_data['content'], self.session.llm_options)
yield self.output_streaming.done()


def query(self, prompt, max_steps: int = 25, thought_max_tokens: int = 300, final_max_tokens: int = 1200, final_prompt: str = FINAL_PROMPT):
self.session.update_options("max_tokens", thought_max_tokens)
self.session.update_options("response_format", {"type": "json_object"})
step_data = json.loads(self.session.query(prompt))
content = step_data['content'] + "\n"
thoughts:str = content
for step_count in range(1, max_steps):
step_data = json.loads(self.session.proceed())
content = step_data['content'] + "\n"
thoughts += content
if step_data['next_action'] == 'final_answer':
break

self.session.update_options("max_tokens", final_max_tokens)
final_data = json.loads(self.session.query(final_prompt))
return final_data['content'], thoughts

def last_message(self):
return self.session.memory.last_message
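
For orientation, a minimal usage sketch of the new agent (not part of the diff): the tool and prompt below are illustrative, and `OPENAI_GPT4oMINI` comes from the `llm_options` change later in this commit.

```python
from just_agents.cot_agent import ChainOfThoughtAgent
from just_agents.llm_options import OPENAI_GPT4oMINI

def letter_count(word: str, letter: str) -> int:
    """Illustrative tool: count occurrences of a letter in a word."""
    return word.count(letter)

# The agent applies DEFAULT_SYSTEM_PROMPT and, by default, emits
# OpenAI-style SSE chunks via OpenaiStreamingProtocol.
agent = ChainOfThoughtAgent(OPENAI_GPT4oMINI, tools=[letter_count])

# query() returns (final_answer, concatenated reasoning steps).
answer, thoughts = agent.query("How many 'r' letters are in 'strawberry'?")
print(thoughts)
print(answer)

# stream() yields one chunk per reasoning step, then the final answer and [DONE].
for chunk in agent.stream("How many 'r' letters are in 'strawberry'?"):
    print(chunk, end="")
```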
5 changes: 5 additions & 0 deletions just_agents/llm_options.py
@@ -10,6 +10,11 @@
"temperature": 0.0
}

OPENAI_GPT4oMINI: Dict = {
"model": "gpt-4o-mini",
"temperature": 0.0
}


LLAMA3: Dict = {
"model": "groq/llama3-groq-70b-8192-tool-use-preview",
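
These option dicts are passed through to the completion call unchanged, so adding a preset is just another dict; a hedged sketch (the model string below is illustrative, not part of this commit):

```python
from typing import Dict

# Illustrative preset: any provider/model string the completion backend
# accepts (the file uses litellm-style names such as "groq/...") works here.
MISTRAL_SMALL: Dict = {
    "model": "mistral/mistral-small-latest",
    "temperature": 0.0,
}
```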
20 changes: 12 additions & 8 deletions just_agents/llm_session.py
@@ -21,7 +21,7 @@
@dataclass(kw_only=True)
class LLMSession:
llm_options: dict[str, Any] = field(default_factory=lambda: LLAMA3)
tools: list[Callable] = field(default_factory=list)
tools: Optional[list[Callable]] = None
available_tools: dict[str, Callable] = field(default_factory=lambda: {})

on_response: list[OnCompletion] = field(default_factory=list)
@@ -78,6 +78,10 @@ def instruct(self, prompt: str):
self.memory.add_message(system_instruction, True)
return system_instruction


def update_options(self, key:str, value:Any):
self.llm_options[key] = value

def query(self, prompt: str, run_callbacks: bool = True, output: Optional[Path] = None) -> str:
"""
Query large language model
@@ -89,17 +93,17 @@ def query(self, prompt: str, run_callbacks: bool = True, output: Optional[Path]

question = {"role": "user", "content": prompt}
self.memory.add_message(question, run_callbacks)
return self._query(run_callbacks, output)
return self.proceed(run_callbacks, output)


def query_add_all(self, messages: list[dict], run_callbacks: bool = True, output: Optional[Path] = None) -> str:
self.memory.add_messages(messages, run_callbacks)
return self._query(run_callbacks, output)
return self.proceed(run_callbacks, output)


def stream_all(self, messages: list, run_callbacks: bool = True): # -> ContentStream:
self.memory.add_messages(messages, run_callbacks)
return self._stream()
return self.proceed_stream()

async def stream_async(self, prompt: str, run_callbacks: bool = True, output: Optional[Path] = None) -> list[Any]:
"""temporary function that allows testing the stream function which Alex wrote but I do not fully understand"""
@@ -122,7 +126,7 @@ def stream(self, prompt: str, run_callbacks: bool = True, output: Optional[Path]
self.memory.add_message(question, run_callbacks)

# Start the streaming process
content_stream = self._stream()
content_stream = self.proceed_stream()

# If output file is provided, write the stream to the file
if output is not None:
@@ -143,11 +147,11 @@ def stream(self, prompt: str, run_callbacks: bool = True, output: Optional[Path]



def _stream(self) -> AsyncGenerator[Any, None]: # -> ContentStream:
def proceed_stream(self) -> AsyncGenerator[Any, None]: # -> ContentStream:
return self.streaming.resp_async_generator(self.memory, self.llm_options, self.available_tools)


def _query(self, run_callbacks: bool = True, output: Optional[Path] = None) -> str:
def proceed(self, run_callbacks: bool = True, output: Optional[Path] = None) -> str:
response: ModelResponse = rotate_completion(messages=self.memory.messages, stream=False, options=self.llm_options)
self._process_response(response)
response = self._process_function_calls(response)
@@ -182,7 +186,7 @@ def _process_function_calls(self, response: ModelResponse) -> Optional[ModelResp
function_to_call = self.available_tools[function_name]
function_args = json.loads(tool_call.function.arguments)
try:
function_response = function_to_call(**function_args)
function_response = str(function_to_call(**function_args))
except Exception as e:
function_response = str(e)
result = {"role":"tool", "content":function_response, "name":function_name, "tool_call_id":tool_call.id}
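
A practical effect of the `str()` wrapper added in `_process_function_calls`: tools are no longer required to return strings. A sketch under that assumption (the tool and query are illustrative, and it assumes `LLMSession` registers `tools` into `available_tools` during initialization):

```python
from just_agents.llm_session import LLMSession
from just_agents.llm_options import OPENAI_GPT4oMINI

def get_weather(city: str) -> dict:
    """Illustrative tool returning structured data; str() turns the
    returned dict into the tool message content."""
    return {"city": city, "temp_c": 21, "conditions": "clear"}

session = LLMSession(llm_options=OPENAI_GPT4oMINI, tools=[get_weather])
print(session.query("What is the weather in Paris right now?"))
```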
15 changes: 3 additions & 12 deletions just_agents/streaming/abstract_streaming.py
@@ -3,6 +3,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Callable, AsyncGenerator, Optional
from just_agents.streaming.protocols.abstract_protocol import AbstractStreamingProtocol

from just_agents.memory import Memory

@@ -26,7 +27,7 @@ class AbstractStreaming(ABC):
"""
Class that is required to implement the streaming logic
"""

output_streaming: AbstractStreamingProtocol

@abstractmethod
async def resp_async_generator(
@@ -60,14 +61,4 @@ def _get_tool_call_message(self, parsers: list[FunctionParser]) -> dict:
for parser in parsers:
tool_calls.append({"type":"function",
"id": parser.id, "function": {"name": parser.name, "arguments": parser.arguments}})
return {"role":"assistant", "content":None, "tool_calls":tool_calls}

def _get_chunk(self, i: int, delta: str, options: Dict):
chunk = {
"id": i,
"object": "chat.completion.chunk",
"created": time.time(),
"model": options["model"],
"choices": [{"delta": {"content": delta}}],
}
return json.dumps(chunk)
return {"role":"assistant", "content":None, "tool_calls":tool_calls}
65 changes: 24 additions & 41 deletions just_agents/streaming/chain_of_thought.py
@@ -4,63 +4,46 @@
from typing import Callable, Optional
from just_agents.memory import Memory
from just_agents.streaming.abstract_streaming import AbstractStreaming, FunctionParser
from just_agents.streaming.protocols.openai_streaming import OpenaiStreamingProtocol
from just_agents.streaming.protocols.abstract_protocol import AbstractStreamingProtocol
from just_agents.utils import rotate_completion
import time
import json

class ChainOfThought(AbstractStreaming):

def __init__(self, output_streaming: AbstractStreamingProtocol = OpenaiStreamingProtocol()):
self.output_streaming = output_streaming

async def resp_async_generator(self, memory: Memory,
options: dict,
available_tools: dict[str, Callable]
) -> AsyncGenerator[str, None]:
memory.add_message({"role": "assistant",
"content": "Thank you! I will now think step by step following my instructions, starting at the beginning after decomposing the problem."}
)
step_count = 0
print("This method depricated use cot_agent instead.")

max_steps = 25

while True:
step_count += 1
step_data = self.make_api_call(memory.messages, options, max_tokens = 300)
opt = options.copy()
opt["max_tokens"] = 300
opt["response_format"] = {"type": "json_object"}
for step_count in range(1, max_steps):
response = rotate_completion(messages=memory.messages, stream=False, options=opt)
step_data = json.loads(response.choices[0].message.content)
memory.add_message({"role": "assistant", "content": json.dumps(step_data)})
# Yield after each step for Streamlit to update
print(step_count, " ", step_data['content'])
print(step_count, " ", step_data)
content = step_data['content'] + "\n"
yield f"data: {self._get_chunk(step_count, content, options)}\n\n"
yield self.output_streaming.get_chunk(step_count, content, opt)
if step_data['next_action'] == 'final_answer' or step_count > 25: # old hard-coded cap on reasoning steps
if step_data['next_action'] == 'final_answer': # the loop bound (max_steps) now caps thinking time; adjustable.
break

# Generate final answer
memory.add_message({"role": "user",
"content": "Please provide the final answer based solely on your reasoning above. Do not use JSON formatting. Only provide the text response without any titles or preambles. Retain any formatting as instructed by the original prompt, such as exact formatting for free response or multiple choice."})
final_data = self.make_api_call(memory.messages, options, max_tokens = 1200, is_final_answer=True)
"content": "Please provide the final answer based solely on your reasoning above."})

opt["max_tokens"] = 1200
response = rotate_completion(messages=memory.messages, stream=False, options=opt)
final_data = json.loads(response.choices[0].message.content)
# yield steps, total_thinking_time
print("Final: ", final_data)
yield f"data: {self._get_chunk(step_count + 1, final_data, options)}\n\n"
yield "data: [DONE]\n\n"


def make_api_call(self, messages, options:dict, max_tokens, is_final_answer=False):
opt = options.copy()
for attempt in range(3):
try:
if is_final_answer:
opt["max_tokens"] = max_tokens
response = rotate_completion(messages=messages, stream=False, options=opt, max_tries = 1)
return response.choices[0].message.content
else:
opt["max_tokens"] = max_tokens
opt["response_format"] = {"type": "json_object"}
response = rotate_completion(messages=messages, stream=False, options=opt, max_tries = 1)
return json.loads(response.choices[0].message.content)
except Exception as e:
if attempt == 2:
if is_final_answer:
return {"title": "Error",
"content": f"Failed to generate final answer after 3 attempts. Error: {str(e)}"}
else:
return {"title": "Error",
"content": f"Failed to generate step after 3 attempts. Error: {str(e)}",
"next_action": "final_answer"}
time.sleep(1) # Wait for 1 second before retrying
yield self.output_streaming.get_chunk(step_count + 1, final_data['content'], opt)
yield self.output_streaming.done()
9 changes: 7 additions & 2 deletions just_agents/streaming/openai_streaming.py
@@ -4,11 +4,16 @@
from typing import Callable, Optional
from just_agents.memory import Memory
from just_agents.streaming.abstract_streaming import AbstractStreaming, FunctionParser
from just_agents.streaming.protocols.abstract_protocol import AbstractStreamingProtocol
from just_agents.utils import rotate_completion
from just_agents.streaming.protocols.openai_streaming import OpenaiStreamingProtocol


class AsyncSession(AbstractStreaming):

def __init__(self, output_streaming: AbstractStreamingProtocol = OpenaiStreamingProtocol()):
self.output_streaming = output_streaming

async def resp_async_generator(self, memory: Memory,
options: dict,
available_tools: dict[str, Callable]
@@ -25,7 +30,7 @@ async def resp_async_generator(self, memory: Memory,
delta: str = part["choices"][0]["delta"].get("content") # type: ignore
if delta:
deltas.append(delta)
yield f"data: {self._get_chunk(i, delta, options)}\n\n"
yield self.output_streaming.get_chunk(i, delta, options)

tool_calls = part["choices"][0]["delta"].get("tool_calls")
if tool_calls and (available_tools is not None):
@@ -45,4 +50,4 @@ async def resp_async_generator(self, memory: Memory,
if len(deltas) > 0:
memory.add_message({"role":"assistant", "content":"".join(deltas)})

yield "data: [DONE]\n\n"
yield self.output_streaming.done()
9 changes: 9 additions & 0 deletions just_agents/streaming/protocols/abstract_protocol.py
@@ -0,0 +1,9 @@
class AbstractStreamingProtocol:

def get_chunk(self, index: int, delta: str, options: dict) -> str:
raise NotImplementedError()

def done(self) -> str:
raise NotImplementedError()
20 changes: 20 additions & 0 deletions just_agents/streaming/protocols/openai_streaming.py
@@ -0,0 +1,20 @@
from just_agents.streaming.protocols.abstract_protocol import AbstractStreamingProtocol
import json
import time


class OpenaiStreamingProtocol(AbstractStreamingProtocol):

def get_chunk(self, index: int, delta: str, options: dict):
chunk = {
"id": index,
"object": "chat.completion.chunk",
"created": time.time(),
"model": options["model"],
"choices": [{"delta": {"content": delta}}],
}
return f"data: {json.dumps(chunk)}\n\n"


def done(self):
return "data: [DONE]\n\n"
9 changes: 7 additions & 2 deletions just_agents/streaming/qwen2_streaming.py
@@ -3,6 +3,8 @@
from typing import Callable, Optional
from just_agents.memory import Memory
from just_agents.streaming.abstract_streaming import AbstractStreaming, FunctionParser
from just_agents.streaming.protocols.abstract_protocol import AbstractStreamingProtocol
from just_agents.streaming.protocols.openai_streaming import OpenaiStreamingProtocol
from just_agents.utils import rotate_completion
import json
from qwen_agent.llm import get_chat_model
@@ -11,6 +13,9 @@

class Qwen2AsyncSession(AbstractStreaming):

def __init__(self, output_streaming: AbstractStreamingProtocol = OpenaiStreamingProtocol()):
self.output_streaming = output_streaming

def _process_function(self, name: str, arguments: str, available_tools: dict[str, Callable]):
function_args = json.loads(arguments)
function_to_call = available_tools[name]
@@ -50,7 +55,7 @@ async def resp_async_generator(self, memory: Memory,
if len(content) > 0:
delta = content[prev_len:]
prev_len = len(content)
yield f"data: {self._get_chunk(i, delta, options)}\n\n"
yield self.output_streaming.get_chunk(i, delta, options)

fncall_msgs = [rsp for rsp in messages if rsp.get('function_call', None)]
memory.add_messages(messages)
@@ -60,4 +65,4 @@
for msg in fncall_msgs:
function_call = msg['function_call']
memory.add_message(self._process_function(function_call["name"], function_call['arguments'], available_tools))
yield "data: [DONE]\n\n"
yield self.output_streaming.done()
