Skip to content

Commit

Permalink
Refactor TotalLoss (#441)
Browse files Browse the repository at this point in the history
It contains the lyapunov loss and barrier loss
  • Loading branch information
hongkai-dai authored Apr 9, 2022
1 parent 59cc3c1 commit acc685d
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 84 deletions.
4 changes: 2 additions & 2 deletions neural_network_lyapunov/test/debug_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ def compute_total_loss(system, x_equilibrium, relu_layer_width, params_val,
dut.lyapunov_derivative_epsilon = lyapunov_derivative_epsilon
total_loss = dut.total_loss(state_samples, state_samples_next)
if requires_grad:
total_loss[0].backward()
total_loss.loss.backward()
grad = np.concatenate([
p.grad.detach().numpy().reshape((-1, )) for p in relu.parameters()
],
axis=0)
return grad
else:
return total_loss[0].item()
return total_loss.loss.item()


def compute_milp_cost_given_relu(system, x_equilibrium, relu_layer_width,
Expand Down
29 changes: 15 additions & 14 deletions neural_network_lyapunov/test/test_train_lyapunov_barrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,35 +242,36 @@ def test_total_loss(self):

self.assertEqual(
positivity_state_samples.shape[0] + 1,
total_loss_return.lyap_positivity_state_samples.shape[0])
total_loss_return.lyap_loss.positivity_state_samples.shape[0])
self.assertEqual(
derivative_state_samples.shape[0] + 1,
total_loss_return.lyap_derivative_state_samples.shape[0])
total_loss_return.lyap_loss.derivative_state_samples.shape[0])
self.assertEqual(
derivative_state_samples_next.shape[0] + 1,
total_loss_return.lyap_derivative_state_samples_next.shape[0])
self.assertAlmostEqual(total_loss_return.loss.item(),
(total_loss_return.lyap_positivity_sample_loss +
total_loss_return.lyap_derivative_sample_loss +
total_loss_return.lyap_positivity_mip_loss +
total_loss_return.lyap_derivative_mip_loss +
total_loss_return.gap_mip_loss).item())
total_loss_return.lyap_loss.derivative_state_samples_next.shape[0])
self.assertAlmostEqual(
total_loss_return.loss.item(),
(total_loss_return.lyap_loss.positivity_sample_loss +
total_loss_return.lyap_loss.derivative_sample_loss +
total_loss_return.lyap_loss.positivity_mip_loss +
total_loss_return.lyap_loss.derivative_mip_loss +
total_loss_return.lyap_loss.gap_mip_loss).item())
# Compute hinge(-V(x)) for sampled state x
loss_expected = 0.
loss_expected += dut.lyapunov_positivity_sample_cost_weight *\
lyapunov_hybrid_system.lyapunov_positivity_loss_at_samples(
x_equilibrium,
total_loss_return.lyap_positivity_state_samples[
total_loss_return.lyap_loss.positivity_state_samples[
-dut.max_sample_pool_size:],
V_lambda, dut.lyapunov_positivity_epsilon,
R=R_options.R(), margin=dut.lyapunov_positivity_sample_margin)
loss_expected += dut.lyapunov_derivative_sample_cost_weight *\
lyapunov_hybrid_system.\
lyapunov_derivative_loss_at_samples_and_next_states(
V_lambda, dut.lyapunov_derivative_epsilon,
total_loss_return.lyap_derivative_state_samples[
total_loss_return.lyap_loss.derivative_state_samples[
-dut.max_sample_pool_size:],
total_loss_return.lyap_derivative_state_samples_next[
total_loss_return.lyap_loss.derivative_state_samples_next[
-dut.max_sample_pool_size:],
x_equilibrium, dut.lyapunov_derivative_eps_type,
R=R_options.R(), margin=dut.lyapunov_derivative_sample_margin)
Expand Down Expand Up @@ -319,9 +320,9 @@ def test_total_loss(self):
self.assertAlmostEqual(total_loss_return.loss.item(),
loss_expected.item(),
places=4)
self.assertAlmostEqual(total_loss_return.lyap_positivity_mip_obj,
self.assertAlmostEqual(total_loss_return.lyap_loss.positivity_mip_obj,
lyapunov_positivity_mip.gurobi_model.ObjVal)
self.assertAlmostEqual(total_loss_return.lyap_derivative_mip_obj,
self.assertAlmostEqual(total_loss_return.lyap_loss.derivative_mip_obj,
lyapunov_derivative_mip.gurobi_model.ObjVal)

def solve_boundary_gap_mip_tester(self, dut):
Expand Down
157 changes: 89 additions & 68 deletions neural_network_lyapunov/train_lyapunov_barrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def solve_barrier_derivative_mip(self):
self.barrier_x_star, self.barrier_c, self.barrier_epsilon)
if self.barrier_derivative_mip_pool_solutions > 1:
barrier_deriv_return.milp.gurobi_model.setParam(
gurobipy.GRB.Param.PoolSearchMode, 2)
gurobipy.GRB.Param.PoolSearchMode, 2)
barrier_deriv_return.milp.gurobi_model.setParam(
gurobipy.GRB.Param.PoolSolutions,
self.barrier_derivative_mip_pool_solutions)
Expand All @@ -587,28 +587,41 @@ def solve_barrier_derivative_mip(self):
return barrier_deriv_return.milp, barrier_deriv_mip_obj,\
mip_adversarial

class LyapLoss:
    """Container for every Lyapunov-related term of the total training loss.

    All fields start out as ``None``; they are populated by the trainer's
    ``total_loss`` routine (MIP objectives, sample/MIP loss terms, and the
    pools of state samples used to evaluate the sampled losses).
    """

    def __init__(self):
        # Each field begins unset; the trainer fills them in as the
        # corresponding MIPs are solved and sample losses computed.
        for field in ("positivity_mip_obj", "derivative_mip_obj",
                      "positivity_sample_loss", "derivative_sample_loss",
                      "positivity_mip_loss", "derivative_mip_loss",
                      "gap_mip_loss", "positivity_state_samples",
                      "derivative_state_samples",
                      "derivative_state_samples_next"):
            setattr(self, field, None)

class BarrierLoss:
    """Container for every barrier-related term of the total training loss.

    All fields start out as ``None``; they are populated by the trainer
    (MIP objectives, sample/MIP loss terms, and the pools of state samples
    used for the sampled losses).
    """

    def __init__(self):
        # Each field begins unset; the trainer fills them in when the
        # safe/unsafe/derivative barrier MIPs and sample losses are
        # evaluated.
        for field in ("safe_mip_obj", "unsafe_mip_obj",
                      "derivative_mip_obj", "safe_sample_loss",
                      "unsafe_sample_loss", "derivative_sample_loss",
                      "safe_mip_loss", "unsafe_mip_loss",
                      "derivative_mip_loss", "safe_state_samples",
                      "unsafe_state_samples", "derivative_state_samples"):
            setattr(self, field, None)

class TotalLossReturn:
def __init__(self, loss: torch.Tensor, lyap_positivity_mip_obj: float,
lyap_derivative_mip_obj: float,
lyap_positivity_sample_loss: torch.Tensor,
lyap_derivative_sample_loss: torch.Tensor,
lyap_positivity_mip_loss,
lyap_derivative_mip_loss: torch.Tensor,
gap_mip_loss: torch.Tensor, lyap_positivity_state_samples,
lyap_derivative_state_samples,
lyap_derivative_state_samples_next):
def __init__(self, loss: torch.Tensor, lyap_loss, barrier_loss):
self.loss = loss
self.lyap_positivity_mip_obj = lyap_positivity_mip_obj
self.lyap_derivative_mip_obj = lyap_derivative_mip_obj
self.lyap_positivity_sample_loss = lyap_positivity_sample_loss
self.lyap_derivative_sample_loss = lyap_derivative_sample_loss
self.lyap_positivity_mip_loss = lyap_positivity_mip_loss
self.lyap_derivative_mip_loss = lyap_derivative_mip_loss
self.gap_mip_loss = gap_mip_loss
self.lyap_positivity_state_samples = lyap_positivity_state_samples
self.lyap_derivative_state_samples = lyap_derivative_state_samples
self.lyap_derivative_state_samples_next = \
lyap_derivative_state_samples_next
assert (isinstance(lyap_loss, Trainer.LyapLoss))
self.lyap_loss = lyap_loss
assert (isinstance(barrier_loss, Trainer.BarrierLoss))
self.barrier_loss = barrier_loss

def total_loss(self, positivity_state_samples,
lyap_derivative_state_samples,
Expand Down Expand Up @@ -648,100 +661,107 @@ def total_loss(self, positivity_state_samples,
value_gap_mip_loss is boundary_value_gap_mip_cost_weight * cost5
"""
dtype = self.lyapunov_hybrid_system.system.dtype
lyap_loss = Trainer.LyapLoss()
barrier_loss = Trainer.BarrierLoss()
if lyap_positivity_mip_cost_weight is not None:
lyap_positivity_mip, lyap_positivity_mip_obj,\
lyap_positivity_mip, lyap_loss.positivity_mip_obj,\
positivity_mip_adversarial = self.solve_positivity_mip()
else:
lyap_positivity_mip = None
lyap_positivity_mip_obj = None
lyap_loss.positivity_mip_obj = None
positivity_mip_adversarial = None
if lyap_derivative_mip_cost_weight is not None:
lyap_derivative_mip, lyap_derivative_mip_obj,\
lyap_derivative_mip, lyap_loss.derivative_mip_obj,\
lyap_derivative_mip_adversarial,\
lyap_derivative_mip_adversarial_next =\
self.solve_lyap_derivative_mip()
else:
lyap_derivative_mip = None
lyap_derivative_mip_obj = None
lyap_loss.derivative_mip_obj = None
lyap_derivative_mip_adversarial = None
lyap_derivative_mip_adversarial_next = None

loss = torch.tensor(0., dtype=dtype)

positivity_mip_loss = 0.
lyap_loss.positivity_mip_loss = torch.tensor(0., dtype=dtype)
if lyap_positivity_mip_cost_weight != 0 and\
lyap_positivity_mip_cost_weight is not None:
positivity_mip_loss = \
lyap_loss.positivity_mip_loss = \
lyap_positivity_mip_cost_weight * \
lyap_positivity_mip.\
compute_objective_from_mip_data_and_solution(
solution_number=0, penalty=1e-13)
lyap_derivative_mip_loss = 0
lyap_loss.derivative_mip_loss = torch.tensor(0, dtype=dtype)
if lyap_derivative_mip_cost_weight != 0\
and lyap_derivative_mip_cost_weight is not None:
mip_cost = lyap_derivative_mip.\
compute_objective_from_mip_data_and_solution(
solution_number=0, penalty=1e-13)
lyap_derivative_mip_loss = \
lyap_loss.derivative_mip_loss = \
lyap_derivative_mip_cost_weight * mip_cost
gap_mip_loss = 0
lyap_loss.gap_mip_loss = 0
if boundary_value_gap_mip_cost_weight != 0:
boundary_value_gap, V_min_milp, V_max_milp, x_min, x_max = \
self.solve_boundary_gap_mip()
print(f"boundary_value_gap: {V_max_milp - V_min_milp}")
gap_mip_loss = \
lyap_loss.gap_mip_loss = \
boundary_value_gap_mip_cost_weight * boundary_value_gap

# We add the most adversarial states of the positivity MIP and derivative
# MIP to the training set. Note we solve positivity MIP and derivative
# MIP separately. This is different from adding the most adversarial
# state of the total loss.
lyap_loss.positivity_state_samples = positivity_state_samples
lyap_loss.derivative_state_samples = lyap_derivative_state_samples
lyap_loss.derivative_state_samples_next = \
lyap_derivative_state_samples_next
if self.add_positivity_adversarial_state and \
lyap_positivity_mip_cost_weight is not None:
positivity_state_samples = torch.cat(
lyap_loss.positivity_state_samples = torch.cat(
(positivity_state_samples, positivity_mip_adversarial), dim=0)
if self.add_derivative_adversarial_state and \
lyap_derivative_mip_cost_weight is not None:
lyap_derivative_state_samples = torch.cat(
lyap_loss.derivative_state_samples = torch.cat(
(lyap_derivative_state_samples,
lyap_derivative_mip_adversarial),
dim=0)
lyap_derivative_state_samples_next = torch.cat(
lyap_loss.derivative_state_samples_next = torch.cat(
(lyap_derivative_state_samples_next,
lyap_derivative_mip_adversarial_next))

if positivity_state_samples.shape[0] > self.max_sample_pool_size:
if lyap_loss.positivity_state_samples.shape[
0] > self.max_sample_pool_size:
positivity_state_samples_in_pool = \
positivity_state_samples[-self.max_sample_pool_size:]
lyap_loss.positivity_state_samples[-self.max_sample_pool_size:]
else:
positivity_state_samples_in_pool = positivity_state_samples
if lyap_derivative_state_samples.shape[0] > self.max_sample_pool_size:
positivity_state_samples_in_pool = \
lyap_loss.positivity_state_samples
if lyap_loss.derivative_state_samples.shape[
0] > self.max_sample_pool_size:
lyap_derivative_state_samples_in_pool = \
lyap_derivative_state_samples[-self.max_sample_pool_size:]
lyap_loss.derivative_state_samples[-self.max_sample_pool_size:]
lyap_derivative_state_samples_next_in_pool = \
lyap_derivative_state_samples_next[-self.max_sample_pool_size:]
lyap_loss.derivative_state_samples_next[
-self.max_sample_pool_size:]
else:
lyap_derivative_state_samples_in_pool = \
lyap_derivative_state_samples
lyap_loss.derivative_state_samples
lyap_derivative_state_samples_next_in_pool = \
lyap_derivative_state_samples_next
positivity_sample_loss, lyap_derivative_sample_loss = \
lyap_loss.derivative_state_samples_next
lyap_loss.positivity_sample_loss, lyap_loss.derivative_sample_loss = \
self.lyapunov_sample_loss(
positivity_state_samples_in_pool,
lyap_derivative_state_samples_in_pool,
lyap_derivative_state_samples_next_in_pool,
lyap_positivity_sample_cost_weight,
lyap_derivative_sample_cost_weight)

loss = positivity_sample_loss + lyap_derivative_sample_loss + \
positivity_mip_loss + lyap_derivative_mip_loss + gap_mip_loss
loss = lyap_loss.positivity_sample_loss + \
lyap_loss.derivative_sample_loss + \
lyap_loss.positivity_mip_loss + lyap_loss.derivative_mip_loss +\
lyap_loss.gap_mip_loss

return Trainer.TotalLossReturn(
loss, lyap_positivity_mip_obj, lyap_derivative_mip_obj,
positivity_sample_loss, lyap_derivative_sample_loss,
positivity_mip_loss, lyap_derivative_mip_loss, gap_mip_loss,
positivity_state_samples, lyap_derivative_state_samples,
lyap_derivative_state_samples_next)
return Trainer.TotalLossReturn(loss, lyap_loss, barrier_loss)

def _save_network(self, iter_count):
if self.save_network_path:
Expand Down Expand Up @@ -843,50 +863,51 @@ def train(self, state_samples_all):
self.lyapunov_derivative_mip_cost_weight,
self.boundary_value_gap_mip_cost_weight)
positivity_state_samples = \
total_loss_return.lyap_positivity_state_samples
total_loss_return.lyap_loss.positivity_state_samples
derivative_state_samples = \
total_loss_return.lyap_derivative_state_samples
total_loss_return.lyap_loss.derivative_state_samples
derivative_state_samples_next = \
total_loss_return.lyap_derivative_state_samples_next
total_loss_return.lyap_loss.derivative_state_samples_next

if self.enable_wandb:
wandb.log({
"loss": total_loss_return.loss.item(),
"positivity MIP cost":
total_loss_return.lyap_positivity_mip_obj,
total_loss_return.lyap_loss.positivity_mip_obj,
"derivative MIP cost":
total_loss_return.lyap_derivative_mip_obj,
"boundary gap MIP": total_loss_return.gap_mip_loss,
total_loss_return.lyap_loss.derivative_mip_obj,
"boundary gap MIP":
total_loss_return.lyap_loss.gap_mip_loss,
"time": time.time() - train_start_time
})
if self.output_flag:
print(f"Iter {iter_count}, loss {total_loss_return.loss}, " +
"positivity cost " +
f"{total_loss_return.lyap_positivity_mip_obj}, " +
f"{total_loss_return.lyap_loss.positivity_mip_obj}, " +
"derivative_cost " +
f"{total_loss_return.lyap_derivative_mip_obj}")
if total_loss_return.lyap_positivity_mip_obj <=\
f"{total_loss_return.lyap_loss.derivative_mip_obj}")
if total_loss_return.lyap_loss.positivity_mip_obj <=\
self.lyapunov_positivity_convergence_tol and\
total_loss_return.lyap_derivative_mip_obj <= \
total_loss_return.lyap_loss.derivative_mip_obj <= \
self.lyapunov_derivative_convergence_tol:
return (True, total_loss_return.loss.item(),
total_loss_return.lyap_positivity_mip_obj,
total_loss_return.lyap_derivative_mip_obj)
if total_loss_return.lyap_positivity_mip_obj < \
total_loss_return.lyap_loss.positivity_mip_obj,
total_loss_return.lyap_loss.derivative_mip_obj)
if total_loss_return.lyap_loss.positivity_mip_obj < \
self.lyapunov_positivity_convergence_tol and\
total_loss_return.lyap_derivative_mip_obj <\
total_loss_return.lyap_loss.derivative_mip_obj <\
best_derivative_mip_cost:
best_training_params = [ # noqa
p.clone() for p in training_params
]
best_derivative_mip_cost = \
total_loss_return.lyap_derivative_mip_obj
total_loss_return.lyap_loss.derivative_mip_obj
total_loss_return.loss.backward()
optimizer.step()
iter_count += 1
return (False, total_loss_return.loss.item(),
total_loss_return.lyap_positivity_mip_obj,
total_loss_return.lyap_derivative_mip_obj)
total_loss_return.lyap_loss.positivity_mip_obj,
total_loss_return.lyap_loss.derivative_mip_obj)

def train_lyapunov_on_samples(self, state_samples_all, num_epochs,
batch_size):
Expand Down

0 comments on commit acc685d

Please sign in to comment.