Merge branch 'main' of github.com:longevity-genie/just-agents
Anton Kulaga committed Jun 8, 2024
2 parents f8e73be + 5a2be2b commit 7e46976
Showing 2 changed files with 137 additions and 8 deletions.
134 changes: 127 additions & 7 deletions just_agents/llm_session.py
@@ -1,6 +1,7 @@
 from pathlib import Path

-from litellm import ModelResponse, completion, Message
+from litellm.utils import ChatCompletionMessageToolCall, Function
+from litellm import ModelResponse, completion, acompletion, Message
 from typing import Any, Dict, List, Optional, Callable
 import litellm
 import json
@@ -9,9 +10,101 @@

 from just_agents.llm_options import LLAMA3
 from just_agents.memory import Memory
+from starlette.responses import ContentStream
+import time

 OnCompletion = Callable[[ModelResponse], None]

+
+class FunctionParser:
+    id: str = ""
+    name: str = ""
+    arguments: str = ""
+
+    def __init__(self, id: str):
+        self.id = id
+
+    def parsed(self, name: str, arguments: str):
+        if name:
+            self.name += name
+        if arguments:
+            self.arguments += arguments
+        if len(self.name) > 0 and len(self.arguments) > 0 and self.arguments.endswith("}"):
+            return True
+        return False
+
+
+def get_chunk(i: int, delta: str, options: Dict):
+    chunk = {
+        "id": i,
+        "object": "chat.completion.chunk",
+        "created": time.time(),
+        "model": options["model"],
+        "choices": [{"delta": {"content": delta}}],
+    }
+    return json.dumps(chunk)
+
+
+def process_function(parser: FunctionParser, available_tools: Dict[str, Callable]):
+    function_args = json.loads(parser.arguments)
+    function_to_call = available_tools[parser.name]
+    try:
+        function_response = function_to_call(**function_args)
+    except Exception as e:
+        function_response = str(e)
+    message = Message(role="tool", content=function_response, name=parser.name,
+                      tool_call_id=parser.id)  # TODO: need to track arguments, e.g. arguments=function_args
+    return message
+
+
+def get_tool_call_message(parsers: list[FunctionParser]) -> Message:
+    tool_calls = []
+    for parser in parsers:
+        tool_calls.append({"type": "function",
+                           "id": parser.id, "function": {"name": parser.name, "arguments": parser.arguments}})
+    return Message(role="assistant", content=None, tool_calls=tool_calls)
+
+
+async def _resp_async_generator(memory: Memory, options: Dict, available_tools: Dict[str, Callable]):
+    response: ModelResponse = completion(messages=memory.messages, stream=True, **options)
+    parser: Optional[FunctionParser] = None
+    function_response = None
+    tool_calls_message = None
+    tool_messages: list[Message] = []
+    parsers: list[FunctionParser] = []
+    deltas: list[str] = []
+    for i, part in enumerate(response):
+        delta: str = part["choices"][0]["delta"].get("content")  # type: ignore
+        if delta:
+            deltas.append(delta)
+            yield f"data: {get_chunk(i, delta, options)}\n\n"
+
+        tool_calls = part["choices"][0]["delta"].get("tool_calls")
+        if tool_calls and (available_tools is not None):
+            if not parser:
+                parser = FunctionParser(id=tool_calls[0].id)
+            if parser.parsed(tool_calls[0].function.name, tool_calls[0].function.arguments):
+                tool_messages.append(process_function(parser, available_tools))
+                parsers.append(parser)
+                parser = None
+
+    if len(tool_messages) > 0:
+        memory.add_message(get_tool_call_message(parsers))
+        for message in tool_messages:
+            memory.add_message(message)
+        response = completion(messages=memory.messages, stream=True, **options)
+        deltas = []
+        for i, part in enumerate(response):
+            delta: str = part["choices"][0]["delta"].get("content")  # type: ignore
+            if delta:
+                deltas.append(delta)
+                yield f"data: {get_chunk(i, delta, options)}\n\n"
+        memory.add_message(Message(role="assistant", content="".join(deltas)))
+    elif len(deltas) > 0:
+        memory.add_message(Message(role="assistant", content="".join(deltas)))
+
+    yield "data: [DONE]\n\n"
+
+
 @dataclass(kw_only=True)
 class LLMSession:
     llm_options: Dict[str, Any] = field(default_factory=lambda: LLAMA3)
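Note: FunctionParser above reassembles a complete tool call from the partial name/argument strings that streaming deltas deliver one fragment at a time; the endswith("}") check is the heuristic that decides when the argument JSON is complete. A minimal illustration (the call id and fragment values are made up):

parser = FunctionParser(id="call_0")
parser.parsed("get_weather", None)       # the function name arrives first
parser.parsed(None, '{"city": "Par')     # arguments arrive in pieces -> still incomplete
done = parser.parsed(None, 'is"}')       # closing brace marks the call complete
assert done and parser.arguments == '{"city": "Paris"}'

The heuristic assumes the closing brace arrives last; an argument payload split right after an inner "}" would be declared complete too early.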
@@ -51,26 +144,52 @@ def instruct(self, prompt: str):
         self.memory.add_message(system_instruction, True)
         return system_instruction

-    def query(self, prompt: str, stream: bool = False, run_callbacks: bool = True, output: Optional[Path] = None) -> str:
+    def query(self, prompt: str, run_callbacks: bool = True, output: Optional[Path] = None) -> str:
         """
         Query large language model
         :param prompt:
-        :param stream:
         :param run_callbacks:
         :param output:
         :return:
         """

         question = Message(role="user", content=prompt)
         self.memory.add_message(question, run_callbacks)
+        return self._query(run_callbacks, output)
+
+    def query_all(self, messages: list, run_callbacks: bool = True, output: Optional[Path] = None) -> str:
+        self.memory.add_messages(messages, run_callbacks)
+        return self._query(run_callbacks, output)
+
+    def stream_all(self, messages: list, run_callbacks: bool = True) -> ContentStream:
+        self.memory.add_messages(messages, run_callbacks)
+        return self._stream()
+
+    def stream(self, prompt: str, run_callbacks: bool = True, output: Optional[Path] = None) -> ContentStream:
+        question = Message(role="user", content=prompt)
-        self.memory.add_message(question)
+        self.memory.add_message(question, run_callbacks)
+        return self._stream()
+
+    def _stream(self) -> ContentStream:
+        return _resp_async_generator(self.memory, self.llm_options, self.available_tools)
+
+    def _query(self, run_callbacks: bool = True, output: Optional[Path] = None) -> str:
         options: Dict = self.llm_options
-        response: ModelResponse = completion(messages=self.memory.messages, stream=stream, **options)
+        response: ModelResponse = completion(messages=self.memory.messages, stream=False, **options)
         self._process_response(response)
         executed_response = self._process_function_calls(response)
         if executed_response is not None:
             response = executed_response
             self._process_response(response)
         answer = self.message_from_response(response)
         self.memory.add_message(answer, run_callbacks)
-        result: str = self.memory.last_message.content if self.memory.last_message is not None and self.memory.last_message.content is not None else str(self.memory.last_message)
+        result: str = self.memory.last_message.content if self.memory.last_message is not None and self.memory.last_message.content is not None else str(
+            self.memory.last_message)
         if output is not None:
             output.write_text(result)
         return result
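For orientation, a hypothetical caller of the refactored API; the prompts are illustrative, and LLMSession fields other than llm_options are assumed to have usable defaults:

session = LLMSession(llm_options=LLAMA3)
session.instruct("You are a concise assistant.")

answer = session.query("What is a telomere?")   # blocking path, stream=False
events = session.stream("Summarize our chat")   # async generator of SSE "data:" events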
@@ -79,7 +198,7 @@ def query(self, prompt: str, stream: bool = False, run_callbacks: bool = True, o
     def _process_function_calls(self, response: ModelResponse) -> Optional[ModelResponse]:
         """
         processes function calls in the response
-        :param response:
+        :param response_message:
         :return:
         """
         response_message = response.choices[0].message
@@ -89,6 +208,7 @@ def _process_function_calls(self, response: ModelResponse) -> Optional[ModelResp
         message = self.message_from_response(response)
         self.memory.add_message(message)
         for tool_call in tool_calls:
             function_name = tool_call.function.name
             function_to_call = self.available_tools[function_name]
             function_args = json.loads(tool_call.function.arguments)
+            print(f"Calling function {function_name}({function_args})")
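Because _stream() returns the async generator as a starlette ContentStream, it can be passed directly to a StreamingResponse. A minimal serving sketch, not part of this commit; the /chat route, request body shape, and session wiring are assumptions:

from starlette.applications import Starlette
from starlette.responses import StreamingResponse
from starlette.routing import Route

async def chat(request):
    body = await request.json()                    # assumed shape: {"messages": [...]}
    stream = session.stream_all(body["messages"])  # session: a configured LLMSession
    return StreamingResponse(stream, media_type="text/event-stream")

app = Starlette(routes=[Route("/chat", chat, methods=["POST"])])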
11 changes: 10 additions & 1 deletion just_agents/memory.py
@@ -50,4 +50,13 @@ def add_message(self, message: Message, run_callbacks: bool = True):

     @property
     def last_message(self) -> Optional[Message]:
-        return self.messages[-1] if len(self.messages) > 0 else None
+        return self.messages[-1] if len(self.messages) > 0 else None
+
+
+    def add_messages(self, messages: list, run_callbacks: bool = True):
+        for message in messages:
+            msg = Message(content=message["content"], role=message["role"])
+            self.messages.append(msg)
+            if run_callbacks:
+                for handler in self.on_message:
+                    handler(msg)
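A small usage sketch for the new add_messages helper, which wraps plain role/content dicts into litellm Message objects before appending; it assumes Memory() can be constructed with defaults:

memory = Memory()
memory.add_messages([
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Ping?"},
])
print(memory.last_message.content)  # -> "Ping?"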
