Jax Bug Fix #2

Open · wants to merge 23 commits into master

Commits (23)
18cdc0e · fix(jax): rename os env, reload module to make tests work together (jeappen, Mar 20, 2024)
47e75ae · fix(jax): add evaluation shape test (which is failing). TODO: Debug (jeappen, Mar 20, 2024)
885e4d3 · fix(jax): fix evaluation bug in predicates, jnp.min and torch.min had… (jeappen, Mar 20, 2024)
677ba06 · fix(jax): change a few more numpy functions -> jax.numpy (jeappen, Mar 21, 2024)
b9a615d · docs(jax): update jax instructions and changelog (jeappen, Mar 21, 2024)
e007146 · move common fns to utils so torch is optional (Mar 30, 2024)
35a2cd0 · fix(stl): fix for using StlPySolver with JAX (temporary conversion to… (jeappen, Apr 13, 2024)
8e5bf1e · docs(stl): add comments to STLPy tools (jeappen, Apr 15, 2024)
74828db · feature(stl): add loop and branch specs. Fix always end_time. (jeappen, May 15, 2024)
3a803ca · feature(realrobot): add script to save output to npy (jeappen, May 27, 2024)
b1848c8 · feature(stl): add loop tests and stlpy solver tests (jeappen, Jun 1, 2024)
a2c47ac · Merge remote-tracking branch 'origin/feature/jax' into feature/jax (jeappen, Jun 1, 2024)
1028bd4 · fix(stlpy): add tests and fix stlpy conversion (jeappen, Jun 3, 2024)
51ee629 · feature(async): add vanish on end, show spec in pgf plot (jeappen, Jun 5, 2024)
f181b47 · feature(stl): Working on JAX Until (jeappen, Aug 8, 2024)
dd46986 · feature(stl): add single plan no formation (jeappen, Aug 9, 2024)
6111114 · Update README.md (jeappen, Oct 12, 2024)
d89e4c6 · fix(ja): towards STL PyTrees compatible with JAX (helpful for vmappin… (jeappen, Oct 14, 2024)
d3e1de6 · feature(jax): Allows multiple goal (in theory) (jeappen, Oct 15, 2024)
f341758 · fix(jax): temp fix for RectPredicate not inheriting PredicateBase (jeappen, Oct 15, 2024)
960a353 · feature(jax): add rich repr for center of pred (jeappen, Oct 16, 2024)
6922e51 · feature(jax): add tests towards fixing avoid not working for differen… (jeappen, Oct 31, 2024)
89b963c · fix(jax): add tolerance value to tests (jeappen, Nov 1, 2024)
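
Commit d89e4c6 ("towards STL PyTrees compatible with JAX") points at the key enabler for `vmap`: STL objects must be registered as JAX pytrees so the tracer can flatten them. Below is a minimal sketch of that registration pattern; the class and field names are illustrative assumptions, not code from this PR.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node


class RectPred:
    """Toy rectangle predicate; center and size are array leaves."""

    def __init__(self, center, size):
        self.center = center
        self.size = size


# Register the class as a pytree so jit/vmap can trace through it.
register_pytree_node(
    RectPred,
    lambda p: ((p.center, p.size), None),   # flatten: (leaves, aux_data)
    lambda aux, leaves: RectPred(*leaves),  # unflatten
)


def robustness(pred, x):
    # Signed distance to the rectangle boundary (positive inside).
    return jnp.min(pred.size / 2 - jnp.abs(x - pred.center))


pred = RectPred(jnp.zeros(2), jnp.ones(2))
xs = jnp.array([[0.1, 0.1], [2.0, 2.0]])
print(jax.vmap(robustness, in_axes=(None, 0))(pred, xs))  # one value per x
```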
4 changes: 4 additions & 0 deletions .gitignore
@@ -159,3 +159,7 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+/.idea/diff-spec.iml
+/.idea/modules.xml
+/.idea/inspectionProfiles/profiles_settings.xml
+/.idea/workspace.xml
15 changes: 15 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,18 @@
+#### 2024-03-21
+
+##### Bug Fixes
+
+* **jax:**
+  * fix evaluation bug in predicates, jnp.min and torch.min had different semantics. Fix examples to allow different
+    batch sizes as input to STL.eval (885e4d3c)
+  * add evaluation shape test (which is failing). TODO: Debug (47e75ae5)
+  * rename os env, reload module to make tests work together (18cdc0ec)
+* Using the optax (25687960)
+
+##### Tests
+
+* **jax:** basic jit jax optimization (a930a76e)
+
 #### 2024-02-24
 
 ##### Tests
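For context on the 885e4d3c entry: `torch.min` and `jnp.min` disagree in two ways that silently break a direct port. The snippet below illustrates the difference; it is not code from the PR.

```python
import jax.numpy as jnp
import torch

x = torch.tensor([[1.0, 4.0], [3.0, 2.0]])

# torch.min with a dim returns a (values, indices) namedtuple,
# and torch.min(a, b) computes an elementwise minimum.
values, indices = torch.min(x, dim=1)
elementwise = torch.min(x, torch.zeros_like(x))

y = jnp.array([[1.0, 4.0], [3.0, 2.0]])

# jnp.min only reduces, and its second positional argument is `axis`,
# so a literal port of torch.min(a, b) reduces over axis b instead.
reduced = jnp.min(y, axis=1)           # values only, no indices
elementwise_jax = jnp.minimum(y, 0.0)  # the elementwise counterpart
```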
6 changes: 3 additions & 3 deletions README.md
@@ -5,7 +5,7 @@ Connect differentiable components with logical operators.
 ## Install
 
 ```bash
-pip install git+https://github.com/ZikangXiong/diff-spec.git
+pip install git+https://github.com/jeappen/diff-spec.git@feature/jax
 ```
 
 ## First Order Logic
@@ -67,12 +67,12 @@ Probability temporal logic is an ongoing work integrating probability and random
 
 If you are using JAX, you can use the JAX backend (stl_jax) and gain immense speedups in many cases.
 
-First set the backend to JAX:
+First set the backend to JAX using Environment Variables for our utility functions:
 
 ```python
 import os
 
-os.environ["JAX_STL_BACKEND"] = "jax"  # set the backend to JAX
+os.environ["DIFF_STL_BACKEND"] = "jax"  # set the backend to JAX (if unset or any other value uses the PyTorch backend)
 ```
 
 Then you can use the JAX backend to optimize the inputs to satisfy the formula.
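One usage note on the README change: the variable must be set before any `ds` module is imported, because the import path itself depends on it. A minimal sketch following the README snippet:

```python
import os

# Must come before the ds imports; unset or any other value
# falls back to the PyTorch backend.
os.environ["DIFF_STL_BACKEND"] = "jax"

if os.environ.get("DIFF_STL_BACKEND") == "jax":
    from ds.stl_jax import STL, RectReachPredicate
else:
    from ds.stl import STL, RectReachPredicate
```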
113 changes: 75 additions & 38 deletions examples/stl/differentiability.py
@@ -1,20 +1,24 @@
 # %%
-import os
-
+import importlib
 import matplotlib.pyplot as plt
 import numpy as np
 import optax
+import os
+
+import ds.utils as ds_utils
 
 # if JAX_BACKEND is set the import will be from jax.numpy
-if os.environ.get("JAX_STL_BACKEND") == "jax":
+if os.environ.get("DIFF_STL_BACKEND") == "jax":
     print("Using JAX backend")
     from ds.stl_jax import STL, RectAvoidPredicate, RectReachPredicate
-    from ds.utils import default_tensor
+
+    importlib.reload(ds_utils)  # Reload the module to reset the backend
     import jax
 else:
     print("Using PyTorch backend")
     from ds.stl import STL, RectAvoidPredicate, RectReachPredicate
-    from ds.utils import default_tensor
+
+    importlib.reload(ds_utils)  # Reload the module to reset the backend
     import torch
     from torch.optim import Adam
 
@@ -33,7 +37,7 @@ def eval_reach_avoid(mute=False):
     form = goal.eventually(0, 10) & obs.always(0, 10)
 
     # Define 2 initial paths in batch
-    path_1 = default_tensor(
+    path_1 = ds_utils.default_tensor(
         np.array(
             [
                 [
@@ -66,24 +70,39 @@
                     [1, 1],
                     [1, 1],
                 ],
+                [
+                    [9, 9],
+                    [3, 2],
+                    [7, 7],
+                    [6, 6],
+                    [5, 5],
+                    [4, 4],
+                    [3, 3],
+                    [2, 2],
+                    [1, 1],
+                    [0, 0],
+                    [0, 0],
+                    [0, 0],
+                    [0, 0],
+                ]
             ]
         )
     )
 
     # eval the formula, default at time 0
-    res1 = form.eval(path=path_1)
+    res1 = form.eval(path=path_1)  # (+,-,-)
     if not mute:
         print("eval result at time 0: ", res1)
 
     # eval the formula at time 2
-    res2 = form.eval(path=path_1, t=2)
+    res2 = form.eval(path=path_1, t=2)  # (+,-,+)
     if not mute:
         print("eval result at time 2: ", res2)
 
     return res1, res2
 
 
-def backward(mute=True):
+def backward(avoid_spec=False, mute=True):
     """
     Planning with gradient descent
     """
@@ -92,50 +111,66 @@
     # goal_1 is a rectangle area centered in [0, 0] with width and height 1
     goal_1 = STL(RectReachPredicate(np.array([0, 0]), np.array([1, 1]), "goal_1"))
     # goal_2 is a rectangle area centered in [2, 2] with width and height 1
-    goal_2 = STL(RectReachPredicate(np.array([2, 2]), np.array([1, 1]), "goal_2"))
+    goal_2 = STL(RectReachPredicate(np.array([3, 3]), np.array([1, 1]), "goal_2"))
+    # goal_2 is a rectangle area centered in [1, 1] with width and height 1
+    avoid_region = STL(RectAvoidPredicate(np.array([1, 1]), np.array([1, 1]), "avoid_region"))
+    avoid_region2 = STL(RectAvoidPredicate(np.array([2, 2]), np.array([1, 1]), "avoid_region2"))
+    avoid_region_goal1 = STL(RectAvoidPredicate(np.array([0, 0]), np.array([1, 1]), "avoid_region_goal1"))
+    avoid_region_goal2 = STL(RectAvoidPredicate(np.array([3, 3]), np.array([1, 1]), "avoid_region_goal2"))
+    end_time = 13
 
-    # form is the formula goal_1 eventually in 0 to 5 and goal_2 eventually in 0 to 5
-    # and that holds always in 0 to 8
-    # In other words, the path will repeatedly visit goal_1 and goal_2 in 0 to 13
-    form = (goal_1.eventually(0, 5) & goal_2.eventually(0, 5)).always(0, 8)
+    if avoid_spec:
+        print("cover while avoiding avoid_region")
+        # NOTE: Cover different just alternates between goal_1 and goal_2
+        form = goal_2.eventually(0, end_time) & goal_1.eventually(0, end_time) \
+               & avoid_region.always(0, end_time) & avoid_region2.always(0, end_time) \
+               & avoid_region_goal1.always(end_time // 2, end_time) & avoid_region_goal2.always(0, end_time // 2)
+    else:
+        # form is the formula goal_1 eventually in 0 to 5 and goal_2 eventually in 0 to 5
+        # and that holds always in 0 to 8
+        # In other words, the path will repeatedly visit goal_1 and goal_2 in 0 to 13
+        form = (goal_1.eventually(0, 5) & goal_2.eventually(0, 5)).always(0, 8)
 
-    path = default_tensor(
-        np.array(
-            [
-                [
-                    [1, 0],
-                    [1, 0],
-                    [1, 0],
-                    [1, 0],
-                    [0, 1],
-                    [0, 1],
-                    [0, 1],
-                    [0, 1],
-                    [0, 1],
-                    [0, 1],
-                    [0, 1],
-                    [0, 1],
-                    [1, 0],
-                    [1, 0],
-                ],
-            ]
-        )
-    )
+    np_path = np.array(
+        [
+            [
+                [1, 0],
+                [1, 0],
+                [1, 0],
+                [1, 0],
+                [0, 1],
+                [0, 1],
+                [0, 1],
+                [0, 1],
+                [0, 1],
+                [0, 1],
+                [0, 1],
+                [0, 1],
+                [1, 0],
+                [1, 0],
+            ],
+        ]
+    )
+
+    random_like = np.random.rand(*np_path.shape)
+
+    path = ds_utils.default_tensor(random_like)
     loss = None
     lr = 0.1
     num_iterations = 1000
 
-    if os.environ.get("JAX_STL_BACKEND") == "jax":
+    if os.environ.get("DIFF_STL_BACKEND") == "jax":
 
         solver = optax.adam(lr)
         var_solver_state = solver.init(path)
 
         @jax.jit
         def train_step(params, solver_state):
             # Performs a one step update.
-            (loss), grad = jax.value_and_grad(form.eval)(
+            (loss), grad = jax.value_and_grad(lambda x: -form.eval(x).mean())(
                 params
             )
-            updates, solver_state = solver.update(-grad, solver_state)
+            updates, solver_state = solver.update(grad, solver_state)
             params = optax.apply_updates(params, updates)
             return params, solver_state, loss
 
@@ -144,12 +179,14 @@ def train_step(params, solver_state):
                 path, var_solver_state
             )
 
-        loss = form.eval(path)
+        loss = train_loss
     else:
+        # PyTorch backend (slower when num_iterations is high)
         path.requires_grad = True
        opt = Adam(params=[path], lr=lr)
 
+        # ds_utils.HARDNESS = 3.0
+
         for _ in range(num_iterations):
             loss = -torch.mean(form.eval(path))
             opt.zero_grad()
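Stepping back from the diff: the JAX branch of `backward` follows the standard optax recipe, namely `value_and_grad` of a scalar loss (negative mean robustness, so minimizing the loss maximizes satisfaction), jitted and stepped in a loop. A condensed sketch assuming `form` and `path` are defined as in the example above:

```python
import jax
import optax

lr, num_iterations = 0.1, 1000
solver = optax.adam(lr)
solver_state = solver.init(path)


@jax.jit
def train_step(params, state):
    # Scalar loss: gradient descent on -robustness pushes paths
    # toward satisfying the formula.
    loss, grad = jax.value_and_grad(lambda p: -form.eval(p).mean())(params)
    updates, state = solver.update(grad, state)
    params = optax.apply_updates(params, updates)
    return params, state, loss


for _ in range(num_iterations):
    path, solver_state, loss = train_step(path, solver_state)
```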