-
Notifications
You must be signed in to change notification settings - Fork 259
/
Copy pathnotebook_withdefender.py
122 lines (104 loc) · 3.87 KB
/
notebook_withdefender.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# ---
# jupyter:
#   jupytext:
#     cell_metadata_filter: -all
#     formats: py:percent,ipynb
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.16.4
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---
# %%
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Attacker agent benchmark comparison in presence of a basic defender
This notebook can be run directly from VSCode. To generate a
traditional Jupyter Notebook to open in your browser,
run the VSCode command `Export Current Python File As Jupyter Notebook`.
"""
# %%
import sys
import os
import logging
import gymnasium as gym
import importlib
import cyberbattle.agents.baseline.learner as learner
import cyberbattle.agents.baseline.plotting as p
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.agent_dql as dqla
import cyberbattle.agents.baseline.agent_randomcredlookup as rca
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
from cyberbattle._env.defender import ScanAndReimageCompromisedMachines
from cyberbattle._env.cyberbattle_env import AttackerGoal, DefenderConstraint, CyberBattleEnv
# Reload the learner and plotting modules so that re-running this cell picks
# up edits to their source without restarting the kernel. Each module is
# reloaded once: reloading the same module twice in a row only re-executes it.
importlib.reload(learner)
importlib.reload(p)
# Route log records to stdout and only emit records at ERROR level or above.
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
# %matplotlib inline
# %% {"tags": ["parameters"]}
# Notebook parameters (this cell carries the "parameters" tag).
# Forwarded as `iteration_count` to every `learner.epsilon_greedy_search`
# call in this file (presumably the per-episode step budget -- confirm in
# the learner module).
iteration_count = 600
# Forwarded as `episode_count` to the DQL training and exploitation runs and
# reported in the plot titles.
training_episode_count = 10
# Directory under which the cumulative-reward figure is saved.
plots_dir = "output/plots"
# %%
# Options handed to the "CyberBattleChain-v0" constructor through `gym.make`:
# the chain size, the attacker goal, the defender's SLA constraint and the
# scan-and-reimage defender agent.
chain_env_options = {
    "size": 10,
    "attacker_goal": AttackerGoal(own_atleast=0, own_atleast_percent=1.0),
    "defender_constraint": DefenderConstraint(maintain_sla=0.80),
    "defender_agent": ScanAndReimageCompromisedMachines(
        probability=0.6,
        scan_capacity=2,
        scan_frequency=5,
    ),
}
gym_env = gym.make("CyberBattleChain-v0", **chain_env_options).unwrapped

# Work with the unwrapped environment and narrow its type for the calls below.
cyberbattlechain_defender = gym_env.unwrapped
assert isinstance(cyberbattlechain_defender, CyberBattleEnv)

# Environment bounds shared by all agents benchmarked in this notebook.
ep = w.EnvironmentBounds.of_identifiers(
    maximum_total_credentials=22,
    maximum_node_count=22,
    identifiers=cyberbattlechain_defender.identifiers,
)
# %%
# Train a Deep Q-learning policy against the defended chain environment with
# an epsilon-greedy search whose epsilon starts at 0.90 and is bounded below
# by 0.10.
dql_policy = dqla.DeepQLearnerPolicy(
    ep=ep,
    gamma=0.15,
    replay_memory_size=10000,
    target_update=5,
    batch_size=256,
    learning_rate=0.01,
)
dqn_with_defender = learner.epsilon_greedy_search(
    cyberbattle_gym_env=cyberbattlechain_defender,
    environment_properties=ep,
    learner=dql_policy,
    episode_count=training_episode_count,
    iteration_count=iteration_count,
    epsilon=0.90,
    render=False,
    epsilon_exponential_decay=5000,
    epsilon_minimum=0.10,
    verbosity=Verbosity.Quiet,
    title="DQL",
)
# %%
# Re-run the learner object returned by the DQL training run with epsilon set
# to 0.0 (an earlier setting of 0.35 is kept here as a note).
trained_dql_learner = dqn_with_defender["learner"]
dql_exploit_run = learner.epsilon_greedy_search(
    cyberbattle_gym_env=cyberbattlechain_defender,
    environment_properties=ep,
    learner=trained_dql_learner,
    episode_count=training_episode_count,
    iteration_count=iteration_count,
    epsilon=0.0,
    render=False,
    # Optional, currently disabled:
    # render_last_episode_rewards_to='images/chain10',
    verbosity=Verbosity.Quiet,
    title="Exploiting DQL",
)
# %%
# Baseline run: the credential-cache exploiter agent driven by the same
# epsilon-greedy search as the DQL agent.
# The episode count comes from the `training_episode_count` notebook parameter
# (default 10) rather than a hard-coded 10, so that this run stays consistent
# with the other runs and with the `episodes=...` value shown in the plot
# titles when the parameter is overridden.
credlookup_run = learner.epsilon_greedy_search(
    cyberbattlechain_defender,
    ep,
    learner=rca.CredentialCacheExploiter(),
    episode_count=training_episode_count,
    iteration_count=iteration_count,
    epsilon=0.90,
    render=False,
    epsilon_exponential_decay=10000,
    epsilon_minimum=0.10,
    verbosity=Verbosity.Quiet,
    title="Credential lookups (ϵ-greedy)",
)
# %%
# Plots
# Runs to compare, in the order they are handed to the plotting helpers.
all_runs = [credlookup_run, dqn_with_defender, dql_exploit_run]

# Averaged cumulative rewards; the figure is also saved under `plots_dir`.
rewards_title = f"Attacker agents vs Basic Defender -- rewards\n env={cyberbattlechain_defender.name}, episodes={training_episode_count}"
rewards_figure_path = os.path.join(plots_dir, "withdefender-cumreward.png")
p.plot_averaged_cummulative_rewards(
    all_runs=all_runs,
    title=rewards_title,
    save_at=rewards_figure_path,
)

# Optional, currently disabled:
# p.plot_episodes_length(all_runs)

# Averaged availability, requested with show=False.
availability_title = f"Attacker agents vs Basic Defender -- availability\n env={cyberbattlechain_defender.name}, episodes={training_episode_count}"
p.plot_averaged_availability(
    title=availability_title,
    all_runs=all_runs,
    show=False,
)