diff --git a/neural_network_lyapunov/test/test_train_lyapunov_barrier.py b/neural_network_lyapunov/test/test_train_lyapunov_barrier.py index bcc16d12..91496b7a 100644 --- a/neural_network_lyapunov/test/test_train_lyapunov_barrier.py +++ b/neural_network_lyapunov/test/test_train_lyapunov_barrier.py @@ -520,6 +520,55 @@ def test_solve_barrier_derivative_mip(self): adversarial, dut.barrier_x_star, dut.barrier_c, dut.barrier_epsilon).detach().numpy()) + def test_barrier_loss(self): + dut = train_lyapunov_barrier.Trainer() + dut.add_barrier(self.barrier_system, + x_star=(self.system.x_lo * 0.25 + + self.system.x_up * 0.75), + c=0.1, + barrier_epsilon=0.3) + dut.safe_regions = [gurobi_torch_mip.MixedIntegerConstraintsReturn()] + dut.safe_regions[0].Ain = torch.eye(2, dtype=self.dtype) + dut.safe_regions[0].rhs = (self.system.x_lo + self.system.x_up) / 2 + safe_state_samples = utils.uniform_sample_in_box( + self.system.x_lo, (self.system.x_lo + self.system.x_up) / 2, 100) + + dut.unsafe_regions = [gurobi_torch_mip.MixedIntegerConstraintsReturn()] + dut.unsafe_regions[0].Ain_input = -torch.eye(2, dtype=self.dtype) + dut.unsafe_regions[0].rhs_in = -(self.system.x_lo * 0.25 + + self.system.x_up * 0.75) + num_unsafe_state_samples = 200 + unsafe_state_samples = utils.uniform_sample_in_box( + self.system.x_lo * 0.25 + self.system.x_up * 0.75, + self.system.x_up, num_unsafe_state_samples) + num_derivative_state_samples = 300 + derivative_state_samples = utils.uniform_sample_in_box( + self.system.x_lo, self.system.x_up, num_derivative_state_samples) + + safe_sample_cost_weight = 2. + unsafe_sample_cost_weight = 3. + derivative_sample_cost_weight = 4. + safe_mip_cost_weight = 5. + unsafe_mip_cost_weight = 6. + derivative_mip_cost_weight = 7. + + barrier_loss = dut.compute_barrier_loss( + safe_state_samples, unsafe_state_samples, derivative_state_samples, + safe_sample_cost_weight, unsafe_sample_cost_weight, + derivative_sample_cost_weight, safe_mip_cost_weight, + unsafe_mip_cost_weight, derivative_mip_cost_weight) + self.assertEqual(len(barrier_loss.safe_mip_obj), len(dut.safe_regions)) + self.assertEqual(len(barrier_loss.safe_mip_loss), + len(dut.safe_regions)) + self.assertEqual(len(barrier_loss.unsafe_mip_obj), + len(dut.unsafe_regions)) + self.assertEqual(len(barrier_loss.unsafe_mip_loss), + len(dut.unsafe_regions)) + self.assertGreater(barrier_loss.unsafe_state_samples.shape[0], + num_unsafe_state_samples) + self.assertGreater(barrier_loss.derivative_state_samples.shape[0], + num_derivative_state_samples) + class TestTrainValueApproximator(unittest.TestCase): def setUp(self): diff --git a/neural_network_lyapunov/train_lyapunov_barrier.py b/neural_network_lyapunov/train_lyapunov_barrier.py index 8859aec4..788c9afd 100644 --- a/neural_network_lyapunov/train_lyapunov_barrier.py +++ b/neural_network_lyapunov/train_lyapunov_barrier.py @@ -623,6 +623,86 @@ def __init__(self, loss: torch.Tensor, lyap_loss, barrier_loss): assert (isinstance(barrier_loss, Trainer.BarrierLoss)) self.barrier_loss = barrier_loss + def compute_barrier_loss(self, safe_state_samples, unsafe_state_samples, + derivative_state_samples, safe_sample_cost_weight, + unsafe_sample_cost_weight, + derivative_sample_cost_weight, + safe_mip_cost_weight, unsafe_mip_cost_weight, + derivative_mip_cost_weight) -> BarrierLoss: + barrier_loss = Trainer.BarrierLoss() + + if safe_mip_cost_weight is not None: + safe_mip, barrier_loss.safe_mip_obj, safe_mip_adversarial = \ + self.solve_barrier_value_mip(safe_flag=True) + if safe_mip_cost_weight != 0: + barrier_loss.safe_mip_loss = [ + safe_mip_cost_weight * + mip.compute_objective_from_mip_data_and_solution( + solution_number=0, penalty=1e-13) for mip in safe_mip + ] + else: + barrier_loss.safe_mip_loss = None + barrier_loss.safe_mip_obj = None + safe_mip_adversarial = None + + if unsafe_mip_cost_weight is not None: + unsafe_mip, barrier_loss.unsafe_mip_obj, unsafe_mip_adversarial = \ + self.solve_barrier_value_mip(safe_flag=False) + if unsafe_mip_cost_weight != 0: + barrier_loss.unsafe_mip_loss = [ + unsafe_mip_cost_weight * + mip.compute_objective_from_mip_data_and_solution( + solution_number=0, penalty=1e-13) for mip in unsafe_mip + ] + else: + barrier_loss.unsafe_mip_loss = None + barrier_loss.unsafe_mip_obj = None + unsafe_mip_adversarial = None + + if derivative_mip_cost_weight is not None: + derivative_mip, barrier_loss.derivative_mip_obj, \ + derivative_mip_adversarial = self.solve_barrier_derivative_mip( + ) + if derivative_mip_cost_weight != 0: + barrier_loss.derivative_mip_loss = derivative_mip_cost_weight \ + * derivative_mip.\ + compute_objective_from_mip_data_and_solution( + solution_number=0, penalty=1e-13) + else: + barrier_loss.derivative_mip_loss = None + barrier_loss.derivative_mip_obj = None + derivative_mip_adversarial = None + + barrier_loss.safe_state_samples = safe_state_samples + barrier_loss.unsafe_state_samples = unsafe_state_samples + barrier_loss.derivative_state_samples = derivative_state_samples + if safe_mip_cost_weight is not None and \ + safe_mip_adversarial is not None and \ + len(safe_mip_adversarial) > 0: + barrier_loss.safe_state_samples = torch.cat( + (safe_state_samples, torch.cat(safe_mip_adversarial, dim=0)), + dim=0) + if unsafe_mip_cost_weight is not None and \ + unsafe_mip_adversarial is not None and \ + len(unsafe_mip_adversarial) > 0: + barrier_loss.unsafe_state_samples = torch.cat( + (unsafe_state_samples, torch.cat(unsafe_mip_adversarial, + dim=0)), + dim=0) + if derivative_mip_cost_weight is not None and \ + derivative_mip_adversarial is not None: + barrier_loss.derivative_state_samples = torch.cat( + (derivative_state_samples, derivative_mip_adversarial), dim=0) + barrier_loss.safe_sample_loss, barrier_loss.unsafe_sample_loss, \ + barrier_loss.derivative_sample_loss = self.barrier_sample_loss( + barrier_loss.safe_state_samples[-self.max_sample_pool_size:], + barrier_loss.unsafe_state_samples[-self.max_sample_pool_size:], + barrier_loss.derivative_state_samples[ + -self.max_sample_pool_size:], + safe_sample_cost_weight, unsafe_sample_cost_weight, + derivative_sample_cost_weight) + return barrier_loss + def total_loss(self, positivity_state_samples, lyap_derivative_state_samples, lyap_derivative_state_samples_next,