Skip to content

Commit

Permalink
Merge branch 'master' into feat/rainbow
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin authored Nov 1, 2024
2 parents 5d206af + 1c79684 commit e018c78
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 38 deletions.
67 changes: 35 additions & 32 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ name: CI

on:
push:
branches: [ master ]
branches: [master]
pull_request:
branches: [ master ]
branches: [master]

jobs:
build:
Expand All @@ -23,34 +23,37 @@ jobs:
python-version: ["3.8", "3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# Use uv for faster downloads
pip install uv
# cpu version of pytorch
# See https://github.com/astral-sh/uv/issues/1497
uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu
pip install .[tests]
# Use headless version
pip install opencv-python-headless
- name: Lint with ruff
run: |
make lint
# - name: Build the doc
# run: |
# make doc
- name: Check codestyle
run: |
make check-codestyle
- name: Type check
run: |
make type
# skip mypy, jax doesn't have its latest version for python 3.8
if: "!(matrix.python-version == '3.8')"
- name: Test with pytest
run: |
make pytest
uv pip install --system .[tests]
# Use headless version
uv pip install --system opencv-python-headless
- name: Lint with ruff
run: |
make lint
# - name: Build the doc
# run: |
# make doc
- name: Check codestyle
run: |
make check-codestyle
- name: Type check
run: |
make type
# skip mypy, jax doesn't have its latest version for python 3.8
if: "!(matrix.python-version == '3.8')"
- name: Test with pytest
run: |
make pytest
4 changes: 3 additions & 1 deletion sbx/crossq/crossq.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,10 @@ def actor_loss(
@jax.jit
def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float):
def temperature_loss(temp_params: flax.core.FrozenDict) -> jax.Array:
# Note: we optimize the log of the entropy coeff which is slightly different from the paper
# as discussed in https://github.com/rail-berkeley/softlearning/issues/37
ent_coef_value = ent_coef_state.apply_fn({"params": temp_params})
ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean() # type: ignore[union-attr]
ent_coef_loss = jnp.log(ent_coef_value) * (entropy - target_entropy).mean() # type: ignore[union-attr]
return ent_coef_loss

ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params)
Expand Down
6 changes: 3 additions & 3 deletions sbx/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,6 @@ def _setup_model(self) -> None:
ent_coef_init = float(self.ent_coef_init.split("_")[1])
assert ent_coef_init > 0.0, "The initial value of ent_coef must be greater than 0"

# Note: we optimize the log of the entropy coeff which is slightly different from the paper
# as discussed in https://github.com/rail-berkeley/softlearning/issues/37
self.ent_coef = EntropyCoef(ent_coef_init)
else:
# This will throw an error if a malformed string (different from 'auto') is passed
Expand Down Expand Up @@ -329,8 +327,10 @@ def soft_update(tau: float, qf_state: RLTrainState) -> RLTrainState:
@jax.jit
def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float):
def temperature_loss(temp_params: flax.core.FrozenDict) -> jax.Array:
# Note: we optimize the log of the entropy coeff which is slightly different from the paper
# as discussed in https://github.com/rail-berkeley/softlearning/issues/37
ent_coef_value = ent_coef_state.apply_fn({"params": temp_params})
ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean() # type: ignore[union-attr]
ent_coef_loss = jnp.log(ent_coef_value) * (entropy - target_entropy).mean() # type: ignore[union-attr]
return ent_coef_loss

ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params)
Expand Down
5 changes: 3 additions & 2 deletions sbx/tqc/tqc.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,9 +387,10 @@ def soft_update(tau: float, qf1_state: RLTrainState, qf2_state: RLTrainState) ->
@jax.jit
def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float):
def temperature_loss(temp_params: flax.core.FrozenDict) -> jax.Array:
# Note: we optimize the log of the entropy coeff which is slightly different from the paper
# as discussed in https://github.com/rail-berkeley/softlearning/issues/37
ent_coef_value = ent_coef_state.apply_fn({"params": temp_params})
# ent_coef_loss = (jnp.log(ent_coef_value) * (entropy - target_entropy)).mean()
ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean() # type: ignore[union-attr]
ent_coef_loss = jnp.log(ent_coef_value) * (entropy - target_entropy).mean() # type: ignore[union-attr]
return ent_coef_loss

ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params)
Expand Down

0 comments on commit e018c78

Please sign in to comment.