diff --git a/.gitignore b/.gitignore index 661341a..31eeb4d 100644 --- a/.gitignore +++ b/.gitignore @@ -159,3 +159,7 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +/.idea/diff-spec.iml +/.idea/modules.xml +/.idea/inspectionProfiles/profiles_settings.xml +/.idea/workspace.xml diff --git a/CHANGELOG.md b/CHANGELOG.md index 1bab28c..9c847b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,18 @@ +#### 2024-03-21 + +##### Bug Fixes + +* **jax:** + * fix evaluation bug in predicates, jnp.min and torch.min had different semantics. Fix examples to allow different + batch sizes as input to STL.eval (885e4d3c) + * add evaluation shape test (which is failing). TODO: Debug (47e75ae5) + * rename os env, reload module to make tests work together (18cdc0ec) + * Using the optax (25687960) + +##### Tests + +* **jax:** basic jit jax optimization (a930a76e) + #### 2024-02-24 ##### Tests diff --git a/README.md b/README.md index fba7348..bf4f726 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ Connect differentiable components with logical operators. ## Install ```bash -pip install git+https://github.com/ZikangXiong/diff-spec.git +pip install git+https://github.com/jeappen/diff-spec.git@feature/jax ``` ## First Order Logic @@ -67,12 +67,12 @@ Probability temporal logic is an ongoing work integrating probability and random If you are using JAX, you can use the JAX backend (stl_jax) and gain immense speedups in many cases. -First set the backend to JAX: +First set the backend to JAX using Environment Variables for our utility functions: ```python import os -os.environ["JAX_STL_BACKEND"] = "jax" # set the backend to JAX +os.environ["DIFF_STL_BACKEND"] = "jax" # set the backend to JAX (if unset or any other value uses the PyTorch backend) ``` Then you can use the JAX backend to optimize the inputs to satisfy the formula. diff --git a/examples/stl/differentiability.py b/examples/stl/differentiability.py index 10190d1..44707d5 100644 --- a/examples/stl/differentiability.py +++ b/examples/stl/differentiability.py @@ -1,20 +1,24 @@ # %% -import os - +import importlib import matplotlib.pyplot as plt import numpy as np import optax +import os + +import ds.utils as ds_utils # if JAX_BACKEND is set the import will be from jax.numpy -if os.environ.get("JAX_STL_BACKEND") == "jax": +if os.environ.get("DIFF_STL_BACKEND") == "jax": print("Using JAX backend") from ds.stl_jax import STL, RectAvoidPredicate, RectReachPredicate - from ds.utils import default_tensor + + importlib.reload(ds_utils) # Reload the module to reset the backend import jax else: print("Using PyTorch backend") from ds.stl import STL, RectAvoidPredicate, RectReachPredicate - from ds.utils import default_tensor + + importlib.reload(ds_utils) # Reload the module to reset the backend import torch from torch.optim import Adam @@ -33,7 +37,7 @@ def eval_reach_avoid(mute=False): form = goal.eventually(0, 10) & obs.always(0, 10) # Define 2 initial paths in batch - path_1 = default_tensor( + path_1 = ds_utils.default_tensor( np.array( [ [ @@ -66,24 +70,39 @@ def eval_reach_avoid(mute=False): [1, 1], [1, 1], ], + [ + [9, 9], + [3, 2], + [7, 7], + [6, 6], + [5, 5], + [4, 4], + [3, 3], + [2, 2], + [1, 1], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + ] ] ) ) # eval the formula, default at time 0 - res1 = form.eval(path=path_1) + res1 = form.eval(path=path_1) # (+,-,-) if not mute: print("eval result at time 0: ", res1) # eval the formula at time 2 - res2 = form.eval(path=path_1, t=2) + res2 = form.eval(path=path_1, t=2) # (+,-,+) if not mute: print("eval result at time 2: ", res2) return res1, res2 -def backward(mute=True): +def backward(avoid_spec=False, mute=True): """ Planning with gradient descent """ @@ -92,39 +111,55 @@ def backward(mute=True): # goal_1 is a rectangle area centered in [0, 0] with width and height 1 goal_1 = STL(RectReachPredicate(np.array([0, 0]), np.array([1, 1]), "goal_1")) # goal_2 is a rectangle area centered in [2, 2] with width and height 1 - goal_2 = STL(RectReachPredicate(np.array([2, 2]), np.array([1, 1]), "goal_2")) + goal_2 = STL(RectReachPredicate(np.array([3, 3]), np.array([1, 1]), "goal_2")) + # goal_2 is a rectangle area centered in [1, 1] with width and height 1 + avoid_region = STL(RectAvoidPredicate(np.array([1, 1]), np.array([1, 1]), "avoid_region")) + avoid_region2 = STL(RectAvoidPredicate(np.array([2, 2]), np.array([1, 1]), "avoid_region2")) + avoid_region_goal1 = STL(RectAvoidPredicate(np.array([0, 0]), np.array([1, 1]), "avoid_region_goal1")) + avoid_region_goal2 = STL(RectAvoidPredicate(np.array([3, 3]), np.array([1, 1]), "avoid_region_goal2")) + end_time = 13 + + if avoid_spec: + print("cover while avoiding avoid_region") + # NOTE: Cover different just alternates between goal_1 and goal_2 + form = goal_2.eventually(0, end_time) & goal_1.eventually(0, end_time) \ + & avoid_region.always(0, end_time) & avoid_region2.always(0, end_time) \ + & avoid_region_goal1.always(end_time // 2, end_time) & avoid_region_goal2.always(0, end_time // 2) + else: + # form is the formula goal_1 eventually in 0 to 5 and goal_2 eventually in 0 to 5 + # and that holds always in 0 to 8 + # In other words, the path will repeatedly visit goal_1 and goal_2 in 0 to 13 + form = (goal_1.eventually(0, 5) & goal_2.eventually(0, 5)).always(0, 8) - # form is the formula goal_1 eventually in 0 to 5 and goal_2 eventually in 0 to 5 - # and that holds always in 0 to 8 - # In other words, the path will repeatedly visit goal_1 and goal_2 in 0 to 13 - form = (goal_1.eventually(0, 5) & goal_2.eventually(0, 5)).always(0, 8) - path = default_tensor( - np.array( + np_path = np.array( + [ [ - [ - [1, 0], - [1, 0], - [1, 0], - [1, 0], - [0, 1], - [0, 1], - [0, 1], - [0, 1], - [0, 1], - [0, 1], - [0, 1], - [0, 1], - [1, 0], - [1, 0], - ], - ] - ) + [1, 0], + [1, 0], + [1, 0], + [1, 0], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [1, 0], + [1, 0], + ], + ] ) + + random_like = np.random.rand(*np_path.shape) + + path = ds_utils.default_tensor(random_like) loss = None lr = 0.1 num_iterations = 1000 - if os.environ.get("JAX_STL_BACKEND") == "jax": + if os.environ.get("DIFF_STL_BACKEND") == "jax": solver = optax.adam(lr) var_solver_state = solver.init(path) @@ -132,10 +167,10 @@ def backward(mute=True): @jax.jit def train_step(params, solver_state): # Performs a one step update. - (loss), grad = jax.value_and_grad(form.eval)( + (loss), grad = jax.value_and_grad(lambda x: -form.eval(x).mean())( params ) - updates, solver_state = solver.update(-grad, solver_state) + updates, solver_state = solver.update(grad, solver_state) params = optax.apply_updates(params, updates) return params, solver_state, loss @@ -144,12 +179,14 @@ def train_step(params, solver_state): path, var_solver_state ) - loss = form.eval(path) + loss = train_loss else: # PyTorch backend (slower when num_iterations is high) path.requires_grad = True opt = Adam(params=[path], lr=lr) + # ds_utils.HARDNESS = 3.0 + for _ in range(num_iterations): loss = -torch.mean(form.eval(path)) opt.zero_grad() diff --git a/src/ds/stl.py b/src/ds/stl.py index ec90766..94f51dc 100644 --- a/src/ds/stl.py +++ b/src/ds/stl.py @@ -1,54 +1,27 @@ -import io -import time -from abc import abstractmethod from collections import deque -from contextlib import contextmanager -from contextlib import redirect_stdout -from typing import TypeVar, Tuple import gurobipy as gp +import io import numpy as np +import time import torch +from abc import abstractmethod +from contextlib import redirect_stdout from gurobipy import GRB from stlpy.STL import LinearPredicate, NonlinearPredicate, STLTree from stlpy.systems import LinearSystem from torch import Tensor from torch.nn.functional import softmax +from typing import TypeVar, Tuple -from ds.utils import default_tensor +from ds.utils import default_tensor, colored, HARDNESS, IMPLIES_TRICK, outside_rectangle_formula, \ + inside_rectangle_formula with redirect_stdout(io.StringIO()): from stlpy.solvers.base import STLSolver import logging -COLORED = False -IMPLIES_TRICK = False -HARDNESS = 100.0 # Reduce hardness of softmax to propagate gradients more easily - - -@contextmanager -def set_hardness(hardness: float): - """Set the hardness of the softmax function for the duration of the context. - Useful for making evaluation strict while allowing gradients to pass through during training. - - :param hardness: hardness of the softmax function - :type hardness: float - """ - global HARDNESS - old_hardness = HARDNESS - HARDNESS = hardness - yield - HARDNESS = old_hardness - - -if COLORED: - from termcolor import colored -else: - - def colored(text, color): - return text - class GurobiMICPSolver(STLSolver): """ @@ -369,6 +342,7 @@ def solve_stlpy_formula( class PredicateBase: def __init__(self, name: str): self.name = name + self.logger = logging.getLogger(__name__) def eval_at_t(self, path: Tensor, t: int = 0) -> Tensor: return self.eval_whole_path(path, t, t + 1)[:, 0] @@ -404,14 +378,19 @@ def __init__(self, cent: np.ndarray, size: np.ndarray, name: str, shrink_factor: self.cent_tensor = default_tensor(cent) self.size_tensor = default_tensor(size) - self.shrink_factor = shrink_factor # shrink the rectangle to make it more conservative - print(f"shrink factor: {shrink_factor}") + self.shrink_factor = shrink_factor # shrink the rectangle to make it more conservative in STLpy + self.logger.info(f"shrink factor: {shrink_factor}") def eval_whole_path( self, path: Tensor, start_t: int = 0, end_t: int = None ) -> Tensor: assert len(path.shape) == 3, "motion must be in batch" eval_path = path[:, start_t:end_t] + # NOTE: This is the soft version of the predicate + # res_soft = STL(None)._tensor_min( + # self.size_tensor / 2 - torch.abs(eval_path - self.cent_tensor), dim=-1 + # ) + res = torch.min( self.size_tensor / 2 - torch.abs(eval_path - self.cent_tensor), dim=-1 )[0] @@ -461,110 +440,6 @@ def get_stlpy_form(self) -> STLTree: return outside_rectangle_formula(bounds, 0, 1, 2, self.name) -def inside_rectangle_formula(bounds, y1_index, y2_index, d, name=None): - """ - Create an STL formula representing being inside a - rectangle with the given bounds: - - :: - - y2_max +-------------------+ - | | - | | - | | - y2_min +-------------------+ - y1_min y1_max - - :param bounds: Tuple ``(y1_min, y1_max, y2_min, y2_max)`` containing - the bounds of the rectangle. - :param y1_index: index of the first (``y1``) dimension - :param y2_index: index of the second (``y2``) dimension - :param d: dimension of the overall signal - :param name: (optional) string describing this formula - - :return inside_rectangle: An ``STLFormula`` specifying being inside the - rectangle at time zero. - """ - assert y1_index < d, "index must be less than signal dimension" - assert y2_index < d, "index must be less than signal dimension" - - # Unpack the bounds - y1_min, y1_max, y2_min, y2_max = bounds - - # Create predicates a*y >= b for each side of the rectangle - a1 = np.zeros((1, d)) - a1[:, y1_index] = 1 - right = LinearPredicate(a1, y1_min) - left = LinearPredicate(-a1, -y1_max) - - a2 = np.zeros((1, d)) - a2[:, y2_index] = 1 - top = LinearPredicate(a2, y2_min) - bottom = LinearPredicate(-a2, -y2_max) - - # Take the conjuction across all the sides - inside_rectangle = right & left & top & bottom - - # set the names - if name is not None: - inside_rectangle.__str__ = lambda: name - inside_rectangle.__repr__ = lambda: name - - return inside_rectangle - - -def outside_rectangle_formula(bounds, y1_index, y2_index, d, name=None): - """ - Create an STL formula representing being outside a - rectangle with the given bounds: - - :: - - y2_max +-------------------+ - | | - | | - | | - y2_min +-------------------+ - y1_min y1_max - - :param bounds: Tuple ``(y1_min, y1_max, y2_min, y2_max)`` containing - the bounds of the rectangle. - :param y1_index: index of the first (``y1``) dimension - :param y2_index: index of the second (``y2``) dimension - :param d: dimension of the overall signal - :param name: (optional) string describing this formula - - :return outside_rectangle: An ``STLFormula`` specifying being outside the - rectangle at time zero. - """ - assert y1_index < d, "index must be less than signal dimension" - assert y2_index < d, "index must be less than signal dimension" - - # Unpack the bounds - y1_min, y1_max, y2_min, y2_max = bounds - - # Create predicates a*y >= b for each side of the rectangle - a1 = np.zeros((1, d)) - a1[:, y1_index] = 1 - right = LinearPredicate(a1, y1_max) - left = LinearPredicate(-a1, -y1_min) - - a2 = np.zeros((1, d)) - a2[:, y2_index] = 1 - top = LinearPredicate(a2, y2_max) - bottom = LinearPredicate(-a2, -y2_min) - - # Take the disjuction across all the sides - outside_rectangle = right | left | top | bottom - - # set the names - if name is not None: - outside_rectangle.__str__ = lambda: name - outside_rectangle.__repr__ = lambda: name - - return outside_rectangle - - AST = TypeVar("AST", list, PredicateBase) @@ -857,16 +732,16 @@ def _convert_implies(self, ast): def _convert_eventually(self, ast): sub_form = self._to_stlpy(ast[1]) - return sub_form.eventually(ast[2], ast[3]) + return sub_form.eventually(ast[2], ast[3] - 1) def _convert_always(self, ast): sub_form = self._to_stlpy(ast[1]) - return sub_form.always(ast[2], ast[3]) + return sub_form.always(ast[2], ast[3] - 1) def _convert_until(self, ast): sub_form_1 = self._to_stlpy(ast[1]) sub_form_2 = self._to_stlpy(ast[2]) - return sub_form_1.until(sub_form_2, ast[3], ast[4]) + return sub_form_1.until(sub_form_2, ast[3], ast[4] - 1) @staticmethod def _is_leaf(ast: AST): diff --git a/src/ds/stl_jax.py b/src/ds/stl_jax.py index d6494d3..0a3a54d 100644 --- a/src/ds/stl_jax.py +++ b/src/ds/stl_jax.py @@ -1,32 +1,36 @@ +from collections import deque + +import importlib import io +import numpy as np import os from abc import abstractmethod -from collections import deque from contextlib import redirect_stdout -from typing import TypeVar - -import jax -import numpy as np from jax.nn import softmax -from stlpy.STL import LinearPredicate, STLTree +from stlpy.STL import LinearPredicate as baseLinearPredicate, STLTree +from typing import TypeVar, NamedTuple + +os.environ["DIFF_STL_BACKEND"] = "jax" # set the backend to JAX for all child processes +import ds.utils as ds_utils -os.environ["JAX_STL_BACKEND"] = "jax" # set the backend to JAX for all child processes -from ds.utils import default_tensor +importlib.reload(ds_utils) # Reload the module to change the backend with redirect_stdout(io.StringIO()): pass import logging -from .stl import colored, HARDNESS, IMPLIES_TRICK, set_hardness +colored, HARDNESS, IMPLIES_TRICK, set_hardness = ds_utils.colored, ds_utils.HARDNESS, ds_utils.IMPLIES_TRICK, ds_utils.set_hardness +outside_npy = ds_utils.outside_rectangle_formula +inside_npy = ds_utils.inside_rectangle_formula # Replace with JAX import jax.numpy as jnp +import re -class PredicateBase: - def __init__(self, name: str): - self.name = name +class PredicateBase(NamedTuple): + name: str def eval_at_t(self, path: jnp.ndarray, t: int = 0) -> jnp.ndarray: return self.eval_whole_path(path, t, t + 1)[:, 0] @@ -35,88 +39,129 @@ def eval_at_t(self, path: jnp.ndarray, t: int = 0) -> jnp.ndarray: def eval_whole_path( self, path: jnp.ndarray, start_t: int = 0, end_t: int = None ) -> jnp.ndarray: + """Stick to JAX when possible.""" raise NotImplementedError @abstractmethod def get_stlpy_form(self) -> STLTree: + """Use Numpy to ensure compatibility with STLpy.""" raise NotImplementedError def __str__(self) -> str: return self.name + def __lt__(self, other: "PredicateBase") -> bool: + """Sort predicates by name.""" + return self.name < other.name -class RectReachPredicate(PredicateBase): - """ - Rectangle reachability predicate - """ - def __init__(self, cent: np.ndarray, size: np.ndarray, name: str, shrink_factor: float = 0.5): - """ - :param cent: center of the rectangle - :param size: bound of the rectangle - :param name: name of the predicate +class RectangularPredicate(NamedTuple): + """ + Rectangle reachability predicate """ - super().__init__(name) - self.cent = cent - self.size = size - self.cent_tensor = default_tensor(cent) - self.size_tensor = default_tensor(size) - self.shrink_factor = shrink_factor # shrink the rectangle to make it more conservative - print(f"shrink factor: {shrink_factor}") + cent: np.ndarray + size: np.ndarray + name: str + shrink_factor: float = 1.0 # shrink the rectangle to make it more conservative (for stlpy) + + @property + def size_tensor(self): + return ds_utils.default_tensor(self.size) + + @property + def cent_tensor(self): + return ds_utils.default_tensor(self.cent) + + def eval_at_t(self, path: jnp.ndarray, t: int = 0) -> jnp.ndarray: + return self.eval_whole_path(path, t, t + 1)[:, 0] + + @abstractmethod + def eval_whole_path( + self, path: jnp.ndarray, start_t: int = 0, end_t: int = None + ) -> jnp.ndarray: + """Stick to JAX when possible.""" + raise NotImplementedError + + @abstractmethod + def get_stlpy_form(self) -> STLTree: + """Use Numpy to ensure compatibility with STLpy.""" + raise NotImplementedError + + def __hash__(self): + return hash(f"{self.cent},{self.size}") + + def __eq__(self, other): + if not isinstance(other, RectangularPredicate): + return False + # return self.cent == other.cent and self.size == other.size + # Above using float difference + return np.allclose(self.cent, other.cent) and np.allclose(self.size, other.size) + + def __str__(self) -> str: + return self.name + + def __lt__(self, other: "RectangularPredicate") -> bool: + """Sort predicates by name.""" + return self.name < other.name + + def __rich_repr__(self): + # Assumes that size is common and not important + yield f"{self.cent}" + + +# PREDICATE_FORM = TypeVar("PREDICATE_FORM", RectangularPredicate, PredicateBase) +PREDICATE_TYPES = (RectangularPredicate, PredicateBase) + + +class RectReachPredicate(RectangularPredicate): + """ + Rectangle reachability predicate + """ def eval_whole_path( self, path: jnp.array, start_t: int = 0, end_t: int = None ) -> jnp.array: + """Stick to JAX when possible.""" assert len(path.shape) == 3, "motion must be in batch" eval_path = path[:, start_t:end_t] res = jnp.min( self.size_tensor / 2 - jnp.abs(eval_path - self.cent_tensor), axis=-1 - )[0] + ) return res def get_stlpy_form(self) -> STLTree: + """Use Numpy to ensure compatibility with STLpy.""" bounds = np.stack( [self.cent - self.size * self.shrink_factor / 2, self.cent + self.size * self.shrink_factor / 2] ).T.flatten() - return inside_rectangle_formula(bounds, 0, 1, 2, self.name) + return inside_npy(bounds, 0, 1, 2, self.name) -class RectAvoidPredicate(PredicateBase): +class RectAvoidPredicate(RectangularPredicate): """ Rectangle avoidance predicate """ - def __init__(self, cent: np.ndarray, size: np.ndarray, name: str): - """ - :param cent: center of the rectangle - :param size: bound of the rectangle - :param name: name of the predicate - """ - super().__init__(name) - self.cent = cent - self.size = size - - self.cent_tensor = default_tensor(cent) - self.size_tensor = default_tensor(size) - def eval_whole_path( self, path: jnp.array, start_t: int = 0, end_t: int = None ) -> jnp.array: + """Stick to JAX when possible.""" assert len(path.shape) == 3, "motion must be in batch" eval_path = path[:, start_t:end_t] res = jnp.max( jnp.abs(eval_path - self.cent_tensor) - self.size_tensor / 2, axis=-1 - )[0] + ) return res def get_stlpy_form(self) -> STLTree: + """Use Numpy to ensure compatibility with STLpy.""" bounds = np.stack( [self.cent - self.size / 2, self.cent + self.size / 2] ).T.flatten() - return outside_rectangle_formula(bounds, 0, 1, 2, self.name) + return outside_npy(bounds, 0, 1, 2, self.name) def inside_rectangle_formula(bounds, y1_index, y2_index, d, name=None): @@ -150,13 +195,13 @@ def inside_rectangle_formula(bounds, y1_index, y2_index, d, name=None): y1_min, y1_max, y2_min, y2_max = bounds # Create predicates a*y >= b for each side of the rectangle - a1 = np.zeros((1, d)) - a1[:, y1_index] = 1 + a1 = jnp.zeros((1, d)) + a1.at[:, y1_index].set(1) right = LinearPredicate(a1, y1_min) left = LinearPredicate(-a1, -y1_max) - a2 = np.zeros((1, d)) - a2[:, y2_index] = 1 + a2 = jnp.zeros((1, d)) + a2.at[:, y2_index].set(1) top = LinearPredicate(a2, y2_min) bottom = LinearPredicate(-a2, -y2_max) @@ -202,13 +247,13 @@ def outside_rectangle_formula(bounds, y1_index, y2_index, d, name=None): y1_min, y1_max, y2_min, y2_max = bounds # Create predicates a*y >= b for each side of the rectangle - a1 = np.zeros((1, d)) - a1[:, y1_index] = 1 + a1 = jnp.zeros((1, d)) + a1.at[:, y1_index].set(1) right = LinearPredicate(a1, y1_max) left = LinearPredicate(-a1, -y1_min) - a2 = np.zeros((1, d)) - a2[:, y2_index] = 1 + a2 = jnp.zeros((1, d)) + a2.at[:, y2_index].set(1) top = LinearPredicate(a2, y2_max) bottom = LinearPredicate(-a2, -y2_min) @@ -223,12 +268,47 @@ def outside_rectangle_formula(bounds, y1_index, y2_index, d, name=None): return outside_rectangle +class LinearPredicate(baseLinearPredicate): + """ + A linear STL predicate :math:`\pi` defined by + + .. math:: + + a^Ty_t - b \geq 0 + + where :math:`y_t \in \mathbb{R}^d` is the value of the signal + at a given timestep :math:`t`, :math:`a \in \mathbb{R}^d`, + and :math:`b \in \mathbb{R}`. + + :param a: a jax numpy array or list representing the vector :math:`a` + :param b: a list, jax numpy array, or scalar representing :math:`b` + :param name: (optional) a string used to identify this predicate. + """ + + def __init__(self, a, b, name=None): + # Convert provided constraints to numpy arrays + self.a = jnp.asarray(a).reshape((-1, 1)) + self.b = jnp.atleast_1d(b) + + # Some dimension-related sanity checks + assert (self.a.shape[1] == 1), "a must be of shape (d,1)" + assert (self.b.shape == (1,)), "b must be of shape (1,)" + + # Store the dimensionality of y_t + self.d = self.a.shape[0] + + # A unique string describing this predicate + self.name = name + + AST = TypeVar("AST", list, PredicateBase) class STL: """ Class for representing STL formulas. + + All methods are functionally pure with no side effects during execution. """ def __init__(self, ast: AST): @@ -288,6 +368,9 @@ def _get_end_time(self, ast: AST) -> int: """Get max time of the formula. Runs in O(n) time where n is the number of nodes. Runs once then memoizes.""" if self._is_leaf(ast): return 1 + if ast[0] == "G": + # Add end time from inner formula since always is unrolled + return ast[-1] + self._get_end_time(ast[1]) if ast[0] in self.sequence_operators: # The last two elements are the start and end times return ast[-1] @@ -439,12 +522,23 @@ def _eval_until( ) # mask condition, once condition > 0 (after until True), # the right sequence is no longer considered - cond = (till_pred > 0).int() + cond = (till_pred > 0).astype(int) index = jnp.argmax(cond, axis=-1) - for i in range(cond.shape[0]): - cond[i, index[i]:] = 1.0 - cond = ~cond.bool() - till_pred = jnp.where(cond, till_pred, default_tensor(1)) + + batch_size, seq_len = cond.shape + row_indices = jnp.arange(batch_size)[:, None] + col_indices = jnp.arange(seq_len) + + mask = col_indices >= index[:, None] + + # Apply the mask + cond = mask.astype(int) + cond = ~cond.astype(bool) + + # for i in range(cond.shape[0]): + # cond[i, index[i]:] = 1.0 + # cond = ~cond.astype(bool) + till_pred = jnp.where(cond, till_pred, ds_utils.default_tensor(1)) if self._is_leaf(sub_form1): res = sub_form1.eval_whole_path(path[:, start_t:end_t]) @@ -456,7 +550,7 @@ def _eval_until( ], axis=-1, ) - res = jnp.where(cond, res, default_tensor(-1)) + res = jnp.where(cond, res, ds_utils.default_tensor(-1)) # when cond < 0, res should always > 0 to be hold return self._tensor_min(-res * till_pred, axis=-1) @@ -514,19 +608,23 @@ def _convert_implies(self, ast): def _convert_eventually(self, ast): sub_form = self._to_stlpy(ast[1]) - return sub_form.eventually(ast[2], ast[3]) + return sub_form.eventually(ast[2], ast[3] - 1) def _convert_always(self, ast): sub_form = self._to_stlpy(ast[1]) - return sub_form.always(ast[2], ast[3]) + return sub_form.always(ast[2], ast[3] - 1) def _convert_until(self, ast): sub_form_1 = self._to_stlpy(ast[1]) sub_form_2 = self._to_stlpy(ast[2]) - return sub_form_1.until(sub_form_2, ast[3], ast[4]) + return sub_form_1.until(sub_form_2, ast[3], ast[4] - 1) @staticmethod def _is_leaf(ast: AST): + # Check is type PREDICATE_FORM + for pred_type in PREDICATE_TYPES: + if isinstance(ast, pred_type): + return True return issubclass(type(ast), PredicateBase) def _tensor_min(self, tensor: jnp.array, axis=-1) -> jnp.array: @@ -546,10 +644,15 @@ def __repr__(self): if self.expr_repr is not None: return self.expr_repr + expr = self._extract_repr() + + self.expr_repr = expr + return expr + + def _extract_repr(self, print_rich=False): single_operators = ("~", "G", "F") binary_operators = ("&", "|", "->", "U") time_bounded_operators = ("G", "F", "U") - # traverse ast operator_stack = [self.ast] expr = "" @@ -564,7 +667,10 @@ def push_stack(ast): while operator_stack: cur = operator_stack.pop() if self._is_leaf(cur): - expr += cur.__str__() + if print_rich: + expr += f"({str(next(cur.__rich_repr__()))})" + else: + expr += cur.__str__() elif isinstance(cur, str): if cur == "(" or cur == ")": expr += cur @@ -605,10 +711,29 @@ def push_stack(ast): push_stack("(") else: push_stack(cur[1]) - - self.expr_repr = expr return expr + def latex_repr(self): + repr = self.__repr__() + + def replace_special_chars(match): + return { + "~": r"\neg", + "&": r"\land", + "|": r"\lor", + "->": r"\rightarrow", + "G": r"\Box", + "F": r"\Diamond", + "U": r"U", + }[match.group(0)] + + replaced_symb = re.sub(r"~|&|\||->|G|F|U", replace_special_chars, repr) + # replace any [a, b] with _{[a,b]} + replaced_symb = re.sub(r"\[(\d+), (\d+)\]", r"_{[\1,\2]}", replaced_symb) + # replace goal_0, goal_1, ... with A, B, C, ... + replaced_symb = re.sub(r"goal_(\d+)", lambda x: chr(ord('A') + int(x.group(1))), replaced_symb) + return replaced_symb + def get_all_predicates(self): all_preds = [] queue = deque([self.ast]) diff --git a/src/ds/utils.py b/src/ds/utils.py index 88040ea..a84b97d 100644 --- a/src/ds/utils.py +++ b/src/ds/utils.py @@ -1,9 +1,143 @@ +import numpy as np import os +from contextlib import contextmanager + +COLORED = False +IMPLIES_TRICK = False +HARDNESS = 100.0 # Reduce hardness of softmax to propagate gradients more easily + + +@contextmanager +def set_hardness(hardness: float): + """Set the hardness of the softmax function for the duration of the context. + Useful for making evaluation strict while allowing gradients to pass through during training. + + :param hardness: hardness of the softmax function + :type hardness: float + """ + global HARDNESS + old_hardness = HARDNESS + HARDNESS = hardness + yield + HARDNESS = old_hardness + + +if COLORED: + from termcolor import colored +else: + + def colored(text, color): + return text + +from stlpy.STL import LinearPredicate, NonlinearPredicate, STLTree + + +def inside_rectangle_formula(bounds, y1_index, y2_index, d, name=None): + """ + Create an STL formula representing being inside a + rectangle with the given bounds: + + :: + + y2_max +-------------------+ + | | + | | + | | + y2_min +-------------------+ + y1_min y1_max + + :param bounds: Tuple ``(y1_min, y1_max, y2_min, y2_max)`` containing + the bounds of the rectangle. + :param y1_index: index of the first (``y1``) dimension + :param y2_index: index of the second (``y2``) dimension + :param d: dimension of the overall signal + :param name: (optional) string describing this formula + + :return inside_rectangle: An ``STLFormula`` specifying being inside the + rectangle at time zero. + """ + assert y1_index < d, "index must be less than signal dimension" + assert y2_index < d, "index must be less than signal dimension" + + # Unpack the bounds + y1_min, y1_max, y2_min, y2_max = bounds + + # Create predicates a*y >= b for each side of the rectangle + a1 = np.zeros((1, d)) + a1[:, y1_index] = 1 + right = LinearPredicate(a1, y1_min) + left = LinearPredicate(-a1, -y1_max) + + a2 = np.zeros((1, d)) + a2[:, y2_index] = 1 + top = LinearPredicate(a2, y2_min) + bottom = LinearPredicate(-a2, -y2_max) + + # Take the conjuction across all the sides + inside_rectangle = right & left & top & bottom + + # set the names + if name is not None: + inside_rectangle.__str__ = lambda: name + inside_rectangle.__repr__ = lambda: name + + return inside_rectangle + + +def outside_rectangle_formula(bounds, y1_index, y2_index, d, name=None): + """ + Create an STL formula representing being outside a + rectangle with the given bounds: + + :: + + y2_max +-------------------+ + | | + | | + | | + y2_min +-------------------+ + y1_min y1_max + + :param bounds: Tuple ``(y1_min, y1_max, y2_min, y2_max)`` containing + the bounds of the rectangle. + :param y1_index: index of the first (``y1``) dimension + :param y2_index: index of the second (``y2``) dimension + :param d: dimension of the overall signal + :param name: (optional) string describing this formula + + :return outside_rectangle: An ``STLFormula`` specifying being outside the + rectangle at time zero. + """ + assert y1_index < d, "index must be less than signal dimension" + assert y2_index < d, "index must be less than signal dimension" + + # Unpack the bounds + y1_min, y1_max, y2_min, y2_max = bounds + + # Create predicates a*y >= b for each side of the rectangle + a1 = np.zeros((1, d)) + a1[:, y1_index] = 1 + right = LinearPredicate(a1, y1_max) + left = LinearPredicate(-a1, -y1_min) + + a2 = np.zeros((1, d)) + a2[:, y2_index] = 1 + top = LinearPredicate(a2, y2_max) + bottom = LinearPredicate(-a2, -y2_min) + + # Take the disjuction across all the sides + outside_rectangle = right | left | top | bottom + + # set the names + if name is not None: + outside_rectangle.__str__ = lambda: name + outside_rectangle.__repr__ = lambda: name + + return outside_rectangle -import numpy as np # if JAX_BACKEND is set the import will be from jax.numpy -if os.environ.get("JAX_STL_BACKEND") == "jax": +if os.environ.get("DIFF_STL_BACKEND") == "jax": # print("Using JAX backend") import jax diff --git a/tests/test_stl.py b/tests/test_stl.py index 27463c8..f90e7ed 100644 --- a/tests/test_stl.py +++ b/tests/test_stl.py @@ -1,26 +1,81 @@ +import importlib +import numpy as np +import os +import torch import unittest -from examples.stl.differentiability import eval_reach_avoid, backward +import ds.utils as ds_utils +import examples.stl.differentiability as stl_diff_examples +from ds.stl import STL, RectReachPredicate +TEST_TOLERANCE = 1e-3 # Small number close to 0 class TestExamples(unittest.TestCase): + def setUp(self): + os.environ["DIFF_STL_BACKEND"] = "" + importlib.reload(stl_diff_examples) # Reload the module to reset the backend + importlib.reload(ds_utils) # Reload the module to reset the backend + def test_run(self): # Test Eval final_result = [] for _ in range(1000): # Fair test with jax - res = eval_reach_avoid(mute=True) + res = stl_diff_examples.eval_reach_avoid(mute=True) final_result.append(res) + # Match expected output + assert ((res[0] > 0) == torch.tensor([True, False, False])).all() + assert ((res[1] > 0) == torch.tensor([True, False, True])).all() - print(final_result) + # print(final_result) # Test differentiability - path = backward() - print(path) + path, loss = stl_diff_examples.backward() + print('Path', path) + assert loss < TEST_TOLERANCE # Loss should be less than 0 to satisfy the formula + + def test_avoid_backward(self): + path, loss = stl_diff_examples.backward(avoid_spec=True) + print('AvoidPath', loss, path) + assert loss < TEST_TOLERANCE # Loss should be less than 0 to satisfy the formula + + def test_evaluations(self, num_tiles=3): + """Run simple evaluations to test shapes and types""" + goal_1 = STL(RectReachPredicate(np.array([0, 0]), np.array([1, 1]), "goal_1")) + # goal_2 is a rectangle area centered in [2, 2] with width and height 1 + goal_2 = STL(RectReachPredicate(np.array([2, 2]), np.array([1, 1]), "goal_2")) + + # form is the formula goal_1 eventually in 0 to 5 and goal_2 eventually in 0 to 5 + # and that holds always in 0 to 8 + # In other words, the path will repeatedly visit goal_1 and goal_2 in 0 to 13 + form = (goal_1.eventually(0, 5) & goal_2.eventually(0, 5)).always(0, 8) + path = ds_utils.default_tensor( + np.array( + [ + [ + [1, 0], + [1, 0], + [1, 0], + [0, 0], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [1, 0], + [1, 0], + ], + ] + ) + ) - # - # self.assertEqual(True, False) # add assertion here + loss = form.eval(path.tile(num_tiles, 1, 1)) # Make a batch of size num_tiles + self.assertGreater(len(loss.shape), 0, f"Not returning correct shape") + self.assertEqual(loss.shape[0], num_tiles, f"Not returning {num_tiles} values") if __name__ == '__main__': diff --git a/tests/test_stl_jax.py b/tests/test_stl_jax.py index a27bfad..c0160d6 100644 --- a/tests/test_stl_jax.py +++ b/tests/test_stl_jax.py @@ -1,13 +1,40 @@ +import importlib +import jax +import jax.numpy as jnp +import numpy as np import os import unittest - -os.environ["JAX_STL_BACKEND"] = "jax" # set the backend to JAX for all child processes -from examples.stl.differentiability import eval_reach_avoid, backward from jax import jit +os.environ["DIFF_STL_BACKEND"] = "jax" # So ds_utils does not require torch + +import ds.utils as ds_utils +import examples.stl.differentiability as stl_diff_examples +from ds.stl_jax import STL, RectReachPredicate + +from ds.stl import StlpySolver + +TEST_TOLERANCE = 1e-3 # Small number close to 0 class TestJAXExamples(unittest.TestCase): + def setUp(self): + os.environ["DIFF_STL_BACKEND"] = "jax" # set the backend to JAX for all child processes + importlib.reload(stl_diff_examples) # Reload the module to reset the backend + importlib.reload(ds_utils) # Reload the module to reset the backend + + self.goal_1 = STL(RectReachPredicate(np.array([0, 0]), np.array([1, 1]), "goal_1")) + # goal_2 is a rectangle area centered in [2, 2] with width and height 1 + self.goal_2 = STL(RectReachPredicate(np.array([2, 2]), np.array([1, 1]), "goal_2")) + + # form is the formula goal_1 eventually in 0 to 5 and goal_2 eventually in 0 to 5 + # and that holds always in 0 to 8 + # In other words, the path will repeatedly visit goal_1 and goal_2 in 0 to 13 + self.form = (self.goal_1.eventually(0, 5) & self.goal_2.eventually(0, 5)).always(0, 8) + self.loop_form = (self.goal_1.eventually(0, 4) & self.goal_2.eventually(0, 4)).always(0, 8) + self.cover_form = self.goal_1.eventually(0, 12) & self.goal_2.eventually(0, 12) + self.seq_form = self.goal_1.eventually(0, 6) & self.goal_2.eventually(6, 12) + def test_run(self): # TODO: Study jit decorator and see optimizations # jit(eval_reach_avoid)() @@ -15,20 +42,95 @@ def test_run(self): final_result = [] for _ in range(1000): # Magic of jax - res = jit(eval_reach_avoid)() + res = jit(stl_diff_examples.eval_reach_avoid)() final_result.append(res) + # Match expected output + assert jnp.all((res[0] > 0) == jnp.array([True, False, False])) + assert jnp.all((res[1] > 0) == jnp.array([True, False, True])) print(final_result) # Test differentiability - path = backward() - print(path) + path, loss = stl_diff_examples.backward() + print('Path', path) + assert loss < TEST_TOLERANCE # Loss should be less than 0 to satisfy the formula # (jax.lax.fori_loop(0, 1000, lambda i, _: jit(eval_reach_avoid)(), None)).block_until_ready() # for _ in range(1000): # eval_reach_avoid() # # self.assertEqual(True, False) # add assertion here + def test_avoid_backward(self): + + path, loss = stl_diff_examples.backward(avoid_spec=True) + print('AvoidPath', loss, path) + assert loss < TEST_TOLERANCE # Loss should be less than 0 to satisfy the formula + + def test_evaluations(self, num_tiles=3): + """Run simple evaluations to test shapes and types""" + path = ds_utils.default_tensor( + np.array( + [ + [ + [1, 0], + [1, 0], + [1, 0], + [0, 0], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [1, 0], + [1, 0], + ], + ] + ) + ) + + loss = self.form.eval(jax.numpy.tile(path, (num_tiles, 1, 1))) # Make a batch of size num_tiles + self.assertGreater(len(loss.shape), 0, f"Not returning correct shape") + self.assertEqual(loss.shape[0], num_tiles, f"Not returning {num_tiles} values") + + def test_loop(self): + # Test loop spec + num_tiles = 4 + path = ds_utils.default_tensor( + np.array( + [ + [ + [0, 0], + [0, 2], + [2, 0], + [2, 2], + ] * 3 + ])) + loss = self.loop_form.eval(jax.numpy.tile(path, (num_tiles, 1, 1))) # Make a batch of size num_tiles + self.assertGreater(len(loss.shape), 0, f"Not returning correct shape") + self.assertEqual(loss.shape[0], num_tiles, f"Not returning {num_tiles} values") + self.assertGreater(loss[0], 0, f"Loss is not greater than 0") + + unsat_path = path.at[0, -4].set([0, 2]) # Make the last point unsatisfiable + loss = self.loop_form.eval(jax.numpy.tile(unsat_path, (num_tiles, 1, 1))) + self.assertLess(loss[0], 0, f"Loss is not less than 0 for unsat path") + + def test_stlpy_solver(self): + """Test the stlpy solver with different forms of STL formulas""" + x_0 = np.array([0, 0]) + solver = StlpySolver(space_dim=2) + total_time = 12 # Common total time for all formulas + + for form in [self.loop_form, self.cover_form, self.seq_form]: + stlpy_form = form.get_stlpy_form() + path, info = solver.solve_stlpy_formula(stlpy_form, x0=x_0, total_time=total_time) + + num_tiles = 4 + loss = form.eval(jax.numpy.tile(path, (num_tiles, 1, 1))) # Make a batch of size num_tiles + self.assertGreater(loss[0], 0, f"STLPY solved path loss is not greater than 0 for {form}") + if __name__ == '__main__': unittest.main()