Skip to content

Commit

Permalink
Compute loss for barrier training. (#442)
Browse files Browse the repository at this point in the history
  • Loading branch information
hongkai-dai authored Apr 10, 2022
1 parent acc685d commit d7b03d3
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 0 deletions.
49 changes: 49 additions & 0 deletions neural_network_lyapunov/test/test_train_lyapunov_barrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,55 @@ def test_solve_barrier_derivative_mip(self):
adversarial, dut.barrier_x_star, dut.barrier_c,
dut.barrier_epsilon).detach().numpy())

def test_barrier_loss(self):
    """Exercises Trainer.compute_barrier_loss end-to-end.

    Sets up one safe region (the lower-left half of the state box) and one
    unsafe region (the upper-right quarter), samples states in each, and
    checks that the returned BarrierLoss reports one MIP objective/loss per
    region and that the adversarial states found by the MIPs were appended
    to the sample pools.
    """
    dut = train_lyapunov_barrier.Trainer()
    dut.add_barrier(self.barrier_system,
                    x_star=(self.system.x_lo * 0.25 +
                            self.system.x_up * 0.75),
                    c=0.1,
                    barrier_epsilon=0.3)
    dut.safe_regions = [gurobi_torch_mip.MixedIntegerConstraintsReturn()]
    # NOTE(review): use the Ain_input/rhs_in fields, consistent with the
    # unsafe region below. The previous assignment to `Ain`/`rhs` created
    # attributes that are not part of MixedIntegerConstraintsReturn, so the
    # safe-region constraints would have been silently ignored.
    dut.safe_regions[0].Ain_input = torch.eye(2, dtype=self.dtype)
    dut.safe_regions[0].rhs_in = (self.system.x_lo + self.system.x_up) / 2
    safe_state_samples = utils.uniform_sample_in_box(
        self.system.x_lo, (self.system.x_lo + self.system.x_up) / 2, 100)

    # Unsafe region: x >= 0.25 * x_lo + 0.75 * x_up, written as
    # -x <= -(0.25 * x_lo + 0.75 * x_up).
    dut.unsafe_regions = [gurobi_torch_mip.MixedIntegerConstraintsReturn()]
    dut.unsafe_regions[0].Ain_input = -torch.eye(2, dtype=self.dtype)
    dut.unsafe_regions[0].rhs_in = -(self.system.x_lo * 0.25 +
                                     self.system.x_up * 0.75)
    num_unsafe_state_samples = 200
    unsafe_state_samples = utils.uniform_sample_in_box(
        self.system.x_lo * 0.25 + self.system.x_up * 0.75,
        self.system.x_up, num_unsafe_state_samples)
    num_derivative_state_samples = 300
    derivative_state_samples = utils.uniform_sample_in_box(
        self.system.x_lo, self.system.x_up, num_derivative_state_samples)

    # Distinct non-zero weights so every loss term is exercised with a
    # non-trivial scale.
    safe_sample_cost_weight = 2.
    unsafe_sample_cost_weight = 3.
    derivative_sample_cost_weight = 4.
    safe_mip_cost_weight = 5.
    unsafe_mip_cost_weight = 6.
    derivative_mip_cost_weight = 7.

    barrier_loss = dut.compute_barrier_loss(
        safe_state_samples, unsafe_state_samples, derivative_state_samples,
        safe_sample_cost_weight, unsafe_sample_cost_weight,
        derivative_sample_cost_weight, safe_mip_cost_weight,
        unsafe_mip_cost_weight, derivative_mip_cost_weight)
    # One MIP objective and one MIP loss entry per region.
    self.assertEqual(len(barrier_loss.safe_mip_obj), len(dut.safe_regions))
    self.assertEqual(len(barrier_loss.safe_mip_loss),
                     len(dut.safe_regions))
    self.assertEqual(len(barrier_loss.unsafe_mip_obj),
                     len(dut.unsafe_regions))
    self.assertEqual(len(barrier_loss.unsafe_mip_loss),
                     len(dut.unsafe_regions))
    # Adversarial MIP states are appended to the sample pools, so each pool
    # must grow beyond its original sample count.
    self.assertGreater(barrier_loss.unsafe_state_samples.shape[0],
                       num_unsafe_state_samples)
    self.assertGreater(barrier_loss.derivative_state_samples.shape[0],
                       num_derivative_state_samples)


class TestTrainValueApproximator(unittest.TestCase):
def setUp(self):
Expand Down
80 changes: 80 additions & 0 deletions neural_network_lyapunov/train_lyapunov_barrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,86 @@ def __init__(self, loss: torch.Tensor, lyap_loss, barrier_loss):
assert (isinstance(barrier_loss, Trainer.BarrierLoss))
self.barrier_loss = barrier_loss

