Commit
* Updating Python annotations to use jax.Array correctly, rather than chex.Array.

* As a result, updating to newer versions of several packages and bumping the version to 0.0.4.

PiperOrigin-RevId: 527524457
botev authored and KfacJaxDev committed Apr 27, 2023
1 parent 2a361b2 commit 9b059b2
Showing 28 changed files with 899 additions and 844 deletions.
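The pattern of the change, as a minimal illustrative sketch (the aliases are the ones the diff below introduces; loss_fn is a hypothetical example, not code from this commit):

# Before: annotations imported from chex.
#   import chex
#   def loss_fn(logits: chex.Array, l2_reg: chex.Numeric) -> chex.Array: ...

# After: aliases re-exported by kfac_jax.utils and backed by jax.Array.
import kfac_jax

Array = kfac_jax.utils.Array
Numeric = kfac_jax.utils.Numeric
PRNGKey = kfac_jax.utils.PRNGKey

def loss_fn(logits: Array, l2_reg: Numeric) -> Array:
  ...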
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -13,7 +13,7 @@ jobs:

strategy:
matrix:
python-version: ["3.7", "3.8", "3.9"]
python-version: ["3.8", "3.9"]
os: [ubuntu-latest]

steps:
26 changes: 13 additions & 13 deletions examples/autoencoder_mnist/experiment.py
@@ -13,21 +13,26 @@
# limitations under the License.
"""Haiku implementation of the standard MNIST Autoencoder."""
import functools
- from typing import Dict, Mapping, Tuple, Union
+ from typing import Mapping, Tuple, Union, Dict

- import chex
import haiku as hk
import jax
from jax import nn
import jax.numpy as jnp

import kfac_jax
from examples import losses
from examples import training


+ Array = kfac_jax.utils.Array
+ Numeric = kfac_jax.utils.Numeric
+ PRNGKey = kfac_jax.utils.PRNGKey


def autoencoder() -> hk.Transformed:
"""Constructs a Haiku transformed object of the autoencoder."""
- def func(batch: Union[chex.Array, Mapping[str, chex.Array]]) -> chex.Array:
+ def func(batch: Union[Array, Mapping[str, Array]]) -> Array:
"""Evaluates the autoencoder."""
if isinstance(batch, Mapping):
batch = batch["images"]
@@ -54,11 +59,11 @@ def func(batch: Union[chex.Array, Mapping[str, chex.Array]]) -> chex.Array:

def autoencoder_loss(
params: hk.Params,
- batch: Union[chex.Array, Mapping[str, chex.Array]],
- l2_reg: chex.Numeric,
+ batch: Union[Array, Mapping[str, Array]],
+ l2_reg: Numeric,
is_training: bool,
average_loss: bool = True,
- ) -> Tuple[chex.Array, Dict[str, chex.Array]]:
+ ) -> Tuple[Array, Dict[str, Array]]:
"""Evaluates the loss of the autoencoder."""

if isinstance(batch, Mapping):
@@ -69,7 +74,7 @@ def autoencoder_loss(
cross_entropy = jnp.sum(losses.sigmoid_cross_entropy(logits, batch), axis=-1)
averaged_cross_entropy = jnp.mean(cross_entropy)

- loss = averaged_cross_entropy if average_loss else cross_entropy
+ loss: Array = averaged_cross_entropy if average_loss else cross_entropy

l2_reg_val = losses.l2_regularizer(params, False, False)
if is_training:
@@ -88,12 +93,7 @@ def autoencoder_loss(
class AutoencoderMnistExperiment(training.MnistExperiment):
"""Jaxline experiment class for running the MNIST Autoencoder."""

- def __init__(
-     self,
-     mode: str,
-     init_rng: chex.PRNGKey,
-     config,
- ):
+ def __init__(self, mode: str, init_rng: PRNGKey, config):
super().__init__(
supervised=False,
flatten_images=True,
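For orientation, a hedged usage sketch of the updated signatures (batch shape and l2_reg value are hypothetical; assumes the transformed function is wrapped with hk.without_apply_rng, as the bare apply(params, images) calls in these examples suggest):

import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
images = jnp.zeros((8, 784))  # hypothetical flattened MNIST batch
params = autoencoder().init(rng, images)
loss, stats = autoencoder_loss(
    params, images, l2_reg=1e-5, is_training=True)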
18 changes: 11 additions & 7 deletions examples/classifier_mnist/experiment.py
@@ -13,20 +13,24 @@
# limitations under the License.
"""Haiku implementation of a small convolutional classifier for MNIST."""
import functools
- from typing import Dict, Mapping, Tuple, Union
+ from typing import Mapping, Tuple, Union, Dict

- import chex
import haiku as hk
import jax
import jax.numpy as jnp

import kfac_jax
from examples import losses
from examples import training

+ Array = kfac_jax.utils.Array
+ Numeric = kfac_jax.utils.Numeric
+ PRNGKey = kfac_jax.utils.PRNGKey


def convolutional_classifier() -> hk.Transformed:
"""Constructs a Haiku transformed object of the classifier network."""
- def func(batch: Union[chex.Array, Mapping[str, chex.Array]]) -> chex.Array:
+ def func(batch: Union[Array, Mapping[str, Array]]) -> Array:
"""Evaluates the classifier."""
if isinstance(batch, Mapping):
batch = batch["images"]
@@ -52,11 +56,11 @@ def func(batch: Union[chex.Array, Mapping[str, chex.Array]]) -> chex.Array:

def classifier_loss(
params: hk.Params,
- batch: Mapping[str, chex.Array],
- l2_reg: chex.Numeric,
+ batch: Mapping[str, Array],
+ l2_reg: Numeric,
is_training: bool,
average_loss: bool = True,
- ) -> Tuple[chex.Array, Dict[str, chex.Array]]:
+ ) -> Tuple[Array, Dict[str, Array]]:
"""Evaluates the loss of the classifier network."""

logits = convolutional_classifier().apply(params, batch["images"])
@@ -78,7 +82,7 @@ def classifier_loss(
class ClassifierMnistExperiment(training.MnistExperiment):
"""Jaxline experiment class for running the MNIST classifier."""

- def __init__(self, mode: str, init_rng: jnp.ndarray, config):
+ def __init__(self, mode: str, init_rng: PRNGKey, config):
super().__init__(
supervised=True,
flatten_images=False,
29 changes: 15 additions & 14 deletions examples/datasets.py
@@ -15,9 +15,8 @@
"""
import types
- from typing import Callable, Dict, Iterator, Mapping, Optional, Tuple, TypeVar
+ from typing import Callable, Iterator, Optional, Tuple, Dict

- import chex
import jax
import jax.numpy as jnp
import numpy as np
@@ -26,8 +25,10 @@
tfds = tensorflow_datasets

# Types for annotation
- T = TypeVar("T")
- Batch = Dict[str, chex.Array]
+ Array = jax.Array
+ Shape = Tuple[int, ...]
+ Batch = Dict[str, Array]
+ TfBatch = Dict[str, tf.Tensor]

# Special global variables
_IMAGENET_MEAN_RGB = (0.485, 0.456, 0.406)
@@ -147,18 +148,18 @@ def imagenet_num_examples_and_split(
def imagenet_dataset(
split: str,
is_training: bool,
- batch_dims: chex.Shape,
+ batch_dims: Shape,
seed: int = 123,
shuffle_files: bool = True,
buffer_size_factor: int = 10,
shuffle: bool = False,
cache: bool = False,
dtype: jnp.dtype = jnp.float32,
- image_size: chex.Shape = (224, 224),
+ image_size: Shape = (224, 224),
data_dir: Optional[str] = None,
extra_preprocessing_func: Optional[
- Callable[[jax.Array, jax.Array],
-          Tuple[jax.Array, jax.Array]]] = None,
+ Callable[[Array, Array],
+          Tuple[Array, Array]]] = None,
) -> Iterator[Batch]:
"""Standard ImageNet dataset pipeline.
@@ -244,16 +245,16 @@ def imagenet_dataset(
# When training we generate a stateless pipeline, at test we don't need it
def scan_fn(
seed_: tf.Tensor,
- data: T
- ) -> Tuple[tf.Tensor, Tuple[T, tf.Tensor]]:
+ data: TfBatch,
+ ) -> Tuple[tf.Tensor, Tuple[TfBatch, tf.Tensor]]:
new_seeds = tf.random.experimental.stateless_split(seed_, num=2)
return new_seeds[0], (data, new_seeds[1])

# create a sequence of seeds across cases by repeated splitting
ds = ds.scan(tf_seed, scan_fn)

def preprocess(
- example: Mapping[str, tf.Tensor],
+ example: Dict[str, tf.Tensor],
seed_: Optional[tf.Tensor] = None
) -> Dict[str, tf.Tensor]:

@@ -302,7 +303,7 @@ def _imagenet_preprocess_image(
image_bytes: tf.Tensor,
seed: tf.Tensor,
is_training: bool,
- image_size: chex.Shape,
+ image_size: Shape,
) -> tf.Tensor:
"""Returns processed and resized images for Imagenet."""

@@ -367,7 +368,7 @@ def _distorted_bounding_box_crop(
def _decode_and_random_crop(
image_bytes: tf.Tensor,
seed: tf.Tensor,
- image_size: chex.Shape = (224, 224),
+ image_size: Shape = (224, 224),
) -> tf.Tensor:
"""Make a random crop of 224 for Imagenet."""
jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)
@@ -390,7 +391,7 @@ def _decode_and_random_crop(
def _decode_and_center_crop(
image_bytes: tf.Tensor,
jpeg_shape: Optional[tf.Tensor] = None,
- image_size: chex.Shape = (224, 224),
+ image_size: Shape = (224, 224),
) -> tf.Tensor:
"""Crops to center of image with padding then scales for Imagenet."""

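For context on the scan_fn hunk above: the pipeline threads a single stateless seed through the dataset, splitting off a fresh per-element seed at every step so preprocessing stays deterministic. A minimal runnable sketch of that pattern (the toy dataset and seed values are assumptions, not from this commit):

import tensorflow as tf

def scan_fn(seed, data):
  # Split the carried seed: one half becomes the new state, the other
  # is paired with the current element for its random preprocessing.
  new_seeds = tf.random.experimental.stateless_split(seed, num=2)
  return new_seeds[0], (data, new_seeds[1])

ds = tf.data.Dataset.range(3)
tf_seed = tf.constant([0, 42], dtype=tf.int64)
ds = ds.scan(tf_seed, scan_fn)  # yields (element, per-element seed) pairs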
64 changes: 33 additions & 31 deletions examples/losses.py
@@ -12,43 +12,45 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for computing and automatically registering losses."""
- from typing import Dict, Optional, Sequence, Tuple
+ from typing import Optional, Sequence, Tuple, Dict

- import chex
import haiku as hk
import jax
from jax import lax
import jax.numpy as jnp
from jax.scipy import special
import kfac_jax

- utils = kfac_jax.utils
+ Array = kfac_jax.utils.Array
+ Numeric = kfac_jax.utils.Numeric
+ Params = kfac_jax.utils.Params


def l2_regularizer(
- params: kfac_jax.utils.Params,
+ params: Params,
haiku_exclude_batch_norm: bool,
haiku_exclude_biases: bool,
- ) -> chex.Array:
+ ) -> Array:
"""Computes an L2 regularizer."""

if haiku_exclude_batch_norm:
- params = hk.data_structures.filter(
+ params = hk.data_structures.filter(  # pytype: disable=wrong-arg-types
lambda m, n, p: "batchnorm" not in m, params)

if haiku_exclude_biases:
- params = hk.data_structures.filter(
-     lambda m, n, p: n != "b", params)
+ params = hk.data_structures.filter(  # pytype: disable=wrong-arg-types
+     lambda m, n, p: n != "b", params
+ )

return 0.5 * kfac_jax.utils.inner_product(params, params)


def sigmoid_cross_entropy(
- logits: chex.Array,
- labels: chex.Array,
+ logits: Array,
+ labels: Array,
weight: float = 1.0,
register_loss: bool = True,
- ) -> chex.Array:
+ ) -> Array:
"""Sigmoid cross-entropy loss."""
if register_loss:
kfac_jax.register_sigmoid_cross_entropy_loss(logits, labels, weight)
@@ -61,12 +63,12 @@ def sigmoid_cross_entropy(


def softmax_cross_entropy(
- logits: chex.Array,
- labels: chex.Array,
- weight: chex.Numeric = 1.0,
+ logits: Array,
+ labels: Array,
+ weight: Numeric = 1.0,
register_loss: bool = True,
- mask: Optional[chex.Array] = None,
- ) -> chex.Array:
+ mask: Optional[Array] = None,
+ ) -> Array:
"""Softmax cross entropy loss."""

if register_loss:
@@ -122,11 +124,11 @@ def softmax_cross_entropy(


def squared_error(
- prediction: chex.Array,
- targets: chex.Array,
+ prediction: Array,
+ targets: Array,
weight: float = 1.0,
register_loss: bool = True,
- ) -> chex.Array:
+ ) -> Array:
"""Squared error loss."""

if prediction.shape != targets.shape:
@@ -139,10 +141,10 @@ def squared_error(


def top_k_accuracy(
- logits_or_probs: chex.Array,
- labels: chex.Array,
+ logits_or_probs: Array,
+ labels: Array,
k: int = 1,
- ) -> chex.Array:
+ ) -> Array:
"""Top-k accuracy."""

if labels.ndim == logits_or_probs.ndim:
@@ -166,11 +168,11 @@ def top_k_accuracy(


def add_label_smoothing(
- labels: chex.Array,
+ labels: Array,
label_smoothing: float,
num_classes: int,
labels_are_one_hot: bool = False,
- ) -> chex.Array:
+ ) -> Array:
"""Adds label smoothing to the labels."""

if label_smoothing < 0. or label_smoothing > 1.:
@@ -192,19 +194,19 @@ def add_label_smoothing(


def classifier_loss_and_stats(
- logits: chex.Array,
- labels_as_int: chex.Array,
- params: kfac_jax.utils.Params,
- l2_reg: chex.Numeric,
+ logits: Array,
+ labels_as_int: Array,
+ params: Params,
+ l2_reg: Numeric,
haiku_exclude_batch_norm: bool,
haiku_exclude_biases: bool,
label_smoothing: float = 0.0,
top_k_stats: Sequence[int] = (1, 5),
average_loss: bool = True,
register_loss: bool = True,
- mask: Optional[chex.Array] = None,
+ mask: Optional[Array] = None,
normalization_mode: str = "batch_size_only",
- ) -> Tuple[chex.Array, Dict[str, chex.Array]]:
+ ) -> Tuple[Array, Dict[str, Array]]:
"""Softmax cross-entropy with regularizer and accuracy statistics."""

batch_size = logits.shape[0]
@@ -223,7 +225,7 @@ def classifier_loss_and_stats(
weight = 1.0

elif normalization_mode == "all_dims":
- weight = 1.0 / utils.product(logits.shape[1:-1])
+ weight = 1.0 / kfac_jax.utils.product(logits.shape[1:-1])

elif normalization_mode == "all_dims_nonmasked":
assert mask is not None
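The diff above shows only the signature of add_label_smoothing; for reference, a minimal sketch of the standard smoothing rule it presumably applies (assumed from convention, not taken from the function body):

import jax
import jax.numpy as jnp

def smooth_labels(labels_as_int: jax.Array, label_smoothing: float,
                  num_classes: int) -> jax.Array:
  # Shift label_smoothing of the probability mass from the true class
  # to the uniform distribution over all classes.
  one_hot = jax.nn.one_hot(labels_as_int, num_classes)
  return (1.0 - label_smoothing) * one_hot + label_smoothing / num_classes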
