Skip to content

Commit

Permalink
Fix for new tensorflow probability version (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin authored Mar 31, 2024
1 parent db6120b commit 46dcd7f
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 8 deletions.
3 changes: 1 addition & 2 deletions sbx/common/distributions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import Any, Optional

import jax.numpy as jnp
import tensorflow_probability
import tensorflow_probability.substrates.jax as tfp

tfp = tensorflow_probability.substrates.jax
tfd = tfp.distributions


Expand Down
3 changes: 1 addition & 2 deletions sbx/ppo/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_probability
import tensorflow_probability.substrates.jax as tfp
from flax.linen.initializers import constant
from flax.training.train_state import TrainState
from gymnasium import spaces
from stable_baselines3.common.type_aliases import Schedule

from sbx.common.policies import BaseJaxPolicy, Flatten

tfp = tensorflow_probability.substrates.jax
tfd = tfp.distributions


Expand Down
3 changes: 1 addition & 2 deletions sbx/sac/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_probability
import tensorflow_probability.substrates.jax as tfp
from flax.training.train_state import TrainState
from gymnasium import spaces
from stable_baselines3.common.type_aliases import Schedule
Expand All @@ -14,7 +14,6 @@
from sbx.common.policies import BaseJaxPolicy, Flatten
from sbx.common.type_aliases import RLTrainState

tfp = tensorflow_probability.substrates.jax
tfd = tfp.distributions


Expand Down
3 changes: 1 addition & 2 deletions sbx/tqc/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_probability
import tensorflow_probability.substrates.jax as tfp
from flax.training.train_state import TrainState
from gymnasium import spaces
from stable_baselines3.common.type_aliases import Schedule
Expand All @@ -14,7 +14,6 @@
from sbx.common.policies import BaseJaxPolicy, Flatten
from sbx.common.type_aliases import RLTrainState

tfp = tensorflow_probability.substrates.jax
tfd = tfp.distributions


Expand Down

0 comments on commit 46dcd7f

Please sign in to comment.