-
Notifications
You must be signed in to change notification settings - Fork 153
/
mitigations.py
127 lines (100 loc) · 4.35 KB
/
mitigations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import requests
from mistralai import Mistral
from openai import OpenAI, AzureOpenAI
import google.generativeai as genai
# Function to create a prompt to generate mitigating controls
def create_mitigations_prompt(threats):
    """Build the prompt that asks an LLM for mitigations of the given threats.

    Args:
        threats: The identified threats. They are interpolated as-is on their
            own line, directly after the line
            "Below is the list of identified threats:".

    Returns:
        str: Prompt text asking for a markdown table with the columns
        Threat Type, Scenario and Suggested Mitigation(s), not wrapped in a
        code block. The text starts and ends with a newline.
    """
    # One entry per prompt line; the empty first and last entries yield the
    # leading and trailing newline of the assembled text.
    prompt_lines = [
        "",
        (
            "Act as a cyber security expert with more than 20 years experience "
            "of using the STRIDE threat modelling methodology. Your task is to "
            "provide potential mitigations for the threats identified in the "
            "threat model. It is very important that your responses are "
            "tailored to reflect the details of the threats."
        ),
        "Your output should be in the form of a markdown table with the following columns:",
        "- Column A: Threat Type",
        "- Column B: Scenario",
        "- Column C: Suggested Mitigation(s)",
        "Below is the list of identified threats:",
        f"{threats}",
        "YOUR RESPONSE (do not wrap in a code block):",
        "",
    ]
    return "\n".join(prompt_lines)
# Function to get mitigations from the GPT response.
def get_mitigations(api_key, model_name, prompt):
    """Ask an OpenAI chat model for threat mitigations.

    Args:
        api_key: OpenAI API key used to construct the client.
        model_name: Name of the chat model to query.
        prompt: The user prompt to send.

    Returns:
        The message content of the first choice in the completion. The system
        message asks the model to answer in Markdown.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful assistant that provides threat mitigation strategies in Markdown format."},
        {"role": "user", "content": prompt},
    ]
    openai_client = OpenAI(api_key=api_key)
    completion = openai_client.chat.completions.create(
        model=model_name,
        messages=conversation,
    )
    # The reply is plain text, so it is returned without further parsing.
    first_choice = completion.choices[0]
    return first_choice.message.content
# Function to get mitigations from the Azure OpenAI response.
def get_mitigations_azure(azure_api_endpoint, azure_api_key, azure_api_version, azure_deployment_name, prompt):
    """Ask an Azure OpenAI deployment for threat mitigations.

    Args:
        azure_api_endpoint: Azure OpenAI resource endpoint.
        azure_api_key: API key for that resource.
        azure_api_version: API version string passed to the client.
        azure_deployment_name: Deployment to query; it is passed to the SDK
            as ``model``.
        prompt: The user prompt to send.

    Returns:
        The message content of the first choice in the completion. The system
        message asks the model to answer in Markdown.
    """
    azure_client = AzureOpenAI(
        api_key=azure_api_key,
        api_version=azure_api_version,
        azure_endpoint=azure_api_endpoint,
    )
    chat_messages = [
        {"role": "system", "content": "You are a helpful assistant that provides threat mitigation strategies in Markdown format."},
        {"role": "user", "content": prompt},
    ]
    reply = azure_client.chat.completions.create(
        model=azure_deployment_name,
        messages=chat_messages,
    )
    # The reply is plain text, so it is returned without further parsing.
    return reply.choices[0].message.content
# Function to get mitigations from the Google model's response.
def get_mitigations_google(google_api_key, google_model, prompt):
    """Ask a Google Generative AI model for threat mitigations.

    Args:
        google_api_key: API key passed to ``genai.configure``. Note that this
            sets the key on the ``genai`` module, not on a per-call client.
        google_model: Name of the Google model to instantiate.
        prompt: The user prompt to send.

    Returns:
        The text of the first part of the first candidate, with literal
        backslash-n sequences turned into real newlines. Returns ``None`` when
        that text cannot be reached (IndexError/AttributeError); in that case
        the error and the raw response are printed to stdout.
    """
    genai.configure(api_key=google_api_key)
    system_prompt = "You are a helpful assistant that provides threat mitigation strategies in Markdown format."
    gemini = genai.GenerativeModel(google_model, system_instruction=system_prompt)
    result = gemini.generate_content(prompt)

    try:
        first_part = result.candidates[0].content.parts[0]
        # Some replies contain the two characters '\' 'n' instead of a
        # newline; normalise them so the Markdown renders correctly.
        text = first_part.text.replace('\\n', '\n')
    except (IndexError, AttributeError) as err:
        print(f"Error accessing response content: {str(err)}")
        print("Raw response:")
        print(result)
        return None
    return text
# Function to get mitigations from the Mistral model's response.
def get_mitigations_mistral(mistral_api_key, mistral_model, prompt):
    """Ask a Mistral chat model for threat mitigations.

    Args:
        mistral_api_key: API key used to construct the Mistral client.
        mistral_model: Name of the Mistral model to query.
        prompt: The user prompt to send.

    Returns:
        The message content of the first choice in the response. The system
        message asks the model to answer in Markdown.
    """
    dialogue = [
        {"role": "system", "content": "You are a helpful assistant that provides threat mitigation strategies in Markdown format."},
        {"role": "user", "content": prompt},
    ]
    mistral_client = Mistral(api_key=mistral_api_key)
    chat_result = mistral_client.chat.complete(model=mistral_model, messages=dialogue)
    # The reply is plain text, so it is returned without further parsing.
    return chat_result.choices[0].message.content
# Function to get mitigations from Ollama hosted LLM.
def get_mitigations_ollama(ollama_model, prompt, ollama_endpoint="http://localhost:11434", timeout=None):
    """Ask a model served by Ollama for threat mitigations.

    Sends one non-streaming request to the server's ``/api/chat`` endpoint.

    Args:
        ollama_model: Name of the Ollama model to run.
        prompt: The user prompt to send.
        ollama_endpoint: Base URL of the Ollama server. Defaults to
            ``http://localhost:11434``, the URL previously hard-coded here.
            A trailing slash is ignored.
        timeout: Forwarded to ``requests.post``. ``None`` (the default) waits
            indefinitely, as before. Pass a number of seconds or a
            ``(connect, read)`` tuple to bound the wait.

    Returns:
        str: The ``message.content`` field of the JSON reply.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status. This
            is raised before the body is parsed, rather than surfacing as a
            less informative KeyError/JSON error.
    """
    url = ollama_endpoint.rstrip("/") + "/api/chat"
    data = {
        "model": ollama_model,
        # Non-streaming, so the reply can be parsed as a single JSON document.
        "stream": False,
        "messages": [
            {
                "role": "system",
                "content": "You are a helpful assistant that provides threat mitigation strategies in Markdown format.",
            },
            {
                "role": "user",
                "content": prompt,
            },
        ],
    }
    response = requests.post(url, json=data, timeout=timeout)
    # Surface HTTP failures explicitly instead of indexing into an error body.
    response.raise_for_status()
    outer_json = response.json()
    # The generated text lives under message.content.
    return outer_json["message"]["content"]