initial commit 🎉
evanatyourservice committed Aug 7, 2024
1 parent 4adc362 commit 1c4c385
Showing 15 changed files with 1,683 additions and 2 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -159,4 +159,6 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/

wandb
76 changes: 75 additions & 1 deletion README.md
@@ -1,2 +1,76 @@
# image-classification-jax
Image classification in JAX with ViT, resnet, cifar10, cifar100, imagenette, and imagenet

Run image classification experiments in JAX with ViT, resnet, cifar10, cifar100, imagenette, and imagenet.

Meant to be simple but good quality. Includes:
- ViT with QK normalization, SwiGLU, and empty registers
- PaLM-style z-loss (a minimal sketch follows this list)
- ability to use schedule-free from `optax.contrib`
- ability to use PSGD optimizers from `psgd-jax` with Hessian calculation
- datasets currently implemented: cifar10, cifar100, imagenette, and imagenet
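
A quick sketch of the PaLM-style z-loss referenced above: an auxiliary penalty on the log of the softmax normalizer that keeps logits from drifting. The helper name and the 1e-4 coefficient are illustrative, not taken from this repo:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def cross_entropy_with_z_loss(logits, labels, z_loss_weight=1e-4):
    """Softmax cross-entropy plus a PaLM-style z-loss penalty on log Z."""
    log_z = logsumexp(logits, axis=-1)  # log of the softmax normalizer
    log_probs = logits - log_z[..., None]
    ce = -jnp.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    z_loss = z_loss_weight * jnp.square(log_z)  # pushes log Z toward zero
    return jnp.mean(ce + z_loss)
```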

Currently there is no model sharding, only data parallelism (the batch is automatically split into `batch_size / n_devices` per device), as sketched below.
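
As a rough sketch of what that split looks like (illustrative only, not the repo's actual training loop):

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()
batch = jnp.zeros((64, 32, 32, 3))  # a global batch of 64 CIFAR-sized images
assert batch.shape[0] % n_devices == 0, "batch_size must be divisible by n_devices"
shards = batch.reshape(n_devices, batch.shape[0] // n_devices, *batch.shape[1:])
# each leading slice of `shards` can now be handed to one device, e.g. via pmap
```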


## Installation

```bash
pip install image-classification-jax
```

## Usage

Set your wandb API key either in your Python script or through the command line:
```bash
export WANDB_API_KEY=<your_key>
```
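
Alternatively, a minimal way to set it from a Python script before training starts (equivalent to the export above):

```python
import os

os.environ["WANDB_API_KEY"] = "<your_key>"  # must be set before wandb is initialized
```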

Use `run_experiment` to run an experiment. The following example uses the `xmat`
optimizer from `psgd-jax` wrapped in schedule-free.

```python
import optax
from image_classification_jax.run_experiment import run_experiment
from psgd_jax.xmat import xmat

lr = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, 0.01, 256),
        optax.constant_schedule(0.01),
    ],
    boundaries=[256],
)

optimizer = optax.contrib.schedule_free(xmat(lr, b1=0.0), learning_rate=lr, b1=0.95)

run_experiment(
    log_to_wandb=True,
    wandb_entity="",
    wandb_project="image_classification_jax",
    wandb_config_update={
        "optimizer": "psgd_xmat",
        "schedule_free": True,
        "learning_rate": 0.01,
        "warmup_steps": 256,
        "b1": 0.95,
    },
    global_seed=100,
    dataset="cifar10",
    batch_size=64,
    n_epochs=10,
    optimizer=optimizer,
    compute_in_bfloat16=False,
    l2_regularization=1e-4,
    randomize_l2_reg=False,
    apply_z_loss=True,
    model_type="vit",
    n_layers=12,
    enc_dim=768,
    n_heads=12,
    n_empty_registers=0,
    dropout_rate=0.0,
    using_schedule_free=True,
    psgd_calc_hessian=True,
    psgd_precond_update_prob=0.1,
)
```
Empty file.
109 changes: 109 additions & 0 deletions image_classification_jax/models/ViT.py
@@ -0,0 +1,109 @@
```python
from typing import Optional

import jax
import numpy as np
from einops import rearrange
from flax import linen as nn
from jax import numpy as jnp

from image_classification_jax.models.network_utils import normal_init, flax_scan


class LearnablePositionalEncoding(nn.Module):
    @nn.compact
    def __call__(self, x: jax.Array):
        assert x.ndim == 3, "Input to LearnablePositionalEncoding must be 3D"
        pe = self.param("pe", normal_init, (x.shape[-2], x.shape[-1]))
        return x + jnp.expand_dims(pe, axis=0)


class SwiGLU(nn.Module):
    @nn.compact
    def __call__(self, x):
        x, gates = jnp.split(x, 2, axis=-1)
        gates = nn.silu(gates)
        return x * gates


class TransformerBlock(nn.Module):
    n_heads: int
    dropout_rate: float = 0.0
    is_training: Optional[bool] = None

    @nn.compact
    def __call__(self, a, is_training: Optional[bool] = None):
        is_training = nn.merge_param("is_training", self.is_training, is_training)

        n_tokens, enc_dim = a.shape[-2:]

        # https://arxiv.org/abs/2302.05442 style but without parallel blocks
        a2 = nn.LayerNorm(use_bias=False)(a)
        a2 = nn.SelfAttention(
            num_heads=self.n_heads,
            dropout_rate=self.dropout_rate,
            kernel_init=normal_init,
            broadcast_dropout=False,
            use_bias=False,
            normalize_qk=True,
        )(a2, deterministic=not is_training)
        b = self.param("att_bias", nn.initializers.zeros, (enc_dim,))
        a2 = a2 + jnp.reshape(b, (1, 1, enc_dim))
        a2 = nn.Dropout(rate=self.dropout_rate)(a2, deterministic=not is_training)
        a = a + a2

        a2 = nn.LayerNorm(use_bias=False)(a)
        a2 = nn.Dense(features=int(enc_dim * 8), kernel_init=normal_init)(a2)
        a2 = SwiGLU()(a2)
        a2 = nn.Dropout(rate=self.dropout_rate)(a2, deterministic=not is_training)
        a2 = nn.Dense(features=enc_dim, kernel_init=normal_init)(a2)
        a2 = nn.Dropout(rate=self.dropout_rate)(a2, deterministic=not is_training)
        a = a + a2

        return a


class Transformer(nn.Module):
    n_layers: int = 12
    enc_dim: int = 768
    n_heads: int = 12
    n_empty_registers: int = 0
    dropout_rate: float = 0.0
    output_dim: int = 1000

    @nn.compact
    def __call__(self, x, is_training: bool):
        x = rearrange(x, "b (h p1) (w p2) c -> b (h w) (p1 p2 c)", p1=16, p2=16)
        x = nn.Dense(features=self.enc_dim, kernel_init=normal_init, use_bias=False)(x)
        x *= jnp.sqrt(self.enc_dim).astype(x.dtype)
        x = LearnablePositionalEncoding()(x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not is_training)

        cls_token = self.param(
            "cls_token", nn.initializers.zeros_init(), (self.enc_dim,)
        )
        cls_token = jnp.tile(jnp.reshape(cls_token, (1, 1, -1)), (x.shape[0], 1, 1))
        x = jnp.concatenate([cls_token, x], axis=1)

        if self.n_empty_registers > 0:
            # https://arxiv.org/abs/2309.16588
            empty_registers = self.param(
                "registers",
                nn.initializers.normal(1 / np.sqrt(self.enc_dim)),
                (self.n_empty_registers, x.shape[-1]),
            )
            empty_registers = jnp.tile(
                jnp.expand_dims(empty_registers, axis=0), (x.shape[0], 1, 1)
            )
            x = jnp.concatenate([x, empty_registers], axis=1)

        x = flax_scan(TransformerBlock, length=self.n_layers, unroll=2)(
            n_heads=self.n_heads,
            dropout_rate=self.dropout_rate,
            is_training=is_training,
        )(x)

        x = nn.LayerNorm(use_bias=False)(x[:, 0])  # take cls token

        return nn.Dense(
            features=self.output_dim, kernel_init=normal_init, use_bias=False
        )(x)
```
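
As a usage sketch (not part of the repo), the `Transformer` module above can be initialized and applied like any Flax module; the hyperparameters and input shape here are illustrative:

```python
import jax
import jax.numpy as jnp
from image_classification_jax.models.ViT import Transformer

model = Transformer(
    n_layers=2, enc_dim=64, n_heads=4, n_empty_registers=0,
    dropout_rate=0.0, output_dim=10,
)
images = jnp.zeros((8, 32, 32, 3))  # CIFAR-sized inputs, split into 16x16 patches
variables = model.init(jax.random.PRNGKey(0), images, is_training=False)
logits = model.apply(variables, images, is_training=False)
print(logits.shape)  # expected: (8, 10)
```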
Empty file.
59 changes: 59 additions & 0 deletions image_classification_jax/models/network_utils.py
@@ -0,0 +1,59 @@
```python
from functools import partial
from typing import Callable, Any, Mapping

import flax
import flax.linen as nn
from flax.core import FrozenDict
from flax.core.scope import CollectionFilter, PRNGSequenceFilter
from flax.linen.transforms import Target, lift_transform
from flax.typing import InOutScanAxis


normal_init = nn.initializers.truncated_normal(stddev=0.02)


def _flax_scan(
    body_fn: Callable[..., Any],
    length: int,
    variable_broadcast: CollectionFilter = False,
    variable_carry: CollectionFilter = False,
    variable_axes: Mapping[CollectionFilter, InOutScanAxis] = {True: 0},
    split_rngs: Mapping[PRNGSequenceFilter, bool] = {True: True},
    unroll: int = 1,
) -> Callable[..., Any]:
    scan_fn = partial(
        flax.core.lift.scan,
        variable_broadcast=variable_broadcast,
        variable_carry=variable_carry,
        variable_axes=variable_axes,
        split_rngs=split_rngs,
        unroll=unroll,
    )

    def wrapper(scope, carry):
        return body_fn(scope, carry), None

    fn = lambda scope, c: scan_fn(wrapper, length=length)(scope, c)[0]

    return fn


def flax_scan(
    target: Target,
    length: int,
    variable_broadcast: CollectionFilter = False,
    variable_carry: CollectionFilter = False,
    variable_axes: Mapping[CollectionFilter, InOutScanAxis] = FrozenDict({True: 0}),
    split_rngs: Mapping[PRNGSequenceFilter, bool] = FrozenDict({True: True}),
    unroll: int = 1,
) -> Target:
    return lift_transform(
        _flax_scan,
        target,
        length=length,
        variable_broadcast=variable_broadcast,
        variable_carry=variable_carry,
        variable_axes=variable_axes,
        split_rngs=split_rngs,
        unroll=unroll,
    )
```
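
As a quick sketch of how `flax_scan` behaves (the `ResidualDense` module here is hypothetical, not from the repo): it stacks `length` copies of the target module, scans the input through them as a carry, and stores each layer's parameters along a leading axis.

```python
import flax.linen as nn
import jax
import jax.numpy as jnp
from image_classification_jax.models.network_utils import flax_scan


class ResidualDense(nn.Module):
    features: int = 8

    @nn.compact
    def __call__(self, x):
        return x + nn.Dense(self.features)(x)


stack = flax_scan(ResidualDense, length=4)(features=8)
variables = stack.init(jax.random.PRNGKey(0), jnp.zeros((2, 8)))
print(jax.tree_util.tree_map(jnp.shape, variables))
# the Dense kernel should show a leading scan axis of length 4, e.g. (4, 8, 8)
```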