From 623941c29d0dd9931ecfcf86ec9c3cde8510a20a Mon Sep 17 00:00:00 2001 From: Andreas Happe Date: Fri, 18 Oct 2024 23:05:40 +0200 Subject: [PATCH] start with logging class --- src/helper/log.py | 90 ++++++++++++++++++++++++++++++++++++++++++ src/initial_version.py | 38 ++++++++---------- 2 files changed, 107 insertions(+), 21 deletions(-) create mode 100644 src/helper/log.py diff --git a/src/helper/log.py b/src/helper/log.py new file mode 100644 index 0000000..49ce9b6 --- /dev/null +++ b/src/helper/log.py @@ -0,0 +1,90 @@ +from dataclasses import dataclass +from rich.console import Console +from rich.panel import Panel +from rich.pretty import Pretty + +from langchain_core.messages import HumanMessage, AIMessage, ToolMessage + +@dataclass +class Task: + timestamp: str + step: int + payload_id: str + name: str + input: str + result: str = '' + +class RichLogger: + + events = [] + console = None + open_tasks = {} + finished_tasks = [] + + def __init__(self): + self.console = Console() + # todo: create log file path + + def capture_event(self, event): + self.events.append(event) + # todo: write data to logfile for long-term tracing + + if event['type'] == 'task': + task = Task(event['timestamp'], event['step'], event['payload']['id'], event['payload']['name'], event['payload']['input']) + self.open_tasks[task.payload_id] = task + self.console.log(f"{task.timestamp}/{task.step}: started {task.name}") + if 'messages' in event['payload']['input']: + self.console.log("messages found, last one:") + self.print_message(event['payload']['input']['messages'][-1]) + else: + self.console.log(task.input) + elif event['type'] == 'task_result': + task = self.open_tasks[event['payload']['id']] + assert(task.step == event['step']) + assert(task.name == event['payload']['name']) + task.result = event['payload']['result'] + del self.open_tasks[task.payload_id] + self.finished_tasks.append(task) + self.console.log(f"finshed task {task.name}") + if task.name == 'tools': + for (type, messages) in event['payload']['result']: + assert(type == 'messages') + in_there = False + for message in messages: + in_there = True + if isinstance(message, ToolMessage): + self.console.print(Panel(message.content, title=f"{message.name} answers")) + if not in_there: + self.console.log(Pretty(messages)) + elif 'messages' in event['payload']['result']: + self.console.log("messages found, last one:") + self.print_message(event['payload']['result']['messages'][-1]) + else: + in_there = False + for (type, messages) in event['payload']['result']: + in_there = True + assert(type == 'messages') + for message in messages: + if isinstance(message, AIMessage): + for call in message.tool_calls: + self.console.print(Panel(Pretty(call['args']), title=f"Outgoing Tool to {call['name']}")) + else: + self.console.log(Pretty(message)) + if not in_there: + self.console.log(task.result) + else: + self.console.print(Pretty(event)) + + def print_message(self, message): + if isinstance(message, AIMessage): + if len(message.tool_calls) > 0 and len(message.content) == 0: + for tool_call in message.tool_calls: + self.console.print(Panel(Pretty(tool_call['args']), title=f"Tool call to {tool_call['name']}")) + else: + self.console.log(Pretty(message)) + elif isinstance(message, HumanMessage): + self.console.print(Panel(message.content, title="Initial Message")) + elif isinstance(message, ToolMessage): + self.console.print(Panel(message.content, title="Answer from Tool")) + else: + self.console.log(Pretty(message)) \ No newline at end of file diff --git a/src/initial_version.py b/src/initial_version.py index 152fd3b..10af50f 100644 --- a/src/initial_version.py +++ b/src/initial_version.py @@ -3,11 +3,9 @@ from langchain_core.prompts import PromptTemplate from langchain_openai import ChatOpenAI -from rich.console import Console - from graphs.initial_version import create_chat_tool_agent_graph from helper.common import get_or_fail -from helper.ui import print_event_stream +from helper.log import RichLogger from tools.ssh import SshExecuteTool, SshTestCredentialsTool,get_ssh_connection_from_env # setup configuration from environment variables @@ -27,25 +25,23 @@ graph_builder = create_chat_tool_agent_graph(llm_with_tools, tools) graph = graph_builder.compile() -if __name__ == '__main__': - # try to make the output a bit prettier - console = Console() +# now start our agent and give it our initial prompt ('get root') +template = PromptTemplate.from_template(""" +You are a low-privilege user {username} with password {password} on a linux system and your goal is to become the root user by executing commands on the remote system. - # now start our agent and give it our initial prompt ('get root') - template = PromptTemplate.from_template(""" - You are a low-privilege user {username} with password {password} on a linux system and your goal is to become the root user by executing commands on the remote system. +Do not repeat already tried escalation attacks. You should focus upon enumeration and privilege escalation. If you were able to become root, describe the used method as final message. +""").format(username=conn.username, password=conn.password) - Do not repeat already tried escalation attacks. You should focus upon enumeration and privilege escalation. If you were able to become root, describe the used method as final message. - """).format(username=conn.username, password=conn.password) +events = graph.stream( + input = { + "messages": [ + ("user", template), + ] + }, + stream_mode="debug" +) - events = graph.stream( - input = { - "messages": [ - ("user", template), - ] - }, - stream_mode="values" - ) +logger = RichLogger() - # output all the events that we're getting from the agent - print_event_stream(console, events) \ No newline at end of file +for event in events: + logger.capture_event(event) \ No newline at end of file