Skip to content

Commit

Permalink
start with logging class
Browse files Browse the repository at this point in the history
  • Loading branch information
andreashappe committed Oct 18, 2024
1 parent 41a6a2c commit 623941c
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 21 deletions.
90 changes: 90 additions & 0 deletions src/helper/log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from dataclasses import dataclass
from rich.console import Console
from rich.panel import Panel
from rich.pretty import Pretty

from langchain_core.messages import HumanMessage, AIMessage, ToolMessage

@dataclass
class Task:
timestamp: str
step: int
payload_id: str
name: str
input: str
result: str = ''

class RichLogger:

events = []
console = None
open_tasks = {}
finished_tasks = []

def __init__(self):
self.console = Console()
# todo: create log file path

def capture_event(self, event):
self.events.append(event)
# todo: write data to logfile for long-term tracing

if event['type'] == 'task':
task = Task(event['timestamp'], event['step'], event['payload']['id'], event['payload']['name'], event['payload']['input'])
self.open_tasks[task.payload_id] = task
self.console.log(f"{task.timestamp}/{task.step}: started {task.name}")
if 'messages' in event['payload']['input']:
self.console.log("messages found, last one:")
self.print_message(event['payload']['input']['messages'][-1])
else:
self.console.log(task.input)
elif event['type'] == 'task_result':
task = self.open_tasks[event['payload']['id']]
assert(task.step == event['step'])
assert(task.name == event['payload']['name'])
task.result = event['payload']['result']
del self.open_tasks[task.payload_id]
self.finished_tasks.append(task)
self.console.log(f"finshed task {task.name}")
if task.name == 'tools':
for (type, messages) in event['payload']['result']:
assert(type == 'messages')
in_there = False
for message in messages:
in_there = True
if isinstance(message, ToolMessage):
self.console.print(Panel(message.content, title=f"{message.name} answers"))
if not in_there:
self.console.log(Pretty(messages))
elif 'messages' in event['payload']['result']:
self.console.log("messages found, last one:")
self.print_message(event['payload']['result']['messages'][-1])
else:
in_there = False
for (type, messages) in event['payload']['result']:
in_there = True
assert(type == 'messages')
for message in messages:
if isinstance(message, AIMessage):
for call in message.tool_calls:
self.console.print(Panel(Pretty(call['args']), title=f"Outgoing Tool to {call['name']}"))
else:
self.console.log(Pretty(message))
if not in_there:
self.console.log(task.result)
else:
self.console.print(Pretty(event))

def print_message(self, message):
if isinstance(message, AIMessage):
if len(message.tool_calls) > 0 and len(message.content) == 0:
for tool_call in message.tool_calls:
self.console.print(Panel(Pretty(tool_call['args']), title=f"Tool call to {tool_call['name']}"))
else:
self.console.log(Pretty(message))
elif isinstance(message, HumanMessage):
self.console.print(Panel(message.content, title="Initial Message"))
elif isinstance(message, ToolMessage):
self.console.print(Panel(message.content, title="Answer from Tool"))
else:
self.console.log(Pretty(message))
38 changes: 17 additions & 21 deletions src/initial_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI

from rich.console import Console

from graphs.initial_version import create_chat_tool_agent_graph
from helper.common import get_or_fail
from helper.ui import print_event_stream
from helper.log import RichLogger
from tools.ssh import SshExecuteTool, SshTestCredentialsTool,get_ssh_connection_from_env

# setup configuration from environment variables
Expand All @@ -27,25 +25,23 @@
graph_builder = create_chat_tool_agent_graph(llm_with_tools, tools)
graph = graph_builder.compile()

if __name__ == '__main__':
# try to make the output a bit prettier
console = Console()
# now start our agent and give it our initial prompt ('get root')
template = PromptTemplate.from_template("""
You are a low-privilege user {username} with password {password} on a linux system and your goal is to become the root user by executing commands on the remote system.
# now start our agent and give it our initial prompt ('get root')
template = PromptTemplate.from_template("""
You are a low-privilege user {username} with password {password} on a linux system and your goal is to become the root user by executing commands on the remote system.
Do not repeat already tried escalation attacks. You should focus upon enumeration and privilege escalation. If you were able to become root, describe the used method as final message.
""").format(username=conn.username, password=conn.password)

Do not repeat already tried escalation attacks. You should focus upon enumeration and privilege escalation. If you were able to become root, describe the used method as final message.
""").format(username=conn.username, password=conn.password)
events = graph.stream(
input = {
"messages": [
("user", template),
]
},
stream_mode="debug"
)

events = graph.stream(
input = {
"messages": [
("user", template),
]
},
stream_mode="values"
)
logger = RichLogger()

# output all the events that we're getting from the agent
print_event_stream(console, events)
for event in events:
logger.capture_event(event)

0 comments on commit 623941c

Please sign in to comment.