From acc685d8987f780bcc711ab5d9a192ebf557562b Mon Sep 17 00:00:00 2001 From: Hongkai Dai Date: Sat, 9 Apr 2022 12:30:02 -0700 Subject: [PATCH] Refactor TotalLoss (#441) It contains the lyapunov loss and barrier loss --- .../test/debug_gradient.py | 4 +- .../test/test_train_lyapunov_barrier.py | 29 ++-- .../train_lyapunov_barrier.py | 157 ++++++++++-------- 3 files changed, 106 insertions(+), 84 deletions(-) diff --git a/neural_network_lyapunov/test/debug_gradient.py b/neural_network_lyapunov/test/debug_gradient.py index a9d16e9f..aaf27874 100644 --- a/neural_network_lyapunov/test/debug_gradient.py +++ b/neural_network_lyapunov/test/debug_gradient.py @@ -41,14 +41,14 @@ def compute_total_loss(system, x_equilibrium, relu_layer_width, params_val, dut.lyapunov_derivative_epsilon = lyapunov_derivative_epsilon total_loss = dut.total_loss(state_samples, state_samples_next) if requires_grad: - total_loss[0].backward() + total_loss.loss.backward() grad = np.concatenate([ p.grad.detach().numpy().reshape((-1, )) for p in relu.parameters() ], axis=0) return grad else: - return total_loss[0].item() + return total_loss.loss.item() def compute_milp_cost_given_relu(system, x_equilibrium, relu_layer_width, diff --git a/neural_network_lyapunov/test/test_train_lyapunov_barrier.py b/neural_network_lyapunov/test/test_train_lyapunov_barrier.py index 57f2c5ef..bcc16d12 100644 --- a/neural_network_lyapunov/test/test_train_lyapunov_barrier.py +++ b/neural_network_lyapunov/test/test_train_lyapunov_barrier.py @@ -242,25 +242,26 @@ def test_total_loss(self): self.assertEqual( positivity_state_samples.shape[0] + 1, - total_loss_return.lyap_positivity_state_samples.shape[0]) + total_loss_return.lyap_loss.positivity_state_samples.shape[0]) self.assertEqual( derivative_state_samples.shape[0] + 1, - total_loss_return.lyap_derivative_state_samples.shape[0]) + total_loss_return.lyap_loss.derivative_state_samples.shape[0]) self.assertEqual( derivative_state_samples_next.shape[0] + 1, - total_loss_return.lyap_derivative_state_samples_next.shape[0]) - self.assertAlmostEqual(total_loss_return.loss.item(), - (total_loss_return.lyap_positivity_sample_loss + - total_loss_return.lyap_derivative_sample_loss + - total_loss_return.lyap_positivity_mip_loss + - total_loss_return.lyap_derivative_mip_loss + - total_loss_return.gap_mip_loss).item()) + total_loss_return.lyap_loss.derivative_state_samples_next.shape[0]) + self.assertAlmostEqual( + total_loss_return.loss.item(), + (total_loss_return.lyap_loss.positivity_sample_loss + + total_loss_return.lyap_loss.derivative_sample_loss + + total_loss_return.lyap_loss.positivity_mip_loss + + total_loss_return.lyap_loss.derivative_mip_loss + + total_loss_return.lyap_loss.gap_mip_loss).item()) # Compute hinge(-V(x)) for sampled state x loss_expected = 0. loss_expected += dut.lyapunov_positivity_sample_cost_weight *\ lyapunov_hybrid_system.lyapunov_positivity_loss_at_samples( x_equilibrium, - total_loss_return.lyap_positivity_state_samples[ + total_loss_return.lyap_loss.positivity_state_samples[ -dut.max_sample_pool_size:], V_lambda, dut.lyapunov_positivity_epsilon, R=R_options.R(), margin=dut.lyapunov_positivity_sample_margin) @@ -268,9 +269,9 @@ def test_total_loss(self): lyapunov_hybrid_system.\ lyapunov_derivative_loss_at_samples_and_next_states( V_lambda, dut.lyapunov_derivative_epsilon, - total_loss_return.lyap_derivative_state_samples[ + total_loss_return.lyap_loss.derivative_state_samples[ -dut.max_sample_pool_size:], - total_loss_return.lyap_derivative_state_samples_next[ + total_loss_return.lyap_loss.derivative_state_samples_next[ -dut.max_sample_pool_size:], x_equilibrium, dut.lyapunov_derivative_eps_type, R=R_options.R(), margin=dut.lyapunov_derivative_sample_margin) @@ -319,9 +320,9 @@ def test_total_loss(self): self.assertAlmostEqual(total_loss_return.loss.item(), loss_expected.item(), places=4) - self.assertAlmostEqual(total_loss_return.lyap_positivity_mip_obj, + self.assertAlmostEqual(total_loss_return.lyap_loss.positivity_mip_obj, lyapunov_positivity_mip.gurobi_model.ObjVal) - self.assertAlmostEqual(total_loss_return.lyap_derivative_mip_obj, + self.assertAlmostEqual(total_loss_return.lyap_loss.derivative_mip_obj, lyapunov_derivative_mip.gurobi_model.ObjVal) def solve_boundary_gap_mip_tester(self, dut): diff --git a/neural_network_lyapunov/train_lyapunov_barrier.py b/neural_network_lyapunov/train_lyapunov_barrier.py index f3dbaf15..8859aec4 100644 --- a/neural_network_lyapunov/train_lyapunov_barrier.py +++ b/neural_network_lyapunov/train_lyapunov_barrier.py @@ -562,7 +562,7 @@ def solve_barrier_derivative_mip(self): self.barrier_x_star, self.barrier_c, self.barrier_epsilon) if self.barrier_derivative_mip_pool_solutions > 1: barrier_deriv_return.milp.gurobi_model.setParam( - gurobipy.GRB.Param.PoolSearchMode, 2) + gurobipy.GRB.Param.PoolSearchMode, 2) barrier_deriv_return.milp.gurobi_model.setParam( gurobipy.GRB.Param.PoolSolutions, self.barrier_derivative_mip_pool_solutions) @@ -587,28 +587,41 @@ def solve_barrier_derivative_mip(self): return barrier_deriv_return.milp, barrier_deriv_mip_obj,\ mip_adversarial + class LyapLoss: + def __init__(self): + self.positivity_mip_obj = None + self.derivative_mip_obj = None + self.positivity_sample_loss = None + self.derivative_sample_loss = None + self.positivity_mip_loss = None + self.derivative_mip_loss = None + self.gap_mip_loss = None + self.positivity_state_samples = None + self.derivative_state_samples = None + self.derivative_state_samples_next = None + + class BarrierLoss: + def __init__(self): + self.safe_mip_obj = None + self.unsafe_mip_obj = None + self.derivative_mip_obj = None + self.safe_sample_loss = None + self.unsafe_sample_loss = None + self.derivative_sample_loss = None + self.safe_mip_loss = None + self.unsafe_mip_loss = None + self.derivative_mip_loss = None + self.safe_state_samples = None + self.unsafe_state_samples = None + self.derivative_state_samples = None + class TotalLossReturn: - def __init__(self, loss: torch.Tensor, lyap_positivity_mip_obj: float, - lyap_derivative_mip_obj: float, - lyap_positivity_sample_loss: torch.Tensor, - lyap_derivative_sample_loss: torch.Tensor, - lyap_positivity_mip_loss, - lyap_derivative_mip_loss: torch.Tensor, - gap_mip_loss: torch.Tensor, lyap_positivity_state_samples, - lyap_derivative_state_samples, - lyap_derivative_state_samples_next): + def __init__(self, loss: torch.Tensor, lyap_loss, barrier_loss): self.loss = loss - self.lyap_positivity_mip_obj = lyap_positivity_mip_obj - self.lyap_derivative_mip_obj = lyap_derivative_mip_obj - self.lyap_positivity_sample_loss = lyap_positivity_sample_loss - self.lyap_derivative_sample_loss = lyap_derivative_sample_loss - self.lyap_positivity_mip_loss = lyap_positivity_mip_loss - self.lyap_derivative_mip_loss = lyap_derivative_mip_loss - self.gap_mip_loss = gap_mip_loss - self.lyap_positivity_state_samples = lyap_positivity_state_samples - self.lyap_derivative_state_samples = lyap_derivative_state_samples - self.lyap_derivative_state_samples_next = \ - lyap_derivative_state_samples_next + assert (isinstance(lyap_loss, Trainer.LyapLoss)) + self.lyap_loss = lyap_loss + assert (isinstance(barrier_loss, Trainer.BarrierLoss)) + self.barrier_loss = barrier_loss def total_loss(self, positivity_state_samples, lyap_derivative_state_samples, @@ -648,84 +661,94 @@ def total_loss(self, positivity_state_samples, value_gap_mip_loss is boundary_value_gap_mip_cost_weight * cost5 """ dtype = self.lyapunov_hybrid_system.system.dtype + lyap_loss = Trainer.LyapLoss() + barrier_loss = Trainer.BarrierLoss() if lyap_positivity_mip_cost_weight is not None: - lyap_positivity_mip, lyap_positivity_mip_obj,\ + lyap_positivity_mip, lyap_loss.positivity_mip_obj,\ positivity_mip_adversarial = self.solve_positivity_mip() else: lyap_positivity_mip = None - lyap_positivity_mip_obj = None + lyap_loss.positivity_mip_obj = None positivity_mip_adversarial = None if lyap_derivative_mip_cost_weight is not None: - lyap_derivative_mip, lyap_derivative_mip_obj,\ + lyap_derivative_mip, lyap_loss.derivative_mip_obj,\ lyap_derivative_mip_adversarial,\ lyap_derivative_mip_adversarial_next =\ self.solve_lyap_derivative_mip() else: lyap_derivative_mip = None - lyap_derivative_mip_obj = None + lyap_loss.derivative_mip_obj = None lyap_derivative_mip_adversarial = None lyap_derivative_mip_adversarial_next = None loss = torch.tensor(0., dtype=dtype) - positivity_mip_loss = 0. + lyap_loss.positivity_mip_loss = torch.tensor(0., dtype=dtype) if lyap_positivity_mip_cost_weight != 0 and\ lyap_positivity_mip_cost_weight is not None: - positivity_mip_loss = \ + lyap_loss.positivity_mip_loss = \ lyap_positivity_mip_cost_weight * \ lyap_positivity_mip.\ compute_objective_from_mip_data_and_solution( solution_number=0, penalty=1e-13) - lyap_derivative_mip_loss = 0 + lyap_loss.derivative_mip_loss = torch.tensor(0, dtype=dtype) if lyap_derivative_mip_cost_weight != 0\ and lyap_derivative_mip_cost_weight is not None: mip_cost = lyap_derivative_mip.\ compute_objective_from_mip_data_and_solution( solution_number=0, penalty=1e-13) - lyap_derivative_mip_loss = \ + lyap_loss.derivative_mip_loss = \ lyap_derivative_mip_cost_weight * mip_cost - gap_mip_loss = 0 + lyap_loss.gap_mip_loss = 0 if boundary_value_gap_mip_cost_weight != 0: boundary_value_gap, V_min_milp, V_max_milp, x_min, x_max = \ self.solve_boundary_gap_mip() print(f"boundary_value_gap: {V_max_milp - V_min_milp}") - gap_mip_loss = \ + lyap_loss.gap_mip_loss = \ boundary_value_gap_mip_cost_weight * boundary_value_gap # We add the most adverisal states of the positivity MIP and derivative # MIP to the training set. Note we solve positivity MIP and derivative # MIP separately. This is different from adding the most adversarial # state of the total loss. + lyap_loss.positivity_state_samples = positivity_state_samples + lyap_loss.derivative_state_samples = lyap_derivative_state_samples + lyap_loss.derivative_state_samples_next = \ + lyap_derivative_state_samples_next if self.add_positivity_adversarial_state and \ lyap_positivity_mip_cost_weight is not None: - positivity_state_samples = torch.cat( + lyap_loss.positivity_state_samples = torch.cat( (positivity_state_samples, positivity_mip_adversarial), dim=0) if self.add_derivative_adversarial_state and \ lyap_derivative_mip_cost_weight is not None: - lyap_derivative_state_samples = torch.cat( + lyap_loss.derivative_state_samples = torch.cat( (lyap_derivative_state_samples, lyap_derivative_mip_adversarial), dim=0) - lyap_derivative_state_samples_next = torch.cat( + lyap_loss.derivative_state_samples_next = torch.cat( (lyap_derivative_state_samples_next, lyap_derivative_mip_adversarial_next)) - if positivity_state_samples.shape[0] > self.max_sample_pool_size: + if lyap_loss.positivity_state_samples.shape[ + 0] > self.max_sample_pool_size: positivity_state_samples_in_pool = \ - positivity_state_samples[-self.max_sample_pool_size:] + lyap_loss.positivity_state_samples[-self.max_sample_pool_size:] else: - positivity_state_samples_in_pool = positivity_state_samples - if lyap_derivative_state_samples.shape[0] > self.max_sample_pool_size: + positivity_state_samples_in_pool = \ + lyap_loss.positivity_state_samples + if lyap_loss.derivative_state_samples.shape[ + 0] > self.max_sample_pool_size: lyap_derivative_state_samples_in_pool = \ - lyap_derivative_state_samples[-self.max_sample_pool_size:] + lyap_loss.derivative_state_samples[-self.max_sample_pool_size:] lyap_derivative_state_samples_next_in_pool = \ - lyap_derivative_state_samples_next[-self.max_sample_pool_size:] + lyap_loss.derivative_state_samples_next[ + -self.max_sample_pool_size:] else: lyap_derivative_state_samples_in_pool = \ - lyap_derivative_state_samples + lyap_loss.derivative_state_samples lyap_derivative_state_samples_next_in_pool = \ - lyap_derivative_state_samples_next - positivity_sample_loss, lyap_derivative_sample_loss = \ + lyap_loss.derivative_state_samples_next + lyap_loss.positivity_sample_loss, lyap_loss.derivative_sample_loss = \ self.lyapunov_sample_loss( positivity_state_samples_in_pool, lyap_derivative_state_samples_in_pool, @@ -733,15 +756,12 @@ def total_loss(self, positivity_state_samples, lyap_positivity_sample_cost_weight, lyap_derivative_sample_cost_weight) - loss = positivity_sample_loss + lyap_derivative_sample_loss + \ - positivity_mip_loss + lyap_derivative_mip_loss + gap_mip_loss + loss = lyap_loss.positivity_sample_loss + \ + lyap_loss.derivative_sample_loss + \ + lyap_loss.positivity_mip_loss + lyap_loss.derivative_mip_loss +\ + lyap_loss.gap_mip_loss - return Trainer.TotalLossReturn( - loss, lyap_positivity_mip_obj, lyap_derivative_mip_obj, - positivity_sample_loss, lyap_derivative_sample_loss, - positivity_mip_loss, lyap_derivative_mip_loss, gap_mip_loss, - positivity_state_samples, lyap_derivative_state_samples, - lyap_derivative_state_samples_next) + return Trainer.TotalLossReturn(loss, lyap_loss, barrier_loss) def _save_network(self, iter_count): if self.save_network_path: @@ -843,50 +863,51 @@ def train(self, state_samples_all): self.lyapunov_derivative_mip_cost_weight, self.boundary_value_gap_mip_cost_weight) positivity_state_samples = \ - total_loss_return.lyap_positivity_state_samples + total_loss_return.lyap_loss.positivity_state_samples derivative_state_samples = \ - total_loss_return.lyap_derivative_state_samples + total_loss_return.lyap_loss.derivative_state_samples derivative_state_samples_next = \ - total_loss_return.lyap_derivative_state_samples_next + total_loss_return.lyap_loss.derivative_state_samples_next if self.enable_wandb: wandb.log({ "loss": total_loss_return.loss.item(), "positivity MIP cost": - total_loss_return.lyap_positivity_mip_obj, + total_loss_return.lyap_loss.positivity_mip_obj, "derivative MIP cost": - total_loss_return.lyap_derivative_mip_obj, - "boundary gap MIP": total_loss_return.gap_mip_loss, + total_loss_return.lyap_loss.derivative_mip_obj, + "boundary gap MIP": + total_loss_return.lyap_loss.gap_mip_loss, "time": time.time() - train_start_time }) if self.output_flag: print(f"Iter {iter_count}, loss {total_loss_return.loss}, " + "positivity cost " + - f"{total_loss_return.lyap_positivity_mip_obj}, " + + f"{total_loss_return.lyap_loss.positivity_mip_obj}, " + "derivative_cost " + - f"{total_loss_return.lyap_derivative_mip_obj}") - if total_loss_return.lyap_positivity_mip_obj <=\ + f"{total_loss_return.lyap_loss.derivative_mip_obj}") + if total_loss_return.lyap_loss.positivity_mip_obj <=\ self.lyapunov_positivity_convergence_tol and\ - total_loss_return.lyap_derivative_mip_obj <= \ + total_loss_return.lyap_loss.derivative_mip_obj <= \ self.lyapunov_derivative_convergence_tol: return (True, total_loss_return.loss.item(), - total_loss_return.lyap_positivity_mip_obj, - total_loss_return.lyap_derivative_mip_obj) - if total_loss_return.lyap_positivity_mip_obj < \ + total_loss_return.lyap_loss.positivity_mip_obj, + total_loss_return.lyap_loss.derivative_mip_obj) + if total_loss_return.lyap_loss.positivity_mip_obj < \ self.lyapunov_positivity_convergence_tol and\ - total_loss_return.lyap_derivative_mip_obj <\ + total_loss_return.lyap_loss.derivative_mip_obj <\ best_derivative_mip_cost: best_training_params = [ # noqa p.clone() for p in training_params ] best_derivative_mip_cost = \ - total_loss_return.lyap_derivative_mip_obj + total_loss_return.lyap_loss.derivative_mip_obj total_loss_return.loss.backward() optimizer.step() iter_count += 1 return (False, total_loss_return.loss.item(), - total_loss_return.lyap_positivity_mip_obj, - total_loss_return.lyap_derivative_mip_obj) + total_loss_return.lyap_loss.positivity_mip_obj, + total_loss_return.lyap_loss.derivative_mip_obj) def train_lyapunov_on_samples(self, state_samples_all, num_epochs, batch_size):