
Commit

lint

Zikang Xiong committed Mar 7, 2024
1 parent 266986f commit d84b3c0
Showing 7 changed files with 162 additions and 115 deletions.
16 changes: 7 additions & 9 deletions examples/stl/differentiability.py
@@ -8,16 +8,18 @@
 # if JAX_BACKEND is set the import will be from jax.numpy
 if os.environ.get("JAX_STL_BACKEND") == "jax":
     print("Using JAX backend")
+    import jax
+
     from ds.stl_jax import STL, RectAvoidPredicate, RectReachPredicate
     from ds.utils import default_tensor
-    import jax
 else:
     print("Using PyTorch backend")
-    from ds.stl import STL, RectAvoidPredicate, RectReachPredicate
-    from ds.utils import default_tensor
     import torch
     from torch.optim import Adam
+
+    from ds.stl import STL, RectAvoidPredicate, RectReachPredicate
+    from ds.utils import default_tensor
 
 
 def eval_reach_avoid(mute=False):
     """
@@ -132,17 +134,13 @@ def backward(mute=True):
         @jax.jit
         def train_step(params, solver_state):
             # Performs a one step update.
-            (loss), grad = jax.value_and_grad(form.eval)(
-                params
-            )
+            (loss), grad = jax.value_and_grad(form.eval)(params)
             updates, solver_state = solver.update(-grad, solver_state)
             params = optax.apply_updates(params, updates)
             return params, solver_state, loss
 
         for _ in range(num_iterations):
-            path, var_solver_state, train_loss = train_step(
-                path, var_solver_state
-            )
+            path, var_solver_state, train_loss = train_step(path, var_solver_state)
 
         loss = form.eval(path)
     else:
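
The JAX branch above jit-compiles one update step and ascends on the formula's robustness: note the negated gradient handed to optax, which turns the optimizer's descent step into ascent. A minimal self-contained sketch of the same pattern, with a hypothetical robustness function standing in for form.eval and made-up shapes and hyperparameters:

import jax
import jax.numpy as jnp
import optax

# Stand-in for form.eval: any differentiable scalar score of a path.
# Here, the negated mean squared distance to a goal point (hypothetical).
def robustness(path):
    goal = jnp.array([1.0, 1.0])
    return -jnp.mean(jnp.sum((path - goal) ** 2, axis=-1))

solver = optax.adam(learning_rate=0.1)
path = jnp.zeros((8, 2))  # 8 timesteps of 2D positions
solver_state = solver.init(path)

@jax.jit
def train_step(params, solver_state):
    # value_and_grad returns the score and its gradient in one pass;
    # feeding -grad to the optimizer turns its descent step into ascent.
    loss, grad = jax.value_and_grad(robustness)(params)
    updates, solver_state = solver.update(-grad, solver_state)
    params = optax.apply_updates(params, updates)
    return params, solver_state, loss

for _ in range(100):
    path, solver_state, loss = train_step(path, solver_state)

The backend switch is driven by the JAX_STL_BACKEND environment variable, so the example would presumably be run as JAX_STL_BACKEND=jax python examples/stl/differentiability.py.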
33 changes: 33 additions & 0 deletions linter.sh
@@ -0,0 +1,33 @@
#!/bin/bash

# Default source directory and options for the tools
SOURCE_DIR="."
AUTOFIX_OPTIONS="--remove-all-unused-imports --remove-unused-variables --expand-star-imports --ignore-init-module-imports --in-place -r"
ISORT_OPTIONS="--profile black --line-length 88"
BLACK_OPTIONS="--line-length 88"


# Install the necessary packages
pip install autoflake isort black

# Run autoflake with the specified options
output=$(autoflake $SOURCE_DIR $AUTOFIX_OPTIONS)
if [ -n "$output" ]; then
echo "Autoflake made changes or found issues:"
echo "$output"
# Uncomment the next line if you want the script to fail on changes
# exit 1
else
echo "No issues found by autoflake."
fi

# Run isort with the specified options
echo "Running isort..."
isort $SOURCE_DIR $ISORT_OPTIONS

# Run black with the specified options
echo "Running black..."
black $SOURCE_DIR $BLACK_OPTIONS

# Final message
echo "Linting complete."
121 changes: 64 additions & 57 deletions src/ds/stl.py
@@ -2,9 +2,8 @@
 import time
 from abc import abstractmethod
 from collections import deque
-from contextlib import contextmanager
-from contextlib import redirect_stdout
-from typing import TypeVar, Tuple
+from contextlib import contextmanager, redirect_stdout
+from typing import Tuple, TypeVar
 
 import gurobipy as gp
 import numpy as np
@@ -93,15 +92,15 @@ class GurobiMICPSolver(STLSolver):
"""

def __init__(
self,
spec,
sys,
x0,
T,
M=1000,
robustness_cost=True,
presolve=True,
verbose=True,
self,
spec,
sys,
x0,
T,
M=1000,
robustness_cost=True,
presolve=True,
verbose=True,
):
assert M > 0, "M should be a (large) positive scalar"
super().__init__(spec, sys, x0, T, verbose)
@@ -157,7 +156,7 @@ def AddQuadraticCost(self, Q, R):
         self.cost += self.x[:, 0] @ Q @ self.x[:, 0] + self.u[:, 0] @ R @ self.u[:, 0]
         for t in range(1, self.T):
             self.cost += (
-                    self.x[:, t] @ Q @ self.x[:, t] + self.u[:, t] @ R @ self.u[:, t]
+                self.x[:, t] @ Q @ self.x[:, t] + self.u[:, t] @ R @ self.u[:, t]
             )
 
     def AddRobustnessCost(self):
@@ -320,16 +319,16 @@ def _get_ctrl_system(dim: int):
         return sys
 
     def solve_stlpy_formula(
-            self,
-            spec: STLTree,
-            x0: np.ndarray,
-            total_time: int,
-            solver_name="gurobi",
-            u_bound: tuple = (-20.0, 20.0),
-            rho_min: float = 0.1,
-            energy_obj: bool = True,
-            time_limit=20,
-            threads=1,
+        self,
+        spec: STLTree,
+        x0: np.ndarray,
+        total_time: int,
+        solver_name="gurobi",
+        u_bound: tuple = (-20.0, 20.0),
+        rho_min: float = 0.1,
+        energy_obj: bool = True,
+        time_limit=20,
+        threads=1,
     ) -> Tuple[np.ndarray, dict]:
         """
         Solve the STL formula
@@ -375,7 +374,7 @@ def eval_at_t(self, path: Tensor, t: int = 0) -> Tensor:

     @abstractmethod
     def eval_whole_path(
-            self, path: Tensor, start_t: int = 0, end_t: int = None
+        self, path: Tensor, start_t: int = 0, end_t: int = None
     ) -> Tensor:
         raise NotImplementedError
 
@@ -392,7 +391,9 @@ class RectReachPredicate(PredicateBase):
     Rectangle reachability predicate
     """
 
-    def __init__(self, cent: np.ndarray, size: np.ndarray, name: str, shrink_factor: float = 0.5):
+    def __init__(
+        self, cent: np.ndarray, size: np.ndarray, name: str, shrink_factor: float = 0.5
+    ):
         """
         :param cent: center of the rectangle
         :param size: bound of the rectangle
@@ -404,11 +405,13 @@ def __init__(self, cent: np.ndarray, size: np.ndarray, name: str, shrink_factor:

         self.cent_tensor = default_tensor(cent)
         self.size_tensor = default_tensor(size)
-        self.shrink_factor = shrink_factor  # shrink the rectangle to make it more conservative
+        self.shrink_factor = (
+            shrink_factor  # shrink the rectangle to make it more conservative
+        )
         print(f"shrink factor: {shrink_factor}")
 
     def eval_whole_path(
-            self, path: Tensor, start_t: int = 0, end_t: int = None
+        self, path: Tensor, start_t: int = 0, end_t: int = None
     ) -> Tensor:
         assert len(path.shape) == 3, "motion must be in batch"
         eval_path = path[:, start_t:end_t]
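
The assertion above fixes the path layout for every predicate: paths are batched tensors of shape (batch, time, dimension). A hedged usage sketch, with made-up center, size, and path values:

import numpy as np

from ds.stl import RectReachPredicate
from ds.utils import default_tensor

# A 1 x 1 rectangle centered at the origin, named "goal" (values illustrative).
pred = RectReachPredicate(np.array([0.0, 0.0]), np.array([1.0, 1.0]), "goal")

# One path in the batch, 4 timesteps, 2D positions.
path = default_tensor(np.zeros((1, 4, 2)))
print(pred.eval_whole_path(path))  # robustness values along the path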
@@ -420,7 +423,10 @@

     def get_stlpy_form(self) -> STLTree:
         bounds = np.stack(
-            [self.cent - self.size * self.shrink_factor / 2, self.cent + self.size * self.shrink_factor / 2]
+            [
+                self.cent - self.size * self.shrink_factor / 2,
+                self.cent + self.size * self.shrink_factor / 2,
+            ]
         ).T.flatten()
         return inside_rectangle_formula(bounds, 0, 1, 2, self.name)
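
For intuition about the bounds construction: stacking pairs the per-axis minima and maxima of the shrunken rectangle, and the transpose-flatten emits them axis by axis as per-axis (min, max) pairs, the ordering passed on to inside_rectangle_formula. A quick check with made-up numbers:

import numpy as np

cent = np.array([0.0, 0.0])
size = np.array([2.0, 2.0])
shrink_factor = 0.5

# Half-extent of the shrunken rectangle: size * shrink_factor / 2 = [0.5, 0.5]
bounds = np.stack(
    [cent - size * shrink_factor / 2, cent + size * shrink_factor / 2]
).T.flatten()
print(bounds)  # [-0.5  0.5 -0.5  0.5]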

@@ -444,7 +450,7 @@ def __init__(self, cent: np.ndarray, size: np.ndarray, name: str):
         self.size_tensor = default_tensor(size)
 
     def eval_whole_path(
-            self, path: Tensor, start_t: int = 0, end_t: int = None
+        self, path: Tensor, start_t: int = 0, end_t: int = None
     ) -> Tensor:
         assert len(path.shape) == 3, "motion must be in batch"
         eval_path = path[:, start_t:end_t]
@@ -572,6 +578,7 @@ class STL:
"""
Class for representing STL formulas.
"""

end_t: int

def __init__(self, ast: AST):
@@ -638,7 +645,7 @@ def _get_end_time(self, ast: AST) -> int:
         return max(self._get_end_time(ast[1]), self._get_end_time(ast[2]))
 
     def _eval(
-            self, ast: AST, path: Tensor, start_t: int = 0, end_t: int = None
+        self, ast: AST, path: Tensor, start_t: int = 0, end_t: int = None
     ) -> Tensor:
         if self._is_leaf(ast):
             return ast.eval_at_t(path, start_t)
@@ -670,12 +677,12 @@ def _eval(
         return res
 
     def _eval_and(
-            self,
-            sub_form1: AST,
-            sub_form2: AST,
-            path: Tensor,
-            start_t: int = 0,
-            end_t: int = None,
+        self,
+        sub_form1: AST,
+        sub_form2: AST,
+        path: Tensor,
+        start_t: int = 0,
+        end_t: int = None,
     ) -> Tensor:
         return self._tensor_min(
             torch.stack(
@@ -689,12 +696,12 @@ def _eval_and(
             )
         )
 
     def _eval_or(
-            self,
-            sub_form1: AST,
-            sub_form2: AST,
-            path: Tensor,
-            start_t: int = 0,
-            end_t: int = None,
+        self,
+        sub_form1: AST,
+        sub_form2: AST,
+        path: Tensor,
+        start_t: int = 0,
+        end_t: int = None,
     ) -> Tensor:
         return self._tensor_max(
             torch.stack(
return -self._eval(ast, path, start_t, end_t)

def _eval_implies(
self,
sub_form1: AST,
sub_form2: AST,
path: Tensor,
start_t: int = 0,
end_t: int = None,
self,
sub_form1: AST,
sub_form2: AST,
path: Tensor,
start_t: int = 0,
end_t: int = None,
) -> Tensor:
if IMPLIES_TRICK:
return self._eval(sub_form1, path, start_t, end_t) * self._eval(
@@ -725,7 +732,7 @@ def _eval_implies(
         return self._eval_or(["~", sub_form1], sub_form2, path, start_t, end_t)
 
     def _eval_always(
-            self, sub_form: AST, path: Tensor, start_t: int, end_t: int
+        self, sub_form: AST, path: Tensor, start_t: int, end_t: int
     ) -> Tensor:
         if self._is_leaf(sub_form):
             return self._tensor_min(
@@ -744,7 +751,7 @@ def _eval_always(
         return self._tensor_min(val_per_time, dim=-1)
 
     def _eval_eventually(
-            self, sub_form: AST, path: Tensor, start_t: int = 0, end_t: int = None
+        self, sub_form: AST, path: Tensor, start_t: int = 0, end_t: int = None
     ) -> Tensor:
         if self._is_leaf(sub_form):
             return self._tensor_max(
@@ -763,12 +770,12 @@ def _eval_eventually(
         return self._tensor_max(val_per_time, dim=-1)
 
     def _eval_until(
-            self,
-            sub_form1: AST,
-            sub_form2: AST,
-            path: Tensor,
-            start_t: int = 0,
-            end_t: int = None,
+        self,
+        sub_form1: AST,
+        sub_form2: AST,
+        path: Tensor,
+        start_t: int = 0,
+        end_t: int = None,
     ) -> Tensor:
         if self._is_leaf(sub_form2):
             till_pred = sub_form2.eval_whole_path(path[:, start_t:end_t])
@@ -785,7 +792,7 @@ def _eval_until(
         cond = (till_pred > 0).int()
         index = torch.argmax(cond, dim=-1)
         for i in range(cond.shape[0]):
-            cond[i, index[i]:] = 1.0
+            cond[i, index[i] :] = 1.0
         cond = ~cond.bool()
         till_pred = torch.where(cond, till_pred, default_tensor(1))
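
The slice change above sits inside the masking trick of _eval_until: from the first timestep where sub_form2's robustness turns positive, the trace is overwritten with +1 so the tail no longer constrains the evaluation (the rest of the method is collapsed here). A standalone sketch of that masking on a made-up robustness trace:

import torch

# Hypothetical robustness of sub_form2 over 6 timesteps (batch of 1);
# the first positive entry is at t = 3.
till_pred = torch.tensor([[-0.5, -0.2, -0.1, 0.4, -0.3, 0.2]])

cond = (till_pred > 0).int()
index = torch.argmax(cond, dim=-1)  # first timestep where till_pred > 0
for i in range(cond.shape[0]):
    cond[i, index[i] :] = 1
cond = ~cond.bool()  # True only strictly before the first satisfaction
masked = torch.where(cond, till_pred, torch.tensor(1.0))
print(masked)  # tensor([[-0.5000, -0.2000, -0.1000,  1.0000,  1.0000,  1.0000]])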

