Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
William Blum committed Aug 6, 2024
1 parent 27c9b53 commit c22f480
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 104 deletions.
6 changes: 3 additions & 3 deletions cyberbattle/_env/cyberbattle_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

import numpy

import plotly.graph_objects as go
from plotly.subplots import make_subplots
from plotly.graph_objects import Scatter # type: ignore
from plotly.subplots import make_subplots # type: ignore

from cyberbattle._env.defender import DefenderAgent
from cyberbattle.simulation.model import PortName, PrivilegeLevel
Expand Down Expand Up @@ -1385,7 +1385,7 @@ def render_as_fig(self):
# plot the cumulative reward and network side by side using plotly
fig = make_subplots(rows=1, cols=2)
fig.add_trace(
go.Scatter(
Scatter(
y=numpy.array(self.__episode_rewards).cumsum(), name="cumulative reward"
),
row=1,
Expand Down
249 changes: 148 additions & 101 deletions cyberbattle/simulation/commandcontrol.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
"""
import networkx as nx
from typing import List, Optional, Dict, Union, Tuple, Set
import plotly.graph_objects as go
from plotly.graph_objects import Scatter, Figure, layout # type: ignore

from . import model, actions


class CommandControl:
""" The Command and Control interface to the simulation.
"""The Command and Control interface to the simulation.
This represents a server that centralize information and secrets
retrieved from the individual clients running on the network nodes.
Expand All @@ -27,16 +27,21 @@ class CommandControl:
__environment: model.Environment
__total_reward: float

def __init__(self, environment_or_actuator: Union[model.Environment, actions.AgentActions]):
def __init__(
self, environment_or_actuator: Union[model.Environment, actions.AgentActions]
):
if isinstance(environment_or_actuator, model.Environment):
self.__environment = environment_or_actuator
self._actuator = actions.AgentActions(self.__environment, throws_on_invalid_actions=True)
self._actuator = actions.AgentActions(
self.__environment, throws_on_invalid_actions=True
)
elif isinstance(environment_or_actuator, actions.AgentActions):
self.__environment = environment_or_actuator._environment
self._actuator = environment_or_actuator
else:
raise ValueError(
"Invalid type: expecting Union[model.Environment, actions.AgentActions])")
"Invalid type: expecting Union[model.Environment, actions.AgentActions])"
)

self.__gathered_credentials = set()
self.__total_reward = 0
Expand All @@ -61,19 +66,24 @@ def list_nodes(self) -> List[actions.DiscoveredNodeInfo]:

def get_node_color(self, node_info: model.NodeInfo) -> str:
if node_info.agent_installed:
return 'red'
return "red"
else:
return 'green'
return "green"

def plot_nodes(self) -> None:
"""Plot the sub-graph of nodes either so far
discovered (their ID is knowned by the agent)
or owned (i.e. where the attacker client is installed)."""
discovered_nodes = [node_id for node_id, _ in self._actuator.discovered_nodes()]
sub_graph = self.__environment.network.subgraph(discovered_nodes)
nx.draw(sub_graph,
with_labels=True,
node_color=[self.get_node_color(self.__environment.get_node(i)) for i in sub_graph.nodes])
nx.draw(
sub_graph,
with_labels=True,
node_color=[
self.get_node_color(self.__environment.get_node(i))
for i in sub_graph.nodes
],
)

def known_vulnerabilities(self) -> model.VulnerabilityLibrary:
"""Return the global list of known vulnerability."""
Expand Down Expand Up @@ -102,62 +112,71 @@ def print_all_attacks(self) -> None:
"""Pretty print list of all possible attacks from all the nodes currently owned by the attacker"""
return self._actuator.print_all_attacks()

def run_attack(self,
node_id: model.NodeID,
vulnerability_id: model.VulnerabilityID
) -> Optional[model.VulnerabilityOutcome]:
def run_attack(
self, node_id: model.NodeID, vulnerability_id: model.VulnerabilityID
) -> Optional[model.VulnerabilityOutcome]:
"""Run an attack and attempt to exploit a vulnerability on the specified node."""
result = self._actuator.exploit_local_vulnerability(node_id, vulnerability_id)
if result.outcome is not None:
self.__save_credentials(result.outcome)
self.__accumulate_reward(result.reward)
return result.outcome

def run_remote_attack(self, node_id: model.NodeID,
target_node_id: model.NodeID,
vulnerability_id: model.VulnerabilityID
) -> Optional[model.VulnerabilityOutcome]:
def run_remote_attack(
self,
node_id: model.NodeID,
target_node_id: model.NodeID,
vulnerability_id: model.VulnerabilityID,
) -> Optional[model.VulnerabilityOutcome]:
"""Run a remote attack from the specified node to exploit a remote vulnerability
in the specified target node"""

result = self._actuator.exploit_remote_vulnerability(
node_id, target_node_id, vulnerability_id)
node_id, target_node_id, vulnerability_id
)
if result.outcome is not None:
self.__save_credentials(result.outcome)
self.__accumulate_reward(result.reward)
return result.outcome

def connect_and_infect(self, source_node_id: model.NodeID,
target_node_id: model.NodeID,
port_name: model.PortName,
credentials: model.CredentialID) -> bool:
def connect_and_infect(
self,
source_node_id: model.NodeID,
target_node_id: model.NodeID,
port_name: model.PortName,
credentials: model.CredentialID,
) -> bool:
"""Install the agent on a remote machine using the
provided credentials"""
result = self._actuator.connect_to_remote_machine(source_node_id, target_node_id, port_name,
credentials)
provided credentials"""
result = self._actuator.connect_to_remote_machine(
source_node_id, target_node_id, port_name, credentials
)
self.__accumulate_reward(result.reward)
return result.outcome is not None

@property
def credentials_gathered_so_far(self) -> Set[model.CachedCredential]:
"""Returns the list of credentials gathered so far by the
attacker (from any node)"""
attacker (from any node)"""
return self.__gathered_credentials


def get_outcome_first_credential(outcome: Optional[model.VulnerabilityOutcome]) -> model.CredentialID:
def get_outcome_first_credential(
outcome: Optional[model.VulnerabilityOutcome],
) -> model.CredentialID:
"""Return the first credential found in a given vulnerability exploit outcome"""
if outcome is not None and isinstance(outcome, model.LeakedCredentials):
return outcome.credentials[0].credential
else:
raise ValueError('Vulnerability outcome does not contain any credential')
raise ValueError("Vulnerability outcome does not contain any credential")


class EnvironmentDebugging:
"""Provides debugging feature exposing internals of the environment
that are not normally revealed to an attacker agent according to
the rules of the simulation.
that are not normally revealed to an attacker agent according to
the rules of the simulation.
"""

__environment: model.Environment
__actuator: actions.AgentActions

Expand All @@ -167,11 +186,15 @@ def __init__(self, actuator_or_c2: Union[actions.AgentActions, CommandControl]):
elif isinstance(actuator_or_c2, CommandControl):
self.__actuator = actuator_or_c2._actuator
else:
raise ValueError("Invalid type: expecting Union[actions.AgentActions, CommandControl])")
raise ValueError(
"Invalid type: expecting Union[actions.AgentActions, CommandControl])"
)

self.__environment = self.__actuator._environment

def network_as_plotly_traces(self, xref: str = "x", yref: str = "y") -> Tuple[List[go.Scatter], dict]:
def network_as_plotly_traces(
self, xref: str = "x", yref: str = "y"
) -> Tuple[List[Scatter], dict]:
known_nodes = [node_id for node_id, _ in self.__actuator.discovered_nodes()]

subgraph = self.__environment.network.subgraph(known_nodes)
Expand All @@ -181,95 +204,119 @@ def network_as_plotly_traces(self, xref: str = "x", yref: str = "y") -> Tuple[Li

def edge_text(source: model.NodeID, target: model.NodeID) -> str:
data = self.__environment.network.get_edge_data(source, target)
name: str = data['kind'].name
name: str = data["kind"].name
return name

color_map = {actions.EdgeAnnotation.LATERAL_MOVE: 'red',
actions.EdgeAnnotation.REMOTE_EXPLOIT: 'orange',
actions.EdgeAnnotation.KNOWS: 'gray'}
color_map = {
actions.EdgeAnnotation.LATERAL_MOVE: "red",
actions.EdgeAnnotation.REMOTE_EXPLOIT: "orange",
actions.EdgeAnnotation.KNOWS: "gray",
}

def edge_color(source: model.NodeID, target: model.NodeID) -> str:
data = self.__environment.network.get_edge_data(source, target)
if 'kind' in data:
return color_map[data['kind']]
return 'black'

layout: dict = dict(title="CyberBattle simulation", font=dict(size=10), showlegend=True,
autosize=False, width=800, height=400,
margin=go.layout.Margin(l=2, r=2, b=15, t=35),
hovermode='closest',
annotations=[dict(
ax=pos[source][0],
ay=pos[source][1], axref=xref, ayref=yref,
x=pos[target][0],
y=pos[target][1], xref=xref, yref=yref,
arrowcolor=edge_color(source, target),
hovertext=edge_text(source, target),
showarrow=True,
arrowhead=1,
arrowsize=1,
arrowwidth=1,
startstandoff=10,
standoff=10,
align='center',
opacity=1
) for (source, target) in list(subgraph.edges)]
)

owned_nodes_coordinates = [(i, c) for i, c in pos.items()
if self.get_node_information(i).agent_installed]
discovered_nodes_coordinates = [(i, c)
for i, c in pos.items()
if not self.get_node_information(i).agent_installed]

trace_owned_nodes = go.Scatter(
if "kind" in data:
return color_map[data["kind"]]
return "black"

_layout: dict = dict(
title="CyberBattle simulation",
font=dict(size=10),
showlegend=True,
autosize=False,
width=800,
height=400,
margin=layout.Margin(l=2, r=2, b=15, t=35),
hovermode="closest",
annotations=[
dict(
ax=pos[source][0],
ay=pos[source][1],
axref=xref,
ayref=yref,
x=pos[target][0],
y=pos[target][1],
xref=xref,
yref=yref,
arrowcolor=edge_color(source, target),
hovertext=edge_text(source, target),
showarrow=True,
arrowhead=1,
arrowsize=1,
arrowwidth=1,
startstandoff=10,
standoff=10,
align="center",
opacity=1,
)
for (source, target) in list(subgraph.edges)
],
)

owned_nodes_coordinates = [
(i, c)
for i, c in pos.items()
if self.get_node_information(i).agent_installed
]
discovered_nodes_coordinates = [
(i, c)
for i, c in pos.items()
if not self.get_node_information(i).agent_installed
]

trace_owned_nodes = Scatter(
x=[c[0] for i, c in owned_nodes_coordinates],
y=[c[1] for i, c in owned_nodes_coordinates],
mode='markers+text',
name='owned',
marker=dict(symbol='circle-dot',
size=5,
# green #0e9d00
color='#D32F2E', # red
line=dict(color='rgb(255,0,0)', width=8)
),
mode="markers+text",
name="owned",
marker=dict(
symbol="circle-dot",
size=5,
# green #0e9d00
color="#D32F2E", # red
line=dict(color="rgb(255,0,0)", width=8),
),
text=[i for i, c in owned_nodes_coordinates],
hoverinfo='text',
textposition="bottom center"
hoverinfo="text",
textposition="bottom center",
)

trace_discovered_nodes = go.Scatter(
trace_discovered_nodes = Scatter(
x=[c[0] for i, c in discovered_nodes_coordinates],
y=[c[1] for i, c in discovered_nodes_coordinates],
mode='markers+text',
name='discovered',
marker=dict(symbol='circle-dot',
size=5,
color='#0e9d00', # green
line=dict(color='rgb(0,255,0)', width=8)
),
mode="markers+text",
name="discovered",
marker=dict(
symbol="circle-dot",
size=5,
color="#0e9d00", # green
line=dict(color="rgb(0,255,0)", width=8),
),
text=[i for i, c in discovered_nodes_coordinates],
hoverinfo='text',
textposition="bottom center"
hoverinfo="text",
textposition="bottom center",
)

dummy_scatter_for_edge_legend = [
go.Scatter(
x=[0], y=[0], mode="lines",
line=dict(color=color_map[a]),
name=a.name
) for a in actions.EdgeAnnotation]

all_scatters = dummy_scatter_for_edge_legend + [trace_owned_nodes, trace_discovered_nodes]
return (all_scatters, layout)
Scatter(
x=[0], y=[0], mode="lines", line=dict(color=color_map[a]), name=a.name
)
for a in actions.EdgeAnnotation
]

all_scatters = dummy_scatter_for_edge_legend + [
trace_owned_nodes,
trace_discovered_nodes,
]
return (all_scatters, _layout)

def plot_discovered_network(self) -> None:
"""Plot the network graph with plotly"""
fig = go.Figure()
traces, layout = self.network_as_plotly_traces()
fig = Figure()
traces, _layout = self.network_as_plotly_traces()
for t in traces:
fig.add_trace(t)
fig.update_layout(layout)
fig.update_layout(_layout)
fig.show()

def get_node_information(self, node_id: model.NodeID) -> model.NodeInfo:
Expand Down

0 comments on commit c22f480

Please sign in to comment.