diff --git a/examples/stl/differentiability.py b/examples/stl/differentiability.py
index 10190d1..defe054 100644
--- a/examples/stl/differentiability.py
+++ b/examples/stl/differentiability.py
@@ -8,16 +8,18 @@
 # if JAX_BACKEND is set the import will be from jax.numpy
 if os.environ.get("JAX_STL_BACKEND") == "jax":
     print("Using JAX backend")
+    import jax
+
     from ds.stl_jax import STL, RectAvoidPredicate, RectReachPredicate
     from ds.utils import default_tensor
-    import jax
 else:
     print("Using PyTorch backend")
-    from ds.stl import STL, RectAvoidPredicate, RectReachPredicate
-    from ds.utils import default_tensor
     import torch
     from torch.optim import Adam
+    from ds.stl import STL, RectAvoidPredicate, RectReachPredicate
+    from ds.utils import default_tensor
+


 def eval_reach_avoid(mute=False):
     """
@@ -132,17 +134,13 @@ def backward(mute=True):
         @jax.jit
         def train_step(params, solver_state):
             # Performs a one step update.
-            (loss), grad = jax.value_and_grad(form.eval)(
-                params
-            )
+            (loss), grad = jax.value_and_grad(form.eval)(params)
             updates, solver_state = solver.update(-grad, solver_state)
             params = optax.apply_updates(params, updates)
             return params, solver_state, loss

         for _ in range(num_iterations):
-            path, var_solver_state, train_loss = train_step(
-                path, var_solver_state
-            )
+            path, var_solver_state, train_loss = train_step(path, var_solver_state)

         loss = form.eval(path)
     else:
diff --git a/linter.sh b/linter.sh
new file mode 100755
index 0000000..b84794b
--- /dev/null
+++ b/linter.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+
+# Default source directory and options for the tools
+SOURCE_DIR="."
+AUTOFIX_OPTIONS="--remove-all-unused-imports --remove-unused-variables --expand-star-imports --ignore-init-module-imports --in-place -r"
+ISORT_OPTIONS="--profile black --line-length 88"
+BLACK_OPTIONS="--line-length 88"
+
+
+# Install the necessary packages
+pip install autoflake isort black
+
+# Run autoflake with the specified options
+output=$(autoflake $SOURCE_DIR $AUTOFIX_OPTIONS)
+if [ -n "$output" ]; then
+    echo "Autoflake made changes or found issues:"
+    echo "$output"
+    # Uncomment the next line if you want the script to fail on changes
+    # exit 1
+else
+    echo "No issues found by autoflake."
+fi
+
+# Run isort with the specified options
+echo "Running isort..."
+isort $SOURCE_DIR $ISORT_OPTIONS
+
+# Run black with the specified options
+echo "Running black..."
+black $SOURCE_DIR $BLACK_OPTIONS
+
+# Final message
+echo "Linting complete."
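Note on the `train_step` hunk above: the pattern pairs `jax.value_and_grad` with an optax optimizer and performs gradient ascent on the formula's robustness (hence the `-grad`). For readers unfamiliar with it, a minimal self-contained sketch of the same update loop follows; the quadratic `robustness` objective and the array shapes are illustrative stand-ins for `form.eval` and the real path tensor, not part of this repository:

    import jax
    import jax.numpy as jnp
    import optax

    # Illustrative stand-in for form.eval: peaks when every waypoint reaches (1, 1).
    def robustness(path):
        return -jnp.sum((path - 1.0) ** 2)

    solver = optax.adam(learning_rate=0.1)
    path = jnp.zeros((1, 10, 2))  # (batch, time, state), as in the examples
    solver_state = solver.init(path)

    @jax.jit
    def train_step(params, solver_state):
        # Ascend on robustness: negate the gradient before the optax update,
        # mirroring the -grad in differentiability.py.
        loss, grad = jax.value_and_grad(robustness)(params)
        updates, solver_state = solver.update(-grad, solver_state)
        params = optax.apply_updates(params, updates)
        return params, solver_state, loss

    for _ in range(100):
        path, solver_state, loss = train_step(path, solver_state)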
diff --git a/src/ds/stl.py b/src/ds/stl.py
index ec90766..0655818 100644
--- a/src/ds/stl.py
+++ b/src/ds/stl.py
@@ -2,9 +2,8 @@
 import time
 from abc import abstractmethod
 from collections import deque
-from contextlib import contextmanager
-from contextlib import redirect_stdout
-from typing import TypeVar, Tuple
+from contextlib import contextmanager, redirect_stdout
+from typing import Tuple, TypeVar

 import gurobipy as gp
 import numpy as np
@@ -93,15 +92,15 @@ class GurobiMICPSolver(STLSolver):
     """

     def __init__(
-            self,
-            spec,
-            sys,
-            x0,
-            T,
-            M=1000,
-            robustness_cost=True,
-            presolve=True,
-            verbose=True,
+        self,
+        spec,
+        sys,
+        x0,
+        T,
+        M=1000,
+        robustness_cost=True,
+        presolve=True,
+        verbose=True,
     ):
         assert M > 0, "M should be a (large) positive scalar"
         super().__init__(spec, sys, x0, T, verbose)
@@ -157,7 +156,7 @@ def AddQuadraticCost(self, Q, R):
         self.cost += self.x[:, 0] @ Q @ self.x[:, 0] + self.u[:, 0] @ R @ self.u[:, 0]
         for t in range(1, self.T):
             self.cost += (
-                    self.x[:, t] @ Q @ self.x[:, t] + self.u[:, t] @ R @ self.u[:, t]
+                self.x[:, t] @ Q @ self.x[:, t] + self.u[:, t] @ R @ self.u[:, t]
             )

     def AddRobustnessCost(self):
@@ -320,16 +319,16 @@ def _get_ctrl_system(dim: int):
         return sys

     def solve_stlpy_formula(
-            self,
-            spec: STLTree,
-            x0: np.ndarray,
-            total_time: int,
-            solver_name="gurobi",
-            u_bound: tuple = (-20.0, 20.0),
-            rho_min: float = 0.1,
-            energy_obj: bool = True,
-            time_limit=20,
-            threads=1,
+        self,
+        spec: STLTree,
+        x0: np.ndarray,
+        total_time: int,
+        solver_name="gurobi",
+        u_bound: tuple = (-20.0, 20.0),
+        rho_min: float = 0.1,
+        energy_obj: bool = True,
+        time_limit=20,
+        threads=1,
     ) -> Tuple[np.ndarray, dict]:
         """
         Solve the STL formula
@@ -375,7 +374,7 @@ def eval_at_t(self, path: Tensor, t: int = 0) -> Tensor:

     @abstractmethod
     def eval_whole_path(
-            self, path: Tensor, start_t: int = 0, end_t: int = None
+        self, path: Tensor, start_t: int = 0, end_t: int = None
     ) -> Tensor:
         raise NotImplementedError

@@ -392,7 +391,9 @@ class RectReachPredicate(PredicateBase):
     Rectangle reachability predicate
     """

-    def __init__(self, cent: np.ndarray, size: np.ndarray, name: str, shrink_factor: float = 0.5):
+    def __init__(
+        self, cent: np.ndarray, size: np.ndarray, name: str, shrink_factor: float = 0.5
+    ):
         """
         :param cent: center of the rectangle
         :param size: bound of the rectangle
@@ -404,11 +405,13 @@ def __init__(self, cent: np.ndarray, size: np.ndarray, name: str, shrink_factor:
         self.cent_tensor = default_tensor(cent)
         self.size_tensor = default_tensor(size)

-        self.shrink_factor = shrink_factor  # shrink the rectangle to make it more conservative
+        self.shrink_factor = (
+            shrink_factor  # shrink the rectangle to make it more conservative
+        )
         print(f"shrink factor: {shrink_factor}")

     def eval_whole_path(
-            self, path: Tensor, start_t: int = 0, end_t: int = None
+        self, path: Tensor, start_t: int = 0, end_t: int = None
     ) -> Tensor:
         assert len(path.shape) == 3, "motion must be in batch"
         eval_path = path[:, start_t:end_t]
@@ -420,7 +423,10 @@ def eval_whole_path(

     def get_stlpy_form(self) -> STLTree:
         bounds = np.stack(
-            [self.cent - self.size * self.shrink_factor / 2, self.cent + self.size * self.shrink_factor / 2]
+            [
+                self.cent - self.size * self.shrink_factor / 2,
+                self.cent + self.size * self.shrink_factor / 2,
+            ]
         ).T.flatten()
         return inside_rectangle_formula(bounds, 0, 1, 2, self.name)

@@ -444,7 +450,7 @@ def __init__(self, cent: np.ndarray, size: np.ndarray, name: str):
         self.size_tensor = default_tensor(size)

     def eval_whole_path(
-            self, path: Tensor, start_t: int = 0, end_t: int = None
+        self, path: Tensor, start_t: int = 0, end_t: int = None
     ) -> Tensor:
         assert len(path.shape) == 3, "motion must be in batch"
         eval_path = path[:, start_t:end_t]
@@ -572,6 +578,7 @@ class STL:
     """
     Class for representing STL formulas.
     """
+
     end_t: int

     def __init__(self, ast: AST):
@@ -638,7 +645,7 @@ def _get_end_time(self, ast: AST) -> int:
         return max(self._get_end_time(ast[1]), self._get_end_time(ast[2]))

     def _eval(
-            self, ast: AST, path: Tensor, start_t: int = 0, end_t: int = None
+        self, ast: AST, path: Tensor, start_t: int = 0, end_t: int = None
     ) -> Tensor:
         if self._is_leaf(ast):
             return ast.eval_at_t(path, start_t)
@@ -670,12 +677,12 @@ def _eval(
         return res

     def _eval_and(
-            self,
-            sub_form1: AST,
-            sub_form2: AST,
-            path: Tensor,
-            start_t: int = 0,
-            end_t: int = None,
+        self,
+        sub_form1: AST,
+        sub_form2: AST,
+        path: Tensor,
+        start_t: int = 0,
+        end_t: int = None,
     ) -> Tensor:
         return self._tensor_min(
             torch.stack(
@@ -689,12 +696,12 @@ def _eval_and(
         )

     def _eval_or(
-            self,
-            sub_form1: AST,
-            sub_form2: AST,
-            path: Tensor,
-            start_t: int = 0,
-            end_t: int = None,
+        self,
+        sub_form1: AST,
+        sub_form2: AST,
+        path: Tensor,
+        start_t: int = 0,
+        end_t: int = None,
     ) -> Tensor:
         return self._tensor_max(
             torch.stack(
@@ -711,12 +718,12 @@ def _eval_not(self, ast: AST, path: Tensor, start_t: int, end_t: int) -> Tensor:
         return -self._eval(ast, path, start_t, end_t)

     def _eval_implies(
-            self,
-            sub_form1: AST,
-            sub_form2: AST,
-            path: Tensor,
-            start_t: int = 0,
-            end_t: int = None,
+        self,
+        sub_form1: AST,
+        sub_form2: AST,
+        path: Tensor,
+        start_t: int = 0,
+        end_t: int = None,
     ) -> Tensor:
         if IMPLIES_TRICK:
             return self._eval(sub_form1, path, start_t, end_t) * self._eval(
@@ -725,7 +732,7 @@ def _eval_implies(
         return self._eval_or(["~", sub_form1], sub_form2, path, start_t, end_t)

     def _eval_always(
-            self, sub_form: AST, path: Tensor, start_t: int, end_t: int
+        self, sub_form: AST, path: Tensor, start_t: int, end_t: int
     ) -> Tensor:
         if self._is_leaf(sub_form):
             return self._tensor_min(
@@ -744,7 +751,7 @@ def _eval_always(
         return self._tensor_min(val_per_time, dim=-1)

     def _eval_eventually(
-            self, sub_form: AST, path: Tensor, start_t: int = 0, end_t: int = None
+        self, sub_form: AST, path: Tensor, start_t: int = 0, end_t: int = None
     ) -> Tensor:
         if self._is_leaf(sub_form):
             return self._tensor_max(
@@ -763,12 +770,12 @@ def _eval_eventually(
         return self._tensor_max(val_per_time, dim=-1)

     def _eval_until(
-            self,
-            sub_form1: AST,
-            sub_form2: AST,
-            path: Tensor,
-            start_t: int = 0,
-            end_t: int = None,
+        self,
+        sub_form1: AST,
+        sub_form2: AST,
+        path: Tensor,
+        start_t: int = 0,
+        end_t: int = None,
     ) -> Tensor:
         if self._is_leaf(sub_form2):
             till_pred = sub_form2.eval_whole_path(path[:, start_t:end_t])
@@ -785,7 +792,7 @@ def _eval_until(
         cond = (till_pred > 0).int()
         index = torch.argmax(cond, dim=-1)
         for i in range(cond.shape[0]):
-            cond[i, index[i]:] = 1.0
+            cond[i, index[i] :] = 1.0
         cond = ~cond.bool()

         till_pred = torch.where(cond, till_pred, default_tensor(1))
diff --git a/src/ds/stl_jax.py b/src/ds/stl_jax.py
index d6494d3..e0afb90 100644
--- a/src/ds/stl_jax.py
+++ b/src/ds/stl_jax.py
@@ -5,7 +5,6 @@
 from contextlib import redirect_stdout
 from typing import TypeVar

-import jax
 import numpy as np
 from jax.nn import softmax
 from stlpy.STL import LinearPredicate, STLTree
@@ -18,11 +17,11 @@

 import logging

-from .stl import colored, HARDNESS, IMPLIES_TRICK, set_hardness
-
 # Replace with JAX
 import jax.numpy as jnp

+from .stl import HARDNESS, IMPLIES_TRICK, colored
+

 class PredicateBase:
     def __init__(self, name: str):
@@ -33,7 +32,7 @@ def eval_at_t(self, path: jnp.ndarray, t: int = 0) -> jnp.ndarray:

     @abstractmethod
     def eval_whole_path(
-            self, path: jnp.ndarray, start_t: int = 0, end_t: int = None
+        self, path: jnp.ndarray, start_t: int = 0, end_t: int = None
     ) -> jnp.ndarray:
         raise NotImplementedError

@@ -50,7 +49,9 @@ class RectReachPredicate(PredicateBase):
     Rectangle reachability predicate
     """

-    def __init__(self, cent: np.ndarray, size: np.ndarray, name: str, shrink_factor: float = 0.5):
+    def __init__(
+        self, cent: np.ndarray, size: np.ndarray, name: str, shrink_factor: float = 0.5
+    ):
         """
         :param cent: center of the rectangle
         :param size: bound of the rectangle
@@ -62,11 +63,13 @@ def __init__(self, cent: np.ndarray, size: np.ndarray, name: str, shrink_factor:
         self.cent_tensor = default_tensor(cent)
         self.size_tensor = default_tensor(size)

-        self.shrink_factor = shrink_factor  # shrink the rectangle to make it more conservative
+        self.shrink_factor = (
+            shrink_factor  # shrink the rectangle to make it more conservative
+        )
         print(f"shrink factor: {shrink_factor}")

     def eval_whole_path(
-            self, path: jnp.array, start_t: int = 0, end_t: int = None
+        self, path: jnp.array, start_t: int = 0, end_t: int = None
     ) -> jnp.array:
         assert len(path.shape) == 3, "motion must be in batch"
         eval_path = path[:, start_t:end_t]
@@ -78,7 +81,10 @@ def eval_whole_path(

     def get_stlpy_form(self) -> STLTree:
         bounds = np.stack(
-            [self.cent - self.size * self.shrink_factor / 2, self.cent + self.size * self.shrink_factor / 2]
+            [
+                self.cent - self.size * self.shrink_factor / 2,
+                self.cent + self.size * self.shrink_factor / 2,
+            ]
         ).T.flatten()
         return inside_rectangle_formula(bounds, 0, 1, 2, self.name)

@@ -102,7 +108,7 @@ def __init__(self, cent: np.ndarray, size: np.ndarray, name: str):
         self.size_tensor = default_tensor(size)

     def eval_whole_path(
-            self, path: jnp.array, start_t: int = 0, end_t: int = None
+        self, path: jnp.array, start_t: int = 0, end_t: int = None
     ) -> jnp.array:
         assert len(path.shape) == 3, "motion must be in batch"
         eval_path = path[:, start_t:end_t]
@@ -295,7 +301,7 @@ def _get_end_time(self, ast: AST) -> int:
         return max(self._get_end_time(ast[1]), self._get_end_time(ast[2]))

     def _eval(
-            self, ast: AST, path: jnp.array, start_t: int = 0, end_t: int = None
+        self, ast: AST, path: jnp.array, start_t: int = 0, end_t: int = None
     ) -> jnp.array:
         if self._is_leaf(ast):
             return ast.eval_at_t(path, start_t)
@@ -327,12 +333,12 @@ def _eval(
         return res

     def _eval_and(
-            self,
-            sub_form1: AST,
-            sub_form2: AST,
-            path: jnp.array,
-            start_t: int = 0,
-            end_t: int = None,
+        self,
+        sub_form1: AST,
+        sub_form2: AST,
+        path: jnp.array,
+        start_t: int = 0,
+        end_t: int = None,
     ) -> jnp.array:
         return self._tensor_min(
             jnp.stack(
@@ -346,12 +352,12 @@ def _eval_and(
         )

     def _eval_or(
-            self,
-            sub_form1: AST,
-            sub_form2: AST,
-            path: jnp.array,
-            start_t: int = 0,
-            end_t: int = None,
+        self,
+        sub_form1: AST,
+        sub_form2: AST,
+        path: jnp.array,
+        start_t: int = 0,
+        end_t: int = None,
     ) -> jnp.array:
         return self._tensor_max(
             jnp.stack(
@@ -364,16 +370,18 @@ def _eval_or(
             axis=-1,
         )

-    def _eval_not(self, ast: AST, path: jnp.array, start_t: int, end_t: int) -> jnp.array:
+    def _eval_not(
+        self, ast: AST, path: jnp.array, start_t: int, end_t: int
+    ) -> jnp.array:
         return -self._eval(ast, path, start_t, end_t)

     def _eval_implies(
-            self,
-            sub_form1: AST,
-            sub_form2: AST,
-            path: jnp.array,
-            start_t: int = 0,
-            end_t: int = None,
+        self,
+        sub_form1: AST,
+        sub_form2: AST,
+        path: jnp.array,
+        start_t: int = 0,
+        end_t: int = None,
     ) -> jnp.array:
         if IMPLIES_TRICK:
             return self._eval(sub_form1, path, start_t, end_t) * self._eval(
@@ -382,7 +390,7 @@ def _eval_implies(
         return self._eval_or(["~", sub_form1], sub_form2, path, start_t, end_t)

     def _eval_always(
-            self, sub_form: AST, path: jnp.array, start_t: int, end_t: int
+        self, sub_form: AST, path: jnp.array, start_t: int, end_t: int
     ) -> jnp.array:
         if self._is_leaf(sub_form):
             return self._tensor_min(
@@ -401,7 +409,7 @@ def _eval_always(
         return self._tensor_min(val_per_time, axis=-1)

     def _eval_eventually(
-            self, sub_form: AST, path: jnp.array, start_t: int = 0, end_t: int = None
+        self, sub_form: AST, path: jnp.array, start_t: int = 0, end_t: int = None
     ) -> jnp.array:
         if self._is_leaf(sub_form):
             return self._tensor_max(
@@ -420,12 +428,12 @@ def _eval_eventually(
         return self._tensor_max(val_per_time, axis=-1)

     def _eval_until(
-            self,
-            sub_form1: AST,
-            sub_form2: AST,
-            path: jnp.array,
-            start_t: int = 0,
-            end_t: int = None,
+        self,
+        sub_form1: AST,
+        sub_form2: AST,
+        path: jnp.array,
+        start_t: int = 0,
+        end_t: int = None,
     ) -> jnp.array:
         if self._is_leaf(sub_form2):
             till_pred = sub_form2.eval_whole_path(path[:, start_t:end_t])
@@ -442,7 +450,7 @@ def _eval_until(
         cond = (till_pred > 0).int()
         index = jnp.argmax(cond, axis=-1)
         for i in range(cond.shape[0]):
-            cond[i, index[i]:] = 1.0
+            cond[i, index[i] :] = 1.0
         cond = ~cond.bool()

         till_pred = jnp.where(cond, till_pred, default_tensor(1))
diff --git a/src/ds/utils.py b/src/ds/utils.py
index 88040ea..e71e675 100644
--- a/src/ds/utils.py
+++ b/src/ds/utils.py
@@ -14,13 +14,14 @@
     DEFAULT_DEVICE = jax.devices()[0]
     DEFAULT_DATATYPE = jnp.float32

-
     def default_tensor(x: np.ndarray, device: str = None, dtype=None) -> jnp.array:
-        return jax.device_put(jnp.asarray(
-            x,
-            dtype=DEFAULT_DATATYPE if dtype is None else dtype,
-        ), DEFAULT_DEVICE if device is None else device)
-
+        return jax.device_put(
+            jnp.asarray(
+                x,
+                dtype=DEFAULT_DATATYPE if dtype is None else dtype,
+            ),
+            DEFAULT_DEVICE if device is None else device,
+        )
 else:
     # print("Using PyTorch backend")
@@ -30,7 +31,6 @@ def default_tensor(x: np.ndarray, device: str = None, dtype=None) -> jnp.array:
     DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     DEFAULT_DATATYPE = torch.float32

-
     def default_tensor(x: np.ndarray, device: str = None, dtype=None) -> torch.Tensor:
         return torch.tensor(
             x,
diff --git a/tests/test_stl.py b/tests/test_stl.py
index 27463c8..f7fd2fe 100644
--- a/tests/test_stl.py
+++ b/tests/test_stl.py
@@ -1,6 +1,6 @@
 import unittest

-from examples.stl.differentiability import eval_reach_avoid, backward
+from examples.stl.differentiability import backward, eval_reach_avoid


 class TestExamples(unittest.TestCase):
@@ -23,5 +23,5 @@ def test_run(self):
         # self.assertEqual(True, False)  # add assertion here


-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_stl_jax.py b/tests/test_stl_jax.py
index a27bfad..ed15770 100644
--- a/tests/test_stl_jax.py
+++ b/tests/test_stl_jax.py
@@ -2,9 +2,10 @@
 import unittest

 os.environ["JAX_STL_BACKEND"] = "jax"  # set the backend to JAX for all child processes
-from examples.stl.differentiability import eval_reach_avoid, backward
 from jax import jit

+from examples.stl.differentiability import backward, eval_reach_avoid
+

 class TestJAXExamples(unittest.TestCase):

@@ -30,5 +31,5 @@ def test_run(self):
         # self.assertEqual(True, False)  # add assertion here


-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
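Aside on the `_eval_until` hunks: in both backends, the code locates the first time step where `till_pred > 0` and latches everything from that index onward to 1 before the final `where`. The loop with item assignment shown in the context lines is PyTorch-style; since JAX arrays are immutable (and `.int()`/`.bool()` are torch methods), the JAX backend would need a functional equivalent. A vectorized sketch of that latching step, assuming the same semantics as the PyTorch version (the function name is illustrative, not part of this repository):

    import jax.numpy as jnp

    def latch_till_pred(till_pred: jnp.ndarray) -> jnp.ndarray:
        # till_pred: (batch, time) robustness values of the 'until' target.
        cond = till_pred > 0
        first_idx = jnp.argmax(cond, axis=-1)  # first satisfying step per row (0 if none)
        t = jnp.arange(till_pred.shape[-1])
        latched = cond | (t[None, :] >= first_idx[:, None])
        # From the first satisfying step onward, clamp to 1; keep earlier values.
        return jnp.where(latched, 1.0, till_pred)

Like the original loop, this latches the whole row when no entry is positive (argmax over an all-False row returns 0).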