From 1c79684e45a5692063e28f3bdecf9fd65dd8d24d Mon Sep 17 00:00:00 2001 From: JamesHeald Date: Fri, 1 Nov 2024 07:47:30 +0000 Subject: [PATCH] Optimize the log of the entropy coeff instead of the entropy coeff (#56) * optimize the log of the entropy coeff instead of the entropy coeff * Update log ent coef for SAC and derivates * Reformat yaml * Use uv for faster downloads * Remove TODO * Remove redundant call --------- Co-authored-by: Antonin RAFFIN --- .github/workflows/ci.yml | 67 +++++++++++++++++++++------------------- sbx/crossq/crossq.py | 4 ++- sbx/dqn/policies.py | 1 - sbx/sac/sac.py | 6 ++-- sbx/tqc/tqc.py | 5 +-- sbx/version.txt | 2 +- 6 files changed, 45 insertions(+), 40 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e71b561..70548c8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -5,9 +5,9 @@ name: CI on: push: - branches: [ master ] + branches: [master] pull_request: - branches: [ master ] + branches: [master] jobs: build: @@ -23,34 +23,37 @@ jobs: python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - # cpu version of pytorch - pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + # Use uv for faster downloads + pip install uv + # cpu version of pytorch + # See https://github.com/astral-sh/uv/issues/1497 + uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu - pip install .[tests] - # Use headless version - pip install opencv-python-headless - - name: Lint with ruff - run: | - make lint - # - name: Build the doc - # run: | - # make doc - - name: Check codestyle - run: | - make check-codestyle - - name: Type check - run: | - make type - # skip mypy, jax doesn't have its latest version for python 3.8 - if: "!(matrix.python-version == '3.8')" - - name: Test with pytest - run: | - make pytest + uv pip install --system .[tests] + # Use headless version + uv pip install --system opencv-python-headless + - name: Lint with ruff + run: | + make lint + # - name: Build the doc + # run: | + # make doc + - name: Check codestyle + run: | + make check-codestyle + - name: Type check + run: | + make type + # skip mypy, jax doesn't have its latest version for python 3.8 + if: "!(matrix.python-version == '3.8')" + - name: Test with pytest + run: | + make pytest diff --git a/sbx/crossq/crossq.py b/sbx/crossq/crossq.py index 94e2bbc..f888672 100644 --- a/sbx/crossq/crossq.py +++ b/sbx/crossq/crossq.py @@ -355,8 +355,10 @@ def actor_loss( @jax.jit def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float): def temperature_loss(temp_params: flax.core.FrozenDict) -> jax.Array: + # Note: we optimize the log of the entropy coeff which is slightly different from the paper + # as discussed in https://github.com/rail-berkeley/softlearning/issues/37 ent_coef_value = ent_coef_state.apply_fn({"params": temp_params}) - ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean() # type: ignore[union-attr] + ent_coef_loss = jnp.log(ent_coef_value) * (entropy - target_entropy).mean() # type: ignore[union-attr] return ent_coef_loss ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params) diff --git a/sbx/dqn/policies.py b/sbx/dqn/policies.py index d8b19ba..4cff77b 100644 --- a/sbx/dqn/policies.py +++ b/sbx/dqn/policies.py @@ -107,7 +107,6 @@ def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array: ), ) - # TODO: jit qf.apply_fn too? self.qf.apply = jax.jit(self.qf.apply) # type: ignore[method-assign] return key diff --git a/sbx/sac/sac.py b/sbx/sac/sac.py index 11f8ff5..e3795cd 100644 --- a/sbx/sac/sac.py +++ b/sbx/sac/sac.py @@ -141,8 +141,6 @@ def _setup_model(self) -> None: ent_coef_init = float(self.ent_coef_init.split("_")[1]) assert ent_coef_init > 0.0, "The initial value of ent_coef must be greater than 0" - # Note: we optimize the log of the entropy coeff which is slightly different from the paper - # as discussed in https://github.com/rail-berkeley/softlearning/issues/37 self.ent_coef = EntropyCoef(ent_coef_init) else: # This will throw an error if a malformed string (different from 'auto') is passed @@ -325,8 +323,10 @@ def soft_update(tau: float, qf_state: RLTrainState) -> RLTrainState: @jax.jit def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float): def temperature_loss(temp_params: flax.core.FrozenDict) -> jax.Array: + # Note: we optimize the log of the entropy coeff which is slightly different from the paper + # as discussed in https://github.com/rail-berkeley/softlearning/issues/37 ent_coef_value = ent_coef_state.apply_fn({"params": temp_params}) - ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean() # type: ignore[union-attr] + ent_coef_loss = jnp.log(ent_coef_value) * (entropy - target_entropy).mean() # type: ignore[union-attr] return ent_coef_loss ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params) diff --git a/sbx/tqc/tqc.py b/sbx/tqc/tqc.py index 5161f4d..f723c31 100644 --- a/sbx/tqc/tqc.py +++ b/sbx/tqc/tqc.py @@ -383,9 +383,10 @@ def soft_update(tau: float, qf1_state: RLTrainState, qf2_state: RLTrainState) -> @jax.jit def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float): def temperature_loss(temp_params: flax.core.FrozenDict) -> jax.Array: + # Note: we optimize the log of the entropy coeff which is slightly different from the paper + # as discussed in https://github.com/rail-berkeley/softlearning/issues/37 ent_coef_value = ent_coef_state.apply_fn({"params": temp_params}) - # ent_coef_loss = (jnp.log(ent_coef_value) * (entropy - target_entropy)).mean() - ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean() # type: ignore[union-attr] + ent_coef_loss = jnp.log(ent_coef_value) * (entropy - target_entropy).mean() # type: ignore[union-attr] return ent_coef_loss ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params) diff --git a/sbx/version.txt b/sbx/version.txt index c5523bd..6633391 100644 --- a/sbx/version.txt +++ b/sbx/version.txt @@ -1 +1 @@ -0.17.0 +0.18.0