From 18cdc0ec8ba4751a96baff789a59b18068bd7725 Mon Sep 17 00:00:00 2001 From: Joe Eappen Date: Wed, 20 Mar 2024 11:45:23 -0400 Subject: [PATCH 01/22] fix(jax): rename os env, reload module to make tests work together --- examples/stl/differentiability.py | 15 +++++++++------ src/ds/stl.py | 5 +++-- src/ds/stl_jax.py | 22 +++++++++++++--------- src/ds/utils.py | 2 +- tests/test_stl.py | 12 +++++++++--- tests/test_stl_jax.py | 13 +++++++++---- 6 files changed, 44 insertions(+), 25 deletions(-) diff --git a/examples/stl/differentiability.py b/examples/stl/differentiability.py index 10190d1..7f6b774 100644 --- a/examples/stl/differentiability.py +++ b/examples/stl/differentiability.py @@ -4,17 +4,20 @@ import matplotlib.pyplot as plt import numpy as np import optax +import importlib + +import ds.utils as ds_utils # if JAX_BACKEND is set the import will be from jax.numpy -if os.environ.get("JAX_STL_BACKEND") == "jax": +if os.environ.get("DIFF_STL_BACKEND") == "jax": print("Using JAX backend") from ds.stl_jax import STL, RectAvoidPredicate, RectReachPredicate - from ds.utils import default_tensor + importlib.reload(ds_utils) # Reload the module to reset the backend import jax else: print("Using PyTorch backend") from ds.stl import STL, RectAvoidPredicate, RectReachPredicate - from ds.utils import default_tensor + importlib.reload(ds_utils) # Reload the module to reset the backend import torch from torch.optim import Adam @@ -33,7 +36,7 @@ def eval_reach_avoid(mute=False): form = goal.eventually(0, 10) & obs.always(0, 10) # Define 2 initial paths in batch - path_1 = default_tensor( + path_1 = ds_utils.default_tensor( np.array( [ [ @@ -98,7 +101,7 @@ def backward(mute=True): # and that holds always in 0 to 8 # In other words, the path will repeatedly visit goal_1 and goal_2 in 0 to 13 form = (goal_1.eventually(0, 5) & goal_2.eventually(0, 5)).always(0, 8) - path = default_tensor( + path = ds_utils.default_tensor( np.array( [ [ @@ -124,7 +127,7 @@ def backward(mute=True): lr = 0.1 num_iterations = 1000 - if os.environ.get("JAX_STL_BACKEND") == "jax": + if os.environ.get("DIFF_STL_BACKEND") == "jax": solver = optax.adam(lr) var_solver_state = solver.init(path) diff --git a/src/ds/stl.py b/src/ds/stl.py index ec90766..7595dc8 100644 --- a/src/ds/stl.py +++ b/src/ds/stl.py @@ -369,6 +369,7 @@ def solve_stlpy_formula( class PredicateBase: def __init__(self, name: str): self.name = name + self.logger = logging.getLogger(__name__) def eval_at_t(self, path: Tensor, t: int = 0) -> Tensor: return self.eval_whole_path(path, t, t + 1)[:, 0] @@ -404,8 +405,8 @@ def __init__(self, cent: np.ndarray, size: np.ndarray, name: str, shrink_factor: self.cent_tensor = default_tensor(cent) self.size_tensor = default_tensor(size) - self.shrink_factor = shrink_factor # shrink the rectangle to make it more conservative - print(f"shrink factor: {shrink_factor}") + self.shrink_factor = shrink_factor # shrink the rectangle to make it more conservative in STLpy + self.logger.info(f"shrink factor: {shrink_factor}") def eval_whole_path( self, path: Tensor, start_t: int = 0, end_t: int = None diff --git a/src/ds/stl_jax.py b/src/ds/stl_jax.py index d6494d3..6fde929 100644 --- a/src/ds/stl_jax.py +++ b/src/ds/stl_jax.py @@ -1,3 +1,4 @@ +import importlib import io import os from abc import abstractmethod @@ -7,11 +8,13 @@ import jax import numpy as np +import importlib from jax.nn import softmax from stlpy.STL import LinearPredicate, STLTree -os.environ["JAX_STL_BACKEND"] = "jax" # set the backend to JAX for all child processes -from ds.utils import default_tensor +os.environ["DIFF_STL_BACKEND"] = "jax" # set the backend to JAX for all child processes +import ds.utils as ds_utils +importlib.reload(ds_utils) # Reload the module to change the backend with redirect_stdout(io.StringIO()): pass @@ -27,6 +30,7 @@ class PredicateBase: def __init__(self, name: str): self.name = name + self.logger = logging.getLogger(__name__) def eval_at_t(self, path: jnp.ndarray, t: int = 0) -> jnp.ndarray: return self.eval_whole_path(path, t, t + 1)[:, 0] @@ -60,10 +64,10 @@ def __init__(self, cent: np.ndarray, size: np.ndarray, name: str, shrink_factor: self.cent = cent self.size = size - self.cent_tensor = default_tensor(cent) - self.size_tensor = default_tensor(size) + self.cent_tensor = ds_utils.default_tensor(cent) + self.size_tensor = ds_utils.default_tensor(size) self.shrink_factor = shrink_factor # shrink the rectangle to make it more conservative - print(f"shrink factor: {shrink_factor}") + self.logger.info(f"shrink factor: {shrink_factor}") def eval_whole_path( self, path: jnp.array, start_t: int = 0, end_t: int = None @@ -98,8 +102,8 @@ def __init__(self, cent: np.ndarray, size: np.ndarray, name: str): self.cent = cent self.size = size - self.cent_tensor = default_tensor(cent) - self.size_tensor = default_tensor(size) + self.cent_tensor = ds_utils.default_tensor(cent) + self.size_tensor = ds_utils.default_tensor(size) def eval_whole_path( self, path: jnp.array, start_t: int = 0, end_t: int = None @@ -444,7 +448,7 @@ def _eval_until( for i in range(cond.shape[0]): cond[i, index[i]:] = 1.0 cond = ~cond.bool() - till_pred = jnp.where(cond, till_pred, default_tensor(1)) + till_pred = jnp.where(cond, till_pred, ds_utils.default_tensor(1)) if self._is_leaf(sub_form1): res = sub_form1.eval_whole_path(path[:, start_t:end_t]) @@ -456,7 +460,7 @@ def _eval_until( ], axis=-1, ) - res = jnp.where(cond, res, default_tensor(-1)) + res = jnp.where(cond, res, ds_utils.default_tensor(-1)) # when cond < 0, res should always > 0 to be hold return self._tensor_min(-res * till_pred, axis=-1) diff --git a/src/ds/utils.py b/src/ds/utils.py index 88040ea..d4ffcaa 100644 --- a/src/ds/utils.py +++ b/src/ds/utils.py @@ -3,7 +3,7 @@ import numpy as np # if JAX_BACKEND is set the import will be from jax.numpy -if os.environ.get("JAX_STL_BACKEND") == "jax": +if os.environ.get("DIFF_STL_BACKEND") == "jax": # print("Using JAX backend") import jax diff --git a/tests/test_stl.py b/tests/test_stl.py index 27463c8..54755d7 100644 --- a/tests/test_stl.py +++ b/tests/test_stl.py @@ -1,22 +1,28 @@ +import importlib +import os import unittest -from examples.stl.differentiability import eval_reach_avoid, backward +import examples.stl.differentiability as stl_diff_examples class TestExamples(unittest.TestCase): + def setUp(self): + os.environ["DIFF_STL_BACKEND"] = "" + importlib.reload(stl_diff_examples) # Reload the module to reset the backend + def test_run(self): # Test Eval final_result = [] for _ in range(1000): # Fair test with jax - res = eval_reach_avoid(mute=True) + res = stl_diff_examples.eval_reach_avoid(mute=True) final_result.append(res) print(final_result) # Test differentiability - path = backward() + path = stl_diff_examples.backward() print(path) # diff --git a/tests/test_stl_jax.py b/tests/test_stl_jax.py index a27bfad..47cc1d1 100644 --- a/tests/test_stl_jax.py +++ b/tests/test_stl_jax.py @@ -1,13 +1,18 @@ +import importlib import os import unittest -os.environ["JAX_STL_BACKEND"] = "jax" # set the backend to JAX for all child processes -from examples.stl.differentiability import eval_reach_avoid, backward from jax import jit +import examples.stl.differentiability as stl_diff_examples + class TestJAXExamples(unittest.TestCase): + def setUp(self): + os.environ["DIFF_STL_BACKEND"] = "jax" # set the backend to JAX for all child processes + importlib.reload(stl_diff_examples) # Reload the module to reset the backend + def test_run(self): # TODO: Study jit decorator and see optimizations # jit(eval_reach_avoid)() @@ -15,13 +20,13 @@ def test_run(self): final_result = [] for _ in range(1000): # Magic of jax - res = jit(eval_reach_avoid)() + res = jit(stl_diff_examples.eval_reach_avoid)() final_result.append(res) print(final_result) # Test differentiability - path = backward() + path = stl_diff_examples.backward() print(path) # (jax.lax.fori_loop(0, 1000, lambda i, _: jit(eval_reach_avoid)(), None)).block_until_ready() # for _ in range(1000): From 47e75ae594490c7e883244c03565f1e7daac54b2 Mon Sep 17 00:00:00 2001 From: Joe Eappen Date: Wed, 20 Mar 2024 12:25:22 -0400 Subject: [PATCH 02/22] fix(jax): add evaluation shape test (which is failing). TODO: Debug --- README.md | 2 +- tests/test_stl.py | 42 ++++++++++++++++++++++++++++++++++++++++- tests/test_stl_jax.py | 44 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 86 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index fba7348..212cee9 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ First set the backend to JAX: ```python import os -os.environ["JAX_STL_BACKEND"] = "jax" # set the backend to JAX +os.environ["DIFF_STL_BACKEND"] = "jax" # set the backend to JAX ``` Then you can use the JAX backend to optimize the inputs to satisfy the formula. diff --git a/tests/test_stl.py b/tests/test_stl.py index 54755d7..8202535 100644 --- a/tests/test_stl.py +++ b/tests/test_stl.py @@ -3,13 +3,16 @@ import unittest import examples.stl.differentiability as stl_diff_examples - +from ds.stl import STL, RectAvoidPredicate, RectReachPredicate +import numpy as np +import ds.utils as ds_utils class TestExamples(unittest.TestCase): def setUp(self): os.environ["DIFF_STL_BACKEND"] = "" importlib.reload(stl_diff_examples) # Reload the module to reset the backend + importlib.reload(ds_utils) # Reload the module to reset the backend def test_run(self): # Test Eval @@ -28,6 +31,43 @@ def test_run(self): # # self.assertEqual(True, False) # add assertion here + def test_evaluations(self, num_tiles = 3): + """Run simple evaluations to test shapes and types""" + goal_1 = STL(RectReachPredicate(np.array([0, 0]), np.array([1, 1]), "goal_1")) + # goal_2 is a rectangle area centered in [2, 2] with width and height 1 + goal_2 = STL(RectReachPredicate(np.array([2, 2]), np.array([1, 1]), "goal_2")) + + # form is the formula goal_1 eventually in 0 to 5 and goal_2 eventually in 0 to 5 + # and that holds always in 0 to 8 + # In other words, the path will repeatedly visit goal_1 and goal_2 in 0 to 13 + form = goal_1.eventually(0, 5) + path = ds_utils.default_tensor( + np.array( + [ + [ + [1, 0], + [1, 0], + [1, 0], + [0, 0], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [1, 0], + [1, 0], + ], + ] + ) + ) + + loss = form.eval(path.tile(num_tiles,1,1)) + self.assertGreater(len(loss.shape), 0, f"Not returning correct shape") + self.assertEqual(loss.shape[0], num_tiles, f"Not returning {num_tiles} values") + if __name__ == '__main__': unittest.main() diff --git a/tests/test_stl_jax.py b/tests/test_stl_jax.py index 47cc1d1..a587b59 100644 --- a/tests/test_stl_jax.py +++ b/tests/test_stl_jax.py @@ -2,9 +2,13 @@ import os import unittest +import jax +import numpy as np from jax import jit +import ds.utils as ds_utils import examples.stl.differentiability as stl_diff_examples +from ds.stl_jax import STL, RectReachPredicate class TestJAXExamples(unittest.TestCase): @@ -12,6 +16,7 @@ class TestJAXExamples(unittest.TestCase): def setUp(self): os.environ["DIFF_STL_BACKEND"] = "jax" # set the backend to JAX for all child processes importlib.reload(stl_diff_examples) # Reload the module to reset the backend + importlib.reload(ds_utils) # Reload the module to reset the backend def test_run(self): # TODO: Study jit decorator and see optimizations @@ -34,6 +39,45 @@ def test_run(self): # # self.assertEqual(True, False) # add assertion here + def test_evaluations(self, num_tiles=3): + """Run simple evaluations to test shapes and types""" + goal_1 = STL(RectReachPredicate(np.array([0, 0]), np.array([1, 1]), "goal_1")) + # goal_2 is a rectangle area centered in [2, 2] with width and height 1 + goal_2 = STL(RectReachPredicate(np.array([2, 2]), np.array([1, 1]), "goal_2")) + + # form is the formula goal_1 eventually in 0 to 5 and goal_2 eventually in 0 to 5 + # and that holds always in 0 to 8 + # In other words, the path will repeatedly visit goal_1 and goal_2 in 0 to 13 + form = goal_1.eventually(0, 5) + path = ds_utils.default_tensor( + np.array( + [ + [ + [1, 0], + [1, 0], + [1, 0], + [0, 0], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [1, 0], + [1, 0], + ], + ] + ) + ) + + (loss), grad = jax.value_and_grad(form.eval)( + jax.numpy.tile(path, (100, 2, 1)) + ) + self.assertGreater(len(loss.shape), 0, f"Not returning correct shape") + self.assertEqual(loss.shape[0], num_tiles, f"Not returning {num_tiles} values") + if __name__ == '__main__': unittest.main() From 885e4d3c47d99371c488ab3d7af6b98f9035f79a Mon Sep 17 00:00:00 2001 From: Joe Eappen Date: Wed, 20 Mar 2024 17:05:44 -0400 Subject: [PATCH 03/22] fix(jax): fix evaluation bug in predicates, jnp.min and torch.min had different semantics. Fix examples to allow different batch sizes as input to STL.eval --- examples/stl/differentiability.py | 4 ++-- src/ds/stl_jax.py | 4 ++-- tests/test_stl.py | 12 +++++++----- tests/test_stl_jax.py | 6 ++---- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/examples/stl/differentiability.py b/examples/stl/differentiability.py index 7f6b774..896849b 100644 --- a/examples/stl/differentiability.py +++ b/examples/stl/differentiability.py @@ -135,10 +135,10 @@ def backward(mute=True): @jax.jit def train_step(params, solver_state): # Performs a one step update. - (loss), grad = jax.value_and_grad(form.eval)( + (loss), grad = jax.value_and_grad(lambda x : -form.eval(x).mean())( params ) - updates, solver_state = solver.update(-grad, solver_state) + updates, solver_state = solver.update(grad, solver_state) params = optax.apply_updates(params, updates) return params, solver_state, loss diff --git a/src/ds/stl_jax.py b/src/ds/stl_jax.py index 6fde929..f06d012 100644 --- a/src/ds/stl_jax.py +++ b/src/ds/stl_jax.py @@ -76,7 +76,7 @@ def eval_whole_path( eval_path = path[:, start_t:end_t] res = jnp.min( self.size_tensor / 2 - jnp.abs(eval_path - self.cent_tensor), axis=-1 - )[0] + ) return res @@ -112,7 +112,7 @@ def eval_whole_path( eval_path = path[:, start_t:end_t] res = jnp.max( jnp.abs(eval_path - self.cent_tensor) - self.size_tensor / 2, axis=-1 - )[0] + ) return res diff --git a/tests/test_stl.py b/tests/test_stl.py index 8202535..6e7dec8 100644 --- a/tests/test_stl.py +++ b/tests/test_stl.py @@ -2,10 +2,12 @@ import os import unittest -import examples.stl.differentiability as stl_diff_examples -from ds.stl import STL, RectAvoidPredicate, RectReachPredicate import numpy as np + import ds.utils as ds_utils +import examples.stl.differentiability as stl_diff_examples +from ds.stl import STL, RectReachPredicate + class TestExamples(unittest.TestCase): @@ -31,7 +33,7 @@ def test_run(self): # # self.assertEqual(True, False) # add assertion here - def test_evaluations(self, num_tiles = 3): + def test_evaluations(self, num_tiles=3): """Run simple evaluations to test shapes and types""" goal_1 = STL(RectReachPredicate(np.array([0, 0]), np.array([1, 1]), "goal_1")) # goal_2 is a rectangle area centered in [2, 2] with width and height 1 @@ -40,7 +42,7 @@ def test_evaluations(self, num_tiles = 3): # form is the formula goal_1 eventually in 0 to 5 and goal_2 eventually in 0 to 5 # and that holds always in 0 to 8 # In other words, the path will repeatedly visit goal_1 and goal_2 in 0 to 13 - form = goal_1.eventually(0, 5) + form = (goal_1.eventually(0, 5) & goal_2.eventually(0, 5)).always(0, 8) path = ds_utils.default_tensor( np.array( [ @@ -64,7 +66,7 @@ def test_evaluations(self, num_tiles = 3): ) ) - loss = form.eval(path.tile(num_tiles,1,1)) + loss = form.eval(path.tile(num_tiles, 1, 1)) # Make a batch of size num_tiles self.assertGreater(len(loss.shape), 0, f"Not returning correct shape") self.assertEqual(loss.shape[0], num_tiles, f"Not returning {num_tiles} values") diff --git a/tests/test_stl_jax.py b/tests/test_stl_jax.py index a587b59..649fa9b 100644 --- a/tests/test_stl_jax.py +++ b/tests/test_stl_jax.py @@ -48,7 +48,7 @@ def test_evaluations(self, num_tiles=3): # form is the formula goal_1 eventually in 0 to 5 and goal_2 eventually in 0 to 5 # and that holds always in 0 to 8 # In other words, the path will repeatedly visit goal_1 and goal_2 in 0 to 13 - form = goal_1.eventually(0, 5) + form = (goal_1.eventually(0, 5) & goal_2.eventually(0, 5)).always(0, 8) path = ds_utils.default_tensor( np.array( [ @@ -72,9 +72,7 @@ def test_evaluations(self, num_tiles=3): ) ) - (loss), grad = jax.value_and_grad(form.eval)( - jax.numpy.tile(path, (100, 2, 1)) - ) + loss = form.eval(jax.numpy.tile(path, (num_tiles, 1, 1))) # Make a batch of size num_tiles self.assertGreater(len(loss.shape), 0, f"Not returning correct shape") self.assertEqual(loss.shape[0], num_tiles, f"Not returning {num_tiles} values") From 677ba065829ada646ab8879756d62960852b19d2 Mon Sep 17 00:00:00 2001 From: Joe Eappen Date: Thu, 21 Mar 2024 11:33:44 -0400 Subject: [PATCH 04/22] fix(jax) : change a few more numpy functions -> jax.numpy --- src/ds/stl_jax.py | 45 ++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 5 deletions(-) diff --git a/src/ds/stl_jax.py b/src/ds/stl_jax.py index f06d012..d98aa66 100644 --- a/src/ds/stl_jax.py +++ b/src/ds/stl_jax.py @@ -10,7 +10,7 @@ import numpy as np import importlib from jax.nn import softmax -from stlpy.STL import LinearPredicate, STLTree +from stlpy.STL import LinearPredicate as baseLinearPredicate, STLTree os.environ["DIFF_STL_BACKEND"] = "jax" # set the backend to JAX for all child processes import ds.utils as ds_utils @@ -154,12 +154,12 @@ def inside_rectangle_formula(bounds, y1_index, y2_index, d, name=None): y1_min, y1_max, y2_min, y2_max = bounds # Create predicates a*y >= b for each side of the rectangle - a1 = np.zeros((1, d)) + a1 = jnp.zeros((1, d)) a1[:, y1_index] = 1 right = LinearPredicate(a1, y1_min) left = LinearPredicate(-a1, -y1_max) - a2 = np.zeros((1, d)) + a2 = jnp.zeros((1, d)) a2[:, y2_index] = 1 top = LinearPredicate(a2, y2_min) bottom = LinearPredicate(-a2, -y2_max) @@ -206,12 +206,12 @@ def outside_rectangle_formula(bounds, y1_index, y2_index, d, name=None): y1_min, y1_max, y2_min, y2_max = bounds # Create predicates a*y >= b for each side of the rectangle - a1 = np.zeros((1, d)) + a1 = jnp.zeros((1, d)) a1[:, y1_index] = 1 right = LinearPredicate(a1, y1_max) left = LinearPredicate(-a1, -y1_min) - a2 = np.zeros((1, d)) + a2 = jnp.zeros((1, d)) a2[:, y2_index] = 1 top = LinearPredicate(a2, y2_max) bottom = LinearPredicate(-a2, -y2_min) @@ -227,12 +227,47 @@ def outside_rectangle_formula(bounds, y1_index, y2_index, d, name=None): return outside_rectangle +class LinearPredicate(baseLinearPredicate): + """ + A linear STL predicate :math:`\pi` defined by + + .. math:: + + a^Ty_t - b \geq 0 + + where :math:`y_t \in \mathbb{R}^d` is the value of the signal + at a given timestep :math:`t`, :math:`a \in \mathbb{R}^d`, + and :math:`b \in \mathbb{R}`. + + :param a: a jax numpy array or list representing the vector :math:`a` + :param b: a list, jax numpy array, or scalar representing :math:`b` + :param name: (optional) a string used to identify this predicate. + """ + + def __init__(self, a, b, name=None): + # Convert provided constraints to numpy arrays + self.a = jnp.asarray(a).reshape((-1, 1)) + self.b = jnp.atleast_1d(b) + + # Some dimension-related sanity checks + assert (self.a.shape[1] == 1), "a must be of shape (d,1)" + assert (self.b.shape == (1,)), "b must be of shape (1,)" + + # Store the dimensionality of y_t + self.d = self.a.shape[0] + + # A unique string describing this predicate + self.name = name + + AST = TypeVar("AST", list, PredicateBase) class STL: """ Class for representing STL formulas. + + All methods are functionally pure with no side effects during execution. """ def __init__(self, ast: AST): From b9a615d21b99a8b89b9aad5115ff3cef90afa7c3 Mon Sep 17 00:00:00 2001 From: Joe Eappen Date: Thu, 21 Mar 2024 11:39:01 -0400 Subject: [PATCH 05/22] docs(jax) : update jax instructions and changelog --- CHANGELOG.md | 15 +++++++++++++++ README.md | 4 ++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1bab28c..9c847b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,18 @@ +#### 2024-03-21 + +##### Bug Fixes + +* **jax:** + * fix evaluation bug in predicates, jnp.min and torch.min had different semantics. Fix examples to allow different + batch sizes as input to STL.eval (885e4d3c) + * add evaluation shape test (which is failing). TODO: Debug (47e75ae5) + * rename os env, reload module to make tests work together (18cdc0ec) + * Using the optax (25687960) + +##### Tests + +* **jax:** basic jit jax optimization (a930a76e) + #### 2024-02-24 ##### Tests diff --git a/README.md b/README.md index 212cee9..1d72263 100644 --- a/README.md +++ b/README.md @@ -67,12 +67,12 @@ Probability temporal logic is an ongoing work integrating probability and random If you are using JAX, you can use the JAX backend (stl_jax) and gain immense speedups in many cases. -First set the backend to JAX: +First set the backend to JAX using Environment Variables for our utility functions: ```python import os -os.environ["DIFF_STL_BACKEND"] = "jax" # set the backend to JAX +os.environ["DIFF_STL_BACKEND"] = "jax" # set the backend to JAX (if unset or any other value uses the PyTorch backend) ``` Then you can use the JAX backend to optimize the inputs to satisfy the formula. From e007146082d6beac839e0c101d4fa6191868d461 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 30 Mar 2024 13:35:03 +0000 Subject: [PATCH 06/22] move common fns to utils so torch is optional --- src/ds/stl.py | 31 +------------------------------ src/ds/stl_jax.py | 3 ++- src/ds/utils.py | 29 +++++++++++++++++++++++++++++ tests/test_stl_jax.py | 2 ++ 4 files changed, 34 insertions(+), 31 deletions(-) diff --git a/src/ds/stl.py b/src/ds/stl.py index 7595dc8..facdc29 100644 --- a/src/ds/stl.py +++ b/src/ds/stl.py @@ -2,7 +2,6 @@ import time from abc import abstractmethod from collections import deque -from contextlib import contextmanager from contextlib import redirect_stdout from typing import TypeVar, Tuple @@ -15,41 +14,13 @@ from torch import Tensor from torch.nn.functional import softmax -from ds.utils import default_tensor +from ds.utils import default_tensor, colored, HARDNESS, IMPLIES_TRICK, set_hardness with redirect_stdout(io.StringIO()): from stlpy.solvers.base import STLSolver import logging -COLORED = False -IMPLIES_TRICK = False -HARDNESS = 100.0 # Reduce hardness of softmax to propagate gradients more easily - - -@contextmanager -def set_hardness(hardness: float): - """Set the hardness of the softmax function for the duration of the context. - Useful for making evaluation strict while allowing gradients to pass through during training. - - :param hardness: hardness of the softmax function - :type hardness: float - """ - global HARDNESS - old_hardness = HARDNESS - HARDNESS = hardness - yield - HARDNESS = old_hardness - - -if COLORED: - from termcolor import colored -else: - - def colored(text, color): - return text - - class GurobiMICPSolver(STLSolver): """ Given an :class:`.STLFormula` :math:`\\varphi` and a :class:`.LinearSystem`, diff --git a/src/ds/stl_jax.py b/src/ds/stl_jax.py index d98aa66..b891ccb 100644 --- a/src/ds/stl_jax.py +++ b/src/ds/stl_jax.py @@ -21,7 +21,8 @@ import logging -from .stl import colored, HARDNESS, IMPLIES_TRICK, set_hardness +colored, HARDNESS, IMPLIES_TRICK, set_hardness = ds_utils.colored, ds_utils.HARDNESS, ds_utils.IMPLIES_TRICK, ds_utils.set_hardness + # Replace with JAX import jax.numpy as jnp diff --git a/src/ds/utils.py b/src/ds/utils.py index d4ffcaa..cd2d21c 100644 --- a/src/ds/utils.py +++ b/src/ds/utils.py @@ -1,7 +1,36 @@ import os +from contextlib import contextmanager import numpy as np + +COLORED = False +IMPLIES_TRICK = False +HARDNESS = 100.0 # Reduce hardness of softmax to propagate gradients more easily + + +@contextmanager +def set_hardness(hardness: float): + """Set the hardness of the softmax function for the duration of the context. + Useful for making evaluation strict while allowing gradients to pass through during training. + + :param hardness: hardness of the softmax function + :type hardness: float + """ + global HARDNESS + old_hardness = HARDNESS + HARDNESS = hardness + yield + HARDNESS = old_hardness + + +if COLORED: + from termcolor import colored +else: + + def colored(text, color): + return text + # if JAX_BACKEND is set the import will be from jax.numpy if os.environ.get("DIFF_STL_BACKEND") == "jax": # print("Using JAX backend") diff --git a/tests/test_stl_jax.py b/tests/test_stl_jax.py index 649fa9b..39e14fe 100644 --- a/tests/test_stl_jax.py +++ b/tests/test_stl_jax.py @@ -6,6 +6,8 @@ import numpy as np from jax import jit +os.environ["DIFF_STL_BACKEND"] = "jax" # So ds_utils does not require torch + import ds.utils as ds_utils import examples.stl.differentiability as stl_diff_examples from ds.stl_jax import STL, RectReachPredicate From 35a2cd01eff83055332e8af2646234595698c2ce Mon Sep 17 00:00:00 2001 From: Joe Eappen Date: Sat, 13 Apr 2024 19:21:35 -0400 Subject: [PATCH 07/22] fix(stl): fix for using StlPySolver with JAX (temporary conversion to numpy) --- src/ds/stl.py | 117 +++------------------------------------------- src/ds/stl_jax.py | 25 +++++----- src/ds/utils.py | 111 +++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 128 insertions(+), 125 deletions(-) diff --git a/src/ds/stl.py b/src/ds/stl.py index facdc29..37108c0 100644 --- a/src/ds/stl.py +++ b/src/ds/stl.py @@ -1,26 +1,27 @@ +import gurobipy as gp import io +import numpy as np import time +import torch from abc import abstractmethod from collections import deque from contextlib import redirect_stdout -from typing import TypeVar, Tuple - -import gurobipy as gp -import numpy as np -import torch from gurobipy import GRB from stlpy.STL import LinearPredicate, NonlinearPredicate, STLTree from stlpy.systems import LinearSystem from torch import Tensor from torch.nn.functional import softmax +from typing import TypeVar, Tuple -from ds.utils import default_tensor, colored, HARDNESS, IMPLIES_TRICK, set_hardness +from ds.utils import default_tensor, colored, HARDNESS, IMPLIES_TRICK, set_hardness, outside_rectangle_formula, \ + inside_rectangle_formula with redirect_stdout(io.StringIO()): from stlpy.solvers.base import STLSolver import logging + class GurobiMICPSolver(STLSolver): """ Given an :class:`.STLFormula` :math:`\\varphi` and a :class:`.LinearSystem`, @@ -433,110 +434,6 @@ def get_stlpy_form(self) -> STLTree: return outside_rectangle_formula(bounds, 0, 1, 2, self.name) -def inside_rectangle_formula(bounds, y1_index, y2_index, d, name=None): - """ - Create an STL formula representing being inside a - rectangle with the given bounds: - - :: - - y2_max +-------------------+ - | | - | | - | | - y2_min +-------------------+ - y1_min y1_max - - :param bounds: Tuple ``(y1_min, y1_max, y2_min, y2_max)`` containing - the bounds of the rectangle. - :param y1_index: index of the first (``y1``) dimension - :param y2_index: index of the second (``y2``) dimension - :param d: dimension of the overall signal - :param name: (optional) string describing this formula - - :return inside_rectangle: An ``STLFormula`` specifying being inside the - rectangle at time zero. - """ - assert y1_index < d, "index must be less than signal dimension" - assert y2_index < d, "index must be less than signal dimension" - - # Unpack the bounds - y1_min, y1_max, y2_min, y2_max = bounds - - # Create predicates a*y >= b for each side of the rectangle - a1 = np.zeros((1, d)) - a1[:, y1_index] = 1 - right = LinearPredicate(a1, y1_min) - left = LinearPredicate(-a1, -y1_max) - - a2 = np.zeros((1, d)) - a2[:, y2_index] = 1 - top = LinearPredicate(a2, y2_min) - bottom = LinearPredicate(-a2, -y2_max) - - # Take the conjuction across all the sides - inside_rectangle = right & left & top & bottom - - # set the names - if name is not None: - inside_rectangle.__str__ = lambda: name - inside_rectangle.__repr__ = lambda: name - - return inside_rectangle - - -def outside_rectangle_formula(bounds, y1_index, y2_index, d, name=None): - """ - Create an STL formula representing being outside a - rectangle with the given bounds: - - :: - - y2_max +-------------------+ - | | - | | - | | - y2_min +-------------------+ - y1_min y1_max - - :param bounds: Tuple ``(y1_min, y1_max, y2_min, y2_max)`` containing - the bounds of the rectangle. - :param y1_index: index of the first (``y1``) dimension - :param y2_index: index of the second (``y2``) dimension - :param d: dimension of the overall signal - :param name: (optional) string describing this formula - - :return outside_rectangle: An ``STLFormula`` specifying being outside the - rectangle at time zero. - """ - assert y1_index < d, "index must be less than signal dimension" - assert y2_index < d, "index must be less than signal dimension" - - # Unpack the bounds - y1_min, y1_max, y2_min, y2_max = bounds - - # Create predicates a*y >= b for each side of the rectangle - a1 = np.zeros((1, d)) - a1[:, y1_index] = 1 - right = LinearPredicate(a1, y1_max) - left = LinearPredicate(-a1, -y1_min) - - a2 = np.zeros((1, d)) - a2[:, y2_index] = 1 - top = LinearPredicate(a2, y2_max) - bottom = LinearPredicate(-a2, -y2_min) - - # Take the disjuction across all the sides - outside_rectangle = right | left | top | bottom - - # set the names - if name is not None: - outside_rectangle.__str__ = lambda: name - outside_rectangle.__repr__ = lambda: name - - return outside_rectangle - - AST = TypeVar("AST", list, PredicateBase) diff --git a/src/ds/stl_jax.py b/src/ds/stl_jax.py index b891ccb..13560b5 100644 --- a/src/ds/stl_jax.py +++ b/src/ds/stl_jax.py @@ -1,19 +1,19 @@ import importlib +import importlib import io +import jax +import numpy as np import os from abc import abstractmethod from collections import deque from contextlib import redirect_stdout -from typing import TypeVar - -import jax -import numpy as np -import importlib from jax.nn import softmax from stlpy.STL import LinearPredicate as baseLinearPredicate, STLTree +from typing import TypeVar os.environ["DIFF_STL_BACKEND"] = "jax" # set the backend to JAX for all child processes import ds.utils as ds_utils + importlib.reload(ds_utils) # Reload the module to change the backend with redirect_stdout(io.StringIO()): @@ -22,7 +22,8 @@ import logging colored, HARDNESS, IMPLIES_TRICK, set_hardness = ds_utils.colored, ds_utils.HARDNESS, ds_utils.IMPLIES_TRICK, ds_utils.set_hardness - +outside_npy = ds_utils.outside_rectangle_formula +inside_npy = ds_utils.inside_rectangle_formula # Replace with JAX import jax.numpy as jnp @@ -85,7 +86,7 @@ def get_stlpy_form(self) -> STLTree: bounds = np.stack( [self.cent - self.size * self.shrink_factor / 2, self.cent + self.size * self.shrink_factor / 2] ).T.flatten() - return inside_rectangle_formula(bounds, 0, 1, 2, self.name) + return inside_npy(bounds, 0, 1, 2, self.name) class RectAvoidPredicate(PredicateBase): @@ -121,7 +122,7 @@ def get_stlpy_form(self) -> STLTree: bounds = np.stack( [self.cent - self.size / 2, self.cent + self.size / 2] ).T.flatten() - return outside_rectangle_formula(bounds, 0, 1, 2, self.name) + return outside_npy(bounds, 0, 1, 2, self.name) def inside_rectangle_formula(bounds, y1_index, y2_index, d, name=None): @@ -156,12 +157,12 @@ def inside_rectangle_formula(bounds, y1_index, y2_index, d, name=None): # Create predicates a*y >= b for each side of the rectangle a1 = jnp.zeros((1, d)) - a1[:, y1_index] = 1 + a1.at[:, y1_index].set(1) right = LinearPredicate(a1, y1_min) left = LinearPredicate(-a1, -y1_max) a2 = jnp.zeros((1, d)) - a2[:, y2_index] = 1 + a2.at[:, y2_index].set(1) top = LinearPredicate(a2, y2_min) bottom = LinearPredicate(-a2, -y2_max) @@ -208,12 +209,12 @@ def outside_rectangle_formula(bounds, y1_index, y2_index, d, name=None): # Create predicates a*y >= b for each side of the rectangle a1 = jnp.zeros((1, d)) - a1[:, y1_index] = 1 + a1.at[:, y1_index].set(1) right = LinearPredicate(a1, y1_max) left = LinearPredicate(-a1, -y1_min) a2 = jnp.zeros((1, d)) - a2[:, y2_index] = 1 + a2.at[:, y2_index].set(1) top = LinearPredicate(a2, y2_max) bottom = LinearPredicate(-a2, -y2_min) diff --git a/src/ds/utils.py b/src/ds/utils.py index cd2d21c..a84b97d 100644 --- a/src/ds/utils.py +++ b/src/ds/utils.py @@ -1,8 +1,6 @@ +import numpy as np import os - from contextlib import contextmanager -import numpy as np - COLORED = False IMPLIES_TRICK = False @@ -31,6 +29,113 @@ def set_hardness(hardness: float): def colored(text, color): return text +from stlpy.STL import LinearPredicate, NonlinearPredicate, STLTree + + +def inside_rectangle_formula(bounds, y1_index, y2_index, d, name=None): + """ + Create an STL formula representing being inside a + rectangle with the given bounds: + + :: + + y2_max +-------------------+ + | | + | | + | | + y2_min +-------------------+ + y1_min y1_max + + :param bounds: Tuple ``(y1_min, y1_max, y2_min, y2_max)`` containing + the bounds of the rectangle. + :param y1_index: index of the first (``y1``) dimension + :param y2_index: index of the second (``y2``) dimension + :param d: dimension of the overall signal + :param name: (optional) string describing this formula + + :return inside_rectangle: An ``STLFormula`` specifying being inside the + rectangle at time zero. + """ + assert y1_index < d, "index must be less than signal dimension" + assert y2_index < d, "index must be less than signal dimension" + + # Unpack the bounds + y1_min, y1_max, y2_min, y2_max = bounds + + # Create predicates a*y >= b for each side of the rectangle + a1 = np.zeros((1, d)) + a1[:, y1_index] = 1 + right = LinearPredicate(a1, y1_min) + left = LinearPredicate(-a1, -y1_max) + + a2 = np.zeros((1, d)) + a2[:, y2_index] = 1 + top = LinearPredicate(a2, y2_min) + bottom = LinearPredicate(-a2, -y2_max) + + # Take the conjuction across all the sides + inside_rectangle = right & left & top & bottom + + # set the names + if name is not None: + inside_rectangle.__str__ = lambda: name + inside_rectangle.__repr__ = lambda: name + + return inside_rectangle + + +def outside_rectangle_formula(bounds, y1_index, y2_index, d, name=None): + """ + Create an STL formula representing being outside a + rectangle with the given bounds: + + :: + + y2_max +-------------------+ + | | + | | + | | + y2_min +-------------------+ + y1_min y1_max + + :param bounds: Tuple ``(y1_min, y1_max, y2_min, y2_max)`` containing + the bounds of the rectangle. + :param y1_index: index of the first (``y1``) dimension + :param y2_index: index of the second (``y2``) dimension + :param d: dimension of the overall signal + :param name: (optional) string describing this formula + + :return outside_rectangle: An ``STLFormula`` specifying being outside the + rectangle at time zero. + """ + assert y1_index < d, "index must be less than signal dimension" + assert y2_index < d, "index must be less than signal dimension" + + # Unpack the bounds + y1_min, y1_max, y2_min, y2_max = bounds + + # Create predicates a*y >= b for each side of the rectangle + a1 = np.zeros((1, d)) + a1[:, y1_index] = 1 + right = LinearPredicate(a1, y1_max) + left = LinearPredicate(-a1, -y1_min) + + a2 = np.zeros((1, d)) + a2[:, y2_index] = 1 + top = LinearPredicate(a2, y2_max) + bottom = LinearPredicate(-a2, -y2_min) + + # Take the disjuction across all the sides + outside_rectangle = right | left | top | bottom + + # set the names + if name is not None: + outside_rectangle.__str__ = lambda: name + outside_rectangle.__repr__ = lambda: name + + return outside_rectangle + + # if JAX_BACKEND is set the import will be from jax.numpy if os.environ.get("DIFF_STL_BACKEND") == "jax": # print("Using JAX backend") From 8e5bf1e0dad94b1e3e592c59b2c70134aa472606 Mon Sep 17 00:00:00 2001 From: Joe Eappen Date: Mon, 15 Apr 2024 11:47:42 -0400 Subject: [PATCH 08/22] docs(stl): add comments to STLPy tools --- src/ds/stl_jax.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/ds/stl_jax.py b/src/ds/stl_jax.py index 13560b5..29930a9 100644 --- a/src/ds/stl_jax.py +++ b/src/ds/stl_jax.py @@ -41,10 +41,12 @@ def eval_at_t(self, path: jnp.ndarray, t: int = 0) -> jnp.ndarray: def eval_whole_path( self, path: jnp.ndarray, start_t: int = 0, end_t: int = None ) -> jnp.ndarray: + """Stick to JAX when possible.""" raise NotImplementedError @abstractmethod def get_stlpy_form(self) -> STLTree: + """Use Numpy to ensure compatibility with STLpy.""" raise NotImplementedError def __str__(self) -> str: @@ -74,6 +76,7 @@ def __init__(self, cent: np.ndarray, size: np.ndarray, name: str, shrink_factor: def eval_whole_path( self, path: jnp.array, start_t: int = 0, end_t: int = None ) -> jnp.array: + """Stick to JAX when possible.""" assert len(path.shape) == 3, "motion must be in batch" eval_path = path[:, start_t:end_t] res = jnp.min( @@ -83,6 +86,7 @@ def eval_whole_path( return res def get_stlpy_form(self) -> STLTree: + """Use Numpy to ensure compatibility with STLpy.""" bounds = np.stack( [self.cent - self.size * self.shrink_factor / 2, self.cent + self.size * self.shrink_factor / 2] ).T.flatten() @@ -110,6 +114,7 @@ def __init__(self, cent: np.ndarray, size: np.ndarray, name: str): def eval_whole_path( self, path: jnp.array, start_t: int = 0, end_t: int = None ) -> jnp.array: + """Stick to JAX when possible.""" assert len(path.shape) == 3, "motion must be in batch" eval_path = path[:, start_t:end_t] res = jnp.max( @@ -119,6 +124,7 @@ def eval_whole_path( return res def get_stlpy_form(self) -> STLTree: + """Use Numpy to ensure compatibility with STLpy.""" bounds = np.stack( [self.cent - self.size / 2, self.cent + self.size / 2] ).T.flatten() From 74828db7cbbe37108f87a5b41b41663505bff2c6 Mon Sep 17 00:00:00 2001 From: Joe Eappen Date: Wed, 15 May 2024 15:34:18 -0400 Subject: [PATCH 09/22] feature(stl): add loop and branch specs. Fix always end_time. --- src/ds/stl_jax.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/ds/stl_jax.py b/src/ds/stl_jax.py index 29930a9..2d9522b 100644 --- a/src/ds/stl_jax.py +++ b/src/ds/stl_jax.py @@ -1,15 +1,14 @@ import importlib -import importlib import io -import jax -import numpy as np import os from abc import abstractmethod from collections import deque from contextlib import redirect_stdout +from typing import TypeVar + +import numpy as np from jax.nn import softmax from stlpy.STL import LinearPredicate as baseLinearPredicate, STLTree -from typing import TypeVar os.environ["DIFF_STL_BACKEND"] = "jax" # set the backend to JAX for all child processes import ds.utils as ds_utils @@ -335,6 +334,9 @@ def _get_end_time(self, ast: AST) -> int: """Get max time of the formula. Runs in O(n) time where n is the number of nodes. Runs once then memoizes.""" if self._is_leaf(ast): return 1 + if ast[0] == "G": + # Add end time from inner formula since always is unrolled + return ast[-1] + self._get_end_time(ast[1]) if ast[0] in self.sequence_operators: # The last two elements are the start and end times return ast[-1] From 3a803caa419d28cf5927dbc4ffef8ff6968ceb43 Mon Sep 17 00:00:00 2001 From: Joe Eappen Date: Mon, 27 May 2024 13:04:34 -0400 Subject: [PATCH 10/22] feature(realrobot): add script to save output to npy --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index 661341a..31eeb4d 100644 --- a/.gitignore +++ b/.gitignore @@ -159,3 +159,7 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +/.idea/diff-spec.iml +/.idea/modules.xml +/.idea/inspectionProfiles/profiles_settings.xml +/.idea/workspace.xml From b1848c8488b8af2392dd17236fa8ba9d1d1196ba Mon Sep 17 00:00:00 2001 From: Joe Eappen Date: Fri, 31 May 2024 22:23:54 -0400 Subject: [PATCH 11/22] feature(stl): add loop tests and stlpy solver tests --- tests/test_stl_jax.py | 54 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 44 insertions(+), 10 deletions(-) diff --git a/tests/test_stl_jax.py b/tests/test_stl_jax.py index 39e14fe..cfdfe01 100644 --- a/tests/test_stl_jax.py +++ b/tests/test_stl_jax.py @@ -12,7 +12,7 @@ import examples.stl.differentiability as stl_diff_examples from ds.stl_jax import STL, RectReachPredicate - +from ds.stl import StlpySolver class TestJAXExamples(unittest.TestCase): def setUp(self): @@ -20,6 +20,16 @@ def setUp(self): importlib.reload(stl_diff_examples) # Reload the module to reset the backend importlib.reload(ds_utils) # Reload the module to reset the backend + self.goal_1 = STL(RectReachPredicate(np.array([0, 0]), np.array([1, 1]), "goal_1")) + # goal_2 is a rectangle area centered in [2, 2] with width and height 1 + self.goal_2 = STL(RectReachPredicate(np.array([2, 2]), np.array([1, 1]), "goal_2")) + + # form is the formula goal_1 eventually in 0 to 5 and goal_2 eventually in 0 to 5 + # and that holds always in 0 to 8 + # In other words, the path will repeatedly visit goal_1 and goal_2 in 0 to 13 + self.form = (self.goal_1.eventually(0, 5) & self.goal_2.eventually(0, 5)).always(0, 8) + self.loop_form = (self.goal_1.eventually(0, 4) & self.goal_2.eventually(0, 4)).always(0, 8) + def test_run(self): # TODO: Study jit decorator and see optimizations # jit(eval_reach_avoid)() @@ -43,14 +53,6 @@ def test_run(self): def test_evaluations(self, num_tiles=3): """Run simple evaluations to test shapes and types""" - goal_1 = STL(RectReachPredicate(np.array([0, 0]), np.array([1, 1]), "goal_1")) - # goal_2 is a rectangle area centered in [2, 2] with width and height 1 - goal_2 = STL(RectReachPredicate(np.array([2, 2]), np.array([1, 1]), "goal_2")) - - # form is the formula goal_1 eventually in 0 to 5 and goal_2 eventually in 0 to 5 - # and that holds always in 0 to 8 - # In other words, the path will repeatedly visit goal_1 and goal_2 in 0 to 13 - form = (goal_1.eventually(0, 5) & goal_2.eventually(0, 5)).always(0, 8) path = ds_utils.default_tensor( np.array( [ @@ -74,10 +76,42 @@ def test_evaluations(self, num_tiles=3): ) ) - loss = form.eval(jax.numpy.tile(path, (num_tiles, 1, 1))) # Make a batch of size num_tiles + loss = self.form.eval(jax.numpy.tile(path, (num_tiles, 1, 1))) # Make a batch of size num_tiles self.assertGreater(len(loss.shape), 0, f"Not returning correct shape") self.assertEqual(loss.shape[0], num_tiles, f"Not returning {num_tiles} values") + def test_loop(self): + # Test loop spec + num_tiles = 4 + path = ds_utils.default_tensor( + np.array( + [ + [ + [0, 0], + [0, 2], + [2, 0], + [2, 2], + ] * 3 + ])) + loss = self.loop_form.eval(jax.numpy.tile(path, (num_tiles, 1, 1))) # Make a batch of size num_tiles + self.assertGreater(len(loss.shape), 0, f"Not returning correct shape") + self.assertEqual(loss.shape[0], num_tiles, f"Not returning {num_tiles} values") + self.assertGreater(loss[0], 0, f"Loss is not greater than 0") + + unsat_path = path.at[0, -4].set([0, 2]) # Make the last point unsatisfiable + loss = self.loop_form.eval(jax.numpy.tile(unsat_path, (num_tiles, 1, 1))) + self.assertLess(loss[0], 0, f"Loss is not less than 0 for unsat path") + + def test_stlpy_solver(self): + x_0 = np.array([0, 0]) + solver = StlpySolver(space_dim=2) + total_time = 12 # For loop spec + stlpy_form = self.loop_form.get_stlpy_form() + path, info = solver.solve_stlpy_formula(stlpy_form, x0=x_0, total_time=total_time) + + num_tiles = 4 + loss = self.loop_form.eval(jax.numpy.tile(path, (num_tiles, 1, 1))) # Make a batch of size num_tiles + self.assertGreater(loss[0], 0, f"STLPY solved path loss is not greater than 0") if __name__ == '__main__': unittest.main() From 1028bd4bef59de4b40dca6e2c3b9ae29272cc790 Mon Sep 17 00:00:00 2001 From: Joe Eappen Date: Sun, 2 Jun 2024 23:44:17 -0400 Subject: [PATCH 12/22] fix(stlpy): add tests and fix stlpy conversion --- src/ds/stl.py | 17 +++++++++-------- src/ds/stl_jax.py | 6 +++--- tests/test_stl_jax.py | 18 ++++++++++++------ 3 files changed, 24 insertions(+), 17 deletions(-) diff --git a/src/ds/stl.py b/src/ds/stl.py index 37108c0..fda4cef 100644 --- a/src/ds/stl.py +++ b/src/ds/stl.py @@ -1,19 +1,20 @@ -import gurobipy as gp import io -import numpy as np import time -import torch from abc import abstractmethod from collections import deque from contextlib import redirect_stdout +from typing import TypeVar, Tuple + +import gurobipy as gp +import numpy as np +import torch from gurobipy import GRB from stlpy.STL import LinearPredicate, NonlinearPredicate, STLTree from stlpy.systems import LinearSystem from torch import Tensor from torch.nn.functional import softmax -from typing import TypeVar, Tuple -from ds.utils import default_tensor, colored, HARDNESS, IMPLIES_TRICK, set_hardness, outside_rectangle_formula, \ +from ds.utils import default_tensor, colored, HARDNESS, IMPLIES_TRICK, outside_rectangle_formula, \ inside_rectangle_formula with redirect_stdout(io.StringIO()): @@ -726,16 +727,16 @@ def _convert_implies(self, ast): def _convert_eventually(self, ast): sub_form = self._to_stlpy(ast[1]) - return sub_form.eventually(ast[2], ast[3]) + return sub_form.eventually(ast[2], ast[3] - 1) def _convert_always(self, ast): sub_form = self._to_stlpy(ast[1]) - return sub_form.always(ast[2], ast[3]) + return sub_form.always(ast[2], ast[3] - 1) def _convert_until(self, ast): sub_form_1 = self._to_stlpy(ast[1]) sub_form_2 = self._to_stlpy(ast[2]) - return sub_form_1.until(sub_form_2, ast[3], ast[4]) + return sub_form_1.until(sub_form_2, ast[3], ast[4] - 1) @staticmethod def _is_leaf(ast: AST): diff --git a/src/ds/stl_jax.py b/src/ds/stl_jax.py index 2d9522b..3713745 100644 --- a/src/ds/stl_jax.py +++ b/src/ds/stl_jax.py @@ -563,16 +563,16 @@ def _convert_implies(self, ast): def _convert_eventually(self, ast): sub_form = self._to_stlpy(ast[1]) - return sub_form.eventually(ast[2], ast[3]) + return sub_form.eventually(ast[2], ast[3] - 1) def _convert_always(self, ast): sub_form = self._to_stlpy(ast[1]) - return sub_form.always(ast[2], ast[3]) + return sub_form.always(ast[2], ast[3] - 1) def _convert_until(self, ast): sub_form_1 = self._to_stlpy(ast[1]) sub_form_2 = self._to_stlpy(ast[2]) - return sub_form_1.until(sub_form_2, ast[3], ast[4]) + return sub_form_1.until(sub_form_2, ast[3], ast[4] - 1) @staticmethod def _is_leaf(ast: AST): diff --git a/tests/test_stl_jax.py b/tests/test_stl_jax.py index cfdfe01..38758ca 100644 --- a/tests/test_stl_jax.py +++ b/tests/test_stl_jax.py @@ -29,6 +29,8 @@ def setUp(self): # In other words, the path will repeatedly visit goal_1 and goal_2 in 0 to 13 self.form = (self.goal_1.eventually(0, 5) & self.goal_2.eventually(0, 5)).always(0, 8) self.loop_form = (self.goal_1.eventually(0, 4) & self.goal_2.eventually(0, 4)).always(0, 8) + self.cover_form = self.goal_1.eventually(0, 12) & self.goal_2.eventually(0, 12) + self.seq_form = self.goal_1.eventually(0, 6) & self.goal_2.eventually(6, 12) def test_run(self): # TODO: Study jit decorator and see optimizations @@ -103,15 +105,19 @@ def test_loop(self): self.assertLess(loss[0], 0, f"Loss is not less than 0 for unsat path") def test_stlpy_solver(self): + """Test the stlpy solver with different forms of STL formulas""" x_0 = np.array([0, 0]) solver = StlpySolver(space_dim=2) - total_time = 12 # For loop spec - stlpy_form = self.loop_form.get_stlpy_form() - path, info = solver.solve_stlpy_formula(stlpy_form, x0=x_0, total_time=total_time) + total_time = 12 # Common total time for all formulas - num_tiles = 4 - loss = self.loop_form.eval(jax.numpy.tile(path, (num_tiles, 1, 1))) # Make a batch of size num_tiles - self.assertGreater(loss[0], 0, f"STLPY solved path loss is not greater than 0") + for form in [self.loop_form, self.cover_form, self.seq_form]: + + stlpy_form = form.get_stlpy_form() + path, info = solver.solve_stlpy_formula(stlpy_form, x0=x_0, total_time=total_time) + + num_tiles = 4 + loss = form.eval(jax.numpy.tile(path, (num_tiles, 1, 1))) # Make a batch of size num_tiles + self.assertGreater(loss[0], 0, f"STLPY solved path loss is not greater than 0 for {form}") if __name__ == '__main__': unittest.main() From 51ee629786b211d0d39a5b726003b170d574804b Mon Sep 17 00:00:00 2001 From: Joe Eappen Date: Tue, 4 Jun 2024 21:47:49 -0400 Subject: [PATCH 13/22] feature(async): add vanish on end, show spec in pgf plot --- src/ds/stl_jax.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/ds/stl_jax.py b/src/ds/stl_jax.py index 3713745..bb871ae 100644 --- a/src/ds/stl_jax.py +++ b/src/ds/stl_jax.py @@ -26,7 +26,7 @@ # Replace with JAX import jax.numpy as jnp - +import re class PredicateBase: def __init__(self, name: str): @@ -658,6 +658,25 @@ def push_stack(ast): self.expr_repr = expr return expr + def latex_repr(self): + repr = self.__repr__() + def replace_special_chars(match): + return { + "~": r"\neg", + "&": r"\land", + "|": r"\lor", + "->": r"\rightarrow", + "G": r"\Box", + "F": r"\Diamond", + "U": r"U", + }[match.group(0)] + replaced_symb = re.sub(r"~|&|\||->|G|F|U", replace_special_chars, repr) + # replace any [a, b] with _{[a,b]} + replaced_symb = re.sub(r"\[(\d+), (\d+)\]", r"_{[\1,\2]}", replaced_symb) + # replace goal_0, goal_1, ... with A, B, C, ... + replaced_symb = re.sub(r"goal_(\d+)", lambda x: chr(ord('A') + int(x.group(1))), replaced_symb) + return replaced_symb + def get_all_predicates(self): all_preds = [] queue = deque([self.ast]) From f181b473ed1b0486bd78320e91da13d16ba6c2ae Mon Sep 17 00:00:00 2001 From: Joe Eappen Date: Thu, 8 Aug 2024 11:38:55 -0400 Subject: [PATCH 14/22] feature(stl): Working on JAX Until --- src/ds/stl_jax.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/ds/stl_jax.py b/src/ds/stl_jax.py index bb871ae..047cb9b 100644 --- a/src/ds/stl_jax.py +++ b/src/ds/stl_jax.py @@ -488,11 +488,22 @@ def _eval_until( ) # mask condition, once condition > 0 (after until True), # the right sequence is no longer considered - cond = (till_pred > 0).int() + cond = (till_pred > 0).astype(int) index = jnp.argmax(cond, axis=-1) - for i in range(cond.shape[0]): - cond[i, index[i]:] = 1.0 - cond = ~cond.bool() + + batch_size, seq_len = cond.shape + row_indices = jnp.arange(batch_size)[:, None] + col_indices = jnp.arange(seq_len) + + mask = col_indices >= index[:, None] + + # Apply the mask + cond = mask.astype(int) + cond = ~cond.astype(bool) + + # for i in range(cond.shape[0]): + # cond[i, index[i]:] = 1.0 + # cond = ~cond.astype(bool) till_pred = jnp.where(cond, till_pred, ds_utils.default_tensor(1)) if self._is_leaf(sub_form1): From dd4698604979e5ba08363dea225059b610ebf96f Mon Sep 17 00:00:00 2001 From: Joe Eappen Date: Fri, 9 Aug 2024 14:09:13 -0400 Subject: [PATCH 15/22] feature(stl): add single plan no formation --- src/ds/stl_jax.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/ds/stl_jax.py b/src/ds/stl_jax.py index 047cb9b..71d1e56 100644 --- a/src/ds/stl_jax.py +++ b/src/ds/stl_jax.py @@ -28,6 +28,7 @@ import jax.numpy as jnp import re + class PredicateBase: def __init__(self, name: str): self.name = name @@ -51,6 +52,10 @@ def get_stlpy_form(self) -> STLTree: def __str__(self) -> str: return self.name + def __lt__(self, other: "PredicateBase") -> bool: + """Sort predicates by name.""" + return self.name < other.name + class RectReachPredicate(PredicateBase): """ @@ -671,6 +676,7 @@ def push_stack(ast): def latex_repr(self): repr = self.__repr__() + def replace_special_chars(match): return { "~": r"\neg", @@ -681,6 +687,7 @@ def replace_special_chars(match): "F": r"\Diamond", "U": r"U", }[match.group(0)] + replaced_symb = re.sub(r"~|&|\||->|G|F|U", replace_special_chars, repr) # replace any [a, b] with _{[a,b]} replaced_symb = re.sub(r"\[(\d+), (\d+)\]", r"_{[\1,\2]}", replaced_symb) From 611111425b2db8d35cb340c025930202abd9c9b7 Mon Sep 17 00:00:00 2001 From: Joe Eappen Date: Sat, 12 Oct 2024 17:16:42 -0400 Subject: [PATCH 16/22] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1d72263..bf4f726 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ Connect differentiable components with logical operators. ## Install ```bash -pip install git+https://github.com/ZikangXiong/diff-spec.git +pip install git+https://github.com/jeappen/diff-spec.git@feature/jax ``` ## First Order Logic From d89e4c6c330e59b62bab9c195fb0a683d0b2d4dd Mon Sep 17 00:00:00 2001 From: Joe Eappen Date: Mon, 14 Oct 2024 15:36:24 -0400 Subject: [PATCH 17/22] fix(ja): towards STL PyTrees compatible with JAX (helpful for vmapping with JAX) --- src/ds/stl_jax.py | 99 +++++++++++++++++++++++++++-------------------- 1 file changed, 58 insertions(+), 41 deletions(-) diff --git a/src/ds/stl_jax.py b/src/ds/stl_jax.py index 71d1e56..eeddf5a 100644 --- a/src/ds/stl_jax.py +++ b/src/ds/stl_jax.py @@ -1,14 +1,14 @@ +from collections import deque + import importlib import io +import numpy as np import os from abc import abstractmethod -from collections import deque from contextlib import redirect_stdout -from typing import TypeVar - -import numpy as np from jax.nn import softmax from stlpy.STL import LinearPredicate as baseLinearPredicate, STLTree +from typing import TypeVar, NamedTuple os.environ["DIFF_STL_BACKEND"] = "jax" # set the backend to JAX for all child processes import ds.utils as ds_utils @@ -29,10 +29,8 @@ import re -class PredicateBase: - def __init__(self, name: str): - self.name = name - self.logger = logging.getLogger(__name__) +class PredicateBase(NamedTuple): + name: str def eval_at_t(self, path: jnp.ndarray, t: int = 0) -> jnp.ndarray: return self.eval_whole_path(path, t, t + 1)[:, 0] @@ -57,25 +55,50 @@ def __lt__(self, other: "PredicateBase") -> bool: return self.name < other.name -class RectReachPredicate(PredicateBase): - """ - Rectangle reachability predicate +class RectangularPredicate(NamedTuple): """ - - def __init__(self, cent: np.ndarray, size: np.ndarray, name: str, shrink_factor: float = 0.5): - """ - :param cent: center of the rectangle - :param size: bound of the rectangle - :param name: name of the predicate + Rectangle reachability predicate """ - super().__init__(name) - self.cent = cent - self.size = size - self.cent_tensor = ds_utils.default_tensor(cent) - self.size_tensor = ds_utils.default_tensor(size) - self.shrink_factor = shrink_factor # shrink the rectangle to make it more conservative - self.logger.info(f"shrink factor: {shrink_factor}") + cent: np.ndarray + size: np.ndarray + name: str + shrink_factor: float = 1.0 # shrink the rectangle to make it more conservative (for stlpy) + + @property + def size_tensor(self): + return ds_utils.default_tensor(self.size) + + @property + def cent_tensor(self): + return ds_utils.default_tensor(self.cent) + + def __hash__(self): + return hash((self.cent, self.size)) + + def __eq__(self, other): + if not isinstance(other, RectReachPredicate): + return False + # return self.cent == other.cent and self.size == other.size + # Above using float difference + return np.allclose(self.cent, other.cent) and np.allclose(self.size, other.size) + + def __str__(self) -> str: + return self.name + + def __lt__(self, other: "RectangularPredicate") -> bool: + """Sort predicates by name.""" + return self.name < other.name + + +# PREDICATE_FORM = TypeVar("PREDICATE_FORM", RectangularPredicate, PredicateBase) +PREDICATE_TYPES = (RectangularPredicate, PredicateBase) + + +class RectReachPredicate(RectangularPredicate): + """ + Rectangle reachability predicate + """ def eval_whole_path( self, path: jnp.array, start_t: int = 0, end_t: int = None @@ -97,24 +120,11 @@ def get_stlpy_form(self) -> STLTree: return inside_npy(bounds, 0, 1, 2, self.name) -class RectAvoidPredicate(PredicateBase): +class RectAvoidPredicate(RectangularPredicate): """ Rectangle avoidance predicate """ - def __init__(self, cent: np.ndarray, size: np.ndarray, name: str): - """ - :param cent: center of the rectangle - :param size: bound of the rectangle - :param name: name of the predicate - """ - super().__init__(name) - self.cent = cent - self.size = size - - self.cent_tensor = ds_utils.default_tensor(cent) - self.size_tensor = ds_utils.default_tensor(size) - def eval_whole_path( self, path: jnp.array, start_t: int = 0, end_t: int = None ) -> jnp.array: @@ -592,6 +602,10 @@ def _convert_until(self, ast): @staticmethod def _is_leaf(ast: AST): + # Check is type PREDICATE_FORM + for pred_type in PREDICATE_TYPES: + if isinstance(ast, pred_type): + return True return issubclass(type(ast), PredicateBase) def _tensor_min(self, tensor: jnp.array, axis=-1) -> jnp.array: @@ -611,10 +625,15 @@ def __repr__(self): if self.expr_repr is not None: return self.expr_repr + expr = self._extract_repr() + + self.expr_repr = expr + return expr + + def _extract_repr(self): single_operators = ("~", "G", "F") binary_operators = ("&", "|", "->", "U") time_bounded_operators = ("G", "F", "U") - # traverse ast operator_stack = [self.ast] expr = "" @@ -670,8 +689,6 @@ def push_stack(ast): push_stack("(") else: push_stack(cur[1]) - - self.expr_repr = expr return expr def latex_repr(self): From d3e1de62fefde81d339ea72f173f999720975ec3 Mon Sep 17 00:00:00 2001 From: Joe Eappen Date: Mon, 14 Oct 2024 20:50:05 -0400 Subject: [PATCH 18/22] feature(jax): Allows multiple goal (in theory) --- src/ds/stl_jax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ds/stl_jax.py b/src/ds/stl_jax.py index eeddf5a..d1c49fb 100644 --- a/src/ds/stl_jax.py +++ b/src/ds/stl_jax.py @@ -74,10 +74,10 @@ def cent_tensor(self): return ds_utils.default_tensor(self.cent) def __hash__(self): - return hash((self.cent, self.size)) + return hash(f"{self.cent},{self.size}") def __eq__(self, other): - if not isinstance(other, RectReachPredicate): + if not isinstance(other, RectangularPredicate): return False # return self.cent == other.cent and self.size == other.size # Above using float difference From f34175866a49c778c8810158ba2c81f26582eeb3 Mon Sep 17 00:00:00 2001 From: Joe Eappen Date: Mon, 14 Oct 2024 23:49:23 -0400 Subject: [PATCH 19/22] fix(jax) : temp fix for RectPredicate not inheriting PredicateBase --- src/ds/stl_jax.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/ds/stl_jax.py b/src/ds/stl_jax.py index d1c49fb..4e02754 100644 --- a/src/ds/stl_jax.py +++ b/src/ds/stl_jax.py @@ -73,6 +73,21 @@ def size_tensor(self): def cent_tensor(self): return ds_utils.default_tensor(self.cent) + def eval_at_t(self, path: jnp.ndarray, t: int = 0) -> jnp.ndarray: + return self.eval_whole_path(path, t, t + 1)[:, 0] + + @abstractmethod + def eval_whole_path( + self, path: jnp.ndarray, start_t: int = 0, end_t: int = None + ) -> jnp.ndarray: + """Stick to JAX when possible.""" + raise NotImplementedError + + @abstractmethod + def get_stlpy_form(self) -> STLTree: + """Use Numpy to ensure compatibility with STLpy.""" + raise NotImplementedError + def __hash__(self): return hash(f"{self.cent},{self.size}") From 960a353b4f5c78b95683591e1a6d5ec950485203 Mon Sep 17 00:00:00 2001 From: Joe Eappen Date: Wed, 16 Oct 2024 12:06:35 -0400 Subject: [PATCH 20/22] feature(jax): add rich repr for center of pred --- src/ds/stl_jax.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/ds/stl_jax.py b/src/ds/stl_jax.py index 4e02754..0a3a54d 100644 --- a/src/ds/stl_jax.py +++ b/src/ds/stl_jax.py @@ -105,6 +105,10 @@ def __lt__(self, other: "RectangularPredicate") -> bool: """Sort predicates by name.""" return self.name < other.name + def __rich_repr__(self): + # Assumes that size is common and not important + yield f"{self.cent}" + # PREDICATE_FORM = TypeVar("PREDICATE_FORM", RectangularPredicate, PredicateBase) PREDICATE_TYPES = (RectangularPredicate, PredicateBase) @@ -645,7 +649,7 @@ def __repr__(self): self.expr_repr = expr return expr - def _extract_repr(self): + def _extract_repr(self, print_rich=False): single_operators = ("~", "G", "F") binary_operators = ("&", "|", "->", "U") time_bounded_operators = ("G", "F", "U") @@ -663,7 +667,10 @@ def push_stack(ast): while operator_stack: cur = operator_stack.pop() if self._is_leaf(cur): - expr += cur.__str__() + if print_rich: + expr += f"({str(next(cur.__rich_repr__()))})" + else: + expr += cur.__str__() elif isinstance(cur, str): if cur == "(" or cur == ")": expr += cur From 6922e51d2549ac56242602b7816cd8aa112651c3 Mon Sep 17 00:00:00 2001 From: Joe Eappen Date: Thu, 31 Oct 2024 18:47:10 -0400 Subject: [PATCH 21/22] feature(jax): add tests towards fixing avoid not working for differentiable approaches (GCBF+ etc) --- examples/stl/differentiability.py | 100 ++++++++++++++++++++---------- src/ds/stl.py | 15 +++-- tests/test_stl.py | 20 +++--- tests/test_stl_jax.py | 26 +++++--- 4 files changed, 109 insertions(+), 52 deletions(-) diff --git a/examples/stl/differentiability.py b/examples/stl/differentiability.py index 896849b..44707d5 100644 --- a/examples/stl/differentiability.py +++ b/examples/stl/differentiability.py @@ -1,10 +1,9 @@ # %% -import os - +import importlib import matplotlib.pyplot as plt import numpy as np import optax -import importlib +import os import ds.utils as ds_utils @@ -12,11 +11,13 @@ if os.environ.get("DIFF_STL_BACKEND") == "jax": print("Using JAX backend") from ds.stl_jax import STL, RectAvoidPredicate, RectReachPredicate + importlib.reload(ds_utils) # Reload the module to reset the backend import jax else: print("Using PyTorch backend") from ds.stl import STL, RectAvoidPredicate, RectReachPredicate + importlib.reload(ds_utils) # Reload the module to reset the backend import torch from torch.optim import Adam @@ -69,24 +70,39 @@ def eval_reach_avoid(mute=False): [1, 1], [1, 1], ], + [ + [9, 9], + [3, 2], + [7, 7], + [6, 6], + [5, 5], + [4, 4], + [3, 3], + [2, 2], + [1, 1], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + ] ] ) ) # eval the formula, default at time 0 - res1 = form.eval(path=path_1) + res1 = form.eval(path=path_1) # (+,-,-) if not mute: print("eval result at time 0: ", res1) # eval the formula at time 2 - res2 = form.eval(path=path_1, t=2) + res2 = form.eval(path=path_1, t=2) # (+,-,+) if not mute: print("eval result at time 2: ", res2) return res1, res2 -def backward(mute=True): +def backward(avoid_spec=False, mute=True): """ Planning with gradient descent """ @@ -95,34 +111,50 @@ def backward(mute=True): # goal_1 is a rectangle area centered in [0, 0] with width and height 1 goal_1 = STL(RectReachPredicate(np.array([0, 0]), np.array([1, 1]), "goal_1")) # goal_2 is a rectangle area centered in [2, 2] with width and height 1 - goal_2 = STL(RectReachPredicate(np.array([2, 2]), np.array([1, 1]), "goal_2")) + goal_2 = STL(RectReachPredicate(np.array([3, 3]), np.array([1, 1]), "goal_2")) + # goal_2 is a rectangle area centered in [1, 1] with width and height 1 + avoid_region = STL(RectAvoidPredicate(np.array([1, 1]), np.array([1, 1]), "avoid_region")) + avoid_region2 = STL(RectAvoidPredicate(np.array([2, 2]), np.array([1, 1]), "avoid_region2")) + avoid_region_goal1 = STL(RectAvoidPredicate(np.array([0, 0]), np.array([1, 1]), "avoid_region_goal1")) + avoid_region_goal2 = STL(RectAvoidPredicate(np.array([3, 3]), np.array([1, 1]), "avoid_region_goal2")) + end_time = 13 + + if avoid_spec: + print("cover while avoiding avoid_region") + # NOTE: Cover different just alternates between goal_1 and goal_2 + form = goal_2.eventually(0, end_time) & goal_1.eventually(0, end_time) \ + & avoid_region.always(0, end_time) & avoid_region2.always(0, end_time) \ + & avoid_region_goal1.always(end_time // 2, end_time) & avoid_region_goal2.always(0, end_time // 2) + else: + # form is the formula goal_1 eventually in 0 to 5 and goal_2 eventually in 0 to 5 + # and that holds always in 0 to 8 + # In other words, the path will repeatedly visit goal_1 and goal_2 in 0 to 13 + form = (goal_1.eventually(0, 5) & goal_2.eventually(0, 5)).always(0, 8) - # form is the formula goal_1 eventually in 0 to 5 and goal_2 eventually in 0 to 5 - # and that holds always in 0 to 8 - # In other words, the path will repeatedly visit goal_1 and goal_2 in 0 to 13 - form = (goal_1.eventually(0, 5) & goal_2.eventually(0, 5)).always(0, 8) - path = ds_utils.default_tensor( - np.array( + np_path = np.array( + [ [ - [ - [1, 0], - [1, 0], - [1, 0], - [1, 0], - [0, 1], - [0, 1], - [0, 1], - [0, 1], - [0, 1], - [0, 1], - [0, 1], - [0, 1], - [1, 0], - [1, 0], - ], - ] - ) + [1, 0], + [1, 0], + [1, 0], + [1, 0], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [1, 0], + [1, 0], + ], + ] ) + + random_like = np.random.rand(*np_path.shape) + + path = ds_utils.default_tensor(random_like) loss = None lr = 0.1 num_iterations = 1000 @@ -135,7 +167,7 @@ def backward(mute=True): @jax.jit def train_step(params, solver_state): # Performs a one step update. - (loss), grad = jax.value_and_grad(lambda x : -form.eval(x).mean())( + (loss), grad = jax.value_and_grad(lambda x: -form.eval(x).mean())( params ) updates, solver_state = solver.update(grad, solver_state) @@ -147,12 +179,14 @@ def train_step(params, solver_state): path, var_solver_state ) - loss = form.eval(path) + loss = train_loss else: # PyTorch backend (slower when num_iterations is high) path.requires_grad = True opt = Adam(params=[path], lr=lr) + # ds_utils.HARDNESS = 3.0 + for _ in range(num_iterations): loss = -torch.mean(form.eval(path)) opt.zero_grad() diff --git a/src/ds/stl.py b/src/ds/stl.py index fda4cef..94f51dc 100644 --- a/src/ds/stl.py +++ b/src/ds/stl.py @@ -1,18 +1,18 @@ -import io -import time -from abc import abstractmethod from collections import deque -from contextlib import redirect_stdout -from typing import TypeVar, Tuple import gurobipy as gp +import io import numpy as np +import time import torch +from abc import abstractmethod +from contextlib import redirect_stdout from gurobipy import GRB from stlpy.STL import LinearPredicate, NonlinearPredicate, STLTree from stlpy.systems import LinearSystem from torch import Tensor from torch.nn.functional import softmax +from typing import TypeVar, Tuple from ds.utils import default_tensor, colored, HARDNESS, IMPLIES_TRICK, outside_rectangle_formula, \ inside_rectangle_formula @@ -386,6 +386,11 @@ def eval_whole_path( ) -> Tensor: assert len(path.shape) == 3, "motion must be in batch" eval_path = path[:, start_t:end_t] + # NOTE: This is the soft version of the predicate + # res_soft = STL(None)._tensor_min( + # self.size_tensor / 2 - torch.abs(eval_path - self.cent_tensor), dim=-1 + # ) + res = torch.min( self.size_tensor / 2 - torch.abs(eval_path - self.cent_tensor), dim=-1 )[0] diff --git a/tests/test_stl.py b/tests/test_stl.py index 6e7dec8..0af39b4 100644 --- a/tests/test_stl.py +++ b/tests/test_stl.py @@ -1,9 +1,9 @@ import importlib +import numpy as np import os +import torch import unittest -import numpy as np - import ds.utils as ds_utils import examples.stl.differentiability as stl_diff_examples from ds.stl import STL, RectReachPredicate @@ -23,15 +23,21 @@ def test_run(self): # Fair test with jax res = stl_diff_examples.eval_reach_avoid(mute=True) final_result.append(res) + # Match expected output + assert ((res[0] > 0) == torch.tensor([True, False, False])).all() + assert ((res[1] > 0) == torch.tensor([True, False, True])).all() - print(final_result) + # print(final_result) # Test differentiability - path = stl_diff_examples.backward() - print(path) + path, loss = stl_diff_examples.backward() + print('Path', path) + assert loss < 0 # Loss should be less than 0 to satisfy the formula - # - # self.assertEqual(True, False) # add assertion here + def test_avoid_backward(self): + path, loss = stl_diff_examples.backward(avoid_spec=True) + print('AvoidPath', loss, path) + assert loss < 0 # Loss should be less than 0 to satisfy the formula def test_evaluations(self, num_tiles=3): """Run simple evaluations to test shapes and types""" diff --git a/tests/test_stl_jax.py b/tests/test_stl_jax.py index 38758ca..f992c81 100644 --- a/tests/test_stl_jax.py +++ b/tests/test_stl_jax.py @@ -1,9 +1,9 @@ import importlib -import os -import unittest - import jax +import jax.numpy as jnp import numpy as np +import os +import unittest from jax import jit os.environ["DIFF_STL_BACKEND"] = "jax" # So ds_utils does not require torch @@ -13,6 +13,8 @@ from ds.stl_jax import STL, RectReachPredicate from ds.stl import StlpySolver + + class TestJAXExamples(unittest.TestCase): def setUp(self): @@ -41,18 +43,28 @@ def test_run(self): # Magic of jax res = jit(stl_diff_examples.eval_reach_avoid)() final_result.append(res) + # Match expected output + assert jnp.all((res[0] > 0) == jnp.array([True, False, False])) + assert jnp.all((res[1] > 0) == jnp.array([True, False, True])) print(final_result) # Test differentiability - path = stl_diff_examples.backward() - print(path) + path, loss = stl_diff_examples.backward() + print('Path', path) + assert loss < 0 # Loss should be less than 0 to satisfy the formula # (jax.lax.fori_loop(0, 1000, lambda i, _: jit(eval_reach_avoid)(), None)).block_until_ready() # for _ in range(1000): # eval_reach_avoid() # # self.assertEqual(True, False) # add assertion here + def test_avoid_backward(self): + + path, loss = stl_diff_examples.backward(avoid_spec=True) + print('AvoidPath', loss, path) + assert loss < 0 # Loss should be less than 0 to satisfy the formula + def test_evaluations(self, num_tiles=3): """Run simple evaluations to test shapes and types""" path = ds_utils.default_tensor( @@ -108,10 +120,9 @@ def test_stlpy_solver(self): """Test the stlpy solver with different forms of STL formulas""" x_0 = np.array([0, 0]) solver = StlpySolver(space_dim=2) - total_time = 12 # Common total time for all formulas + total_time = 12 # Common total time for all formulas for form in [self.loop_form, self.cover_form, self.seq_form]: - stlpy_form = form.get_stlpy_form() path, info = solver.solve_stlpy_formula(stlpy_form, x0=x_0, total_time=total_time) @@ -119,5 +130,6 @@ def test_stlpy_solver(self): loss = form.eval(jax.numpy.tile(path, (num_tiles, 1, 1))) # Make a batch of size num_tiles self.assertGreater(loss[0], 0, f"STLPY solved path loss is not greater than 0 for {form}") + if __name__ == '__main__': unittest.main() From 89b963c89a3906631aaf0921cb166271edeb49c8 Mon Sep 17 00:00:00 2001 From: Joe Eappen Date: Fri, 1 Nov 2024 15:01:33 -0400 Subject: [PATCH 22/22] fix(jax): add tolerance value to tests --- tests/test_stl.py | 5 +++-- tests/test_stl_jax.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_stl.py b/tests/test_stl.py index 0af39b4..f90e7ed 100644 --- a/tests/test_stl.py +++ b/tests/test_stl.py @@ -8,6 +8,7 @@ import examples.stl.differentiability as stl_diff_examples from ds.stl import STL, RectReachPredicate +TEST_TOLERANCE = 1e-3 # Small number close to 0 class TestExamples(unittest.TestCase): @@ -32,12 +33,12 @@ def test_run(self): # Test differentiability path, loss = stl_diff_examples.backward() print('Path', path) - assert loss < 0 # Loss should be less than 0 to satisfy the formula + assert loss < TEST_TOLERANCE # Loss should be less than 0 to satisfy the formula def test_avoid_backward(self): path, loss = stl_diff_examples.backward(avoid_spec=True) print('AvoidPath', loss, path) - assert loss < 0 # Loss should be less than 0 to satisfy the formula + assert loss < TEST_TOLERANCE # Loss should be less than 0 to satisfy the formula def test_evaluations(self, num_tiles=3): """Run simple evaluations to test shapes and types""" diff --git a/tests/test_stl_jax.py b/tests/test_stl_jax.py index f992c81..c0160d6 100644 --- a/tests/test_stl_jax.py +++ b/tests/test_stl_jax.py @@ -14,6 +14,7 @@ from ds.stl import StlpySolver +TEST_TOLERANCE = 1e-3 # Small number close to 0 class TestJAXExamples(unittest.TestCase): @@ -52,7 +53,7 @@ def test_run(self): # Test differentiability path, loss = stl_diff_examples.backward() print('Path', path) - assert loss < 0 # Loss should be less than 0 to satisfy the formula + assert loss < TEST_TOLERANCE # Loss should be less than 0 to satisfy the formula # (jax.lax.fori_loop(0, 1000, lambda i, _: jit(eval_reach_avoid)(), None)).block_until_ready() # for _ in range(1000): # eval_reach_avoid() @@ -63,7 +64,7 @@ def test_avoid_backward(self): path, loss = stl_diff_examples.backward(avoid_spec=True) print('AvoidPath', loss, path) - assert loss < 0 # Loss should be less than 0 to satisfy the formula + assert loss < TEST_TOLERANCE # Loss should be less than 0 to satisfy the formula def test_evaluations(self, num_tiles=3): """Run simple evaluations to test shapes and types"""