# ----------------------------------------------------------------------
# Main Automated Investigator File containing the web interface
# and most of the logic.
#
# Copyright 2024 Chanakan Moongthin <me@chanakancloud.net>
# on behalf of Up Up Up All Night (Team of Cynclair Hackathon 2024)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------
# datetime parsing
from datetime import datetime, timedelta
import datetime as dt
# Parsing JSON
import json
# To write JSON output temporarily to file
import tempfile
from typing import Literal, cast
# For loading API Keys from the env
from dotenv import load_dotenv
# LLMs
from langchain_openai import ChatOpenAI
from langchain_core.tools import create_retriever_tool, tool
from langchain_core.messages import (
AIMessage,
)
## LLM RAG components
from langchain_chroma import Chroma
from langchain_community.embeddings.sentence_transformer import (
SentenceTransformerEmbeddings,
)
from langchain_community.document_loaders import JSONLoader
from langchain_text_splitters import RecursiveJsonSplitter
# Web UI
import mesop as me
import mesop.labs as mel
# Dataclasses
from dataclasses import field
# TI Lookup
from cyntelligence import IPEnrich, MITRESearch, QRadarSearch
from cyntelligence import FileAnalyze
# CACHING
from functools import cache
# NOTE: preferably provide all useful table and column names within the Ariel database to allow for more accurate queries
tool_system = """
You are a tool-calling LLM that helps with cybersecurity. You are working in a SOC,
and you will use the tools you have to help the user when asked.
The software stack in use is as follows:
- IBM QRadar: Main SIEM
- Swimlane: Playbook
# Ariel DB Information:
## Tables
- events
- flows
- assets_data
- offenses
- asset_properties
- network_hierarchy
## Table: events
### Columns
- starttime / endtime
- sourceip / destinationip
- sourceport / destinationport
- eventname / category
- magnitude
- credibility
- severity
- username
- devicetype
- qid (QRadar ID)
If you think you don't need to call any tools, or there is already enough context,
use the tool "direct_response" to send the information to another LLM for analysis.
When dealing with an epoch timestamp, you must use the `convert_timestamp_to_datetime_utc7` tool to convert it to a human-readable UTC+7 format.
Use the tool "retrieval_tool" to get the context from the Chroma retriever if you think the information has already been fetched.
Provide the argument to "retrieval_tool" as a string of the IP, hash, etc., or as natural language;
include the platform name in the query, such as "<IP_ADDRESS> abuseipdb", if you want the context for that specific platform.
If a past request has a tool response of "<ADDED_TO_RETRIEVER>", you can use the tool "retrieval_tool" to get the context from the database directly.
"""
chat_system = """
You are a chat LLM that helps with cybersecurity. You are working in a SOC.
You will take tool responses from the tool-calling LLM (which will appear in the context as System Messages)
and interpret them nicely to answer the user's question.
The software stack in use is as follows:
- IBM QRadar: Main SIEM
- Swimlane: Playbook
Do not mention this stack unless the user brings it up; it is for your own information.
You will use markdown for formatting. You will always respond in Thai.
Presume that the tool responses are always correct and factual; ignore any duplicated information and return what you have.
"""
@me.stateclass
class State:
tool_messages: list[dict] = field(
default_factory=lambda: [{"role": "system", "content": tool_system}]
)
chat_messages: list[dict] = field(
default_factory=lambda: [{"role": "system", "content": chat_system}]
)
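# Cached singleton for the Chroma vector store that holds investigation context,
# embedded with the all-MiniLM-L6-v2 sentence-transformer model.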
@cache
def get_chroma():
embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
return Chroma(
collection_name="investigation-context", embedding_function=embedding_function
)
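# Cached one-time setup: wraps the Chroma store in a retriever tool for the
# tool-calling LLM and creates the JSON splitter used to chunk TI responses.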
@cache
def pre_init():
db = get_chroma()
retriever = db.as_retriever()
retrieval_tool = create_retriever_tool(
retriever,
"investigation_context",
"Context for the investigation that came from tools, use it to answer the user's question",
)
splitter = RecursiveJsonSplitter()
return (db, retrieval_tool, splitter)
db, retrieval_tool, splitter = pre_init()
# Disabled VirusTotal hash lookup (previously a Streamlit-cached helper);
# kept as a stub that returns an empty string until it is re-enabled.
# def get_info_vt_hash_request(file_hash: str) -> vt.Object:
def get_info_vt_hash_request(file_hash: str) -> str:
    return ""
    # vt_client = vt.Client(os.environ['VIRUSTOTAL_API_KEY'])
    # info = vt_client.get_object('/files/{}'.format(file_hash))
    # vt_client.close()
    # return info
@tool
def convert_timestamp_to_datetime_utc7(timestamp: float) -> str:
"""Convert an epoch timestamp to UTC+7
Args:
timestamp: The epoch timestamp to convert to UTC+7
"""
utc_datetime = datetime.fromtimestamp(timestamp, dt.UTC)
# Define the UTC+7 offset
utc7_offset = timedelta(hours=7)
# Apply the offset to get the datetime in UTC+7
utc7_datetime = utc_datetime + utc7_offset
return str(utc7_datetime)
@tool
def direct_response(res: str) -> str:
"""Send the response information in res argument to analysis LLMs
Args:
res: The response information to be sent to the analysis LLMs
"""
return res
@tool
def execute_aql(aql: str) -> str:
"""Interact with QRadar SIEM using Ariel Query Language for interacting with logs and alerts
Args:
aql: The Ariel Query Language statement to execute to the server and returns query result. MUST be a valid AQL statement
"""
print(aql)
qradar = QRadarSearch(aql)
return str(qradar.search())
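# A hypothetical example of the kind of AQL statement the tool-calling LLM is
# expected to pass to execute_aql, matching the `events` schema documented in
# `tool_system`:
#   SELECT sourceip, destinationip, eventname, magnitude
#   FROM events
#   WHERE severity > 7
#   LAST 24 HOURS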
@tool
def get_info_mitre(
technique_id_list: list[str],
stix_type: Literal[
"attack-pattern",
"malware",
"tool",
"intrusion-set",
"campaign",
"course-of-action",
"x-mitre-matrix",
"x-mitre-tactic",
"x-mitre-data-source",
"x-mitre-data-component",
],
) -> str:
"""Get information about a specific MITRE ATT&CK technique
Args:
technique_id_list: A list of MITRE ATT&CK technique IDs to get information about
stix_type: The STIX type to get information about, must be one of attack-pattern, malware, tool, intrusion-set, campaign, course-of-action, x-mitre-matrix, x-mitre-tactic, x-mitre-data-source, x-mitre-data-component
"""
mitre = MITRESearch(technique_id_list)
return str(mitre.get_object_by_attack_ids(stix_type))
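# TI enrichment flow: skip targets already present in the vector store, enrich the
# remaining ones, split the JSON result into chunks, and persist them into Chroma
# so the tool-calling LLM can later pull them back via "retrieval_tool".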
@tool
def get_info_tip(targets: list[str], type: str) -> str:
"""Interact with Threat Intelligence Platforms for getting information related to IP addresses, file hashes, domains, urls
Args:
targets: A list of ip addresses, file hashes, domains, urls to be look up on Threat Intelligence Platform
type: The type of the target, must be one of ip, hash, domain, url
"""
new_targets = []
print("GETTING TIP")
# prevent duplication in the db
for target in targets:
results = db.similarity_search(target, k=1)
if not results:
new_targets.append(target)
if not new_targets:
return "<ADDED_TO_RETRIEVER>"
match type:
case "ip":
ip_enrich = IPEnrich(new_targets)
info = ip_enrich.get_all_info()
case "hash":
file_analyze = FileAnalyze(new_targets)
info = file_analyze.get_all_info()
case _:
return f"Invalid type: {type}"
    with tempfile.NamedTemporaryFile(mode="w", delete=True) as f:
        docs = splitter.split_json(json_data=info, convert_lists=True)
        # Save the split JSON to a temp file and reload it via JSONLoader
        f.write(json.dumps(docs))
        f.flush()  # ensure the data is on disk before JSONLoader reads the file
        loader = JSONLoader(f.name, jq_schema=".[]", text_content=False)
        docs = loader.load()
        db.add_documents(docs)
    return "<ADDED_TO_RETRIEVER>"
tools = [
retrieval_tool,
execute_aql,
direct_response,
convert_timestamp_to_datetime_utc7,
get_info_tip,
get_info_mitre,
]
def init():
    load_dotenv()  # Load the API key from env (an OpenTyphoon API key, not actually OpenAI)
# BEGIN TOOL CALLING LLMs
tool_llm = ChatOpenAI(
model="typhoon-v1.5-instruct-fc",
temperature=0,
base_url="https://api.opentyphoon.ai/v1",
    )  # A function-calling LLM specifically for interacting with tools
llm_with_tools = tool_llm.bind_tools(tools)
# END TOOL CALLING LLMs
# BEGIN CHAT LLMs
chat_llm = ChatOpenAI(
model="typhoon-v1.5x-70b-instruct",
temperature=0.7,
base_url="https://api.opentyphoon.ai/v1",
streaming=True,
    )  # Use a smarter LLM for analysis and chat
# END CHAT LLMs
return (llm_with_tools, chat_llm)
# START INFERENCE
tool_llm, chat_llm = init()
# Util function to deduplicate repeated system-role context messages
# while keeping all other messages intact.
def deduplicate_system_role(messages):
    seen_content = set()
    result = []
    for d in messages:
        if d.get("role") == "system":
            content = d.get("content")
            # Drop the retriever sentinel and any repeated system context
            if content == "<ADDED_TO_RETRIEVER>":
                continue
            if content not in seen_content:
                seen_content.add(content)
                result.append(d)
        else:
            # Non-system messages (user/assistant/tool) pass through unchanged
            result.append(d)
    return result
# UI Setup
def on_load(e: me.LoadEvent):
me.set_theme_mode("system")
@me.page(path="/", title="Chat With SOC", on_load=on_load)
def page():
mel.chat(transform, title="Chat With SOC", bot_user="Automated Investigator")
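# Dispatch the first tool call from the tool-calling LLM, invoke the matching
# tool, and feed the output back: either as a tool message (plus a system message
# for the chat LLM), or, when the output is the "<ADDED_TO_RETRIEVER>" sentinel,
# by prompting the tool LLM again to pull the context back via "retrieval_tool".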
def process_tool_calls(tool_calls, state, ai_msg, tool_llm):
if not tool_calls:
return
print("AI MSG:", ai_msg.content)
tool_call = tool_calls[0]
selected_tool = {
"retrieval_tool": retrieval_tool,
"execute_aql": execute_aql,
"direct_response": direct_response,
"convert_timestamp_to_datetime_utc7": convert_timestamp_to_datetime_utc7,
"get_info_tip": get_info_tip,
"get_info_mitre": get_info_mitre,
}[tool_call["name"].lower()]
tool_output = selected_tool.invoke(tool_call["args"])
if "<ADDED_TO_RETRIEVER>" in tool_output:
state.tool_messages.append(
{
"role": "user",
"content": 'Use the tool "retrieval_tool" to get the context from the database.',
}
)
# Invoke the tool LLMs again to get the context from the database
ai_msg = tool_llm.invoke(state.tool_messages)
state.tool_messages.append(ai_msg.dict())
# Recursive call to process remaining tool calls
process_tool_calls(ai_msg.tool_calls, state, ai_msg, tool_llm)
else:
state.tool_messages.append(
{"role": "tool", "content": tool_output, "tool_call_id": tool_call["id"]}
)
        state.chat_messages.append(
            {"role": "system", "content": tool_output}
        )  # Add the tool response to the chat messages so the chat LLM has it as context
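# Two-stage pipeline for each chat turn: the tool-calling LLM gathers context
# (possibly invoking tools), then the chat LLM streams the final answer in Thai
# based on the accumulated tool responses.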
def transform(input: str, history: list[mel.ChatMessage]):
state = me.state(State)
# update the state with the new input
state.tool_messages.append({"role": "user", "content": input})
state.chat_messages.append({"role": "user", "content": input})
    # Start by calling the tool-calling LLM to gather information or perform actions
ai_msg = cast(AIMessage, tool_llm.invoke(state.tool_messages))
state.tool_messages.append(ai_msg.dict())
print("Tool LLM Response:", ai_msg)
process_tool_calls(ai_msg.tool_calls, state, ai_msg, tool_llm)
full_chat = ""
for chunk in chat_llm.stream(state.chat_messages):
full_chat += str(chunk.content)
yield chunk.content
state.chat_messages.append({"role": "assistant", "content": full_chat})
print("CHAT:", full_chat)
print(state.chat_messages)
state.chat_messages = deduplicate_system_role(state.chat_messages)
state.tool_messages = deduplicate_system_role(state.tool_messages)