SWAG in Optax

This package implements SWAG (Stochastic Weight Averaging-Gaussian) as an Optax gradient transformation, so it can be used in JAX training loops.

Installation

Install from PyPI with pip:

pip install optax-swag

To install the latest version directly from source, run:

pip install git+https://github.com/activatedgeek/optax-swag.git

Usage

To collect statistics of the training iterates, chain swag after your other transforms:

import optax
from optax_swag import swag

optimizer = optax.chain(
    ...,  # Other optimizer and transform config.
    swag(freq, rank),  # Always add as the last transform.
)

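The SWAGState object can be read from the optimizer state for downstream use. optax.chain keeps one state per transform in a tuple, in chain order, so with swag added last its state is the final entry.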

Sampling

Reference code for generating samples from the collected statistics is provided below.

import jax
import jax.numpy as jnp

from optax_swag import sample_swag

swa_opt_state = ...  # Reference to a SWAGState object from the optimizer.
n_samples = 10

rng = jax.random.PRNGKey(42)
rng, *samples_rng = jax.random.split(rng, 1 + n_samples)

swag_sample_params = jax.vmap(sample_swag, in_axes=(0, None))(
    jnp.array(samples_rng), swa_opt_state)

The resulting swag_sample_params can now be used for downstream evaluation.

NOTE: Make sure to update non-parameter variables (e.g. BatchNorm running statistics) for each generated sample, since they are not part of the sampled parameters.
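One common form of downstream evaluation is to average predictions over the samples. The sketch below assumes the sampled parameters form a pytree whose leaves carry a leading sample axis, which is what jax.vmap produces by default. The model apply_fn and the random stacked_params are hypothetical stand-ins for a real network and for swag_sample_params; the toy model has no BatchNorm-style state, so the note above does not apply to it.

```python
import jax
import jax.numpy as jnp


def apply_fn(params, x):
    """Hypothetical model: a single linear layer returning class logits."""
    return x @ params["w"] + params["b"]


n_samples, n_features, n_classes = 10, 4, 3

# Stand-in for swag_sample_params: random values with a leading sample axis.
key_w, key_b, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
stacked_params = {
    "w": jax.random.normal(key_w, (n_samples, n_features, n_classes)),
    "b": jax.random.normal(key_b, (n_samples, n_classes)),
}
x = jax.random.normal(key_x, (8, n_features))  # a batch of 8 inputs

# Evaluate every sample on the same batch: map over params, broadcast x.
logits = jax.vmap(apply_fn, in_axes=(0, None))(stacked_params, x)

# Model average: mean of the per-sample class probabilities.
bma_probs = jnp.mean(jax.nn.softmax(logits, axis=-1), axis=0)
```

Averaging probabilities rather than logits keeps each row of bma_probs a valid distribution over classes.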

License

Apache 2.0