def compute_barrier_loss(self, safe_state_samples, unsafe_state_samples,
                         derivative_state_samples, safe_sample_cost_weight,
                         unsafe_sample_cost_weight,
                         derivative_sample_cost_weight,
                         safe_mip_cost_weight, unsafe_mip_cost_weight,
                         derivative_mip_cost_weight) -> BarrierLoss:
    """Assemble every barrier-related loss term into a BarrierLoss.

    For each MIP weight, ``None`` skips solving the corresponding MIP
    entirely, while ``0`` still solves it (so the objective value and
    adversarial states are reported) but adds no MIP loss term.
    Adversarial states recovered from the MIPs are appended to the
    matching sample pools, and the sample-based losses are then evaluated
    on the most recent ``self.max_sample_pool_size`` entries of each pool.
    """
    loss = Trainer.BarrierLoss()

    # Barrier-value MIPs over the safe regions (one program per region).
    if safe_mip_cost_weight is None:
        loss.safe_mip_loss = None
        loss.safe_mip_obj = None
        safe_mip_adversarial = None
    else:
        safe_mips, loss.safe_mip_obj, safe_mip_adversarial = \
            self.solve_barrier_value_mip(safe_flag=True)
        if safe_mip_cost_weight != 0:
            # Re-evaluate each MIP objective from its solution data;
            # presumably this yields a torch expression so gradients can
            # flow back to the network parameters — confirm in
            # gurobi_torch_mip.
            loss.safe_mip_loss = [
                safe_mip_cost_weight *
                prog.compute_objective_from_mip_data_and_solution(
                    solution_number=0, penalty=1e-13)
                for prog in safe_mips
            ]

    # Barrier-value MIPs over the unsafe regions.
    if unsafe_mip_cost_weight is None:
        loss.unsafe_mip_loss = None
        loss.unsafe_mip_obj = None
        unsafe_mip_adversarial = None
    else:
        unsafe_mips, loss.unsafe_mip_obj, unsafe_mip_adversarial = \
            self.solve_barrier_value_mip(safe_flag=False)
        if unsafe_mip_cost_weight != 0:
            loss.unsafe_mip_loss = [
                unsafe_mip_cost_weight *
                prog.compute_objective_from_mip_data_and_solution(
                    solution_number=0, penalty=1e-13)
                for prog in unsafe_mips
            ]

    # Barrier-derivative MIP (a single program, not a list).
    if derivative_mip_cost_weight is None:
        loss.derivative_mip_loss = None
        loss.derivative_mip_obj = None
        derivative_mip_adversarial = None
    else:
        derivative_mip, loss.derivative_mip_obj, \
            derivative_mip_adversarial = \
            self.solve_barrier_derivative_mip()
        if derivative_mip_cost_weight != 0:
            loss.derivative_mip_loss = derivative_mip_cost_weight * \
                derivative_mip.compute_objective_from_mip_data_and_solution(
                    solution_number=0, penalty=1e-13)

    # Grow each sample pool with whatever adversarial states the MIPs
    # produced. The safe/unsafe adversarial results are lists of tensors;
    # the derivative result is a single tensor.
    loss.safe_state_samples = safe_state_samples
    if safe_mip_adversarial is not None and len(safe_mip_adversarial) > 0:
        loss.safe_state_samples = torch.cat(
            (safe_state_samples, torch.cat(safe_mip_adversarial, dim=0)),
            dim=0)
    loss.unsafe_state_samples = unsafe_state_samples
    if unsafe_mip_adversarial is not None and \
            len(unsafe_mip_adversarial) > 0:
        loss.unsafe_state_samples = torch.cat(
            (unsafe_state_samples,
             torch.cat(unsafe_mip_adversarial, dim=0)), dim=0)
    loss.derivative_state_samples = derivative_state_samples
    if derivative_mip_adversarial is not None:
        loss.derivative_state_samples = torch.cat(
            (derivative_state_samples, derivative_mip_adversarial), dim=0)

    # Sample-based losses, restricted to the newest entries of each pool.
    pool = self.max_sample_pool_size
    loss.safe_sample_loss, loss.unsafe_sample_loss, \
        loss.derivative_sample_loss = self.barrier_sample_loss(
            loss.safe_state_samples[-pool:],
            loss.unsafe_state_samples[-pool:],
            loss.derivative_state_samples[-pool:],
            safe_sample_cost_weight, unsafe_sample_cost_weight,
            derivative_sample_cost_weight)
    return loss

def total_loss(self, positivity_state_samples,
lyap_derivative_state_samples,
lyap_derivative_state_samples_next,
Expand Down

0 comments on commit d7b03d3

Please sign in to comment.