From 9b059b2a1bb285ba4de2432cf14a5481ef50fd64 Mon Sep 17 00:00:00 2001 From: Alex Botev Date: Thu, 27 Apr 2023 03:03:47 -0700 Subject: [PATCH] * Updating python annotation to use jax.Array correctly, rather than chex.Array. * As a result updating to newer version of several packages and bumping version to 0.0.4. PiperOrigin-RevId: 527524457 --- .github/workflows/ci.yml | 2 +- examples/autoencoder_mnist/experiment.py | 26 +- examples/classifier_mnist/experiment.py | 18 +- examples/datasets.py | 29 +- examples/losses.py | 64 ++--- examples/lrelunet101_imagenet/experiment.py | 38 +-- examples/optimizers.py | 83 +++--- examples/training.py | 100 +++---- kfac_jax/_src/curvature_blocks.py | 82 +++--- kfac_jax/_src/curvature_estimator.py | 144 +++++----- kfac_jax/_src/layers_and_loss_tags.py | 9 +- kfac_jax/_src/loss_functions.py | 72 ++--- kfac_jax/_src/optimizer.py | 294 ++++++++++---------- kfac_jax/_src/patches_second_moment.py | 89 +++--- kfac_jax/_src/tag_graph_matcher.py | 43 +-- kfac_jax/_src/tracer.py | 163 +++++------ kfac_jax/_src/utils/__init__.py | 14 +- kfac_jax/_src/utils/accumulators.py | 118 +++++--- kfac_jax/_src/utils/math.py | 133 ++++----- kfac_jax/_src/utils/misc.py | 23 +- kfac_jax/_src/utils/parallel.py | 58 ++-- kfac_jax/_src/utils/staging.py | 16 +- kfac_jax/_src/utils/types.py | 79 ++---- readthedocs.yml | 2 +- requirements.txt | 9 +- requirements_tests.txt | 11 +- setup.py | 18 +- tests/test_estimator.py | 6 +- 28 files changed, 899 insertions(+), 844 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index adf7ad1..8c904e0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: - python-version: ["3.7", "3.8", "3.9"] + python-version: ["3.8", "3.9"] os: [ubuntu-latest] steps: diff --git a/examples/autoencoder_mnist/experiment.py b/examples/autoencoder_mnist/experiment.py index cb1ec91..9e74f82 100644 --- a/examples/autoencoder_mnist/experiment.py +++ b/examples/autoencoder_mnist/experiment.py @@ -13,21 +13,26 @@ # limitations under the License. 
"""Haiku implementation of the standard MNIST Autoencoder.""" import functools -from typing import Dict, Mapping, Tuple, Union +from typing import Mapping, Tuple, Union, Dict -import chex import haiku as hk import jax from jax import nn import jax.numpy as jnp +import kfac_jax from examples import losses from examples import training +Array = kfac_jax.utils.Array +Numeric = kfac_jax.utils.Numeric +PRNGKey = kfac_jax.utils.PRNGKey + + def autoencoder() -> hk.Transformed: """Constructs a Haiku transformed object of the autoencoder.""" - def func(batch: Union[chex.Array, Mapping[str, chex.Array]]) -> chex.Array: + def func(batch: Union[Array, Mapping[str, Array]]) -> Array: """Evaluates the autoencoder.""" if isinstance(batch, Mapping): batch = batch["images"] @@ -54,11 +59,11 @@ def func(batch: Union[chex.Array, Mapping[str, chex.Array]]) -> chex.Array: def autoencoder_loss( params: hk.Params, - batch: Union[chex.Array, Mapping[str, chex.Array]], - l2_reg: chex.Numeric, + batch: Union[Array, Mapping[str, Array]], + l2_reg: Numeric, is_training: bool, average_loss: bool = True, -) -> Tuple[chex.Array, Dict[str, chex.Array]]: +) -> Tuple[Array, Dict[str, Array]]: """Evaluates the loss of the autoencoder.""" if isinstance(batch, Mapping): @@ -69,7 +74,7 @@ def autoencoder_loss( cross_entropy = jnp.sum(losses.sigmoid_cross_entropy(logits, batch), axis=-1) averaged_cross_entropy = jnp.mean(cross_entropy) - loss = averaged_cross_entropy if average_loss else cross_entropy + loss: Array = averaged_cross_entropy if average_loss else cross_entropy l2_reg_val = losses.l2_regularizer(params, False, False) if is_training: @@ -88,12 +93,7 @@ def autoencoder_loss( class AutoencoderMnistExperiment(training.MnistExperiment): """Jaxline experiment class for running the MNIST Autoencoder.""" - def __init__( - self, - mode: str, - init_rng: chex.PRNGKey, - config, - ): + def __init__(self, mode: str, init_rng: PRNGKey, config): super().__init__( supervised=False, flatten_images=True, diff --git a/examples/classifier_mnist/experiment.py b/examples/classifier_mnist/experiment.py index f24b89a..e149467 100644 --- a/examples/classifier_mnist/experiment.py +++ b/examples/classifier_mnist/experiment.py @@ -13,20 +13,24 @@ # limitations under the License. 
"""Haiku implementation of a small convolutional classifier for MNIST.""" import functools -from typing import Dict, Mapping, Tuple, Union +from typing import Mapping, Tuple, Union, Dict -import chex import haiku as hk import jax import jax.numpy as jnp +import kfac_jax from examples import losses from examples import training +Array = kfac_jax.utils.Array +Numeric = kfac_jax.utils.Numeric +PRNGKey = kfac_jax.utils.PRNGKey + def convolutional_classifier() -> hk.Transformed: """Constructs a Haiku transformed object of the classifier network.""" - def func(batch: Union[chex.Array, Mapping[str, chex.Array]]) -> chex.Array: + def func(batch: Union[Array, Mapping[str, Array]]) -> Array: """Evaluates the classifier.""" if isinstance(batch, Mapping): batch = batch["images"] @@ -52,11 +56,11 @@ def func(batch: Union[chex.Array, Mapping[str, chex.Array]]) -> chex.Array: def classifier_loss( params: hk.Params, - batch: Mapping[str, chex.Array], - l2_reg: chex.Numeric, + batch: Mapping[str, Array], + l2_reg: Numeric, is_training: bool, average_loss: bool = True, -) -> Tuple[chex.Array, Dict[str, chex.Array]]: +) -> Tuple[Array, Dict[str, Array]]: """Evaluates the loss of the classifier network.""" logits = convolutional_classifier().apply(params, batch["images"]) @@ -78,7 +82,7 @@ def classifier_loss( class ClassifierMnistExperiment(training.MnistExperiment): """Jaxline experiment class for running the MNIST classifier.""" - def __init__(self, mode: str, init_rng: jnp.ndarray, config): + def __init__(self, mode: str, init_rng: PRNGKey, config): super().__init__( supervised=True, flatten_images=False, diff --git a/examples/datasets.py b/examples/datasets.py index 718fb5a..345a3d3 100644 --- a/examples/datasets.py +++ b/examples/datasets.py @@ -15,9 +15,8 @@ """ import types -from typing import Callable, Dict, Iterator, Mapping, Optional, Tuple, TypeVar +from typing import Callable, Iterator, Optional, Tuple, Dict -import chex import jax import jax.numpy as jnp import numpy as np @@ -26,8 +25,10 @@ tfds = tensorflow_datasets # Types for annotation -T = TypeVar("T") -Batch = Dict[str, chex.Array] +Array = jax.Array +Shape = Tuple[int, ...] +Batch = Dict[str, Array] +TfBatch = Dict[str, tf.Tensor] # Special global variables _IMAGENET_MEAN_RGB = (0.485, 0.456, 0.406) @@ -147,18 +148,18 @@ def imagenet_num_examples_and_split( def imagenet_dataset( split: str, is_training: bool, - batch_dims: chex.Shape, + batch_dims: Shape, seed: int = 123, shuffle_files: bool = True, buffer_size_factor: int = 10, shuffle: bool = False, cache: bool = False, dtype: jnp.dtype = jnp.float32, - image_size: chex.Shape = (224, 224), + image_size: Shape = (224, 224), data_dir: Optional[str] = None, extra_preprocessing_func: Optional[ - Callable[[jax.Array, jax.Array], - Tuple[jax.Array, jax.Array]]] = None, + Callable[[Array, Array], + Tuple[Array, Array]]] = None, ) -> Iterator[Batch]: """Standard ImageNet dataset pipeline. 
@@ -244,8 +245,8 @@ def imagenet_dataset( # When training we generate a stateless pipeline, at test we don't need it def scan_fn( seed_: tf.Tensor, - data: T - ) -> Tuple[tf.Tensor, Tuple[T, tf.Tensor]]: + data: TfBatch, + ) -> Tuple[tf.Tensor, Tuple[TfBatch, tf.Tensor]]: new_seeds = tf.random.experimental.stateless_split(seed_, num=2) return new_seeds[0], (data, new_seeds[1]) @@ -253,7 +254,7 @@ def scan_fn( ds = ds.scan(tf_seed, scan_fn) def preprocess( - example: Mapping[str, tf.Tensor], + example: Dict[str, tf.Tensor], seed_: Optional[tf.Tensor] = None ) -> Dict[str, tf.Tensor]: @@ -302,7 +303,7 @@ def _imagenet_preprocess_image( image_bytes: tf.Tensor, seed: tf.Tensor, is_training: bool, - image_size: chex.Shape, + image_size: Shape, ) -> tf.Tensor: """Returns processed and resized images for Imagenet.""" @@ -367,7 +368,7 @@ def _distorted_bounding_box_crop( def _decode_and_random_crop( image_bytes: tf.Tensor, seed: tf.Tensor, - image_size: chex.Shape = (224, 224), + image_size: Shape = (224, 224), ) -> tf.Tensor: """Make a random crop of 224 for Imagenet.""" jpeg_shape = tf.image.extract_jpeg_shape(image_bytes) @@ -390,7 +391,7 @@ def _decode_and_random_crop( def _decode_and_center_crop( image_bytes: tf.Tensor, jpeg_shape: Optional[tf.Tensor] = None, - image_size: chex.Shape = (224, 224), + image_size: Shape = (224, 224), ) -> tf.Tensor: """Crops to center of image with padding then scales for Imagenet.""" diff --git a/examples/losses.py b/examples/losses.py index 98608a4..bfd72f5 100644 --- a/examples/losses.py +++ b/examples/losses.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Utility functions for computing and automatically registering losses.""" -from typing import Dict, Optional, Sequence, Tuple +from typing import Optional, Sequence, Tuple, Dict -import chex import haiku as hk import jax from jax import lax @@ -22,33 +21,36 @@ from jax.scipy import special import kfac_jax -utils = kfac_jax.utils +Array = kfac_jax.utils.Array +Numeric = kfac_jax.utils.Numeric +Params = kfac_jax.utils.Params def l2_regularizer( - params: kfac_jax.utils.Params, + params: Params, haiku_exclude_batch_norm: bool, haiku_exclude_biases: bool, -) -> chex.Array: +) -> Array: """Computes an L2 regularizer.""" if haiku_exclude_batch_norm: - params = hk.data_structures.filter( + params = hk.data_structures.filter( # pytype: disable=wrong-arg-types lambda m, n, p: "batchnorm" not in m, params) if haiku_exclude_biases: - params = hk.data_structures.filter( - lambda m, n, p: n != "b", params) + params = hk.data_structures.filter( # pytype: disable=wrong-arg-types + lambda m, n, p: n != "b", params + ) return 0.5 * kfac_jax.utils.inner_product(params, params) def sigmoid_cross_entropy( - logits: chex.Array, - labels: chex.Array, + logits: Array, + labels: Array, weight: float = 1.0, register_loss: bool = True, -) -> chex.Array: +) -> Array: """Sigmoid cross-entropy loss.""" if register_loss: kfac_jax.register_sigmoid_cross_entropy_loss(logits, labels, weight) @@ -61,12 +63,12 @@ def sigmoid_cross_entropy( def softmax_cross_entropy( - logits: chex.Array, - labels: chex.Array, - weight: chex.Numeric = 1.0, + logits: Array, + labels: Array, + weight: Numeric = 1.0, register_loss: bool = True, - mask: Optional[chex.Array] = None, -) -> chex.Array: + mask: Optional[Array] = None, +) -> Array: """Softmax cross entropy loss.""" if register_loss: @@ -122,11 +124,11 @@ def softmax_cross_entropy( def squared_error( - prediction: chex.Array, 
- targets: chex.Array, + prediction: Array, + targets: Array, weight: float = 1.0, register_loss: bool = True, -) -> chex.Array: +) -> Array: """Squared error loss.""" if prediction.shape != targets.shape: @@ -139,10 +141,10 @@ def squared_error( def top_k_accuracy( - logits_or_probs: chex.Array, - labels: chex.Array, + logits_or_probs: Array, + labels: Array, k: int = 1, -) -> chex.Array: +) -> Array: """Top-k accuracy.""" if labels.ndim == logits_or_probs.ndim: @@ -166,11 +168,11 @@ def top_k_accuracy( def add_label_smoothing( - labels: chex.Array, + labels: Array, label_smoothing: float, num_classes: int, labels_are_one_hot: bool = False, -) -> chex.Array: +) -> Array: """Adds label smoothing to the labels.""" if label_smoothing < 0. or label_smoothing > 1.: @@ -192,19 +194,19 @@ def add_label_smoothing( def classifier_loss_and_stats( - logits: chex.Array, - labels_as_int: chex.Array, - params: kfac_jax.utils.Params, - l2_reg: chex.Numeric, + logits: Array, + labels_as_int: Array, + params: Params, + l2_reg: Numeric, haiku_exclude_batch_norm: bool, haiku_exclude_biases: bool, label_smoothing: float = 0.0, top_k_stats: Sequence[int] = (1, 5), average_loss: bool = True, register_loss: bool = True, - mask: Optional[chex.Array] = None, + mask: Optional[Array] = None, normalization_mode: str = "batch_size_only", -) -> Tuple[chex.Array, Dict[str, chex.Array]]: +) -> Tuple[Array, Dict[str, Array]]: """Softmax cross-entropy with regularizer and accuracy statistics.""" batch_size = logits.shape[0] @@ -223,7 +225,7 @@ def classifier_loss_and_stats( weight = 1.0 elif normalization_mode == "all_dims": - weight = 1.0 / utils.product(logits.shape[1:-1]) + weight = 1.0 / kfac_jax.utils.product(logits.shape[1:-1]) elif normalization_mode == "all_dims_nonmasked": assert mask is not None diff --git a/examples/lrelunet101_imagenet/experiment.py b/examples/lrelunet101_imagenet/experiment.py index 9f16a6e..8c8885c 100644 --- a/examples/lrelunet101_imagenet/experiment.py +++ b/examples/lrelunet101_imagenet/experiment.py @@ -15,15 +15,21 @@ import functools from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Union -import chex import haiku as hk from jax import nn import jax.numpy as jnp +import kfac_jax from examples import losses from examples import training -from ml_collections import config_dict +import ml_collections import numpy as np + +Array = kfac_jax.utils.Array +Numeric = kfac_jax.utils.Numeric +PRNGKey = kfac_jax.utils.PRNGKey +Shape = kfac_jax.utils.Shape +DType = kfac_jax.utils.DType FloatStrOrBool = Union[str, float, bool] @@ -47,7 +53,7 @@ def __init__(self, scale: float = 1.0, axis: int = -1): self.scale = scale self.axis = axis - def __call__(self, shape: chex.Shape, dtype: chex.ArrayDType) -> chex.Array: # pytype: disable=signature-mismatch # numpy-scalars + def __call__(self, shape: Shape, dtype: DType) -> Array: # pytype: disable=signature-mismatch # numpy-scalars # This has essentially copied from https://github.com/deepmind/dks if self.axis != -1: @@ -139,7 +145,7 @@ def __init__( self.layers = layers self.activation = activation - def __call__(self, inputs: chex.Array, **_: Any) -> chex.Array: + def __call__(self, inputs: Array, **_: Any) -> Array: out = inputs for conv_i in self.layers: @@ -177,7 +183,7 @@ def __init__( name=f"block_{i}" )) - def __call__(self, inputs: chex.Array, **kwargs: Any) -> chex.Array: + def __call__(self, inputs: Array, **kwargs: Any) -> Array: out = inputs for block in self.blocks: out = block(out, **kwargs) @@ -295,10 +301,10 @@ def 
__init__( def __call__( self, - inputs: chex.Array, + inputs: Array, is_training: bool, **kwargs: Any - ) -> chex.Array: + ) -> Array: out = inputs out = self.initial_conv(out) out = hk.max_pool( @@ -323,9 +329,9 @@ def lrelunet( ) -> hk.Transformed: """Constructs a Haiku transformed object of the LReLUNet101 network.""" def func( - batch: Union[chex.Array, Mapping[str, chex.Array]], + batch: Union[Array, Mapping[str, Array]], is_training: bool - ) -> chex.Array: + ) -> Array: """Evaluates the network.""" if isinstance(batch, dict): batch = batch["images"] @@ -336,18 +342,18 @@ def func( def lrelunet_loss( params: hk.Params, - rng: chex.PRNGKey, - batch: Mapping[str, chex.Array], + rng: PRNGKey, + batch: Mapping[str, Array], is_training: bool, - l2_reg: chex.Numeric, + l2_reg: Numeric, label_smoothing: float = 0.1, average_loss: bool = True, num_classes: int = 1000, depth: int = 101, **kwargs: Any, ) -> Tuple[ - chex.Array, - Union[Dict[str, chex.Array], Tuple[hk.State, Dict[str, chex.Array]]] + Array, + Union[Dict[str, Array], Tuple[hk.State, Dict[str, Array]]] ]: """Evaluates the loss of the LReLUNet model.""" logits = lrelunet(num_classes=num_classes, depth=depth, **kwargs).apply( @@ -371,8 +377,8 @@ class LReLUNetImageNetExperiment(training.ImageNetExperiment): def __init__( self, mode: str, - init_rng: chex.PRNGKey, - config: config_dict.ConfigDict + init_rng: PRNGKey, + config: ml_collections.ConfigDict, ): """Initializes the network instance.""" super().__init__( diff --git a/examples/optimizers.py b/examples/optimizers.py index 4878fcc..f2aa15e 100644 --- a/examples/optimizers.py +++ b/examples/optimizers.py @@ -13,10 +13,9 @@ # limitations under the License. """Utilities for setting up different optimizers.""" import functools -from typing import Any, Callable, Iterator, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Iterator, Mapping, Optional, Sequence, Type, Tuple, Union from absl import logging -import chex import jax import jax.numpy as jnp import kfac_jax @@ -24,6 +23,13 @@ import optax +Array = kfac_jax.utils.Array +Numeric = kfac_jax.utils.Numeric +PRNGKey = kfac_jax.utils.PRNGKey +Params = kfac_jax.utils.Params +Batch = kfac_jax.utils.Batch +# FuncState = kfac_jax.utils.FuncState +FuncState = Any OptaxState = Any @@ -81,21 +87,21 @@ def __init__( def init( self, - params: kfac_jax.utils.Params, - rng: jnp.ndarray, - batch: kfac_jax.utils.Batch, - func_state: Optional[kfac_jax.utils.FuncState] = None + params: Params, + rng: PRNGKey, + batch: Batch, + func_state: Optional[FuncState] = None, ) -> OptaxState: """Initializes the optimizer and returns the appropriate optimizer state.""" return self._jit_init(params, rng, batch, func_state) def _step( self, - params: kfac_jax.utils.Params, + params: Params, state: OptaxState, - rng: chex.PRNGKey, - batch: kfac_jax.utils.Batch, - func_state: Optional[kfac_jax.utils.FuncState] = None, + rng: PRNGKey, + batch: Batch, + func_state: Optional[FuncState] = None, ) -> kfac_jax.optimizer.ReturnEither: """A single step of optax.""" batch = self._batch_process_func(batch) @@ -110,7 +116,8 @@ def _step( has_aux=self._value_func_has_aux, has_state=self._value_func_has_state, ) - stats["loss"] = loss # pytype: disable=unsupported-operands # numpy-scalars + stats = stats or {} + stats["loss"] = loss stats, grads = jax.lax.pmean((stats, grads), axis_name="optax_axis") # Compute and apply updates via our optimizer. 
@@ -128,16 +135,16 @@ def _step( def step( self, - params: kfac_jax.utils.Params, + params: Params, state: OptaxState, - rng: jnp.ndarray, - data_iterator: Iterator[kfac_jax.utils.Batch], - func_state: Optional[kfac_jax.utils.FuncState] = None, + rng: PRNGKey, + data_iterator: Iterator[Batch], + func_state: Optional[FuncState] = None, global_step_int: Optional[int] = None - ) -> Union[Tuple[kfac_jax.utils.Params, Any, kfac_jax.utils.FuncState, - Mapping[str, jnp.ndarray]], - Tuple[kfac_jax.utils.Params, Any, - Mapping[str, jnp.ndarray]]]: + ) -> Union[ + Tuple[Params, Any, FuncState, Mapping[str, Array]], + Tuple[Params, Any, Mapping[str, Array]], + ]: """A step with similar interface to KFAC.""" result = self._jit_step( params=params, @@ -155,7 +162,7 @@ def step( def tf1_rmsprop( - learning_rate_fn: Callable[[chex.Numeric], chex.Numeric], + learning_rate_fn: Callable[[Numeric], Numeric], decay: float = .9, momentum: float = 0., epsilon: float = 1e-8 @@ -186,9 +193,9 @@ def update_fn(updates, state, params=None): def linear_interpolation( - x: chex.Numeric, + x: Numeric, interpolation_points: Tuple[Tuple[float, float], ...] -) -> chex.Array: +) -> Array: """Performs linear interpolation between the interpolation points.""" xs, ys = zip(*interpolation_points) masks = [x < ci for ci in xs[1:]] @@ -214,11 +221,11 @@ def linear_interpolation( def imagenet_sgd_schedule( - global_step: chex.Numeric, + global_step: Numeric, dataset_size: int, train_total_batch_size: int, **_: Any, -) -> chex.Array: +) -> Array: """Standard linear scaling schedule for ImageNet.""" # Can be found in Section 5.1 of https://arxiv.org/pdf/1706.02677.pdf steps_per_epoch = dataset_size / train_total_batch_size @@ -233,18 +240,18 @@ def imagenet_sgd_schedule( def fixed_schedule( - global_step: chex.Numeric, - value: chex.Numeric, + global_step: Numeric, + value: Numeric, **_: Any, -) -> chex.Array: +) -> Array: """Fixed/constant schedule.""" return jnp.ones_like(global_step) * value def kfac_resnet50_schedule( - global_step: chex.Numeric, + global_step: Numeric, **_: Any, -) -> chex.Array: +) -> Array: """Custom schedule for KFAC.""" return jnp.power(10.0, linear_interpolation( x=global_step, @@ -255,7 +262,7 @@ def kfac_resnet50_schedule( def cosine_schedule( - global_step: chex.Numeric, + global_step: Numeric, dataset_size: int, train_total_batch_size: int, epochs: Optional[float], @@ -264,7 +271,7 @@ def cosine_schedule( warmup_epochs: Optional[float] = None, warmup_steps: Optional[int] = None, **_: Any, -) -> chex.Array: +) -> Array: """A cosine schedule described in the TAT paper.""" if (steps is None) == (epochs is None): @@ -292,7 +299,7 @@ def cosine_schedule( def stepwise_schedule( - global_step: chex.Numeric, + global_step: Numeric, dataset_size: int, train_total_batch_size: int, lr_decay_factors: Sequence[float], @@ -302,7 +309,7 @@ def stepwise_schedule( step_boundaries: Optional[Sequence[float]] = None, warmup_steps: Optional[int] = None, **_: Any, -) -> chex.Array: +) -> Array: """A basic stepwise schedule.""" if (epoch_boundaries is None) == (step_boundaries is None): @@ -334,7 +341,7 @@ def stepwise_schedule( def construct_schedule( name: str, **kwargs, -) -> Callable[[chex.Numeric], chex.Array]: +) -> Callable[[Numeric], Array]: """Constructs the actual schedule from its name and extra kwargs.""" if name == "fixed": return functools.partial(fixed_schedule, **kwargs) @@ -350,9 +357,9 @@ def construct_schedule( raise NotImplementedError(name) -def kfac_bn_registration_kwargs(bn_registration: str) -> 
Mapping[str, Union[ - Tuple[str, ...], - Mapping[str, Type[kfac_jax.CurvatureBlock]]]]: +def kfac_bn_registration_kwargs(bn_registration: str) -> Mapping[ + str, Union[Tuple[str, ...], Mapping[str, Type[kfac_jax.CurvatureBlock]]] +]: """Constructs KFAC kwargs for the given batch-norm registration strategy.""" if bn_registration == "generic": return dict(patterns_to_skip=("scale_and_shift", "scale_only")) @@ -371,7 +378,7 @@ def create_optimizer( name: str, config: config_dict.ConfigDict, train_model_func: kfac_jax.optimizer.ValueFunc, - l2_reg: chex.Numeric, + l2_reg: Numeric, has_aux: bool, has_func_state: bool, has_rng: bool, diff --git a/examples/training.py b/examples/training.py index 376b749..3fafbf9 100644 --- a/examples/training.py +++ b/examples/training.py @@ -17,10 +17,9 @@ import functools import os import time -from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union +from typing import Any, Callable, Iterator, Optional, Tuple, Union, Dict from absl import logging -import chex import jax import jax.numpy as jnp from jaxline import experiment @@ -28,12 +27,18 @@ import kfac_jax from examples import datasets from examples import optimizers -from ml_collections import config_dict -import numpy as np +import ml_collections # Types for annotation -InitFunc = Callable[[chex.PRNGKey, kfac_jax.utils.Batch], kfac_jax.utils.Params] +Array = kfac_jax.utils.Array +Numeric = kfac_jax.utils.Numeric +PRNGKey = kfac_jax.utils.PRNGKey +Params = kfac_jax.utils.Params +Batch = kfac_jax.utils.Batch +FuncState = kfac_jax.utils.FuncState + +InitFunc = Callable[[PRNGKey, Batch], Params] class SupervisedExperiment(experiment.AbstractExperiment): @@ -70,8 +75,8 @@ class SupervisedExperiment(experiment.AbstractExperiment): def __init__( self, mode: str, - init_rng: chex.PRNGKey, - config: config_dict.ConfigDict, + init_rng: PRNGKey, + config: ml_collections.ConfigDict, init_parameters_func: InitFunc, model_loss_func: kfac_jax.optimizer.ValueFunc, has_aux: bool, @@ -244,17 +249,15 @@ def eval_total_batch_size(self) -> int: @property @functools.lru_cache(maxsize=1) - def train_inputs(self) -> Union[ - Iterator[kfac_jax.utils.Batch], - Tuple[Iterator[kfac_jax.utils.Batch], Iterator[kfac_jax.utils.Batch]], - ]: + def train_inputs(self) -> Union[Iterator[Batch], + Tuple[Iterator[Batch], Iterator[Batch]]]: """The training data iterator.""" return self._train_input def progress( self, - global_step: chex.Numeric, - ) -> chex.Numeric: + global_step: Numeric, + ) -> Numeric: """Computes the current progress of the training as a number in [0,1].""" if self.config.training.steps is not None: @@ -269,17 +272,15 @@ def progress( def should_run_step( self, global_step: int, - config: config_dict.ConfigDict, + config: ml_collections.ConfigDict, ) -> bool: del config # not used return int(self.progress(global_step)) < 1 - def create_optimizer(self) -> Union[ - optimizers.OptaxWrapper, - kfac_jax.Optimizer, - ]: + def create_optimizer(self) -> Union[optimizers.OptaxWrapper, + kfac_jax.Optimizer]: """Creates the optimizer specified in the experiment's config.""" optimizer_config = copy.deepcopy(self.config.optimizer) return optimizers.create_optimizer( @@ -375,13 +376,13 @@ def _build_train_input( ) -> datasets.tf.data.Dataset: """Constructs the training dataset.""" - def step( # pytype: disable=signature-mismatch # jax-ndarray + def step( # pytype: disable=signature-mismatch self, - global_step: jnp.ndarray, - rng: jnp.ndarray, - **unused_args: Any - ) -> Dict[str, jnp.ndarray]: - del global_step # 
Instead, we use the self._python_step + global_step: Array, + rng: PRNGKey, + **unused_args: Any, + ) -> Dict[str, Numeric]: + del global_step # Perform optimizer step result = self.optimizer.step( @@ -415,7 +416,7 @@ def step( # pytype: disable=signature-mismatch # jax-ndarray for i in range(gathered_stat.shape[0]): stats[f"{name}_{i}"] = jnp.array([gathered_stat[i]]) - return kfac_jax.utils.get_first(stats) # questionable? + return kfac_jax.utils.get_first(stats) # _ # _____ ____ _| | @@ -435,13 +436,13 @@ def _build_eval_input( def _evaluate_single_batch( self, - global_step: jnp.ndarray, - params: kfac_jax.utils.Params, - func_state: kfac_jax.utils.FuncState, + global_step: Array, + params: Params, + func_state: FuncState, opt_state: Union[kfac_jax.Optimizer.State, optimizers.OptaxState], - rng: chex.PRNGKey, - batch: kfac_jax.utils.Batch, - ) -> Dict[str, chex.Array]: + rng: PRNGKey, + batch: Batch, + ) -> Dict[str, Array]: """Evaluates a single batch.""" del global_step # This might be used in subclasses @@ -462,16 +463,14 @@ def _evaluate_single_batch( if hasattr(opt_state, "data_seen"): stats["data_seen"] = opt_state.data_seen - return kfac_jax.utils.pmean_if_pmap(stats, "eval_axis") # pytype: disable=bad-return-type # numpy-scalars + return kfac_jax.utils.pmean_if_pmap(stats, "eval_axis") # pytype: disable=bad-return-type - def evaluate( # pytype: disable=signature-mismatch # numpy-scalars + def evaluate( # pytype: disable=signature-mismatch self, - global_step: chex.Array, - rng: chex.PRNGKey, - writer: Optional[pipe_utils.Writer], - ) -> Dict[str, chex.Array]: - del writer # not used - + global_step: Array, + rng: PRNGKey, + **unused_args: Any, + ) -> Dict[str, Numeric]: all_stats = dict() # Evaluates both the train and eval split metrics @@ -491,7 +490,7 @@ def evaluate( # pytype: disable=signature-mismatch # numpy-scalars averaged_stats.add(stats, 1) # Extract all stats - for k, v in averaged_stats.value.items(): # pytype: disable=attribute-error # numpy-scalars + for k, v in averaged_stats.value.items(): # pytype: disable=attribute-error all_stats[f"{name}_{k}"] = kfac_jax.utils.get_first(v) logging.info("Evaluation for %s is completed with %d number of batches.", @@ -499,16 +498,16 @@ def evaluate( # pytype: disable=signature-mismatch # numpy-scalars all_stats["progress"] = self.progress(self._python_step) - return jax.tree_util.tree_map(np.array, all_stats) + return all_stats def train_standalone_supervised( random_seed: int, - full_config: config_dict.ConfigDict, + full_config: ml_collections.ConfigDict, experiment_ctor: - Callable[[str, chex.PRNGKey, config_dict.ConfigDict], SupervisedExperiment], + Callable[[str, PRNGKey, ml_collections.ConfigDict], SupervisedExperiment], storage_folder: Optional[str], -) -> Dict[str, chex.Array]: +) -> Dict[str, Array]: """Run an experiment without the Jaxline runtime.""" rng = jax.random.PRNGKey(random_seed) @@ -552,7 +551,7 @@ def train_standalone_supervised( stats["time"] = stats.get("time", []) + [elapsed_time] for k in sorted(scalars): - stats.setdefault(k, []).append(scalars[k]) + stats.setdefault(k, []).append(jnp.asarray(scalars[k])) # Logging if i % full_config.log_tensors_interval == 0: @@ -564,11 +563,12 @@ def train_standalone_supervised( logging.info("-" * 20) i += 1 + stats = {k: jnp.stack(v) for k, v in stats.items()} if storage_folder is not None: jnp.savez(f"{storage_folder}/snapshot_final.npz", *jax.tree_util.tree_leaves(experiment_instance.snapshot_state())) jnp.savez(f"{storage_folder}/stats.npz", **stats) - return 
stats # pytype: disable=bad-return-type # numpy-scalars + return stats class MnistExperiment(SupervisedExperiment): @@ -579,8 +579,8 @@ def __init__( supervised: bool, flatten_images: bool, mode: str, - init_rng: jnp.ndarray, - config: config_dict.ConfigDict, + init_rng: PRNGKey, + config: ml_collections.ConfigDict, init_parameters_func: InitFunc, model_loss_func: kfac_jax.optimizer.ValueFunc, has_aux: bool, @@ -654,8 +654,8 @@ class ImageNetExperiment(SupervisedExperiment): def __init__( self, mode: str, - init_rng: chex.PRNGKey, - config: config_dict.ConfigDict, + init_rng: PRNGKey, + config: ml_collections.ConfigDict, init_parameters_func: InitFunc, model_loss_func: kfac_jax.optimizer.ValueFunc, has_aux: bool, diff --git a/kfac_jax/_src/curvature_blocks.py b/kfac_jax/_src/curvature_blocks.py index 5558bb0..1a21368 100644 --- a/kfac_jax/_src/curvature_blocks.py +++ b/kfac_jax/_src/curvature_blocks.py @@ -15,9 +15,8 @@ import abc import collections import functools -from typing import Any, Dict, Mapping, Optional, Sequence, Set, Tuple, Union +from typing import Optional, Sequence, Any, Set, Tuple, Union, Dict -import chex import jax import jax.numpy as jnp from kfac_jax._src import layers_and_loss_tags as tags @@ -27,11 +26,12 @@ import numpy as np # Types for annotation -Numeric = chex.Numeric -Scalar = chex.Scalar -PRNGKey = chex.PRNGKey -Shape = chex.Shape -Array = chex.Array +Array = utils.Array +Scalar = utils.Scalar +Numeric = utils.Numeric +PRNGKey = utils.PRNGKey +Shape = utils.Shape +DType = utils.DType ScalarOrSequence = Union[Scalar, Sequence[Scalar]] # Special global variables @@ -205,7 +205,7 @@ def parameters_shapes(self) -> Tuple[Shape, ...]: lambda x: tuple(x.aval.shape), self.parameter_variables)) @property - def dtype(self) -> chex.ArrayDType: + def dtype(self) -> DType: dtypes = set(p.aval.dtype for p in self.parameter_variables) # pytype: disable=attribute-error if len(dtypes) > 1: raise ValueError("Not all parameters are the same dtype.") @@ -233,7 +233,7 @@ def number_of_parameters(self) -> int: def dim(self) -> int: """The number of elements of all parameter variables together.""" - return sum(utils.product(shape) for shape in self.parameters_shapes) # pytype: disable=bad-return-type # numpy-scalars + return sum(utils.product(shape) for shape in self.parameters_shapes) def scale(self, state: "CurvatureBlock.State", use_cache: bool) -> Numeric: """A scalar pre-factor of the curvature approximation. 
@@ -451,7 +451,7 @@ def _eigenvalues_unscaled( def update_curvature_matrix_estimate( self, state: "CurvatureBlock.State", - estimation_data: Mapping[str, Sequence[Array]], + estimation_data: Dict[str, Sequence[Array]], ema_old: Numeric, ema_new: Numeric, batch_size: Numeric, @@ -605,7 +605,7 @@ def _eigenvalues_unscaled( def update_curvature_matrix_estimate( self, state: CurvatureBlock.State, - estimation_data: Mapping[str, Sequence[Array]], + estimation_data: Dict[str, Sequence[Array]], ema_old: Numeric, ema_new: Numeric, batch_size: Numeric, @@ -656,7 +656,7 @@ def _init( return Diagonal.State( cache=None, - diagonal_factors=tuple(utils.WeightedMovingAverage.zero( + diagonal_factors=tuple(utils.WeightedMovingAverage.zeros_array( shape, self.dtype) for shape in self.parameters_shapes), ) @@ -693,7 +693,7 @@ def _eigenvalues_unscaled( def update_curvature_matrix_estimate( self, state: "Diagonal.State", - estimation_data: Mapping[str, Sequence[Array]], + estimation_data: Dict[str, Sequence[Array]], ema_old: Numeric, ema_new: Numeric, batch_size: Numeric, @@ -713,22 +713,21 @@ def update_curvature_matrix_estimate( def _update_curvature_matrix_estimate( self, state: "Diagonal.State", - estimation_data: Mapping[str, Sequence[Array]], + estimation_data: Dict[str, Sequence[Array]], ema_old: Numeric, ema_new: Numeric, batch_size: Numeric, ) -> "Diagonal.State": pass - def _update_cache( # pytype: disable=signature-mismatch # numpy-scalars + def _update_cache( self, state: "Diagonal.State", identity_weight: Numeric, - exact_powers: Numeric, - approx_powers: Numeric, + exact_powers: Set[Scalar], + approx_powers: Set[Scalar], eigenvalues: bool, ) -> "Diagonal.State": - return state.copy() def _to_dense_unscaled(self, state: "Diagonal.State") -> Array: @@ -851,7 +850,7 @@ def _init( return Full.State( cache=cache, - matrix=utils.WeightedMovingAverage.zero( + matrix=utils.WeightedMovingAverage.zeros_array( [self.dim, self.dim], self.dtype), ) @@ -909,7 +908,7 @@ def _eigenvalues_unscaled( def update_curvature_matrix_estimate( self, state: "Full.State", - estimation_data: Mapping[str, Sequence[Array]], + estimation_data: Dict[str, Sequence[Array]], ema_old: Numeric, ema_new: Numeric, batch_size: Numeric, @@ -928,7 +927,7 @@ def update_curvature_matrix_estimate( def _update_curvature_matrix_estimate( self, state: "Full.State", - estimation_data: Mapping[str, Sequence[Array]], + estimation_data: Dict[str, Sequence[Array]], ema_old: Numeric, ema_new: Numeric, batch_size: Numeric, @@ -1079,9 +1078,9 @@ def _init( return TwoKroneckerFactored.State( cache=cache, - inputs_factor=utils.WeightedMovingAverage.zero( + inputs_factor=utils.WeightedMovingAverage.zeros_array( [d_in, d_in], self.dtype), - outputs_factor=utils.WeightedMovingAverage.zero( + outputs_factor=utils.WeightedMovingAverage.zeros_array( [d_out, d_out], self.dtype), ) @@ -1166,7 +1165,7 @@ def _eigenvalues_unscaled( def update_curvature_matrix_estimate( self, state: "TwoKroneckerFactored.State", - estimation_data: Mapping[str, Sequence[Array]], + estimation_data: Dict[str, Sequence[Array]], ema_old: Numeric, ema_new: Numeric, batch_size: Numeric, @@ -1186,22 +1185,21 @@ def update_curvature_matrix_estimate( def _update_curvature_matrix_estimate( self, state: "TwoKroneckerFactored.State", - estimation_data: Mapping[str, Sequence[Array]], + estimation_data: Dict[str, Sequence[Array]], ema_old: Numeric, ema_new: Numeric, batch_size: Numeric, ) -> "TwoKroneckerFactored.State": pass - def _update_cache( # pytype: disable=signature-mismatch # 
numpy-scalars + def _update_cache( self, state: "TwoKroneckerFactored.State", identity_weight: Numeric, - exact_powers: Numeric, - approx_powers: Numeric, + exact_powers: Set[Scalar], + approx_powers: Set[Scalar], eigenvalues: bool, ) -> "TwoKroneckerFactored.State": - # Copy this first since we mutate it later in this function. state = state.copy() @@ -1275,7 +1273,7 @@ class NaiveDiagonal(Diagonal): def _update_curvature_matrix_estimate( self, state: "NaiveDiagonal.State", - estimation_data: Mapping[str, Sequence[Array]], + estimation_data: Dict[str, Sequence[Array]], ema_old: Numeric, ema_new: Numeric, batch_size: Numeric, @@ -1302,7 +1300,7 @@ class NaiveFull(Full): def _update_curvature_matrix_estimate( self, state: Full.State, - estimation_data: Mapping[str, Sequence[Array]], + estimation_data: Dict[str, Sequence[Array]], ema_old: Numeric, ema_new: Numeric, batch_size: Numeric, @@ -1340,7 +1338,7 @@ def has_bias(self) -> bool: def _update_curvature_matrix_estimate( self, state: "Diagonal.State", - estimation_data: Mapping[str, Sequence[Array]], + estimation_data: Dict[str, Sequence[Array]], ema_old: Numeric, ema_new: Numeric, batch_size: Numeric, @@ -1372,7 +1370,7 @@ class DenseFull(Full): def _update_curvature_matrix_estimate( self, state: "Full.State", - estimation_data: Mapping[str, Sequence[Array]], + estimation_data: Dict[str, Sequence[Array]], ema_old: Numeric, ema_new: Numeric, batch_size: Numeric, @@ -1414,7 +1412,7 @@ def output_size(self) -> int: def _update_curvature_matrix_estimate( self, state: TwoKroneckerFactored.State, - estimation_data: Mapping[str, Sequence[Array]], + estimation_data: Dict[str, Sequence[Array]], ema_old: Numeric, ema_new: Numeric, batch_size: Numeric, @@ -1518,7 +1516,7 @@ def conv2d_tangent_squared( def _update_curvature_matrix_estimate( self, state: Diagonal.State, - estimation_data: Mapping[str, Sequence[Array]], + estimation_data: Dict[str, Sequence[Array]], ema_old: Numeric, ema_new: Numeric, batch_size: Numeric, @@ -1620,7 +1618,7 @@ def conv2d_tangent_outer_product( def _update_curvature_matrix_estimate( self, state: Full.State, - estimation_data: Mapping[str, Sequence[Array]], + estimation_data: Dict[str, Sequence[Array]], ema_old: Numeric, ema_new: Numeric, batch_size: Numeric, @@ -1660,10 +1658,10 @@ def weights_output_channel_index(self) -> int: @property def weights_spatial_size(self) -> int: """The spatial filter size of the weights.""" - return utils.product(self.weights_spatial_shape) # pytype: disable=bad-return-type # numpy-scalars + return utils.product(self.weights_spatial_shape) @property - def weights_spatial_shape(self) -> chex.Shape: + def weights_spatial_shape(self) -> Shape: spatial_index = self._layer_tag_eq.params["dimension_numbers"].rhs_spec[2:] return tuple(self.parameters_shapes[0][i] for i in spatial_index) @@ -1705,7 +1703,7 @@ def num_locations( def compute_inputs_stats( self, inputs: Array, - ) -> chex.Array: + ) -> Array: """Computes the statistics for the inputs factor.""" # Note that the input statistics are computed and stored with an extra @@ -1774,7 +1772,7 @@ def compute_outputs_stats( def _update_curvature_matrix_estimate( self, state: TwoKroneckerFactored.State, - estimation_data: Mapping[str, Sequence[Array]], + estimation_data: Dict[str, Sequence[Array]], ema_old: Numeric, ema_new: Numeric, batch_size: Numeric, @@ -1850,7 +1848,7 @@ def has_shift(self) -> bool: def _update_curvature_matrix_estimate( self, state: Diagonal.State, - estimation_data: Mapping[str, Sequence[Array]], + estimation_data: 
Dict[str, Sequence[Array]], ema_old: Numeric, ema_new: Numeric, batch_size: Numeric, @@ -1911,7 +1909,7 @@ def _has_shift(self) -> bool: def _update_curvature_matrix_estimate( self, state: Full.State, - estimation_data: Mapping[str, Sequence[Array]], + estimation_data: Dict[str, Sequence[Array]], ema_old: Numeric, ema_new: Numeric, batch_size: Numeric, diff --git a/kfac_jax/_src/curvature_estimator.py b/kfac_jax/_src/curvature_estimator.py index db88ca6..58bd99a 100644 --- a/kfac_jax/_src/curvature_estimator.py +++ b/kfac_jax/_src/curvature_estimator.py @@ -49,9 +49,8 @@ """ import abc import functools -from typing import Any, Callable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Optional, Sequence, Mapping, Generic, TypeVar, Tuple, Union, Dict -import chex import jax from jax import scipy import jax.numpy as jnp @@ -63,17 +62,22 @@ import numpy as np # Types for annotation +Array = utils.Array +PRNGKey = utils.PRNGKey +Numeric = utils.Numeric +Scalar = utils.Scalar +Shape = utils.Shape CurvatureBlockCtor = Callable[ [tags.LayerTagEqn, str], curvature_blocks.CurvatureBlock ] -StateType = Any +StateType = TypeVar("StateType") # Special global variables _ESTIMATION_MODES = ("fisher_gradients", "fisher_empirical", "fisher_exact", "fisher_curvature_prop", "ggn_exact", "ggn_curvature_prop") -_DEFAULT_TAG_TO_BLOCK_CTOR: MutableMapping[str, CurvatureBlockCtor] = dict( +_DEFAULT_TAG_TO_BLOCK_CTOR: Dict[str, CurvatureBlockCtor] = dict( dense_tag=curvature_blocks.DenseTwoKroneckerFactored, conv2d_tag=curvature_blocks.Conv2DTwoKroneckerFactored, generic_tag=curvature_blocks.NaiveDiagonal, @@ -109,7 +113,7 @@ def __init__( self, func: utils.Func, params_index: int = 0, - batch_size_extractor: Callable[[utils.Batch], chex.Numeric] = + batch_size_extractor: Callable[[utils.Batch], Numeric] = utils.default_batch_size_extractor, ): """Initializes the ImplicitExactCurvature instance. 
@@ -136,7 +140,7 @@ def __init__( ) self._batch_size_extractor = batch_size_extractor - def batch_size(self, func_args: utils.FuncArgs) -> chex.Numeric: + def batch_size(self, func_args: utils.FuncArgs) -> Numeric: """The expected batch size given a list of loss instances.""" return self._batch_size_extractor(func_args[-1]) @@ -144,8 +148,8 @@ def batch_size(self, func_args: utils.FuncArgs) -> chex.Numeric: def _multiply_loss_fisher( cls, losses: Sequence[loss_functions.NegativeLogProbLoss], - loss_vectors: Sequence[Sequence[chex.Array]] - ) -> Tuple[Tuple[chex.Array, ...], ...]: + loss_vectors: Sequence[Sequence[Array]] + ) -> Tuple[Tuple[Array, ...], ...]: """Multiplies ``loss_vectors`` by the Fisher of the total loss.""" assert len(losses) == len(loss_vectors) return tuple(loss.multiply_fisher(vec) @@ -155,8 +159,8 @@ def _multiply_loss_fisher( def _multiply_loss_ggn( cls, losses: Sequence[loss_functions.LossFunction], - loss_vectors: Sequence[Sequence[chex.Array]] - ) -> Tuple[Tuple[chex.Array, ...], ...]: + loss_vectors: Sequence[Sequence[Array]] + ) -> Tuple[Tuple[Array, ...], ...]: """Multiplies ``loss_vectors`` by the GGN of the total loss.""" return tuple(loss.multiply_ggn(vec) for loss, vec in zip(losses, loss_vectors)) @@ -165,8 +169,8 @@ def _multiply_loss_ggn( def _multiply_loss_fisher_factor( cls, losses: Sequence[loss_functions.NegativeLogProbLoss], - loss_inner_vectors: Sequence[chex.Array], - ) -> Tuple[Tuple[chex.Array, ...], ...]: + loss_inner_vectors: Sequence[Array], + ) -> Tuple[Tuple[Array, ...], ...]: """Multiplies the vectors with the Fisher factors of each loss. Args: @@ -186,8 +190,8 @@ def _multiply_loss_fisher_factor( def _multiply_loss_ggn_factor( cls, losses: Sequence[loss_functions.LossFunction], - loss_inner_vectors: Sequence[chex.Array], - ) -> Tuple[Tuple[chex.Array, ...], ...]: + loss_inner_vectors: Sequence[Array], + ) -> Tuple[Tuple[Array, ...], ...]: """Multiplies the vectors with the GGN factors of each loss. Args: @@ -206,8 +210,8 @@ def _multiply_loss_ggn_factor( def _multiply_loss_fisher_factor_transpose( cls, losses: Sequence[loss_functions.NegativeLogProbLoss], - loss_vectors: Sequence[Sequence[chex.Array]] - ) -> Tuple[chex.Array, ...]: + loss_vectors: Sequence[Sequence[Array]] + ) -> Tuple[Array, ...]: """Multiplies the vectors with the transposed Fisher factors of each loss. Args: @@ -227,8 +231,8 @@ def _multiply_loss_fisher_factor_transpose( def _multiply_loss_ggn_factor_transpose( cls, losses: Sequence[loss_functions.LossFunction], - loss_vectors: Sequence[Sequence[chex.Array]] - ) -> Tuple[chex.Array, ...]: + loss_vectors: Sequence[Sequence[Array]] + ) -> Tuple[Array, ...]: """Multiplies the vectors with the transposed GGN factors of each loss. Args: @@ -355,7 +359,7 @@ def multiply_fisher_factor_transpose( self, func_args: utils.FuncArgs, parameter_structured_vector: utils.Params, - ) -> Tuple[chex.Array, ...]: + ) -> Tuple[Array, ...]: """Multiplies the vector with the transposed factor of the Fisher matrix. Args: @@ -384,7 +388,7 @@ def multiply_ggn_factor_transpose( self, func_args: utils.FuncArgs, parameter_structured_vector: utils.Params, - ) -> Tuple[chex.Array, ...]: + ) -> Tuple[Array, ...]: """Multiplies the vector with the transposed factor of the GGN matrix. 
Args: @@ -406,7 +410,7 @@ def multiply_ggn_factor_transpose( def multiply_fisher_factor( self, func_args: utils.FuncArgs, - loss_inner_vectors: Sequence[chex.Array], + loss_inner_vectors: Sequence[Array], ) -> utils.Params: """Multiplies the vector with the factor of the Fisher matrix. @@ -439,7 +443,7 @@ def multiply_fisher_factor( def multiply_ggn_factor( self, func_args: utils.FuncArgs, - loss_inner_vectors: Sequence[chex.Array], + loss_inner_vectors: Sequence[Array], ) -> utils.Params: """Multiplies the vector with the factor of the GGN matrix. @@ -467,7 +471,7 @@ def multiply_ggn_factor( def multiply_jacobian_transpose( self, func_args: utils.FuncArgs, - loss_input_vectors: Sequence[Sequence[chex.Array]], + loss_input_vectors: Sequence[Sequence[Array]], ) -> utils.Params: """Multiplies a vector by the model's transposed Jacobian. @@ -488,7 +492,7 @@ def get_loss_inner_vector_shapes_and_batch_size( self, func_args: utils.FuncArgs, mode: str - ) -> Tuple[Tuple[chex.Shape, ...], int]: + ) -> Tuple[Tuple[Shape, ...], int]: """Get shapes of loss inner vectors, and the batch size. Args: @@ -513,7 +517,7 @@ def get_loss_inner_vector_shapes_and_batch_size( def get_loss_input_shapes_and_batch_size( self, func_args: utils.FuncArgs - ) -> Tuple[Tuple[Tuple[chex.Shape, ...], ...], int]: + ) -> Tuple[Tuple[Tuple[Shape, ...], ...], int]: """Get shapes of loss input vectors, and the batch size. Args: @@ -531,7 +535,7 @@ def get_loss_input_shapes_and_batch_size( batch_size) -class CurvatureEstimator(utils.Finalizable): +class CurvatureEstimator(Generic[StateType], utils.Finalizable): """An abstract curvature estimator class. This is a class that abstracts away the process of estimating a curvature @@ -590,7 +594,7 @@ def dim(self) -> int: @abc.abstractmethod def init( self, - rng: chex.PRNGKey, + rng: PRNGKey, func_args: utils.FuncArgs, exact_powers_to_cache: Optional[curvature_blocks.ScalarOrSequence], approx_powers_to_cache: Optional[curvature_blocks.ScalarOrSequence], @@ -626,8 +630,8 @@ def multiply_matpower( self, state: StateType, parameter_structured_vector: utils.Params, - identity_weight: Union[chex.Array, Sequence[float], float], - power: Union[float, int], + identity_weight: Numeric, + power: Scalar, exact_power: bool, use_cached: bool, pmap_axis_name: Optional[str], @@ -662,7 +666,7 @@ def multiply( self, state: StateType, parameter_structured_vector: utils.Params, - identity_weight: Union[chex.Array, Sequence[float], float], + identity_weight: Numeric, exact_power: bool, use_cached: bool, pmap_axis_name: Optional[str], @@ -683,7 +687,7 @@ def multiply_inverse( self, state: StateType, parameter_structured_vector: utils.Params, - identity_weight: Union[chex.Array, Sequence[float], float], + identity_weight: Numeric, exact_power: bool, use_cached: bool, pmap_axis_name: Optional[str], @@ -705,7 +709,7 @@ def eigenvalues( self, state: StateType, use_cached: bool, - ) -> chex.Array: + ) -> Array: """Computes the eigenvalues of the curvature matrix. 
Args: @@ -723,10 +727,10 @@ def eigenvalues( def update_curvature_matrix_estimate( self, state: StateType, - ema_old: chex.Numeric, - ema_new: chex.Numeric, - batch_size: chex.Numeric, - rng: chex.PRNGKey, + ema_old: Numeric, + ema_new: Numeric, + batch_size: Numeric, + rng: PRNGKey, func_args: utils.FuncArgs, pmap_axis_name: Optional[str], estimation_mode: Optional[str] = None, @@ -785,7 +789,7 @@ def update_curvature_matrix_estimate( def update_cache( self, state: StateType, - identity_weight: chex.Numeric, + identity_weight: Numeric, exact_powers: Optional[curvature_blocks.ScalarOrSequence], approx_powers: Optional[curvature_blocks.ScalarOrSequence], eigenvalues: bool, @@ -815,11 +819,12 @@ def update_cache( """ @abc.abstractmethod - def to_dense_matrix(self, state: StateType) -> chex.Array: + def to_dense_matrix(self, state: StateType) -> Array: """Returns an explicit dense array representing the curvature matrix.""" -class BlockDiagonalCurvature(CurvatureEstimator): +class BlockDiagonalCurvature( + CurvatureEstimator["BlockDiagonalCurvature.State"]): """Block diagonal curvature estimator class.""" @utils.pytree_dataclass @@ -952,7 +957,7 @@ def num_blocks(self) -> int: return len(self.blocks) @property - def block_dims(self) -> chex.Shape: + def block_dims(self) -> Shape: """The number of elements of all parameter variables for each block.""" return tuple(block.dim for block in self.blocks) @@ -990,7 +995,7 @@ def params_block_index(self) -> utils.Params: which approximates the part of the curvature matrix associated with the parameter. """ - params_block_index: List[Optional[int]] = [None] * self.num_params_variables + params_block_index: list[Optional[int]] = [None] * self.num_params_variables for i, block_indices in enumerate(self.jaxpr.layer_indices): for index in block_indices: @@ -1014,11 +1019,11 @@ def _compute_losses_vjp(self, func_args: utils.FuncArgs): def params_vector_to_blocks_vectors( self, parameter_structured_vector: utils.Params, - ) -> Tuple[Tuple[chex.Array, ...]]: + ) -> Tuple[Tuple[Array, ...]]: """Splits the parameters to values for each corresponding block.""" params_values_flat = jax.tree_util.tree_leaves(parameter_structured_vector) - blocks_vectors: List[Tuple[chex.Array, ...]] = [] + blocks_vectors: list[Tuple[Array, ...]] = [] for indices in self.jaxpr.layer_indices: blocks_vectors.append(tuple(params_values_flat[i] for i in indices)) @@ -1027,7 +1032,7 @@ def params_vector_to_blocks_vectors( def blocks_vectors_to_params_vector( self, - blocks_vectors: Sequence[Sequence[chex.Array]], + blocks_vectors: Sequence[Sequence[Array]], ) -> utils.Params: """Reverses the effect of ``self.vectors_to_blocks``.""" @@ -1035,7 +1040,7 @@ def blocks_vectors_to_params_vector( raise ValueError("Incorrect number of block vectors. 
Expected " f"{self.num_blocks}, but got {len(blocks_vectors)}.") - values_flat: List[Optional[chex.Array]] = [None] * self.num_params_variables + values_flat: list[Optional[Array]] = [None] * self.num_params_variables for idx, (indices, vectors) in enumerate( zip(self.jaxpr.layer_indices, blocks_vectors)): @@ -1059,7 +1064,7 @@ def _finalize(self, func_args: utils.FuncArgs): @utils.auto_scope_method def init( self, - rng: chex.PRNGKey, + rng: PRNGKey, func_args: utils.FuncArgs, exact_powers_to_cache: Optional[curvature_blocks.ScalarOrSequence], approx_powers_to_cache: Optional[curvature_blocks.ScalarOrSequence], @@ -1088,8 +1093,8 @@ def multiply_matpower( self, state: "BlockDiagonalCurvature.State", parameter_structured_vector: utils.Params, - identity_weight: Union[Sequence[chex.Numeric], chex.Numeric], - power: Union[float, int], + identity_weight: Union[Numeric, Sequence[Numeric]], + power: Scalar, exact_power: bool, use_cached: bool, pmap_axis_name: Optional[str], @@ -1135,7 +1140,7 @@ def block_eigenvalues( self, state: "BlockDiagonalCurvature.State", use_cached: bool, - ) -> Tuple[chex.Array, ...]: + ) -> Tuple[Array, ...]: """Computes the eigenvalues for each block of the curvature estimator. Args: @@ -1159,7 +1164,7 @@ def eigenvalues( self, state: "BlockDiagonalCurvature.State", use_cached: bool, - ) -> chex.Array: + ) -> Array: blocks_eigenvalues = self.block_eigenvalues(state, use_cached) return jnp.concatenate(blocks_eigenvalues, axis=0) @@ -1168,10 +1173,10 @@ def eigenvalues( def update_curvature_matrix_estimate( self, state: "BlockDiagonalCurvature.State", - ema_old: chex.Numeric, - ema_new: chex.Numeric, - batch_size: chex.Numeric, - rng: chex.PRNGKey, + ema_old: Numeric, + ema_new: Numeric, + batch_size: Numeric, + rng: PRNGKey, func_args: utils.FuncArgs, pmap_axis_name: Optional[str], estimation_mode: Optional[str] = None, @@ -1283,7 +1288,7 @@ def update_blocks(vjp_vec_, state_, ema_old_, ema_new_): else: vjp_vec[i] = loss.multiply_ggn_factor_replicated_one_hot([index]) - if utils.is_array_instance(vjp_vec[i]): + if isinstance(vjp_vec[i], Array): # In the special case of only one parameter, it still needs to be a # tuple for the tangents. 
vjp_vec[i] = (vjp_vec[i],) @@ -1304,7 +1309,7 @@ def update_blocks(vjp_vec_, state_, ema_old_, ema_new_): def update_cache( self, state: "BlockDiagonalCurvature.State", - identity_weight: Union[Sequence[chex.Numeric], chex.Numeric], + identity_weight: Union[Numeric, Sequence[Numeric]], exact_powers: Optional[curvature_blocks.ScalarOrSequence], approx_powers: Optional[curvature_blocks.ScalarOrSequence], eigenvalues: bool, @@ -1339,8 +1344,9 @@ def filter_outputs(thunk, vals): matches = jax.tree_util.tree_map(lambda o, v: o is v, thunk(), vals) def new_thunk(): - return jax.tree_util.tree_map(lambda o, m: None if m else o, - thunk(), matches) + return jax.tree_util.tree_map( + lambda o, m: None if m else o, thunk(), matches + ) return new_thunk # Create new thunks that only return the state arrays that they actually @@ -1365,7 +1371,7 @@ def new_thunk(): def to_diagonal_block_dense_matrix( self, state: "BlockDiagonalCurvature.State", - ) -> Tuple[chex.Array, ...]: + ) -> Tuple[Array, ...]: """Returns a tuple of arrays with explicit dense matrices of each block.""" return tuple(block.to_dense_matrix(block_state) for block, block_state in zip(self.blocks, state.blocks_states)) @@ -1374,7 +1380,7 @@ def to_diagonal_block_dense_matrix( def to_dense_matrix( self, state: "BlockDiagonalCurvature.State" - ) -> chex.Array: + ) -> Array: return scipy.linalg.block_diag(*self.to_diagonal_block_dense_matrix(state)) @@ -1496,7 +1502,7 @@ def modified_losses_jvp(vjp_vec): # Need to reorder all of the block information to follow the canonical # order of variables params_vars = BlockDiagonalCurvature.params_vector_to_blocks_vectors( - self, self.jaxpr.params_vars) + self, self.jaxpr.params_vars) # pytype: disable=wrong-arg-types order = np.argsort([p.count for p in jax.tree_util.tree_leaves(params_vars)]) @@ -1507,13 +1513,13 @@ def modified_losses_jvp(vjp_vec): def params_vector_to_blocks_vectors( self, parameter_structured_vector: utils.Params, - ) -> Tuple[Tuple[chex.Array, ...]]: + ) -> Tuple[Tuple[Array, ...]]: return (tuple(jax.tree_util.tree_leaves(parameter_structured_vector)),) def blocks_vectors_to_params_vector( self, - blocks_vectors: Sequence[Sequence[chex.Array]], + blocks_vectors: Sequence[Sequence[Array]], ) -> utils.Params: assert len(blocks_vectors) == self.num_blocks @@ -1524,10 +1530,10 @@ def blocks_vectors_to_params_vector( def update_curvature_matrix_estimate( self, state: BlockDiagonalCurvature.State, - ema_old: chex.Numeric, - ema_new: chex.Numeric, - batch_size: chex.Numeric, - rng: chex.PRNGKey, + ema_old: Numeric, + ema_new: Numeric, + batch_size: Numeric, + rng: PRNGKey, func_args: utils.FuncArgs, pmap_axis_name: Optional[str], estimation_mode: Optional[str] = None, @@ -1536,7 +1542,7 @@ def update_curvature_matrix_estimate( rng = jax.random.split(rng, batch_size) def single_state_update( - index: chex.Numeric, + index: Numeric, state_: curvature_blocks.Full.State ) -> curvature_blocks.Full.State: @@ -1564,7 +1570,7 @@ def single_state_update( def update_cache( self, state: BlockDiagonalCurvature.State, - identity_weight: chex.Numeric, + identity_weight: Numeric, exact_powers: Optional[curvature_blocks.ScalarOrSequence], approx_powers: Optional[curvature_blocks.ScalarOrSequence], eigenvalues: bool, diff --git a/kfac_jax/_src/layers_and_loss_tags.py b/kfac_jax/_src/layers_and_loss_tags.py index 3d36448..6f0dab5 100644 --- a/kfac_jax/_src/layers_and_loss_tags.py +++ b/kfac_jax/_src/layers_and_loss_tags.py @@ -13,18 +13,17 @@ # limitations under the License. 
"""K-FAC losses and layers tagging Jax primitives.""" import types -from typing import Any, Generic, Optional, Sequence, Tuple, Type, TypeVar, Union +from typing import Any, Generic, Optional, Sequence, Type, TypeVar, Tuple, Union -import chex import jax from jax import core from jax.interpreters import batching as jax_batching # Types for annotation T = TypeVar("T") -ArrayOrXla = TypeVar("ArrayOrXla", chex.Array, jax.interpreters.xla.XlaOp) -Array = chex.Array +Array = jax.Array Arrays = Tuple[Array, ...] +ArrayOrXla = TypeVar("ArrayOrXla", Array, jax.interpreters.xla.XlaOp) class LossTag(core.Primitive, Generic[T]): @@ -133,7 +132,7 @@ def _xla_translation( *args: jax.interpreters.xla.XlaOp, args_names: Sequence[str], ) -> Tuple[jax.interpreters.xla.XlaOp, ...]: - """The XLA translation rule for this primitive (creates a no-op Tuple).""" + """The XLA translation rule for this primitive (creates a no-op tuple).""" del avals_in, avals_out # not used return self.get_outputs(*args, args_names=args_names) diff --git a/kfac_jax/_src/loss_functions.py b/kfac_jax/_src/loss_functions.py index 1b85749..593a9ab 100644 --- a/kfac_jax/_src/loss_functions.py +++ b/kfac_jax/_src/loss_functions.py @@ -15,7 +15,6 @@ import abc from typing import Optional, Sequence, Tuple -import chex import distrax import jax import jax.numpy as jnp @@ -24,7 +23,11 @@ from kfac_jax._src import utils -Array = chex.Array +Array = utils.Array +Numeric = utils.Numeric +PRNGKey = utils.PRNGKey +Shape = utils.Shape +DType = utils.DType class LossFunction(utils.Finalizable): @@ -36,7 +39,7 @@ class LossFunction(utils.Finalizable): needed. """ - def __init__(self, weight: chex.Numeric): + def __init__(self, weight: Numeric): """Initializes the loss instance. Args: @@ -50,11 +53,11 @@ def __init__(self, weight: chex.Numeric): self.finalize() @property - def dtype(self) -> chex.ArrayDType: + def dtype(self) -> DType: return self.parameter_dependants[0].dtype @property - def weight(self) -> chex.Numeric: + def weight(self) -> Numeric: """The relative weight of the loss.""" return self._weight @@ -75,7 +78,7 @@ def num_parameter_dependants(self) -> int: @property @abc.abstractmethod - def parameter_independants(self) -> Tuple[chex.Numeric, ...]: + def parameter_independants(self) -> Tuple[Numeric, ...]: """All the parameter independent arrays of the loss.""" @property @@ -289,7 +292,7 @@ def multiply_ggn_factor_replicated_one_hot_unweighted( @property @abc.abstractmethod - def ggn_factor_inner_shape(self) -> chex.Shape: + def ggn_factor_inner_shape(self) -> Shape: """The shape of the array returned by `self.multiply_ggn_factor`.""" @@ -440,11 +443,11 @@ def multiply_fisher_factor_replicated_one_hot_unweighted( @property @abc.abstractmethod - def fisher_factor_inner_shape(self) -> chex.Shape: + def fisher_factor_inner_shape(self) -> Shape: """The shape of the array returned by :func:`~LossFunction.multiply_fisher_factor`.""" @abc.abstractmethod - def sample(self, rng: chex.PRNGKey) -> Array: + def sample(self, rng: PRNGKey) -> Array: """Sample ``targets`` from the underlying distribution.""" def grad_of_evaluate_on_sample( @@ -501,7 +504,7 @@ def multiply_ggn_factor_replicated_one_hot_unweighted( return self.multiply_fisher_factor_replicated_one_hot_unweighted(index) @property - def ggn_factor_inner_shape(self) -> chex.Shape: + def ggn_factor_inner_shape(self) -> Shape: return self.fisher_factor_inner_shape @@ -514,13 +517,14 @@ def dist(self) -> distrax.Distribution: """The underlying Distrax distribution.""" def _evaluate(self, 
targets: Array) -> Array: - return -self.dist.log_prob(targets) # keeps leading dims intact + # keeps leading dims intact + return -self.dist.log_prob(targets) # pytype: disable=bad-return-type - def sample(self, rng: chex.PRNGKey) -> Array: - return self.dist.sample(seed=rng) # pytype: disable=bad-return-type # numpy-scalars + def sample(self, rng: PRNGKey) -> Array: + return self.dist.sample(seed=rng) # pytype: disable=bad-return-type @property - def fisher_factor_inner_shape(self) -> chex.Shape: + def fisher_factor_inner_shape(self) -> Shape: return jax.eval_shape( lambda: self.sample(rng=jax.random.PRNGKey(0))).shape @@ -542,8 +546,8 @@ def __init__( self, mean: Array, targets: Optional[Array] = None, - variance: chex.Numeric = 0.5, - weight: chex.Numeric = 1.0, + variance: Numeric = 0.5, + weight: Numeric = 1.0, ): """Initializes the loss instance. @@ -567,7 +571,7 @@ def mean(self) -> Array: return self._mean @property - def variance(self) -> chex.Numeric: + def variance(self) -> Numeric: return self._variance @property @@ -575,7 +579,7 @@ def targets(self) -> Optional[Array]: return self._targets @property - def parameter_independants(self) -> Tuple[chex.Numeric, ...]: + def parameter_independants(self) -> Tuple[Numeric, ...]: arrays = (self.variance, self.weight) if self._targets is not None: arrays = (self._targets,) + arrays @@ -666,7 +670,7 @@ def __init__( mean: Array, variance: Array, targets: Optional[Array] = None, - weight: chex.Numeric = 1.0, + weight: Numeric = 1.0, ): """Initializes the loss instance. @@ -690,7 +694,7 @@ def targets(self) -> Optional[Array]: return self._targets @property - def parameter_independants(self) -> Tuple[chex.Numeric, ...]: + def parameter_independants(self) -> Tuple[Numeric, ...]: arrays = (self.weight,) if self._targets is not None: arrays = (self._targets,) + arrays @@ -793,7 +797,7 @@ def multiply_fisher_factor_replicated_one_hot_unweighted( return mean_output, var_output @property - def fisher_factor_inner_shape(self) -> chex.Shape: + def fisher_factor_inner_shape(self) -> Shape: return self._mean.shape[:-1] + self._mean.shape[-1:] * 2 def multiply_ggn_unweighted( @@ -820,7 +824,7 @@ def multiply_ggn_factor_replicated_one_hot_unweighted( raise NotImplementedError() @property - def ggn_factor_inner_shape(self) -> chex.Shape: + def ggn_factor_inner_shape(self) -> Shape: raise NotImplementedError() @@ -840,7 +844,7 @@ def __init__( self, logits: Array, targets: Optional[Array] = None, - weight: chex.Numeric = 1.0, + weight: Numeric = 1.0, ): """Initializes the loss instance. @@ -858,7 +862,7 @@ def targets(self) -> Optional[Array]: return self._targets @property - def parameter_independants(self) -> Tuple[chex.Numeric, ...]: + def parameter_independants(self) -> Tuple[Numeric, ...]: arrays = (self.weight,) if self._targets is not None: arrays = (self._targets,) + arrays @@ -871,7 +875,7 @@ def dist(self) -> distrax.Bernoulli: @property def _probs(self) -> Array: """The probabilities of the underlying Bernoulli distribution.""" - return self.dist.probs + return self.dist.probs # pytype: disable=bad-return-type @property def params(self) -> Tuple[Array]: @@ -933,7 +937,7 @@ def __init__( logits: Array, targets: Optional[Array] = None, mask: Optional[Array] = None, - weight: chex.Numeric = 1.0, + weight: Numeric = 1.0, ): """Initializes the loss instance. 
@@ -969,7 +973,7 @@ def mask(self) -> Optional[Array]: return self._mask @property - def parameter_independants(self) -> Tuple[chex.Numeric, ...]: + def parameter_independants(self) -> Tuple[Numeric, ...]: arrays = (self.weight,) if self.mask is not None: @@ -1017,7 +1021,7 @@ def params(self) -> Tuple[Array]: return (self._logits,) @property - def fisher_factor_inner_shape(self) -> chex.Shape: + def fisher_factor_inner_shape(self) -> Shape: return self._logits.shape def copy_with_different_inputs( @@ -1184,7 +1188,7 @@ def register_normal_predictive_distribution( mean: Array, targets: Optional[Array] = None, variance: float = 0.5, - weight: chex.Numeric = 1.0, + weight: Numeric = 1.0, ): """Registers a normal predictive distribution. @@ -1224,7 +1228,7 @@ def register_normal_predictive_distribution( def register_squared_error_loss( prediction: Array, targets: Optional[Array] = None, - weight: chex.Numeric = 1.0, + weight: Numeric = 1.0, ) -> Array: """Registers a squared error loss function. @@ -1250,7 +1254,7 @@ def register_squared_error_loss( def register_multi_bernoulli_predictive_distribution( logits: Array, targets: Optional[Array] = None, - weight: chex.Numeric = 1.0, + weight: Numeric = 1.0, ): """Registers a multi-Bernoulli predictive distribution. @@ -1285,7 +1289,7 @@ def register_multi_bernoulli_predictive_distribution( def register_sigmoid_cross_entropy_loss( logits: Array, targets: Optional[Array] = None, - weight: chex.Numeric = 1.0, + weight: Numeric = 1.0, ): """Registers a sigmoid cross-entropy loss function. @@ -1314,7 +1318,7 @@ def register_categorical_predictive_distribution( logits: Array, targets: Optional[Array] = None, mask: Optional[Array] = None, - weight: chex.Numeric = 1.0, + weight: Numeric = 1.0, ): """Registers a categorical predictive distribution. @@ -1380,7 +1384,7 @@ def register_softmax_cross_entropy_loss( logits: Array, targets: Optional[Array] = None, mask: Optional[Array] = None, - weight: chex.Numeric = 1.0, + weight: Numeric = 1.0, ) -> Array: """Registers a softmax cross-entropy loss function. 
diff --git a/kfac_jax/_src/optimizer.py b/kfac_jax/_src/optimizer.py index c37b1dd..560f1e7 100644 --- a/kfac_jax/_src/optimizer.py +++ b/kfac_jax/_src/optimizer.py @@ -15,9 +15,8 @@ """K-FAC optimizer.""" import functools -from typing import Any, Callable, Iterator, Mapping, Optional, Sequence, Tuple, Union +from typing import Callable, Iterator, Optional, Sequence, Any, Generic, Tuple, Union, Dict -import chex import jax from jax import lax import jax.numpy as jnp @@ -26,27 +25,36 @@ from typing_extensions import TypeAlias # Types for annotation +Array = utils.Array +PRNGKey = utils.PRNGKey +Numeric = utils.Numeric +Params = utils.Params +Batch = utils.Batch +FuncState = Any +# FuncState = utils.FuncState +FuncAux = utils.FuncAux + OptimizerState: TypeAlias = "Optimizer.State" -ScheduleType = Callable[[chex.Array], Optional[chex.Array]] +ScheduleType = Callable[[Array], Optional[Array]] FuncArgsVariants = Union[ - Tuple[utils.Params, utils.Batch], - Tuple[utils.Params, utils.FuncState, utils.Batch], - Tuple[utils.Params, chex.PRNGKey, utils.Batch], - Tuple[utils.Params, utils.FuncState, chex.PRNGKey, utils.Batch], + Tuple[Params, Batch], + Tuple[Params, FuncState, Batch], + Tuple[Params, PRNGKey, Batch], + Tuple[Params, FuncState, PRNGKey, Batch], ] FuncOutputs = Union[ - chex.Array, - Tuple[chex.Array, utils.FuncState], - Tuple[chex.Array, utils.FuncAux], - Tuple[chex.Array, Tuple[utils.FuncState, utils.FuncAux]], + Array, + Tuple[Array, FuncState], + Tuple[Array, FuncAux], + Tuple[Array, Tuple[FuncState, FuncAux]], ] ValueFunc = Callable[..., FuncOutputs] -ValueAndGradFunc = Callable[..., Tuple[FuncOutputs, utils.Params]] +ValueAndGradFunc = Callable[..., Tuple[FuncOutputs, Params]] ReturnWithFuncState = Tuple[ - utils.Params, OptimizerState, utils.FuncState, Mapping[str, chex.Array] + Params, OptimizerState, FuncState, Dict[str, Array] ] ReturnWithoutFuncState = Tuple[ - utils.Params, OptimizerState, Mapping[str, chex.Array] + Params, OptimizerState, Dict[str, Array] ] ReturnEither = Union[ReturnWithFuncState, ReturnWithoutFuncState] @@ -55,7 +63,7 @@ class Optimizer(utils.WithStagedMethods): """The K-FAC optimizer.""" @utils.pytree_dataclass - class State(utils.State): + class State(Generic[Params], utils.State): r"""Persistent state of the optimizer. Attributes: @@ -67,16 +75,16 @@ class State(utils.State): data_seen: The number of training cases that the optimizer has processed. step_counter: An integer giving the current step number :math:`t`. 
""" - velocities: utils.Params + velocities: Params estimator_state: curvature_estimator.BlockDiagonalCurvature.State - damping: Optional[chex.Array] - data_seen: chex.Numeric - step_counter: chex.Numeric + damping: Optional[Array] + data_seen: Numeric + step_counter: Numeric def __init__( self, value_and_grad_func: ValueAndGradFunc, - l2_reg: chex.Numeric, + l2_reg: Numeric, value_func_has_aux: bool = False, value_func_has_state: bool = False, value_func_has_rng: bool = False, @@ -86,30 +94,30 @@ def __init__( momentum_schedule: Optional[ScheduleType] = None, use_adaptive_damping: bool = False, damping_schedule: Optional[ScheduleType] = None, - initial_damping: Optional[chex.Numeric] = None, - min_damping: chex.Numeric = 1e-8, - max_damping: chex.Numeric = jnp.inf, + initial_damping: Optional[Numeric] = None, + min_damping: Numeric = 1e-8, + max_damping: Numeric = jnp.inf, include_damping_in_quad_change: bool = False, damping_adaptation_interval: int = 5, - damping_adaptation_decay: chex.Numeric = 0.9, - damping_lower_threshold: chex.Numeric = 0.25, - damping_upper_threshold: chex.Numeric = 0.75, + damping_adaptation_decay: Numeric = 0.9, + damping_lower_threshold: Numeric = 0.25, + damping_upper_threshold: Numeric = 0.75, always_use_exact_qmodel_for_damping_adjustment: bool = False, - norm_constraint: Optional[chex.Numeric] = None, + norm_constraint: Optional[Numeric] = None, num_burnin_steps: int = 10, estimation_mode: str = "fisher_gradients", - curvature_ema: chex.Numeric = 0.95, + curvature_ema: Numeric = 0.95, inverse_update_period: int = 5, use_exact_inverses: bool = False, - batch_process_func: Optional[Callable[[utils.Batch], utils.Batch]] = None, + batch_process_func: Optional[Callable[[Batch], Batch]] = None, register_only_generic: bool = False, patterns_to_skip: Sequence[str] = (), - auto_register_kwargs: Optional[Mapping[str, Any]] = None, + auto_register_kwargs: Optional[Dict[str, Any]] = None, layer_tag_to_block_ctor: - Optional[Mapping[str, curvature_estimator.CurvatureBlockCtor]] = None, + Optional[Dict[str, curvature_estimator.CurvatureBlockCtor]] = None, multi_device: bool = False, debug: bool = False, - batch_size_extractor: Callable[[utils.Batch], chex.Numeric] = + batch_size_extractor: Callable[[Batch], Numeric] = utils.default_batch_size_extractor, pmap_axis_name: str = "kfac_axis", forbid_setting_attributes_after_finalize: bool = True, @@ -349,7 +357,7 @@ def __init__( if momentum_schedule is not None: - def schedule_with_first_step_zero(global_step: chex.Array) -> chex.Array: + def schedule_with_first_step_zero(global_step: Array) -> Array: value = momentum_schedule(global_step) check = jnp.equal(global_step, 0) return check * jnp.zeros_like(value) + (1 - check) * value @@ -416,7 +424,7 @@ def num_burnin_steps(self) -> int: return self._num_burnin_steps @property - def l2_reg(self) -> chex.Array: + def l2_reg(self) -> Array: """The weight of the additional diagonal term added to the curvature.""" return self._l2_reg @@ -426,7 +434,7 @@ def estimator(self) -> curvature_estimator.BlockDiagonalCurvature: return self._estimator @property - def damping_decay_factor(self) -> chex.Numeric: + def damping_decay_factor(self) -> Numeric: """How fast to decay the damping, when using damping adaptation.""" return self._damping_adaptation_decay ** self._damping_adaptation_interval @@ -447,30 +455,30 @@ def _approx_powers_to_cache(self) -> Optional[Union[int, Sequence[int]]]: def should_update_damping( self, state: "Optimizer.State", - ) -> chex.Array: + ) -> Array: """Whether 
at the current step the optimizer should update the damping.""" return (state.step_counter + 1) % self._damping_adaptation_interval == 0 @functools.partial(utils.staged, static_argnums=1) def _rng_split( self, - rng: chex.PRNGKey, + rng: PRNGKey, num: int, - ) -> Tuple[chex.Array, ...]: + ) -> Tuple[Array, ...]: """Splits the ``rng`` key.""" return tuple(jax.random.split(rng, num)) @utils.auto_scope_method - def compute_loss_value(self, func_args: FuncArgsVariants) -> chex.Array: + def compute_loss_value(self, func_args: FuncArgsVariants) -> Array: """Computes the value of the loss function being optimized.""" return self._value_func(*func_args) def verify_args_and_get_step_counter( self, - step_counter: chex.Array, - learning_rate: Optional[chex.Array] = None, - momentum: Optional[chex.Array] = None, - damping: Optional[chex.Array] = None, + step_counter: Array, + learning_rate: Optional[Array] = None, + momentum: Optional[Array] = None, + damping: Optional[Array] = None, global_step_int: Optional[int] = None, ) -> int: """Verifies that the arguments passed to the step function are correct.""" @@ -529,11 +537,11 @@ def verify_args_and_get_step_counter( @utils.staged def _setup_state_and_schedules( self, - learning_rate: Optional[chex.Array], - momentum: Optional[chex.Array], - damping: Optional[chex.Array], - step_counter: chex.Array - ) -> Tuple[Optional[chex.Array], Optional[chex.Array], chex.Array]: + learning_rate: Optional[Array], + momentum: Optional[Array], + damping: Optional[Array], + step_counter: Array + ) -> Tuple[Optional[Array], Optional[Array], Array]: """Helper function for setting up learning rate, momentum and damping.""" # Compute schedules if applicable @@ -556,11 +564,11 @@ def _setup_state_and_schedules( def _setup_func_args_and_rng( self, - params: utils.Params, - rng: chex.PRNGKey, - batch: utils.Batch, - func_state: Optional[utils.FuncState], - ) -> Tuple[FuncArgsVariants, chex.Array]: + params: Params, + rng: PRNGKey, + batch: Batch, + func_state: Optional[FuncState], + ) -> Tuple[FuncArgsVariants, Array]: """Helper function for setting up the model function arguments correctly.""" # Preprocess the batch and construct correctly the function arguments @@ -587,9 +595,9 @@ def _update_estimator_curvature( self, estimator_state: curvature_estimator.BlockDiagonalCurvature.State, func_args: FuncArgsVariants, - rng: chex.PRNGKey, - ema_old: chex.Numeric, - ema_new: chex.Numeric, + rng: PRNGKey, + ema_old: Numeric, + ema_new: Numeric, ) -> curvature_estimator.BlockDiagonalCurvature.State: """Updates the curvature estimator state.""" @@ -608,7 +616,7 @@ def _update_estimator_curvature( def _compute_loss_and_grads( self, func_args: FuncArgsVariants, - ) -> Tuple[chex.Array, utils.Params, utils.FuncState, utils.FuncAux]: + ) -> Tuple[Array, Params, FuncState, FuncAux]: """Computes the model loss value and its gradients.""" out, grads = self._value_and_grad_func(*func_args) @@ -621,7 +629,7 @@ def _compute_loss_and_grads( def _maybe_update_inverse_cache( self, state: "Optimizer.State", - damping: chex.Array, + damping: Array, ) -> "Optimizer.State": """Updates the estimator state cache if it is the right iteration.""" @@ -649,10 +657,10 @@ def _maybe_update_inverse_cache( def _compute_preconditioned_gradient( self, state: "Optimizer.State", - grads: utils.Params, - coefficient: Optional[chex.Array], - damping: chex.Array, - ) -> Tuple[utils.Params, Optional[chex.Array]]: + grads: Params, + coefficient: Optional[Array], + damping: Array, + ) -> Tuple[Params, Optional[Array]]: 
"""Computes the preconditioned gradient, maybe applying norm-constraint.""" preconditioned_grads = self.estimator.multiply_inverse( @@ -685,11 +693,11 @@ def _compute_preconditioned_gradient( def _compute_quad_change_for_damping( self, state: "Optimizer.State", - delta: utils.Params, - grads: utils.Params, - damping: chex.Array, + delta: Params, + grads: Params, + damping: Array, func_args: FuncArgsVariants, - ) -> chex.Array: + ) -> Array: """The quadratic model change, when lr and momentum are non-adaptive.""" assert not (self._use_adaptive_learning_rate or self._use_adaptive_momentum) @@ -701,18 +709,18 @@ def _compute_quad_change_for_damping( quad_model = self.compute_approx_quad_model(state, [delta], grads) w = jnp.ones([]) - return self._solve_quad_model(quad_model, damping, [delta], [w])[1] # pytype: disable=bad-return-type # numpy-scalars + return self._solve_quad_model(quad_model, damping, [delta], [w])[1] def _coefficients_and_quad_change( self, state: "Optimizer.State", - vectors: Sequence[utils.Params], - grads: utils.Params, - learning_rate: Optional[chex.Array], - momentum: Optional[chex.Array], - damping: chex.Array, + vectors: Sequence[Params], + grads: Params, + learning_rate: Optional[Array], + momentum: Optional[Array], + damping: Array, func_args: Optional[FuncArgsVariants] = None, - ) -> Tuple[Tuple[chex.Array, ...], Optional[chex.Array]]: + ) -> Tuple[Tuple[Optional[Array], Optional[Array]], Array]: """The correct update coefficients and corresponding quadratic change.""" # Compute the coefficients of the update vectors @@ -726,7 +734,7 @@ def _coefficients_and_quad_change( quad_model = self.compute_exact_quad_model(vectors, grads, func_args) - return self._solve_quad_model(quad_model, damping, vectors, coefficients) # pytype: disable=bad-return-type # numpy-scalars + return self._solve_quad_model(quad_model, damping, vectors, coefficients) else: assert all(c is not None for c in coefficients) @@ -748,11 +756,11 @@ def _coefficients_and_quad_change( @utils.auto_scope_method def _update_damping( self, - old_damping: chex.Array, - old_loss: chex.Array, - quad_change: chex.Array, + old_damping: Array, + old_loss: Array, + quad_change: Array, new_func_args: FuncArgsVariants, - ) -> Tuple[chex.Array, chex.Array, chex.Array]: + ) -> Tuple[Array, Array, Array]: """Updates the damping parameter.""" new_loss = self.compute_loss_value(new_func_args) @@ -768,10 +776,10 @@ def _update_damping( @utils.staged def _init( self, - params: utils.Params, - rng: chex.PRNGKey, - batch: utils.Batch, - func_state: Optional[utils.FuncState] = None, + params: Params, + rng: PRNGKey, + batch: Batch, + func_state: Optional[FuncState] = None, ) -> "Optimizer.State": """A staged function to initialize the optimizer state .""" @@ -799,10 +807,10 @@ def _init( def init( self, - params: utils.Params, - rng: chex.PRNGKey, - batch: utils.Batch, - func_state: Optional[utils.FuncState] = None, + params: Params, + rng: PRNGKey, + batch: Batch, + func_state: Optional[FuncState] = None, ) -> "Optimizer.State": """Initializes the optimizer and returns the appropriate optimizer state.""" @@ -814,11 +822,11 @@ def init( @functools.partial(utils.staged, donate_argnums=[1, 3, 5]) def _burnin( self, - params: utils.Params, + params: Params, state: "Optimizer.State", - rng: chex.Array, - batch: utils.Batch, - func_state: Optional[utils.FuncState], + rng: Array, + batch: Batch, + func_state: Optional[FuncState], accumulator: utils.MultiChunkAccumulator ) -> Tuple["Optimizer.State", utils.MultiChunkAccumulator]: 
"""A single burnin step, updating only the curvature estimate.""" @@ -846,12 +854,12 @@ def _burnin( def burnin( self, num_steps: int, - params: utils.Params, + params: Params, state: "Optimizer.State", - rng: chex.PRNGKey, - data_iterator: Iterator[utils.Batch], - func_state: Optional[utils.FuncState] = None, - ) -> Tuple["Optimizer.State", Optional[utils.FuncState]]: + rng: PRNGKey, + data_iterator: Iterator[Batch], + func_state: Optional[FuncState] = None, + ) -> Tuple["Optimizer.State", Optional[FuncState]]: """Runs all burnin steps required.""" if num_steps > 0: @@ -874,14 +882,14 @@ def burnin( @utils.auto_scope_method def _step( self, - params: utils.Params, + params: Params, state: "Optimizer.State", - rng: chex.Array, - batch: utils.Batch, - func_state: Optional[utils.FuncState], - learning_rate: Optional[chex.Array], - momentum: Optional[chex.Array], - damping: Optional[chex.Array] + rng: Array, + batch: Batch, + func_state: Optional[FuncState], + learning_rate: Optional[Array], + momentum: Optional[Array], + damping: Optional[Array] )-> ReturnEither: """A single full step of the optimizer.""" @@ -1025,15 +1033,15 @@ def _step( def step( self, - params: utils.Params, + params: Params, state: "Optimizer.State", - rng: chex.PRNGKey, - data_iterator: Optional[Iterator[utils.Batch]] = None, - batch: Optional[utils.Batch] = None, - func_state: Optional[utils.FuncState] = None, - learning_rate: Optional[chex.Array] = None, - momentum: Optional[chex.Array] = None, - damping: Optional[chex.Array] = None, + rng: PRNGKey, + data_iterator: Optional[Iterator[Batch]] = None, + batch: Optional[Batch] = None, + func_state: Optional[FuncState] = None, + learning_rate: Optional[Array] = None, + momentum: Optional[Array] = None, + damping: Optional[Array] = None, global_step_int: Optional[int] = None )-> ReturnEither: """Performs a single update step using the optimizer. @@ -1105,8 +1113,8 @@ def step( def compute_l2_quad_matrix( self, - vectors: Sequence[utils.Params] - ) -> chex.Array: + vectors: Sequence[Params] + ) -> Array: """Computes the matrix corresponding to the prior/regularizer. 
Args: @@ -1121,10 +1129,10 @@ def compute_l2_quad_matrix( @utils.auto_scope_method def compute_exact_quad_model( self, - vectors: Sequence[utils.Params], - grads: utils.Params, + vectors: Sequence[Params], + grads: Params, func_args: Optional[FuncArgsVariants] = None, - ) -> Tuple[chex.Array, chex.Array, chex.Array]: + ) -> Tuple[Array, Array, Array]: """Computes the components of the exact quadratic model.""" if func_args is None: raise ValueError("When you have not provided `c_factor_v` you must " @@ -1148,9 +1156,9 @@ def compute_exact_quad_model( def compute_approx_quad_model( self, state: "Optimizer.State", - vectors: Sequence[utils.Params], - grads: utils.Params, - ) -> Tuple[chex.Array, chex.Array, chex.Array]: + vectors: Sequence[Params], + grads: Params, + ) -> Tuple[Array, Array, Array]: """Computes the components of the approximate quadratic model.""" # v_i^T C v_j @@ -1172,11 +1180,11 @@ def c_times_v(v): def compute_quadratic_model_value( self, - a: chex.Array, - a_damped: chex.Array, - b: chex.Array, - w: chex.Array, - ) -> chex.Array: + a: Array, + a_damped: Array, + b: Array, + w: Array, + ) -> Array: """Computes the quadratic model value from the inputs provided.""" a_final = a_damped if self._include_damping_in_quad_change else a @@ -1186,11 +1194,11 @@ def compute_quadratic_model_value( @utils.staged def _solve_quad_model( self, - quad_model_parameters: Tuple[chex.Array, chex.Array, chex.Array], - damping: chex.Array, - vectors: Sequence[utils.Params], - fixed_coefficients: Optional[Sequence[Union[chex.Array, None]]] = None, - ) -> Tuple[Tuple[chex.Array, ...], chex.Array]: + quad_model_parameters: Tuple[Array, Array, Array], + damping: Array, + vectors: Sequence[Params], + fixed_coefficients: Optional[Sequence[Union[Numeric, None]]] = None, + ) -> Tuple[Tuple[Optional[Array], ...], Array]: """Solves for the optimal learning rate and momentum of the quadratic model. The quadratic model is represented as: @@ -1239,7 +1247,8 @@ def _solve_quad_model( # since the convention everywhere else is to sync quantities immediately # after they are first computed). 
A, A_damped, b = utils.pmean_if_pmap((A, A_damped, b), self.pmap_axis_name) - # pylint: enable=invalid-name + # This needs explicit annotation + A_damped: Array if all(c is None for c in fixed_coefficients): # Adapt all coefficients @@ -1277,8 +1286,9 @@ def _solve_quad_model( w[1 - index] = jnp.asarray([fixed_coefficients[1 - index]]) b_extra = A_damped[1 - index, index: index + 1] * w[1 - index] - A_solve = A_damped[index: index + 1, index: index + 1] # pylint: disable=invalid-name + A_solve = A_damped[index: index + 1, index: index + 1] b_solve = b[index: index + 1] + b_extra + # pylint: enable=invalid-name w[index] = - b_solve / A_solve[0] w = jnp.concatenate(w, axis=0) @@ -1293,11 +1303,11 @@ def _solve_quad_model( @utils.staged def _compute_new_damping_and_rho( self, - old_loss: chex.Array, - new_loss: chex.Array, - quad_change: chex.Array, - current_damping: chex.Array, - ) -> Tuple[chex.Array, chex.Array]: + old_loss: Array, + new_loss: Array, + quad_change: Array, + current_damping: Array, + ) -> Tuple[Array, Array]: """Computes the reduction ratio and the updated value of the damping.""" # Reduction ratio @@ -1321,7 +1331,7 @@ def _compute_new_damping_and_rho( def weighted_sum_of_objects( self, objects: Sequence[utils.PyTree], - coefficients: Sequence[chex.Numeric], + coefficients: Sequence[Numeric], ) -> utils.PyTree: """Returns the weighted sum of the objects in the sequence.""" return utils.weighted_sum_of_objects(objects, coefficients) @@ -1342,7 +1352,7 @@ def convert_value_and_grad_to_value_func( Returns: A function that returns only the loss value. """ - def value_func(*args) -> chex.Array: + def value_func(*args) -> Array: out, _ = value_and_grad_func(*args) return out[0] if has_aux else out @@ -1350,10 +1360,10 @@ def value_func(*args) -> chex.Array: def make_func_args( - params: utils.Params, - func_state: Optional[utils.FuncState], - rng: Optional[chex.PRNGKey], - batch: utils.Batch, + params: Params, + func_state: Optional[FuncState], + rng: Optional[PRNGKey], + batch: Batch, has_state: bool, has_rng: bool, ) -> FuncArgsVariants: @@ -1399,7 +1409,7 @@ def extract_func_outputs( raw_outputs: FuncOutputs, has_aux: bool, has_state: bool, -) -> Tuple[chex.Array, Optional[utils.FuncState], Optional[utils.FuncAux]]: +) -> Tuple[Array, Optional[FuncState], Optional[FuncAux]]: """Converts the raw output of the model function into loss,func_state and aux. Args: diff --git a/kfac_jax/_src/patches_second_moment.py b/kfac_jax/_src/patches_second_moment.py index db32add..2e42ee1 100644 --- a/kfac_jax/_src/patches_second_moment.py +++ b/kfac_jax/_src/patches_second_moment.py @@ -13,9 +13,8 @@ # limitations under the License. 
"""K-FAC optimized functions for patches second moment(PSM) computation.""" import functools -from typing import List, Optional, Sequence, Tuple, TypeVar, Union +from typing import Optional, Sequence, TypeVar, Tuple, Union, List -import chex import jax from jax import interpreters from jax import lax @@ -25,8 +24,10 @@ # Types for annotation T = TypeVar("T") +Array = utils.Array +Shape = utils.Shape TracedType = interpreters.partial_eval.DynamicJaxprTracer -DimNumbers = Tuple[chex.Shape, chex.Shape, chex.Shape] +DimNumbers = Tuple[Shape, Shape, Shape] PaddingVariants = Union[str, int, Sequence[int], Sequence[Tuple[int, int]]] # Special global variables @@ -76,15 +77,15 @@ def spatial_axes(self) -> Tuple[int]: """Returns the indices of the spatial axes.""" return self.order[2:] - def get_n(self, shape: chex.Shape) -> int: + def get_n(self, shape: Shape) -> int: """Returns the batch size of the given shape, under this spec layout.""" return shape[self.n_axis] - def get_c(self, shape: chex.Shape) -> int: + def get_c(self, shape: Shape) -> int: """Returns the channel size of the given shape, under this spec layout.""" return shape[self.c_axis] - def get_spatial(self, shape: chex.Shape) -> Tuple[int, ...]: + def get_spatial(self, shape: Shape) -> Tuple[int, ...]: """Returns the spatial sizes of the given shape, under this spec layout.""" return tuple(shape[i] for i in self.spatial_axes) @@ -122,10 +123,10 @@ def change_nhwc_to_ihwo(self) -> "_ConvSpec": def _slice_array( - array: chex.Array, + array: Array, indices: Sequence[Union[int, TracedType]], sizes: Sequence[int], -) -> chex.Array: +) -> Array: """Takes a slice from the array provided.""" if any(isinstance(x, TracedType) for x in indices): # Any of the indices are dynamic values. @@ -137,11 +138,11 @@ def _slice_array( def _output_spatial_shape( - inputs_spatial_shape: chex.Shape, - kernel_spatial_shape: chex.Shape, - spatial_strides: chex.Shape, + inputs_spatial_shape: Shape, + kernel_spatial_shape: Shape, + spatial_strides: Shape, padding: Union[str, Sequence[Tuple[int, int]]], -) -> chex.Shape: +) -> Shape: """Returns the output spatial shape of the corresponding convolution.""" if isinstance(padding, str): if padding.lower() == "valid": @@ -161,9 +162,9 @@ def _output_spatial_shape( def _normalize_padding( - inputs_spatial_shape: chex.Shape, - kernel_spatial_shape: chex.Shape, - spatial_strides: chex.Shape, + inputs_spatial_shape: Shape, + kernel_spatial_shape: Shape, + spatial_strides: Shape, padding: PaddingVariants, ) -> Tuple[Tuple[int, int], ...]: """Returns the padding as a tuple of pairs of integers.""" @@ -197,8 +198,8 @@ def _normalize_padding( def _normalize_strides( - kernel_spatial_shape: chex.Shape, - strides: Union[int, chex.Shape], + kernel_spatial_shape: Shape, + strides: Union[int, Shape], ) -> Tuple[int, ...]: """Returns the strides as a tuple of integers.""" n = len(kernel_spatial_shape) @@ -227,9 +228,9 @@ def _data_format_to_dim_numbers( def _parse_simple_args( - inputs_shape: chex.Shape, - kernel_spatial_shape: Union[int, chex.Shape], - strides: Union[int, chex.Shape] = 1, + inputs_shape: Shape, + kernel_spatial_shape: Union[int, Shape], + strides: Union[int, Shape] = 1, padding: PaddingVariants = "VALID", data_format: Optional[str] = "NHWC", dim_numbers: Optional[Union[DimNumbers, lax.ConvDimensionNumbers]] = None, @@ -237,7 +238,7 @@ def _parse_simple_args( Tuple[int, ...], Tuple[int, ...], Tuple[Tuple[int, int], ...], - lax.ConvDimensionNumbers + lax.ConvDimensionNumbers, ]: """Parses all convolutional 
arguments to a single unified format. @@ -308,9 +309,9 @@ def _parse_simple_args( def _num_conv_locations_full_spec( - input_spatial_shape: chex.Shape, - kernel_spatial_shape: chex.Shape, - spatial_strides: chex.Shape, + input_spatial_shape: Shape, + kernel_spatial_shape: Shape, + spatial_strides: Shape, spatial_padding: Sequence[Tuple[int, int]], ) -> int: """The number of convolution locations from the unified spec for arguments.""" @@ -339,9 +340,9 @@ def _num_conv_locations_full_spec( def num_conv_locations( - inputs_spatial_shape: chex.Shape, - kernel_spatial_shape: Union[int, chex.Shape], - spatial_strides: Union[int, chex.Shape], + inputs_spatial_shape: Shape, + kernel_spatial_shape: Union[int, Shape], + spatial_strides: Union[int, Shape], spatial_padding: Union[str, int, Sequence[Tuple[int, int]]], ) -> int: """Returns the number of convolution locations for the provided shapes.""" @@ -360,9 +361,9 @@ def num_conv_locations( @utils.auto_scope_function def _the_conv4d( - lhs: chex.Array, + lhs: Array, lhs_spec: _ConvSpec, - rhs: chex.Array, + rhs: Array, rhs_spec: _ConvSpec, pad_h: int, pad_w: int, @@ -370,7 +371,7 @@ def _the_conv4d( stride_w: int, per_channel: bool = False, precision: Optional[jax.lax.Precision] = None, -) -> chex.Array: +) -> Array: """Performs a special conv4d or conv2d based on the global flag.""" assert len(rhs_spec) == 6 if get_use_4d_convolution_in_psm_loop(): @@ -482,9 +483,9 @@ def single_conv(x, y): def _validate_inputs_lengths( - inputs: chex.Array, - kernel_spatial_shape: chex.Shape, - strides: chex.Shape, + inputs: Array, + kernel_spatial_shape: Shape, + strides: Shape, padding: Tuple[Tuple[int, int], ...], ) -> None: """Checks that the provided arguments are valid.""" @@ -518,9 +519,9 @@ def _validate_inputs_lengths( "batch_group_count", "unroll_loop", "precision")) @utils.auto_scope_function def patches_moments_explicit( - inputs: chex.Array, - kernel_spatial_shape: Union[int, chex.Shape], - strides: Union[int, chex.Shape] = 1, + inputs: Array, + kernel_spatial_shape: Union[int, Shape], + strides: Union[int, Shape] = 1, padding: PaddingVariants = "VALID", data_format: Optional[str] = "NHWC", dim_numbers: Optional[Union[DimNumbers, lax.ConvDimensionNumbers]] = None, @@ -530,8 +531,8 @@ def patches_moments_explicit( batch_group_count: int = 1, unroll_loop: bool = False, precision: Optional[jax.lax.Precision] = None, - weighting_array: Optional[chex.Array] = None, -) -> Tuple[chex.Array, chex.Array]: + weighting_array: Optional[Array] = None, +) -> Tuple[Array, Array]: """The exact same functionality as :func:`~patches_moments`, but explicitly extracts the patches via :func:`jax.lax.conv_general_dilated_patches`, potentially having a higher memory usage.""" kernel_spatial_shape, strides, padding, dim_numbers = _parse_simple_args( inputs.shape, kernel_spatial_shape, padding=padding, strides=strides, @@ -627,7 +628,7 @@ def general_loop_body(i, image): else: wf_n = weighting_array[in_spec.n_axis] wf_spatial = [weighting_array.shape[a] for a in in_spec.spatial_axes] - wf_sizes = in_spec.create_shape(wf_n, 1, *wf_spatial) + wf_sizes = in_spec.create_shape(wf_n, jnp.ones([]), *wf_spatial) wf_i = _slice_array(weighting_array, index, wf_sizes) else: wf_i = None @@ -683,9 +684,9 @@ def loop_body(args): "batch_group_count", "unroll_loop", "precision")) @utils.auto_scope_function def patches_moments( - inputs: chex.Array, - kernel_spatial_shape: Union[int, chex.Shape], - strides: Union[int, chex.Shape] = 1, + inputs: Array, + kernel_spatial_shape: Union[int, 
Shape], + strides: Union[int, Shape] = 1, padding: PaddingVariants = "VALID", data_format: Optional[str] = "NHWC", dim_numbers: Optional[Union[DimNumbers, lax.ConvDimensionNumbers]] = None, @@ -695,8 +696,8 @@ def patches_moments( batch_group_count: int = 1, unroll_loop: bool = False, precision: Optional[jax.lax.Precision] = None, - weighting_array: Optional[chex.Array] = None, -) -> Tuple[chex.Array, chex.Array]: + weighting_array: Optional[Array] = None, +) -> Tuple[Array, Array]: """Computes the first and second moment of the convolutional patches. Since the code is written to support arbitrary convolution data formats, e.g. diff --git a/kfac_jax/_src/tag_graph_matcher.py b/kfac_jax/_src/tag_graph_matcher.py index 7cb7c57..3470cf2 100644 --- a/kfac_jax/_src/tag_graph_matcher.py +++ b/kfac_jax/_src/tag_graph_matcher.py @@ -16,10 +16,9 @@ import functools import itertools import pprint -from typing import Any, Callable, Dict, List, Mapping, MutableMapping, Optional, Sequence, Set, Tuple, TypeVar, Union +from typing import Any, Callable, Mapping, Optional, Sequence, TypeVar, Tuple, Union, Dict, Set from absl import logging -import chex import immutabledict import jax import jax.numpy as jnp @@ -30,6 +29,8 @@ HIGHER_ORDER_NAMES = ("cond", "while", "scan", "xla_call", "xla_pmap") # Types for annotation +Array = utils.Array +PyTreeDef = utils.PyTreeDef Var = jax.core.Var Vars = Sequence[Var] Jaxpr = jax.core.Jaxpr @@ -42,7 +43,7 @@ EquivalenceFunction = Callable[[JaxprEqn, JaxprEqn], bool] MakeVarFunc = Callable[[jax.core.AbstractValue], Var] VarProcessor = Callable[[Vars, MakeVarFunc], Tuple[Vars, JaxprEqns]] -PatternComputeFunc = Callable[[chex.Array, Sequence[chex.Array]], chex.Array] +PatternComputeFunc = Callable[[Array, Sequence[Array]], Array] ParameterExtractorFunc = Callable[[JaxprEqns], Mapping[str, Any]] TagCtor = Callable[[Vars, Vars, JaxprEqns, MakeVarFunc], JaxprEqn] @@ -224,9 +225,9 @@ class JaxprGraph: """ name: str closed_jaxpr: ClosedJaxpr - params_tree: chex.PyTreeDef + params_tree: PyTreeDef params_vars: Vars - out_tree: chex.PyTreeDef + out_tree: PyTreeDef tag_ctor: Optional[TagCtor] # Until we stop supporting Python 3.7 we can't use @functools.cached_property, # so we set these attributes in __post_init__ @@ -318,7 +319,7 @@ def sub_graph_eqns(self, root_vars: Vars, leaf_vars: Vars) -> JaxprEqns: def make_jax_graph( func: utils.Func, - func_args: Sequence[Any], + func_args: utils.FuncArgs, params_index: Union[int, Sequence[int]], name: str, compute_only_loss_tags: bool, @@ -789,9 +790,9 @@ def find_layer_tags_and_patterns( def read_env( - env: Mapping[Var, chex.Array], - var: Union[jax.core.Literal, Var, Sequence[Var]], -) -> Union[float, chex.Array, Sequence[chex.Array]]: + env: Mapping[Var, Array], + var: Union[jax.core.Literal, Vars], +) -> Union[float, Array, Sequence[Array]]: """Reads from the variable-to-array environment during tracing.""" if isinstance(var, (list, tuple)): return jax.tree_util.tree_map(lambda x: read_env(env, x), var) @@ -805,10 +806,10 @@ def read_env( def write_env( - env: MutableMapping[Var, chex.Array], - var: Union[Var, List[Var]], - val: Union[chex.Array, List[chex.Array]], -) -> None: + env: Dict[Var, Array], + var: Union[Var, Vars], + val: Union[Array, Sequence[Array]], +): """Writes to the variable-to-array environment during tracing.""" if isinstance(var, tuple): raise NotImplementedError() @@ -953,7 +954,7 @@ def merge_broadcasts_jaxpr(jaxpr: J) -> J: # |___/ -def _dense(x: chex.Array, params: Sequence[chex.Array]) -> chex.Array: +def 
_dense(x: Array, params: Sequence[Array]) -> Array: """Example of a dense layer function.""" w, *opt_b = params y = jnp.matmul(x, w) @@ -987,7 +988,7 @@ def _make_dense_pattern( ) -def _conv2d(x: chex.Array, params: Sequence[chex.Array]) -> chex.Array: +def _conv2d(x: Array, params: Sequence[Array]) -> Array: """Example of a conv2d layer function.""" w = params[0] y = jax.lax.conv_general_dilated( @@ -1029,11 +1030,11 @@ def _make_conv2d_pattern( def _scale_and_shift( - x: chex.Array, - params: Sequence[chex.Array], + x: Array, + params: Sequence[Array], has_scale: bool, has_shift: bool, -) -> chex.Array: +) -> Array: """Example of a scale and shift function.""" if has_scale and has_shift: scale, shift = params @@ -1080,11 +1081,11 @@ def _make_scale_and_shift_pattern( def _normalization_haiku( - inputs: Sequence[chex.Array], - params: Sequence[chex.Array], + inputs: Sequence[Array], + params: Sequence[Array], has_scale: bool, has_shift: bool, -) -> chex.Array: +) -> Array: """Example of normalization as is defined in Haiku.""" if len(params) not in (1, 2): raise ValueError("The inputs to the `normalization_haiku` computation must " diff --git a/kfac_jax/_src/tracer.py b/kfac_jax/_src/tracer.py index 411cde8..fda0f95 100644 --- a/kfac_jax/_src/tracer.py +++ b/kfac_jax/_src/tracer.py @@ -13,36 +13,38 @@ # limitations under the License. """K-FAC tracing functionality for functions needed for curvature estimation.""" import functools -from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, TypeVar, Union +from typing import Any, Callable, Sequence, TypeVar, Tuple, Union, Dict, List -import chex import jax import jax.numpy as jnp from kfac_jax._src import layers_and_loss_tags as tags from kfac_jax._src import loss_functions from kfac_jax._src import tag_graph_matcher as tgm from kfac_jax._src import utils +from typing_extensions import TypeAlias # Types for annotations +Array = utils.Array +Shape = utils.Shape +Params = utils.Params +FuncArgs = utils.FuncArgs +FuncOuts = utils.FuncOuts +Var = jax.core.Var + T = TypeVar("T") -J = TypeVar("J", jax.core.Jaxpr, jax.core.ClosedJaxpr) +# J = TypeVar("J", jax.core.Jaxpr, jax.core.ClosedJaxpr) +ProcJaxpr: TypeAlias = "ProcessedJaxpr" TaggedFunction = Callable[..., Tuple[loss_functions.LossFunction, ...]] FuncWithTags = Callable[..., Any] -LossTagInputs = Tuple[chex.Array, ...] -LayerTagInputs = Tuple[chex.Array, ...] -FunctionTransformation = Union[ - Callable[["ProcessedJaxpr", utils.FuncArgs], T], - Callable[["ProcessedJaxpr", utils.FuncArgs, utils.Params], T], -] -TransformedFunction = Union[ - Callable[[utils.FuncArgs], T], - Callable[[utils.FuncArgs, bool], Union[T, "ProcessedJaxpr"]], - Callable[[utils.FuncArgs, utils.Params], T], - Callable[[utils.FuncArgs, utils.Params, bool], Union[T, "ProcessedJaxpr"]], -] +LossTagInputs = Tuple[Array, ...] +LayerTagInputs = Tuple[Array, ...] 
+ +FunctionTransformation = Callable[..., Union[ProcJaxpr, T]] +TransformedFunction = Callable[..., Union[ProcJaxpr, T]] + LossTagsVjp = Tuple[ Tuple[loss_functions.LossFunction, ...], - Callable[[Sequence[LossTagInputs]], utils.Params] + Callable[[Sequence[LossTagInputs]], Params] ] LossTagsJvp = Tuple[ Tuple[loss_functions.LossFunction, ...], @@ -50,20 +52,20 @@ ] LayerTagVjp = Tuple[ Tuple[loss_functions.LossFunction, ...], - Callable[[Tuple[LossTagInputs, ...]], Tuple[Dict[str, chex.Array], ...]] + Callable[[Tuple[LossTagInputs, ...]], Tuple[Dict[str, Array], ...]] ] JaxprOrClosedJaxpr = Union[jax.core.Jaxpr, jax.core.ClosedJaxpr] -def shape_and_type(x: chex.Array) -> Tuple[chex.Shape, chex.ArrayDType]: +def shape_and_type(x: Array) -> Tuple[Shape, jnp.dtype]: """Returns the shape and type of the given array.""" return x.shape, x.dtype def make_cache_key( - func_args: utils.FuncArgs, + func_args: FuncArgs, *args: Any -) -> Tuple[utils.PyTreeDef, Tuple[Tuple[chex.Shape, chex.ArrayDType], ...]]: +) -> Tuple[utils.PyTreeDef, Tuple[Tuple[Shape, jnp.dtype], ...]]: """Creates a key for caching Jax function arguments.""" args_flat, tree_structure = jax.tree_util.tree_flatten((func_args, args)) @@ -83,7 +85,7 @@ def extract_tags( def order_layer_tags( - params_vars_flat: Sequence[jax.core.Var], + params_vars_flat: Sequence[Var], layer_tags: Sequence[tags.LayerTagEqn], allow_left_out_params: bool = False, ) -> Tuple[Tuple[tags.LayerTagEqn, ...], Tuple[Tuple[int, ...], ...]]: @@ -183,22 +185,22 @@ def __init__( self.finalize() @property - def in_vars_flat(self) -> List[jax.core.Var]: + def in_vars_flat(self) -> List[Var]: """A flat list of all of the abstract input variables.""" return self.jaxpr.invars @property - def in_vars(self) -> utils.PyTree: + def in_vars(self) -> utils.PyTree[Var]: """The abstract input variables, as an un-flatten structure.""" return jax.tree_util.tree_unflatten(self.in_tree, self.in_vars_flat) @property - def params_vars(self) -> utils.PyTree: + def params_vars(self) -> utils.PyTree[Var]: """The abstract parameter variables, as an un-flatten structure.""" return self.in_vars[self.params_index] @property - def params_vars_flat(self) -> List[jax.core.Var]: + def params_vars_flat(self) -> List[Var]: """A flat list of all abstract parameter variables.""" return jax.tree_util.tree_leaves(self.params_vars) @@ -211,12 +213,12 @@ def params_tree(self) -> utils.PyTreeDef: def make_from_func( cls, func: utils.Func, - func_args: utils.FuncArgs, + func_args: FuncArgs, params_index: int = 0, auto_register_tags: bool = True, allow_left_out_params: bool = False, ** auto_registration_kwargs: Any, - ) -> "ProcessedJaxpr": + ) -> ProcJaxpr: """Constructs a :class:`~ProcessedJaxpr` from a the given function. 
Args: @@ -258,7 +260,7 @@ def make_from_func( allow_left_out_params=allow_left_out_params, ) - def __eq__(self, other: "ProcessedJaxpr") -> bool: + def __eq__(self, other: ProcJaxpr) -> bool: """Compares two ProcessedJaxpr instances by tree structure.""" # Verify whether input trees are equivalent @@ -338,50 +340,45 @@ def cached_transformation( @functools.wraps(transformation) def wrapped_transformation( - func_args: utils.FuncArgs, + func_args: FuncArgs, *args: Any, return_only_jaxpr: bool = False, - ) -> Union[ProcessedJaxpr, Any]: - + ) -> Union[ProcessedJaxpr, T]: # Construct a key and check cache for hits key = make_cache_key(func_args) jaxpr, f = cache.get(key, (None, None)) - if jaxpr is not None: - if return_only_jaxpr: - return jaxpr - else: - return f(func_args, *args) - - # Process the function - processed_jaxpr = ProcessedJaxpr.make_from_func( - func=func, - func_args=func_args, - params_index=params_index, - auto_register_tags=auto_register_tags, - allow_left_out_params=allow_left_out_params, - **auto_registration_kwargs - ) + if jaxpr is None: + assert f is None + # Process the function + jaxpr = ProcessedJaxpr.make_from_func( + func=func, + func_args=func_args, + params_index=params_index, + auto_register_tags=auto_register_tags, + allow_left_out_params=allow_left_out_params, + **auto_registration_kwargs + ) - if not allow_no_losses and not processed_jaxpr.loss_tags: - raise ValueError("No registered losses have been found during tracing.") + if not allow_no_losses and not jaxpr.loss_tags: + raise ValueError("No registered losses have been found during tracing.") - if cache and raise_error_on_diff_jaxpr: + if cache and raise_error_on_diff_jaxpr: - # If any previous `ProcessedJaxpr` exists verify that they are equivalent - ref_jaxpr, _ = cache[next(iter(cache))] + # If any previous `ProcessedJaxpr` exists verify that it is equivalent + ref_jaxpr, _ = cache[next(iter(cache))] - if ref_jaxpr != processed_jaxpr: - raise ValueError("The consecutive tracing of the provided function " - "yielded a non-equivalent `ProcessedJaxpr`.") + if ref_jaxpr != jaxpr: + raise ValueError("The consecutive tracing of the provided function " + "yielded a non-equivalent `ProcessedJaxpr`.") - transformed = functools.partial(transformation, processed_jaxpr) - cache[key] = (processed_jaxpr, transformed) + f = functools.partial(transformation, jaxpr) + cache[key] = (jaxpr, f) if return_only_jaxpr: - return processed_jaxpr + return jaxpr else: - return transformed(func_args, *args) + return f(func_args, *args) return wrapped_transformation @@ -390,10 +387,10 @@ def construct_compute_losses_inputs( jaxpr: jax.core.Jaxpr, consts: Sequence[Any], num_losses: int, - primal_func_args: utils.FuncArgs, + primal_func_args: FuncArgs, params_index: int ) -> Callable[ - [utils.Params], + [Params], Tuple[Tuple[LossTagInputs, ...], Tuple[LossTagInputs, ...]] ]: """Constructs a function that computes the inputs to all loss tags. @@ -420,7 +417,7 @@ def construct_compute_losses_inputs( """ def forward_compute_losses( - primal_params: utils.Params + primal_params: Params ) -> Tuple[Tuple[LossTagInputs, ...], Tuple[LossTagInputs, ...]]: """Computes and returns the inputs to the first ``num_losses`` loss tags.""" @@ -469,7 +466,7 @@ def forward_compute_losses( def _loss_tags_vjp( p_jaxpr: ProcessedJaxpr, - primal_func_args: utils.FuncArgs, + primal_func_args: FuncArgs, ) -> LossTagsVjp: """Computes a (backward-mode) vector-Jacobian product w.r.t. all loss tags. 
@@ -509,7 +506,7 @@ def _loss_tags_vjp( zero_tangents = jax.tree_util.tree_map(jnp.zeros_like, losses_inputs) - def losses_vjp_func(losses_tangents: Sequence[LossTagInputs]) -> utils.Params: + def losses_vjp_func(losses_tangents: Sequence[LossTagInputs]) -> Params: """Computes the vector-Jacobian product w.r.t. the parameters. Args: @@ -543,8 +540,8 @@ def losses_vjp_func(losses_tangents: Sequence[LossTagInputs]) -> utils.Params: def _loss_tags_jvp( p_jaxpr: ProcessedJaxpr, - primal_func_args: utils.FuncArgs, - params_tangents: utils.Params, + primal_func_args: FuncArgs, + params_tangents: Params, ) -> LossTagsJvp: """Computes a (forward-mode) Jacobian-vector product w.r.t. all loss tags. @@ -597,9 +594,9 @@ def _loss_tags_jvp( def _loss_tags_hvp( processed_jaxpr: ProcessedJaxpr, - primal_func_args: utils.FuncArgs, - params_tangents: utils.Params, -) -> Tuple[utils.Params, Tuple[loss_functions.LossFunction, ...]]: + primal_func_args: FuncArgs, + params_tangents: Params, +) -> Tuple[Params, Tuple[loss_functions.LossFunction, ...]]: """Computes a Hessian-vector product of the function w.r.t. all loss tags. The function takes as inputs the concrete values of the primals for the @@ -631,7 +628,7 @@ def _loss_tags_hvp( params_index=processed_jaxpr.params_index) def compute_losses( - param_primals: utils.Params + param_primals: Params ) -> Tuple[loss_functions.LossFunction, ...]: """Computes the sum of all losses as a scalar.""" @@ -640,7 +637,7 @@ def compute_losses( return tuple(tag.primitive.loss(*inputs, **tag.params) for tag, inputs in zip(processed_jaxpr.loss_tags, loss_inputs)) - def losses_sum(param_primals: utils.Params) -> chex.Array: + def losses_sum(param_primals: Params) -> Array: # This computes the sum of losses evaluated. Makes it easier because we can # now use jax.grad rather than jax.vjp for taking derivatives. return sum(jnp.sum(loss.evaluate()) for loss in @@ -655,7 +652,7 @@ def losses_sum(param_primals: utils.Params) -> chex.Array: def _layer_tag_vjp( processed_jaxpr: ProcessedJaxpr, - primal_func_args: utils.FuncArgs, + primal_func_args: FuncArgs, ) -> LayerTagVjp: """Computes primal values and tangents w.r.t. all layer tags. @@ -681,7 +678,7 @@ def _layer_tag_vjp( [tag.invars for tag in processed_jaxpr.layer_tags]) layer_input_vars = tuple(set(layer_vars_flat)) - def forward() -> Tuple[chex.Array, ...]: + def forward() -> Tuple[Array, ...]: """Computes the values of all inputs to all **layer** tags.""" own_func_args = primal_func_args @@ -714,7 +711,7 @@ def forward() -> Tuple[chex.Array, ...]: return read(layer_input_vars) def forward_aux( - aux: Mapping[jax.core.Var, chex.Array] + aux: Dict[Var, Array] ) -> Tuple[Tuple[LossTagInputs, ...], Tuple[LossTagInputs, ...]]: """Computes the inputs and kwargs of all **loss** tags. @@ -807,7 +804,7 @@ def write(var, val): def vjp_func( tangents: Tuple[LossTagInputs, ...] - ) -> Tuple[Dict[str, chex.Array], ...]: + ) -> Tuple[Dict[str, Array], ...]: """Computes a (reverse-mode) vector-Jacobian product w.r.t. all layer tags. Args: @@ -854,12 +851,10 @@ def vjp_func( return losses, vjp_func -# Pytype throws an error with output type annotation -# -> Callable[[utils.FuncArgs], LossTagsVjp] def loss_tags_vjp( func: utils.Func, params_index: int = 0, -) -> ...: +) -> TransformedFunction[LossTagsVjp]: """Creates a function for the vector-Jacobian product w.r.t. all loss tags. The returned function has a similar interface to :func:`jax.vjp`. 
It takes as @@ -878,7 +873,7 @@ def loss_tags_vjp( Returns: A function that computes the vector-Jacobian product with signature - `Callable[[utils.FuncArgs], LossTagsVjp]`. + `Callable[[FuncArgs], LossTagsVjp]`. """ # Note that this function is independent of any layer tags, hence we can avoid # calling the auto registration. @@ -892,8 +887,6 @@ def loss_tags_vjp( ) -# PyType throws an error with output type annotation: -# -> Callable[[utils.FuncArgs, utils.Params], LossTagsVjp] def loss_tags_jvp( func: utils.Func, params_index: int = 0, @@ -917,7 +910,7 @@ def loss_tags_jvp( Returns: A function that computes the Jacobian-vector product with signature - `Callable[[utils.FuncArgs, utils.Params], LossTagsVjp]`. + `Callable[[FuncArgs, Params], LossTagsVjp]`. """ # Note that this function is independent of any layer tags, hence we can avoid # calling the auto registration. @@ -931,8 +924,6 @@ def loss_tags_jvp( ) -# PyType throws an error with output type annotation: -# -> Callable[[utils.FuncArgs, utils.Params], LossTagsVjp] def loss_tags_hvp( func: utils.Func, params_index: int = 0, @@ -953,7 +944,7 @@ def loss_tags_hvp( Returns: A function that computes the Hessian-vector product and also returns all - losses, with signature `Callable[[utils.FuncArgs, utils.Params], + losses, with signature `Callable[[FuncArgs, Params], Tuple[LossTagsVjp, Tuple[loss_functions.LossFunction, ...]]`. """ # Note that this function is independent of any layer tags, hence we can avoid @@ -968,8 +959,6 @@ def loss_tags_hvp( ) -# PyType throws an error with output type annotation: -# -> Tuple[Callable[[utils.FuncArgs], LossTagsVjp], TransformedJaxprFunction] def layer_tags_vjp( func: utils.Func, params_index: int = 0, @@ -1003,7 +992,7 @@ def layer_tags_vjp( Returns: Returns a function that computes primal values and tangents wrt all layer - tags, with signature `Callable[[utils.FuncArgs], LossTagsVjp]`. + tags, with signature `Callable[[FuncArgs, Params], LossTagsVjp]`. 
""" return cached_transformation( diff --git a/kfac_jax/_src/utils/__init__.py b/kfac_jax/_src/utils/__init__.py index 74f8b14..13f91c0 100644 --- a/kfac_jax/_src/utils/__init__.py +++ b/kfac_jax/_src/utils/__init__.py @@ -21,22 +21,28 @@ from kfac_jax._src.utils import types # types +Array = types.Array +PRNGKey = types.PRNGKey +Scalar = types.Scalar +Numeric = types.Numeric +Shape = types.Shape +DType = types.DType +PyTree = types.PyTree +ArrayTree = types.ArrayTree +TArrayTree = types.TArrayTree Params = types.Params Batch = types.Batch FuncState = types.FuncState FuncAux = types.FuncAux PyTreeDef = types.PyTreeDef -PyTreeType = types.PyTreeType -PyTree = types.PyTree -TPyTree = types.TPyTree FuncArgs = types.FuncArgs +FuncOuts = types.FuncOuts Func = types.Func ValueFunc = types.ValueFunc ValueAndGradFunc = types.ValueAndGradFunc AssumedFuncOutput = types.AssumedFuncOutput tree_is_empty = types.tree_is_empty abstract_objects_equal = types.abstract_objects_equal -is_array_instance = types.is_array_instance get_float_dtype_and_check_consistency = ( types.get_float_dtype_and_check_consistency) del types diff --git a/kfac_jax/_src/utils/accumulators.py b/kfac_jax/_src/utils/accumulators.py index 86b1a9e..db94b49 100644 --- a/kfac_jax/_src/utils/accumulators.py +++ b/kfac_jax/_src/utils/accumulators.py @@ -14,7 +14,6 @@ """K-FAC for accumulating statistics.""" from typing import Any, Optional, Generic -import chex import jax import jax.numpy as jnp @@ -22,76 +21,107 @@ from kfac_jax._src.utils import parallel from kfac_jax._src.utils import types -PyTree = types.PyTree -TPyTree = types.TPyTree +Array = types.Array +Numeric = types.Numeric +Shape = types.Shape +DType = types.DType +ArrayTree = types.ArrayTree +TArrayTree = types.TArrayTree @misc.pytree_dataclass -class WeightedMovingAverage(Generic[TPyTree]): +class WeightedMovingAverage(Generic[TArrayTree]): """A wrapped class for an arbitrary weighted moving average.""" - weight: chex.Array - raw_value: TPyTree + weight: Numeric + raw_value: Optional[TArrayTree] @property - def value(self) -> TPyTree: + def value(self) -> TArrayTree: """The value of the underlying arrays data structure.""" + if self.raw_value is None: + raise ValueError("`raw_value` has not been set yet.") return jax.tree_util.tree_map(lambda x: x / self.weight, self.raw_value) def update( self, - value: TPyTree, - old_weight_multiplier: chex.Numeric, - new_weight: chex.Numeric, - ) -> None: + value: TArrayTree, + old_weight_multiplier: Numeric, + new_weight: Numeric, + ): """Updates the underlying array and weight accordingly.""" - self.weight = self.weight * old_weight_multiplier + new_weight - self.raw_value = jax.tree_util.tree_map( - lambda x, y: x * old_weight_multiplier + y * new_weight, - self.raw_value, - value, - ) + if self.raw_value is None: + self.raw_value = value + self.weight = jnp.asarray(new_weight).astype(self.weight.dtype) - def sync(self, pmap_axis_name: Optional[str]) -> None: + else: + self.weight = self.weight * old_weight_multiplier + new_weight + self.raw_value = jax.tree_util.tree_map( + lambda x, y: x * old_weight_multiplier + y * new_weight, + self.raw_value, + value, + ) + + def sync(self, pmap_axis_name: Optional[str]): """Syncs the underlying array across devices.""" + if self.raw_value is None: + raise ValueError("`raw_value` has not been set yet.") self.raw_value = parallel.pmean_if_pmap(self.raw_value, pmap_axis_name) + def clear(self, value_to_none: bool = False): + """Resets the weighted average.""" + self.weight = 
jnp.zeros_like(self.weight) + self.raw_value = None if value_to_none else jnp.zeros_like(self.raw_value) + + def value_and_clear(self) -> TArrayTree: + """Retrieves the value of the weighted average and clears it.""" + value = self.value + self.clear() + return value + + def copy(self) -> "WeightedMovingAverage[TArrayTree]": + """Returns a copy of the PyTree structure (but not the JAX arrays).""" + (flattened, structure) = jax.tree_util.tree_flatten(self) + return jax.tree_util.tree_unflatten(structure, flattened) + @classmethod - def zero( + def zeros_array( cls, - shape: chex.Shape, - dtype: Optional[chex.ArrayDType] = None, - ) -> "WeightedMovingAverage": + shape: Shape, + dtype: Optional[DType] = None, + ) -> "WeightedMovingAverage[Array]": """Initializes a `WeightedMovingAverage` with a single array of zeros.""" return WeightedMovingAverage( weight=jnp.zeros([], dtype=dtype), - raw_value=jnp.zeros(shape, dtype=dtype)) + raw_value=jnp.zeros(shape, dtype=dtype), + ) @classmethod - def zeros_like(cls, value: PyTree) -> "WeightedMovingAverage": + def zeros_like(cls, value: TArrayTree) -> "WeightedMovingAverage[TArrayTree]": """Initializes a `WeightedMovingAverage` with zeros structure like `value`.""" return WeightedMovingAverage( weight=jnp.array( - 0.0, dtype=types.get_float_dtype_and_check_consistency(value)), - raw_value=jax.tree_util.tree_map(jnp.zeros_like, value) + 0.0, dtype=types.get_float_dtype_and_check_consistency(value) + ), + raw_value=jax.tree_util.tree_map(jnp.zeros_like, value), ) - def copy(self): - """Returns a copy of the PyTree structure (but not the JAX arrays).""" - (flattened, structure) = jax.tree_util.tree_flatten(self) - return jax.tree_util.tree_unflatten(structure, flattened) + @classmethod + def empty(cls, dtype: Optional[DType] = None) -> "WeightedMovingAverage[Any]": + """Returns an empty moving average instance.""" + weight = jnp.zeros([]) if dtype is None else jnp.zeros([], dtype=dtype) + return WeightedMovingAverage(weight=weight, raw_value=None) - def clear(self): - self.weight = jnp.zeros_like(self.weight) - self.raw_value = jnp.zeros_like(self.raw_value) + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.weight!r}, {self.raw_value!r})" -class MultiChunkAccumulator(Generic[TPyTree]): +class MultiChunkAccumulator(Generic[TArrayTree]): """Statistics accumulation, abstracted over multiple chunks.""" def __init__( self, - init_obj_value: Optional[TPyTree], - weight: chex.Numeric, + init_obj_value: Optional[TArrayTree], + weight: Numeric, multi_device: bool, ): """Initializes an accumulator instance with the provided object and counter. 
@@ -108,12 +138,12 @@ def __init__( self._multi_device = multi_device @property - def accumulator(self) -> TPyTree: + def accumulator(self) -> TArrayTree: """The current value of the underlying not-normalized accumulator.""" return self._accumulator @property - def weight(self) -> chex.Numeric: + def weight(self) -> Numeric: """The current normalization weight of the underlying accumulator.""" return self._weight @@ -123,7 +153,7 @@ def multi_device(self) -> bool: return self._multi_device @property - def value(self) -> TPyTree: + def value(self) -> TArrayTree: """The current normalized value of the accumulator.""" if types.tree_is_empty(self.accumulator): @@ -139,13 +169,13 @@ def clear(self) -> None: self._accumulator = None self._weight = None - def value_and_clear(self) -> TPyTree: + def value_and_clear(self) -> TArrayTree: """Retrieves the normalized value of the accumulator and clears it.""" value = self.value self.clear() return value - def add(self, value_obj: TPyTree, weight: chex.Numeric = 1): + def add(self, value_obj: TArrayTree, weight: Numeric = 1): """Adds an element to the moving average and the max. The exact update equation for the statistics are: @@ -164,7 +194,7 @@ def add(self, value_obj: TPyTree, weight: chex.Numeric = 1): self._accumulator = value_obj - if isinstance(weight, types.CHEX_SCALAR_TYPES): + if isinstance(weight, types.SCALAR_TYPES): self._weight = jnp.full_like(self._weight, weight) elif not isinstance(weight, jax.Array): @@ -200,9 +230,9 @@ def add(self, value_obj: TPyTree, weight: chex.Numeric = 1): @classmethod def zeros_like( cls, - obj: TPyTree, + obj: TArrayTree, multi_device: bool - ) -> "MultiChunkAccumulator[TPyTree]": + ) -> "MultiChunkAccumulator[TArrayTree]": """Creates a zero initialized accumulator as `obj`.""" if multi_device: diff --git a/kfac_jax/_src/utils/math.py b/kfac_jax/_src/utils/math.py index d745ce1..69750ea 100644 --- a/kfac_jax/_src/utils/math.py +++ b/kfac_jax/_src/utils/math.py @@ -14,9 +14,8 @@ """K-FAC utilities for various mathematical operations.""" import functools import string -from typing import Callable, Optional, Sequence, Union, Iterable, Tuple +from typing import Callable, Optional, Sequence, Iterable, TypeVar, Tuple, Union -import chex import jax from jax import lax from jax.experimental.sparse import linalg as experimental_splinalg @@ -30,8 +29,12 @@ import tree -PyTree = types.PyTree -TPyTree = types.TPyTree +Array = types.Array +Numeric = types.Numeric +PRNGKey = types.PRNGKey +ArrayTree = types.ArrayTree +TArrayTree = types.TArrayTree +TNumeric = TypeVar("TNumeric", bound=Numeric) _ALPHABET = string.ascii_lowercase @@ -51,7 +54,7 @@ def get_special_case_zero_inv() -> bool: return _SPECIAL_CASE_ZERO_INV -def product(iterable_object: Iterable[chex.Numeric]) -> chex.Numeric: +def product(iterable_object: Iterable[TNumeric]) -> TNumeric: """Computes the product of all elements in the iterable.""" x = 1 @@ -61,7 +64,7 @@ def product(iterable_object: Iterable[chex.Numeric]) -> chex.Numeric: return x -def outer_product(*arrays: chex.Array) -> chex.Array: +def outer_product(*arrays: Array) -> Array: """Computes the outer product of an arbitrary number of vectors.""" if not all(a.ndim == 1 for a in arrays): raise ValueError("All arrays must be vectors.") @@ -70,36 +73,36 @@ def outer_product(*arrays: chex.Array) -> chex.Array: return jnp.einsum(f"{in_str}->{out_str}", *arrays) -def scalar_mul(obj: TPyTree, scalar: chex.Numeric) -> TPyTree: +def scalar_mul(obj: TArrayTree, scalar: Numeric) -> TArrayTree: """Multiplies 
all PyTree leaves of the object by the provided scalar.""" # The check below is in its current form because of how `jax.jit` tracing # mechanism work. If we use `scalar == 1` and `scalar` is an array, inside a # `jit` context, jax will raise an error, since you are not allowed to use # abstract values in concrete boolean statements, like native python # if/while/for constructs. - if isinstance(scalar, types.CHEX_SCALAR_TYPES) and scalar == 1.0: + if isinstance(scalar, types.SCALAR_TYPES) and scalar == 1.0: return obj return jax.tree_util.tree_map(lambda x: x * scalar, obj) -def scalar_div(obj: TPyTree, scalar: chex.Numeric) -> TPyTree: +def scalar_div(obj: TArrayTree, scalar: Numeric) -> TArrayTree: """Divides all PyTree leaves of the object by the provided scalar.""" # The check below is in its current form because of how `jax.jit` tracing # mechanism work. If we use `scalar == 1` and `scalar` is an array, inside a # `jit` context, jax will raise an error, since you are not allowed to use # abstract values in concrete boolean statements, like native python # if/while/for constructs. - if isinstance(scalar, types.CHEX_SCALAR_TYPES) and scalar == 1.0: + if isinstance(scalar, types.SCALAR_TYPES) and scalar == 1.0: return obj return jax.tree_util.tree_map(lambda x: x / scalar, obj) def weighted_sum_of_objects( - objects: Sequence[TPyTree], - coefficients: Sequence[chex.Numeric], -) -> TPyTree: + objects: Sequence[TArrayTree], + coefficients: Sequence[Numeric], +) -> TArrayTree: """Computes a weighted sum of the objects'. The function computes `sum_i coefficients[i] * objects[i]`. All objects must @@ -131,7 +134,7 @@ def weighted_sum_of_objects( return accumulator -def _inner_product_float64(obj1: PyTree, obj2: PyTree) -> chex.Array: +def _inner_product_float64(obj1: ArrayTree, obj2: ArrayTree) -> Array: """Computes inner product explicitly in float64 precision.""" raise NotImplementedError() @@ -151,19 +154,19 @@ def _inner_product_float64(obj1: PyTree, obj2: PyTree) -> chex.Array: # elements_inner_products = jax.tree_util.tree_map(array_ip, obj1, obj2) # flat_list = jax.tree_util.tree_leaves(elements_inner_products) - # result = flat_list[0] + # result = flat_List[0] - # for element_ip in flat_list[1:]: + # for element_ip in flat_List[1:]: # result = result + element_ip # return jnp.array(result, dtype=original_dtype) def inner_product( - obj1: PyTree, - obj2: PyTree, + obj1: ArrayTree, + obj2: ArrayTree, in_float64: bool = False -) -> chex.Array: +) -> Array: """Computes the inner product ``. To compute the inner product, each of the two input objects is assumed to @@ -194,10 +197,10 @@ def inner_product( def symmetric_matrix_inner_products( - vectors1: Sequence[PyTree], - vectors2: Sequence[PyTree], - ip_function: Callable[[PyTree, PyTree], chex.Array] = inner_product, -) -> chex.Array: + vectors1: Sequence[ArrayTree], + vectors2: Sequence[ArrayTree], + ip_function: Callable[[ArrayTree, ArrayTree], Array] = inner_product, +) -> Array: """Computes a matrix of the inner products between the two sequences. Args: @@ -227,9 +230,9 @@ def symmetric_matrix_inner_products( def matrix_of_inner_products( - vectors: Sequence[PyTree], - ip_function: Callable[[PyTree, PyTree], chex.Array] = inner_product, -) -> chex.Array: + vectors: Sequence[ArrayTree], + ip_function: Callable[[ArrayTree, ArrayTree], Array] = inner_product, +) -> Array: """Computes the matrix of inner products of the sequence of vectors. 
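The comment inside `scalar_mul` and `scalar_div` about the `jax.jit` tracing mechanism can be made concrete with a small, purely illustrative sketch (`naive_scale` is a hypothetical helper, not part of the library): a traced array cannot drive a Python `if`, which is why the `scalar == 1.0` shortcut is only taken for concrete Python scalars.

import jax
import jax.numpy as jnp

def naive_scale(x, s):
  if s == 1.0:  # fine for a Python float, but `s` becomes a tracer under jit
    return x
  return x * s

# jax.jit(naive_scale)(jnp.ones(3), jnp.asarray(1.0)) would raise a
# concretization error, since the traced `s` cannot be converted to a concrete
# boolean. Guarding the shortcut with isinstance(s, SCALAR_TYPES) avoids this.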
Args: @@ -246,10 +249,10 @@ def matrix_of_inner_products( def vector_of_inner_products( - base: PyTree, - vectors: Sequence[PyTree], - ip_function: Callable[[PyTree, PyTree], chex.Array] = inner_product, -) -> chex.Array: + base: ArrayTree, + vectors: Sequence[ArrayTree], + ip_function: Callable[[ArrayTree, ArrayTree], Array] = inner_product, +) -> Array: """Computes a vector of inner products with base. Args: @@ -270,10 +273,10 @@ def vector_of_inner_products( def block_permuted( - matrix: chex.Array, + matrix: Array, block_sizes: Sequence[int], block_order: Sequence[int], -) -> chex.Array: +) -> Array: """Permutes whole blocks of the input matrix. Given a square matrix, this function splits it into blocks, each one having @@ -308,7 +311,7 @@ def block_permuted( return jnp.block(reordered_blocks) -def norm(obj: PyTree) -> chex.Array: +def norm(obj: ArrayTree) -> Array: """Computes the Euclidean norm of the provided PyTree object.""" elements_squared_norm = jax.tree_util.tree_map( lambda x: jnp.sum(jnp.square(x)), obj) @@ -316,7 +319,7 @@ def norm(obj: PyTree) -> chex.Array: return jnp.sqrt(sum(jax.tree_util.tree_leaves(elements_squared_norm))) -def per_parameter_norm(obj: PyTree, key_prefix: str) -> PyTree: +def per_parameter_norm(obj: ArrayTree, key_prefix: str) -> ArrayTree: per_param_norm = jax.tree_util.tree_map(jnp.linalg.norm, obj) per_param_norm = tree.flatten_with_path(per_param_norm) @@ -326,7 +329,7 @@ def per_parameter_norm(obj: PyTree, key_prefix: str) -> PyTree: } -def psd_inv_cholesky(matrix: chex.Array, damping: chex.Array) -> chex.Array: +def psd_inv_cholesky(matrix: Array, damping: Array) -> Array: """Computes the inverse of `matrix + damping*I`, with matrix assumed PSD.""" if matrix.shape[:1] != matrix.shape[1:]: @@ -338,11 +341,11 @@ def psd_inv_cholesky(matrix: chex.Array, damping: chex.Array) -> chex.Array: def psd_matrix_norm( - matrix: chex.Array, + matrix: Array, norm_type: str = "avg_trace", method_2norm: str = "lobpcg", - rng_key: Optional[chex.PRNGKey] = None -) -> chex.Array: + rng_key: Optional[PRNGKey] = None +) -> Array: """Computes one of several different matrix norms for PSD matrices. Args: @@ -439,9 +442,9 @@ def psd_matrix_norm( def pi_adjusted_kronecker_inverse( - *arrays: chex.Array, - damping: chex.Numeric, -) -> Tuple[chex.Array, ...]: + *arrays: Array, + damping: Numeric, +) -> Tuple[Array, ...]: """Computes pi-adjusted factored damping inverses. The inverse of `a_1 kron a_2 kron ... 
kron a_n + damping * I` is not Kronecker @@ -484,7 +487,7 @@ def pi_adjusted_kronecker_inverse( damping = damping.astype(c.dtype) # pytype: disable=attribute-error # numpy-scalars - def regular_inverse() -> Tuple[chex.Array, ...]: + def regular_inverse() -> Tuple[Array, ...]: non_scalars = sum(1 if a.size != 1 else 0 for a in arrays) @@ -516,7 +519,7 @@ def regular_inverse() -> Tuple[chex.Array, ...]: return tuple(u_hats_inv) - def zero_inverse() -> Tuple[chex.Array, ...]: + def zero_inverse() -> Tuple[Array, ...]: # In the special case where for some reason one of the factors is zero, then # the inverse is just `damping^-1 * I`, hence we write each factor as @@ -550,8 +553,8 @@ def zero_inverse() -> Tuple[chex.Array, ...]: def kronecker_product_axis_mul_v( - factors: Sequence[chex.Array], - v: chex.Array, + factors: Sequence[Array], + v: Array, axis_groups: Optional[Sequence[Sequence[int]]] = None, transpose: Union[bool, Sequence[bool]] = False, ): @@ -617,9 +620,9 @@ def kronecker_product_axis_mul_v( def kronecker_eigen_basis_axis_mul_v( - q_factors: Sequence[chex.Array], - eigenvalues: chex.Array, - v: chex.Array, + q_factors: Sequence[Array], + eigenvalues: Array, + v: Array, axis_groups: Optional[Sequence[Sequence[int]]] = None, ): """Computes a matrix-vector product in a Kronecker product eigen-basis. @@ -661,22 +664,22 @@ def kronecker_eigen_basis_axis_mul_v( def kronecker_product_mul_v( - a: chex.Array, - b: chex.Array, - v: chex.Array, + a: Array, + b: Array, + v: Array, a_is_symmetric: bool, -) -> chex.Array: +) -> Array: """Computes `unvec[(a kron b) vec(v)]` for correctly sized input matrices.""" del a_is_symmetric # not used return kronecker_product_axis_mul_v([b, a], v) def kronecker_eigen_basis_mul_v( - q_a: chex.Array, - q_b: chex.Array, - eigenvalues: chex.Array, - v: chex.Array, -) -> chex.Array: + q_a: Array, + q_b: Array, + eigenvalues: Array, + v: Array, +) -> Array: """Computes a matrix-vector product in a Kronecker product eigen-basis. The function computes: @@ -702,7 +705,7 @@ def kronecker_eigen_basis_mul_v( return kronecker_eigen_basis_axis_mul_v([q_b, q_a], eigenvalues, v) -def _host_eigh(x: chex.Array, *_) -> Tuple[chex.Array, chex.Array]: +def _host_eigh(x: Array, *_) -> Tuple[Array, Array]: """This calls the CPU numpy function for eigh.""" shape_s = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype) @@ -712,9 +715,9 @@ def _host_eigh(x: chex.Array, *_) -> Tuple[chex.Array, chex.Array]: def _eigh( - x: chex.Array, + x: Array, force_on_host: bool = False, -) -> Tuple[chex.Array, chex.Array]: +) -> Tuple[Array, Array]: """Computes eigenvectors and eigenvalues, with optionally offloading to cpu.""" if force_on_host: @@ -733,9 +736,9 @@ def _eigh( def safe_psd_eigh( - x: chex.Array, + x: Array, force_on_host: bool = False, -) -> Tuple[chex.Array, chex.Array]: +) -> Tuple[Array, Array]: """Computes the eigenvalue decomposition for a PSD matrix. The function is similar to `jax.numpy.linalg.eigh`, but it clips the returned @@ -769,9 +772,9 @@ def safe_psd_eigh( def loop_and_parallelize_average( - func: Callable[..., PyTree], + func: Callable[..., ArrayTree], max_parallel_size: int, -) -> Callable[..., PyTree]: +) -> Callable[..., ArrayTree]: """Returns a function that computes the average of `func` over any arguments. 
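The Kronecker-factored helpers above avoid materializing the full Kronecker product; they rest on the standard vec identity (A ⊗ B) vec(V) = vec(A V Bᵀ), stated here for row-major flattening as used by `jnp.reshape`. A small numerical check of that identity, illustrative only:

import jax
import jax.numpy as jnp

key_a, key_b, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
a = jax.random.normal(key_a, [3, 3])
b = jax.random.normal(key_b, [4, 4])
v = jax.random.normal(key_v, [3, 4])

# Dense version: explicitly builds the (12 x 12) Kronecker product.
dense = (jnp.kron(a, b) @ v.reshape(-1)).reshape(v.shape)

# Factored version: two small matrix products, no Kronecker product formed.
factored = a @ v @ b.T

assert jnp.allclose(dense, factored, atol=1e-5)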
The returned function is mathematically equivalent to @@ -795,7 +798,7 @@ def loop_and_parallelize_average( vmap_fn = jax.vmap(func) @functools.wraps(func) - def average_func(*args) -> PyTree: + def average_func(*args) -> ArrayTree: lead_axis_sizes = set(x.shape[0] for x in jax.tree_util.tree_leaves(args)) diff --git a/kfac_jax/_src/utils/misc.py b/kfac_jax/_src/utils/misc.py index 0b71bcb..ba5b2a9 100644 --- a/kfac_jax/_src/utils/misc.py +++ b/kfac_jax/_src/utils/misc.py @@ -15,20 +15,21 @@ import abc import dataclasses import functools -from typing import Any, Iterator, Sequence, Type, Union, Tuple +from typing import Any, Iterator, Sequence, Type, Tuple, Union -import chex import jax import jax.numpy as jnp from kfac_jax._src.utils import types -PyTree = types.PyTree -TPyTree = types.TPyTree +Array = types.Array +Numeric = types.Numeric +ArrayTree = types.ArrayTree +TArrayTree = types.TArrayTree def fake_element_from_iterator( - iterator: Iterator[TPyTree], -) -> Tuple[TPyTree, Iterator[TPyTree]]: + iterator: Iterator[TArrayTree], +) -> Tuple[TArrayTree, Iterator[TArrayTree]]: """Returns a zeroed-out initial element of the iterator "non-destructively". This function mutates the input iterator, hence after calling this function @@ -47,7 +48,7 @@ def fake_element_from_iterator( """ init_element = next(iterator) fake_element = jax.tree_util.tree_map(jnp.zeros_like, init_element) - def equivalent_iterator() -> Iterator[PyTree]: + def equivalent_iterator() -> Iterator[ArrayTree]: yield init_element # For some reason unknown to us, "yield from" can fail in certain # circumstances @@ -57,9 +58,9 @@ def equivalent_iterator() -> Iterator[PyTree]: def to_tuple_or_repeat( - x: Union[chex.Numeric, Sequence[chex.Numeric]], + x: Union[Numeric, Sequence[Numeric]], length: int, -) -> Tuple[chex.Numeric, ...]: +) -> Tuple[Numeric, ...]: """Converts `x` to a tuple of fixed length. If `x` is an array, it is split along its last axis to a tuple (assumed to @@ -86,7 +87,7 @@ def to_tuple_or_repeat( raise ValueError(f"Unrecognized type for `x` - {type(x)}.") -def first_dim_is_size(size: int, *args: chex.Array) -> bool: +def first_dim_is_size(size: int, *args: Array) -> bool: """Checks that each element of `args` has first axis size equal to `size`.""" return all(arg.shape[0] == size for arg in args) @@ -215,5 +216,5 @@ def wrapped(*args, **kwargs): def default_batch_size_extractor( batch: types.Batch, -) -> chex.Numeric: +) -> Numeric: return jax.tree_util.tree_leaves(batch)[0].shape[0] diff --git a/kfac_jax/_src/utils/parallel.py b/kfac_jax/_src/utils/parallel.py index 29ca4ea..e795ec3 100644 --- a/kfac_jax/_src/utils/parallel.py +++ b/kfac_jax/_src/utils/parallel.py @@ -16,15 +16,17 @@ import numbers from typing import Callable, Optional, Sequence -import chex import jax from jax import core from jax import lax import jax.numpy as jnp from kfac_jax._src.utils import types -PyTree = types.PyTree -TPyTree = types.TPyTree + +Array = types.Array +Numeric = types.Numeric +PRNGKey = types.PRNGKey +TArrayTree = types.TArrayTree # TODO(jamesmartens,botev): add a test for this function? 
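The `fake_element_from_iterator` docstring above describes a "non-destructive" peek; a short sketch of that contract with a made-up batch iterator, assuming the helper is re-exported from `kfac_jax.utils` as in the rest of the package:

import jax.numpy as jnp
import kfac_jax

def batches():
  for i in range(1, 4):
    yield {"images": jnp.full([2, 4], float(i))}

fake, it = kfac_jax.utils.fake_element_from_iterator(batches())

assert (fake["images"] == 0).all()      # zeroed-out copy of the first batch
assert (next(it)["images"] == 1).all()  # the real first batch is still yielded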
@@ -46,12 +48,12 @@ def in_pmap(axis_name: Optional[str]) -> bool: def wrap_if_pmap( - p_func: Callable[[TPyTree, str], TPyTree], -) -> Callable[[TPyTree, Optional[str]], TPyTree]: + p_func: Callable[[TArrayTree, str], TArrayTree], +) -> Callable[[TArrayTree, Optional[str]], TArrayTree]: """Wraps `p_func` to be executed only when inside a `jax.pmap` context.""" @functools.wraps(p_func) - def p_func_if_pmap(obj: TPyTree, axis_name: Optional[str]) -> TPyTree: + def p_func_if_pmap(obj: TArrayTree, axis_name: Optional[str]) -> TArrayTree: return p_func(obj, axis_name) if in_pmap(axis_name) else obj @@ -65,35 +67,35 @@ def p_func_if_pmap(obj: TPyTree, axis_name: Optional[str]) -> TPyTree: compute_sum = jax.pmap(lambda x: lax.psum(x, "i"), axis_name="i") -def index_if_not_scalar(value: chex.Numeric, index: int = 0) -> chex.Numeric: +def index_if_not_scalar(value: Numeric, index: int = 0) -> Numeric: """Index `value` at axis 0 if it is not a scalar, otherwise return it.""" - if types.is_array_instance(value): + if isinstance(value, Array): if value.ndim > 0: # pytype: disable=attribute-error # numpy-scalars return value[index] else: return value - elif isinstance(value, types.CHEX_SCALAR_TYPES): + elif isinstance(value, (float, int)): return value else: - raise ValueError("The input should be an instance of `chex.Numeric`.") + raise ValueError("The input should be an instance of `Numeric`.") @jax.jit -def get_first(obj: TPyTree) -> TPyTree: +def get_first(obj: TArrayTree) -> TArrayTree: """Index the PyTree leaves `x` of `obj` by `x[0]` if they are not scalars.""" return jax.tree_util.tree_map(index_if_not_scalar, obj) -def get_mean(obj: TPyTree) -> TPyTree: +def get_mean(obj: TArrayTree) -> TArrayTree: """Returns the average of `obj` over different devices.""" return get_first(compute_mean(obj)) -def get_sum(obj: TPyTree) -> TPyTree: +def get_sum(obj: TArrayTree) -> TArrayTree: """Returns the sum of `obj` over different devices.""" return get_first(compute_sum(obj)) @@ -103,7 +105,7 @@ def get_sum(obj: TPyTree) -> TPyTree: jit_zeros_like = jax.jit(lambda x: jax.tree_util.tree_map(jnp.zeros_like, x)) -def replicate_all_local_devices(obj: TPyTree) -> TPyTree: +def replicate_all_local_devices(obj: TArrayTree) -> TArrayTree: """Replicates `obj` to all local Jax devices.""" if types.tree_is_empty(obj): return obj @@ -111,7 +113,7 @@ def replicate_all_local_devices(obj: TPyTree) -> TPyTree: return jax.device_put_replicated(obj, devices=jax.local_devices()) -def make_different_rng_key_on_all_devices(rng: chex.PRNGKey) -> chex.PRNGKey: +def make_different_rng_key_on_all_devices(rng: PRNGKey) -> PRNGKey: """Makes a different PRNG for all Jax devices and processes.""" rng = jax.random.fold_in(rng, jax.process_index()) @@ -125,11 +127,11 @@ def make_different_rng_key_on_all_devices(rng: chex.PRNGKey) -> chex.PRNGKey: static_broadcasted_argnums=1) -def check_and_fix_format_for_pmap(obj: TPyTree) -> TPyTree: +def check_and_fix_format_for_pmap(obj: TArrayTree) -> TArrayTree: """Checks shape[0]==device_count and broadcasts scalars to [device_count].""" device_count = jax.local_device_count() - def check_and_fix(x: chex.Numeric) -> chex.Array: + def check_and_fix(x: Numeric) -> Array: # broadcast any 0D scalars if isinstance(x, numbers.Number) or not x.shape: # pytype: disable=attribute-error # numpy-scalars @@ -147,9 +149,9 @@ def check_and_fix(x: chex.Numeric) -> chex.Array: def host_sync( - obj: TPyTree, - sync_op: Callable[[TPyTree, str], TPyTree], -) -> TPyTree: + obj: TArrayTree, + sync_op: 
Callable[[TArrayTree, str], TArrayTree], +) -> TArrayTree: """Syncs `obj` across multiple hosts with the operation `sync_op`.""" # The implementation here is to use the pmap syncing mechanisms but with only @@ -180,21 +182,21 @@ def host_sync( return obj -def host_all_gather(x: TPyTree) -> TPyTree: +def host_all_gather(x: TArrayTree) -> TArrayTree: """Gathers on every host the values of the PyTree leaves `x`.""" return host_sync(x, lax.all_gather) -def host_mean(x: TPyTree) -> TPyTree: +def host_mean(x: TArrayTree) -> TArrayTree: """Computes the mean of the PyTree leaves of `x` over multiple hosts.""" return host_sync(x, lax.pmean) def sync_and_divide_value( - value: TPyTree, - counter: chex.Numeric, + value: TArrayTree, + counter: Numeric, axis_name: Optional[str] = None, -) -> TPyTree: +) -> TArrayTree: """Computes the mean of `value` over all hosts and divides it by `counter`.""" value = jax.tree_util.tree_map(lambda x: x / counter, value) return pmean_if_pmap(value, axis_name) @@ -209,7 +211,7 @@ def sync_and_divide_value( # We might be able to change this to "return jnp.array(x)" in newer JAX versions -def copy_array(x: chex.Array) -> chex.Array: +def copy_array(x: Array) -> Array: """Copies a Jax array so that it can be donated freely.""" return x + jnp.zeros_like(x) @@ -219,9 +221,9 @@ def copy_array(x: chex.Array) -> chex.Array: def distribute_thunks( - thunks: Sequence[Callable[[], PyTree]], + thunks: Sequence[Callable[[], TArrayTree]], pmap_axis_name: str, - ) -> PyTree: + ) -> TArrayTree: """Distributes the computation of a list of thunks over the pmapped devices. Given a list of thunks, this function distributes their computation over the diff --git a/kfac_jax/_src/utils/staging.py b/kfac_jax/_src/utils/staging.py index f088b01..7073987 100644 --- a/kfac_jax/_src/utils/staging.py +++ b/kfac_jax/_src/utils/staging.py @@ -14,7 +14,7 @@ """K-FAC utilities for classes with staged methods.""" import functools import operator -from typing import Any, Callable, Optional, Sequence, Union, Tuple +from typing import Any, Callable, Optional, Sequence, Tuple, Union import jax import jax.numpy as jnp @@ -23,7 +23,7 @@ from kfac_jax._src.utils import parallel from kfac_jax._src.utils import types -PyTree = types.PyTree +TArrayTree = types.TArrayTree class WithStagedMethods(misc.Finalizable): @@ -112,18 +112,18 @@ def staging_context(self) -> "StagingContext": """Returns a staging context manager, linked to this instance.""" return self.StagingContext(self) - def get_first(self, obj: PyTree) -> PyTree: + def get_first(self, obj: TArrayTree) -> TArrayTree: """Indexes the `obj` PyTree leaves over leading axis if `multi_device`.""" return parallel.get_first(obj) if self.multi_device else obj - def copy_obj(self, obj: PyTree) -> PyTree: + def copy_obj(self, obj: TArrayTree) -> TArrayTree: """Copies the object.""" if self.multi_device: return parallel.pmap_copy_obj(obj) else: return parallel.copy_obj(obj) - def replicate(self, obj: PyTree) -> PyTree: + def replicate(self, obj: TArrayTree) -> TArrayTree: """Replicates the object to all local devices if `multi_device`.""" if self.multi_device: return parallel.replicate_all_local_devices(obj) @@ -132,10 +132,10 @@ def replicate(self, obj: PyTree) -> PyTree: def staged( - method: Callable[..., PyTree], + method: Callable[..., TArrayTree], static_argnums: Optional[Union[int, Sequence[int]]] = None, donate_argnums: Optional[Union[int, Sequence[int]]] = None, -) -> Callable[..., PyTree]: +) -> Callable[..., TArrayTree]: """Makes the instance method 
staged. This decorator **should** only be applied to instance methods of classes that @@ -186,7 +186,7 @@ def try(self, x): donate_argnums=donate_argnums) @functools.wraps(method) - def decorated(instance: "WithStagedMethods", *args: Any) -> PyTree: + def decorated(instance: "WithStagedMethods", *args: Any) -> TArrayTree: if instance.in_staging: return method(instance, *args) diff --git a/kfac_jax/_src/utils/types.py b/kfac_jax/_src/utils/types.py index 91ab8ac..d81171e 100644 --- a/kfac_jax/_src/utils/types.py +++ b/kfac_jax/_src/utils/types.py @@ -12,45 +12,45 @@ # See the License for the specific language governing permissions and # limitations under the License. """K-FAC annotation types and general tree operations.""" -import sys -from typing import Any, Callable, Sequence, TypeVar, Union, Tuple +from typing import Callable, TypeVar, Sequence, Mapping, Tuple, Union -import chex import jax import jax.numpy as jnp # Types for annotation T = TypeVar("T") -Params = chex.ArrayTree -Batch = chex.ArrayTree -FuncState = chex.ArrayTree -FuncAux = chex.ArrayTree -PyTreeDef = chex.PyTreeDef -PyTreeType = Any -PyTree = chex.ArrayTree -TPyTree = TypeVar("TPyTree", bound=PyTree) -FuncArgs = Sequence[PyTree] -Func = Callable[..., Union[chex.Array, Tuple[chex.Array, FuncAux]]] -ValueFunc = Callable[..., chex.Array] -ValueAndGradFunc = Callable[..., Tuple[chex.Array, Params]] - -AssumedFuncOutput = Union[ - chex.Array, - Tuple[chex.Array, FuncAux], - Tuple[chex.Array, Tuple[FuncState, FuncAux]], -] - -CHEX_SCALAR_TYPES = (float, int) - - -def tree_is_empty(obj: PyTree) -> bool: +Array = jax.Array +PRNGKey = Array +Scalar = Union[float, int] +Numeric = Union[Array, Scalar] +Shape = Tuple[int, ...] +DType = jnp.dtype +PyTree = Union[T, Sequence["PyTree[T]"], Mapping[str, "PyTree[T]"]] +ArrayTree = PyTree[Array] +TArrayTree = TypeVar("TArrayTree", bound=ArrayTree) +Params = TypeVar("Params", bound=ArrayTree) +Batch = TypeVar("Batch", bound=ArrayTree) +FuncState = TypeVar("FuncState", bound=ArrayTree) +FuncAux = TypeVar("FuncAux", bound=ArrayTree) +PyTreeDef = jax.tree_util.PyTreeDef +FuncArgs = Sequence[ArrayTree] +FuncOuts = Union[Array, Tuple[Array, FuncAux]] +Func = Callable[..., FuncOuts] +ValueFunc = Callable[..., Array] +ValueAndGradFunc = Callable[..., Tuple[Array, Params]] +AssumedFuncOutput = Union[Array, Tuple[Array, FuncAux], + Tuple[Array, Tuple[FuncState, FuncAux]]] +SCALAR_TYPES = (float, int) + + +def tree_is_empty(obj: ArrayTree) -> bool: """Returns whether the given PyTree is empty.""" return not jax.tree_util.tree_leaves(obj) def abstract_objects_equal( - obj1: PyTree, - obj2: PyTree, + obj1: ArrayTree, + obj2: ArrayTree, check_dtype: bool = True ) -> bool: """`True` if the objects have the same PyTree structure, shapes and dtypes.""" @@ -61,26 +61,7 @@ def abstract_objects_equal( jax.tree_util.tree_leaves(obj2)))) -def is_array_instance(var: chex.Numeric) -> bool: - """Return true if var is a instance of a jax or numpy array type.""" - if sys.version_info >= (3, 10): - return isinstance(var, chex.Array) - else: - # python 3.9 and earlier don't support instance on Generics (e.g. Union). - # Instead fallback to comparing to a tuple which (currently) matches - # chex.Array. 
- array_types = ( - chex.ArrayDevice, - chex.ArrayNumpy, - chex.ArrayBatched, - chex.ArraySharded, - ) - return isinstance(var, array_types) - - -def get_float_dtype_and_check_consistency( - obj: PyTree -) -> jnp.dtype: +def get_float_dtype_and_check_consistency(obj: ArrayTree) -> DType: """Checks that all leaves have the same float dtype, and returns this.""" leaves = jax.tree_util.tree_leaves(obj) @@ -89,7 +70,7 @@ def get_float_dtype_and_check_consistency( for leaf in leaves: - if leaf.dtype in {jnp.float32, jnp.float64}: # include bfloat16 etc? + if leaf.dtype in (jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64): if dtype is not None and leaf.dtype != dtype: raise ValueError("Inconsistent dtypes detected.") diff --git a/readthedocs.yml b/readthedocs.yml index 78e4971..53d7d28 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -9,7 +9,7 @@ sphinx: fail_on_warning: false python: - version: 3.7 + version: 3.8 install: - requirements: requirements_docs.txt - requirements: requirements.txt diff --git a/requirements.txt b/requirements.txt index 37d9bc9..cf75ca4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,9 @@ absl-py>=0.12.0 immutabledict>=2.2.1 -numpy>=1.19.5 -distrax>=0.1.2 -chex>=0.1.5 -jax>=0.3.17 -jaxlib>=0.3.15 +numpy>=1.21 +distrax>=0.1.3 +jax>=0.4.7 +jaxlib>=0.4.7 dm-tree>=0.1.7 optax>=0.1.4 typing_extensions>=4.2.0; python_version<"3.10" diff --git a/requirements_tests.txt b/requirements_tests.txt index d62df6d..c10aea9 100644 --- a/requirements_tests.txt +++ b/requirements_tests.txt @@ -1,11 +1,10 @@ pytest-xdist absl-py==0.12.0 immutabledict==2.2.1 -numpy==1.20 -distrax==0.1.2 -chex==0.1.5 -jax==0.3.17 -jaxlib==0.3.15 -dm-haiku==0.0.7 +numpy==1.21 +distrax==0.1.3 +jax==0.4.7 +jaxlib==0.4.7 +dm-haiku==0.0.9 dm-tree==0.1.7 optax==0.1.4 diff --git a/setup.py b/setup.py index 556b4cd..2663e53 100644 --- a/setup.py +++ b/setup.py @@ -51,22 +51,27 @@ def _parse_requirements(requirements_txt_path): url="https://github.com/deepmind/kfac-jax", license="Apache 2.0", author="DeepMind", - description="A Jax package for approximate curvature estimation and " - "optimization using KFAC.", + description=( + "A Jax package for approximate curvature estimation and " + "optimization using KFAC." + ), long_description=open(os.path.join(_CURRENT_DIR, "README.md")).read(), long_description_content_type="text/markdown", author_email="kfac-jax-dev@google.com", # Contained modules and scripts. packages=setuptools.find_namespace_packages(exclude=["tests", "examples"]), install_requires=_parse_requirements( - os.path.join(_CURRENT_DIR, "requirements.txt")), + os.path.join(_CURRENT_DIR, "requirements.txt") + ), tests_require=_parse_requirements( - os.path.join(_CURRENT_DIR, "requirements_tests.txt")), + os.path.join(_CURRENT_DIR, "requirements_tests.txt") + ), extras_require={ "tests": _parse_requirements( - os.path.join(_CURRENT_DIR, "requirements_tests.txt")), + os.path.join(_CURRENT_DIR, "requirements_tests.txt") + ), }, - requires_python=">=3.7", + requires_python=">=3.8", include_package_data=True, zip_safe=False, # PyPI package information. 
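The `types.py` rewrite above replaces the `chex` aliases with `jax.Array`-based ones and drops `is_array_instance`, which on jax >= 0.4 reduces to a plain `isinstance` check (as the updated `index_if_not_scalar` already does). A sketch of how the new generic aliases are meant to be used, with a hypothetical `scale_tree` helper:

import jax
import jax.numpy as jnp
import kfac_jax

TArrayTree = kfac_jax.utils.TArrayTree
Numeric = kfac_jax.utils.Numeric

def scale_tree(tree: TArrayTree, c: Numeric) -> TArrayTree:
  # TArrayTree is a TypeVar bound to ArrayTree, so the annotation expresses
  # that the returned PyTree has the same structure type as the input.
  return jax.tree_util.tree_map(lambda x: x * c, tree)

# `jax.Array` is a concrete class in jax >= 0.4, so instance checks no longer
# need the old python-version workaround:
assert isinstance(jnp.ones([2]), jax.Array)
assert not isinstance(1.0, jax.Array)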
@@ -77,7 +82,6 @@ def _parse_requirements(requirements_txt_path): "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Topic :: Scientific/Engineering :: Artificial Intelligence", diff --git a/tests/test_estimator.py b/tests/test_estimator.py index 0c311cc..47280b1 100644 --- a/tests/test_estimator.py +++ b/tests/test_estimator.py @@ -24,6 +24,8 @@ from tests import models import numpy as np +StateType = kfac_jax.curvature_estimator.StateType + NON_LINEAR_MODELS_AND_CURVATURE_TYPE = [ model + ("ggn",) for model in models.NON_LINEAR_MODELS @@ -48,12 +50,12 @@ @functools.partial(jax.jit, static_argnums=(0, 3, 4)) def compute_exact_approx_curvature( - estimator: kfac_jax.CurvatureEstimator, + estimator: kfac_jax.CurvatureEstimator[StateType], rng: chex.PRNGKey, func_args: kfac_jax.utils.FuncArgs, batch_size: int, curvature_type: str, -) -> kfac_jax.curvature_estimator.StateType: +) -> StateType: """Computes the full Fisher matrix approximation for the estimator.""" state = estimator.init( rng=rng,