Commit 1c4c385 (1 parent: 4adc362)
Showing 15 changed files with 1,683 additions and 2 deletions.
README.md
@@ -1,2 +1,76 @@
# image-classification-jax

Run image classification experiments in JAX with ViT, ResNet, cifar10, cifar100, imagenette, and imagenet.

Meant to be simple but high quality. Includes:
- ViT with qk normalization, SwiGLU, and empty registers
- PaLM-style z-loss (see the sketch after this list)
- ability to use schedule-free from `optax.contrib`
- ability to use PSGD optimizers from `psgd-jax`, with Hessian calculation
- datasets currently implemented: cifar10, cifar100, imagenette, and imagenet
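For reference, z-loss penalizes the squared log of the softmax normalizer, which keeps logits from drifting to large magnitudes (PaLM uses a coefficient of `1e-4`). A minimal sketch of the idea — the function name and default coefficient here are illustrative, not this repo's API:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def z_loss(logits: jnp.ndarray, coeff: float = 1e-4) -> jnp.ndarray:
    # log Z = logsumexp over the class axis; penalizing its square
    # nudges the softmax normalizer Z toward 1
    log_z = logsumexp(logits, axis=-1)
    return coeff * jnp.mean(jnp.square(log_z))
```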
Currently there is no model sharding, only data parallelism (the batch is automatically split as `batch_size / n_devices` across devices; the usual splitting pattern is sketched below).
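For intuition, this is the standard JAX pattern for that split — a sketch only, not necessarily how this repo implements it internally:

```python
import jax
import jax.numpy as jnp

def shard_batch(batch: jnp.ndarray) -> jnp.ndarray:
    # reshape (batch_size, ...) to (n_devices, batch_size // n_devices, ...)
    # so jax.pmap can place one shard on each device
    n = jax.local_device_count()
    return batch.reshape(n, batch.shape[0] // n, *batch.shape[1:])
```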
## Installation

```bash
pip install image-classification-jax
```
## Usage

Set your W&B API key either in your Python script or through the command line:

```bash
export WANDB_API_KEY=<your_key>
```
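Equivalently, set it from Python before any logging starts; `WANDB_API_KEY` is the environment variable `wandb` reads:

```python
import os

# must be set before wandb initializes a run
os.environ["WANDB_API_KEY"] = "<your_key>"
```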
Use `run_experiment` to run an experiment. The following example uses the `xmat`
optimizer from `psgd-jax` wrapped in schedule-free. Note that momentum is disabled
in the inner optimizer (`b1=0.0`) because the schedule-free wrapper performs its
own averaging, controlled by its `b1`.
```python
import optax
from image_classification_jax.run_experiment import run_experiment
from psgd_jax.xmat import xmat

# linear warmup to 0.01 over 256 steps, then constant
lr = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, 0.01, 256),
        optax.constant_schedule(0.01),
    ],
    boundaries=[256],
)

# schedule-free supplies the momentum, so the inner optimizer uses b1=0.0
optimizer = optax.contrib.schedule_free(xmat(lr, b1=0.0), learning_rate=lr, b1=0.95)

run_experiment(
    log_to_wandb=True,
    wandb_entity="",
    wandb_project="image_classification_jax",
    wandb_config_update={  # extra config entries to log to wandb
        "optimizer": "psgd_xmat",
        "schedule_free": True,
        "learning_rate": 0.01,
        "warmup_steps": 256,
        "b1": 0.95,
    },
    global_seed=100,
    dataset="cifar10",
    batch_size=64,
    n_epochs=10,
    optimizer=optimizer,
    compute_in_bfloat16=False,
    l2_regularization=1e-4,
    randomize_l2_reg=False,
    apply_z_loss=True,
    model_type="vit",
    n_layers=12,
    enc_dim=768,
    n_heads=12,
    n_empty_registers=0,
    dropout_rate=0.0,
    using_schedule_free=True,  # tell run_experiment the optimizer is schedule-free wrapped
    psgd_calc_hessian=True,  # have PSGD calculate Hessian info for its preconditioner
    psgd_precond_update_prob=0.1,
)
```
Empty file.
@@ -0,0 +1,109 @@
```python
from typing import Optional

import jax
import numpy as np
from einops import rearrange
from flax import linen as nn
from jax import numpy as jnp

from image_classification_jax.models.network_utils import normal_init, flax_scan


class LearnablePositionalEncoding(nn.Module):
    """Adds a learned embedding per token position, broadcast over the batch."""

    @nn.compact
    def __call__(self, x: jax.Array):
        assert x.ndim == 3, "Input to LearnablePositionalEncoding must be 3D"
        pe = self.param("pe", normal_init, (x.shape[-2], x.shape[-1]))
        return x + jnp.expand_dims(pe, axis=0)


class SwiGLU(nn.Module):
    """SwiGLU activation: splits features into values and gates, halving the width."""

    @nn.compact
    def __call__(self, x):
        x, gates = jnp.split(x, 2, axis=-1)
        gates = nn.silu(gates)
        return x * gates


class TransformerBlock(nn.Module):
    n_heads: int
    dropout_rate: float = 0.0
    is_training: Optional[bool] = None

    @nn.compact
    def __call__(self, a, is_training: Optional[bool] = None):
        is_training = nn.merge_param("is_training", self.is_training, is_training)

        n_tokens, enc_dim = a.shape[-2:]

        # pre-norm attention with qk normalization,
        # https://arxiv.org/abs/2302.05442 style but without parallel blocks
        a2 = nn.LayerNorm(use_bias=False)(a)
        a2 = nn.SelfAttention(
            num_heads=self.n_heads,
            dropout_rate=self.dropout_rate,
            kernel_init=normal_init,
            broadcast_dropout=False,
            use_bias=False,
            normalize_qk=True,
        )(a2, deterministic=not is_training)
        # single learned bias on the attention output (projections are bias-free)
        b = self.param("att_bias", nn.initializers.zeros, (enc_dim,))
        a2 = a2 + jnp.reshape(b, (1, 1, enc_dim))
        a2 = nn.Dropout(rate=self.dropout_rate)(a2, deterministic=not is_training)
        a = a + a2

        # pre-norm MLP: 8x expansion, halved to an effective 4x by the SwiGLU gating
        a2 = nn.LayerNorm(use_bias=False)(a)
        a2 = nn.Dense(features=int(enc_dim * 8), kernel_init=normal_init)(a2)
        a2 = SwiGLU()(a2)
        a2 = nn.Dropout(rate=self.dropout_rate)(a2, deterministic=not is_training)
        a2 = nn.Dense(features=enc_dim, kernel_init=normal_init)(a2)
        a2 = nn.Dropout(rate=self.dropout_rate)(a2, deterministic=not is_training)
        a = a + a2

        return a


class Transformer(nn.Module):
    n_layers: int = 12
    enc_dim: int = 768
    n_heads: int = 12
    n_empty_registers: int = 0
    dropout_rate: float = 0.0
    output_dim: int = 1000

    @nn.compact
    def __call__(self, x, is_training: bool):
        # patchify into 16x16 patches and embed, scaled by sqrt(enc_dim)
        x = rearrange(x, "b (h p1) (w p2) c -> b (h w) (p1 p2 c)", p1=16, p2=16)
        x = nn.Dense(features=self.enc_dim, kernel_init=normal_init, use_bias=False)(x)
        x *= jnp.sqrt(self.enc_dim).astype(x.dtype)
        x = LearnablePositionalEncoding()(x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not is_training)

        # prepend a learnable cls token
        cls_token = self.param(
            "cls_token", nn.initializers.zeros_init(), (self.enc_dim,)
        )
        cls_token = jnp.tile(jnp.reshape(cls_token, (1, 1, -1)), (x.shape[0], 1, 1))
        x = jnp.concatenate([cls_token, x], axis=1)

        if self.n_empty_registers > 0:
            # registers, https://arxiv.org/abs/2309.16588
            empty_registers = self.param(
                "registers",
                nn.initializers.normal(1 / np.sqrt(self.enc_dim)),
                (self.n_empty_registers, x.shape[-1]),
            )
            empty_registers = jnp.tile(
                jnp.expand_dims(empty_registers, axis=0), (x.shape[0], 1, 1)
            )
            x = jnp.concatenate([x, empty_registers], axis=1)

        # run the layer stack as a single scanned (weight-stacked) block
        x = flax_scan(TransformerBlock, length=self.n_layers, unroll=2)(
            n_heads=self.n_heads,
            dropout_rate=self.dropout_rate,
            is_training=is_training,
        )(x)

        x = nn.LayerNorm(use_bias=False)(x[:, 0])  # take cls token

        return nn.Dense(
            features=self.output_dim, kernel_init=normal_init, use_bias=False
        )(x)
```
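For orientation, initializing this model follows the standard Flax pattern. A sketch with hypothetical small hyperparameters; image height and width must be multiples of the 16-pixel patch size:

```python
import jax
import jax.numpy as jnp

model = Transformer(n_layers=2, enc_dim=64, n_heads=4, output_dim=10)
variables = model.init(
    {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)},
    jnp.zeros((1, 32, 32, 3)),  # NHWC batch; 32x32 yields four 16x16 patches
    is_training=False,
)
logits = model.apply(variables, jnp.zeros((1, 32, 32, 3)), is_training=False)
```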
Empty file.
image_classification_jax/models/network_utils.py
@@ -0,0 +1,59 @@
```python
from functools import partial
from typing import Callable, Any, Mapping

import flax
import flax.linen as nn
from flax.core import FrozenDict
from flax.core.scope import CollectionFilter, PRNGSequenceFilter
from flax.linen.transforms import Target, lift_transform
from flax.typing import InOutScanAxis


normal_init = nn.initializers.truncated_normal(stddev=0.02)


def _flax_scan(
    body_fn: Callable[..., Any],
    length: int,
    variable_broadcast: CollectionFilter = False,
    variable_carry: CollectionFilter = False,
    variable_axes: Mapping[CollectionFilter, InOutScanAxis] = {True: 0},
    split_rngs: Mapping[PRNGSequenceFilter, bool] = {True: True},
    unroll: int = 1,
) -> Callable[..., Any]:
    scan_fn = partial(
        flax.core.lift.scan,
        variable_broadcast=variable_broadcast,
        variable_carry=variable_carry,
        variable_axes=variable_axes,
        split_rngs=split_rngs,
        unroll=unroll,
    )

    def wrapper(scope, carry):
        # the module output becomes the carry; no per-step outputs are collected
        return body_fn(scope, carry), None

    fn = lambda scope, c: scan_fn(wrapper, length=length)(scope, c)[0]

    return fn


def flax_scan(
    target: Target,
    length: int,
    variable_broadcast: CollectionFilter = False,
    variable_carry: CollectionFilter = False,
    variable_axes: Mapping[CollectionFilter, InOutScanAxis] = FrozenDict({True: 0}),
    split_rngs: Mapping[PRNGSequenceFilter, bool] = FrozenDict({True: True}),
    unroll: int = 1,
) -> Target:
    """Repeats `target` `length` times as a single scanned layer stack.

    Parameters are stacked along a new leading axis of size `length`, which
    typically compiles faster than `length` separate layers, and RNGs are
    split per step.
    """
    return lift_transform(
        _flax_scan,
        target,
        length=length,
        variable_broadcast=variable_broadcast,
        variable_carry=variable_carry,
        variable_axes=variable_axes,
        split_rngs=split_rngs,
        unroll=unroll,
    )
```
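As a quick illustration of what `flax_scan` produces: scanning a module stacks each of its parameters along a new leading axis of size `length`. A toy sketch (the `Block` module here is illustrative, not part of this repo):

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

class Block(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.relu(nn.Dense(features=x.shape[-1])(x))

stacked = flax_scan(Block, length=4)()
variables = stacked.init(jax.random.PRNGKey(0), jnp.ones((2, 8)))
# the inner Dense kernel now has shape (4, 8, 8): one slice per scanned step
```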