Skip to content

Commit

Permalink
* Modifying the structure of the experiments class by separating the …
Browse files Browse the repository at this point in the history
…Supervised experiment from the Jaxline experiment inheritance.

* Making all data iterators be lazily created.
* Modifying the State classes to be decomposable to dictionaries - now we can serialize every State class into native simple python structures (e.g. dict, tuple, list) and then reconstruct it from those. This can be useful for storing the class in a simpler structure.

PiperOrigin-RevId: 532473167
  • Loading branch information
botev authored and KfacJaxDev committed May 16, 2023
1 parent 3be9b1a commit 95df497
Show file tree
Hide file tree
Showing 8 changed files with 332 additions and 180 deletions.
360 changes: 214 additions & 146 deletions examples/training.py

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions kfac_jax/_src/curvature_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
Shape = utils.Shape
DType = utils.DType
ScalarOrSequence = Union[Scalar, Sequence[Scalar]]
Cache = Dict[str, Union[Array, Dict[str, Array]]]

# Special global variables
# The default value that would be used for the argument
Expand Down Expand Up @@ -127,7 +128,7 @@ class CurvatureBlock(utils.Finalizable):
you would have to explicitly specify all powers that you will need to cache.
"""

@utils.pytree_dataclass
@utils.register_state_class
class State(utils.State):
"""Persistent state of the block.
Expand Down Expand Up @@ -633,7 +634,7 @@ def _to_dense_unscaled(self, state: CurvatureBlock.State) -> Array:
class Diagonal(CurvatureBlock, abc.ABC):
"""An abstract class for approximating only the diagonal of curvature."""

@utils.pytree_dataclass
@utils.register_state_class
class State(CurvatureBlock.State):
"""Persistent state of the block.
Expand Down Expand Up @@ -743,7 +744,7 @@ def _to_dense_unscaled(self, state: "Diagonal.State") -> Array:
class Full(CurvatureBlock, abc.ABC):
"""An abstract class for approximating the block matrix with a full matrix."""

@utils.pytree_dataclass
@utils.register_state_class
class State(CurvatureBlock.State):
"""Persistent state of the block.
Expand Down Expand Up @@ -986,7 +987,7 @@ def _to_dense_unscaled(self, state: "Full.State") -> Array:
class TwoKroneckerFactored(CurvatureBlock, abc.ABC):
"""An abstract class for approximating the block with a Kronecker product."""

@utils.pytree_dataclass
@utils.register_state_class
class State(CurvatureBlock.State):
"""Persistent state of the block.
Expand Down
2 changes: 1 addition & 1 deletion kfac_jax/_src/curvature_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,7 @@ class BlockDiagonalCurvature(
CurvatureEstimator["BlockDiagonalCurvature.State"]):
"""Block diagonal curvature estimator class."""

@utils.pytree_dataclass
@utils.register_state_class
class State(utils.State):
"""Persistent state of the estimator.
Expand Down
11 changes: 10 additions & 1 deletion kfac_jax/_src/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
class Optimizer(utils.WithStagedMethods):
"""The K-FAC optimizer."""

@utils.pytree_dataclass
@utils.register_state_class
class State(Generic[Params], utils.State):
r"""Persistent state of the optimizer.
Expand All @@ -81,6 +81,15 @@ class State(Generic[Params], utils.State):
data_seen: Numeric
step_counter: Numeric

@classmethod
def from_dict(cls, dict_representation: Dict[str, Any]) -> OptimizerState:
dict_representation["estimator_state"] = (
curvature_estimator.BlockDiagonalCurvature.State.from_dict(
dict_representation["estimator_state"]
)
)
return cls(**dict_representation)

def __init__(
self,
value_and_grad_func: ValueAndGradFunc,
Expand Down
1 change: 1 addition & 0 deletions kfac_jax/_src/tag_graph_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ def graph(self) -> JaxprGraph:
clean_broadcasts=True,
)
object.__setattr__(self, "_graph", graph)
assert self._graph is not None
return self._graph

def tag_ctor(
Expand Down
2 changes: 1 addition & 1 deletion kfac_jax/_src/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
default_batch_size_extractor = misc.default_batch_size_extractor
auto_scope_function = misc.auto_scope_function
auto_scope_method = misc.auto_scope_method
pytree_dataclass = misc.pytree_dataclass
register_state_class = misc.register_state_class
Finalizable = misc.Finalizable
State = misc.State
del misc
Expand Down
14 changes: 2 additions & 12 deletions kfac_jax/_src/utils/accumulators.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,15 @@
TArrayTree = types.TArrayTree


@misc.pytree_dataclass
class WeightedMovingAverage(Generic[TArrayTree]):
@misc.register_state_class
class WeightedMovingAverage(Generic[TArrayTree], misc.State):
"""A wrapped class for an arbitrary weighted moving average."""
weight: Numeric
raw_value: Optional[TArrayTree]

@property
def value(self) -> TArrayTree:
"""The value of the underlying arrays data structure."""
if self.raw_value is None:
raise ValueError("`raw_value` has not been set yet.")
return jax.tree_util.tree_map(lambda x: x / self.weight, self.raw_value)

def update(
Expand Down Expand Up @@ -78,11 +76,6 @@ def value_and_clear(self) -> TArrayTree:
self.clear()
return value

def copy(self) -> "WeightedMovingAverage[TArrayTree]":
"""Returns a copy of the PyTree structure (but not the JAX arrays)."""
(flattened, structure) = jax.tree_util.tree_flatten(self)
return jax.tree_util.tree_unflatten(structure, flattened)

@classmethod
def zeros_array(
cls,
Expand Down Expand Up @@ -111,9 +104,6 @@ def empty(cls, dtype: Optional[DType] = None) -> "WeightedMovingAverage[Any]":
weight = jnp.zeros([]) if dtype is None else jnp.zeros([], dtype=dtype)
return WeightedMovingAverage(weight=weight, raw_value=None)

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.weight!r}, {self.raw_value!r})"


class MultiChunkAccumulator(Generic[TArrayTree]):
"""Statistics accumulation, abstracted over multiple chunks."""
Expand Down
113 changes: 98 additions & 15 deletions kfac_jax/_src/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import abc
import dataclasses
import functools
from typing import Any, Iterator, Sequence, Type, Tuple, Union
from typing import Any, Iterator, Sequence, Type, Tuple, Union, Dict, TypeVar

import jax
import jax.numpy as jnp
Expand All @@ -25,6 +25,11 @@
Numeric = types.Numeric
ArrayTree = types.ArrayTree
TArrayTree = types.TArrayTree
StateType = TypeVar("StateType")
StateTree = types.PyTree["State"]


STATE_CLASSES_SERIALIZATION_DICT = {}


def fake_element_from_iterator(
Expand Down Expand Up @@ -92,7 +97,45 @@ def first_dim_is_size(size: int, *args: Array) -> bool:
return all(arg.shape[0] == size for arg in args)


def pytree_dataclass(class_type: Type[Any]) -> Type[Any]:
class State(abc.ABC):
  """Abstract base class for state classes.

  The helpers below call `dataclasses.fields` on the class, so concrete
  subclasses are expected to be dataclasses. An instance flattens into the
  tuple of its field values and is rebuilt by passing those values back to the
  constructor as keyword arguments.
  """

  @classmethod
  def field_names(cls) -> Tuple[str, ...]:
    """Returns the names of all dataclass fields of the class."""
    return tuple(field.name for field in dataclasses.fields(cls))  # pytype: disable=wrong-arg-types

  @classmethod
  def field_types(cls) -> Dict[str, Type[Any]]:
    """Returns a mapping from every field name to its declared type."""
    return {field.name: field.type for field in dataclasses.fields(cls)}  # pytype: disable=wrong-arg-types

  @property
  def field_values(self) -> Tuple[ArrayTree, ...]:
    """The values of all fields, in the same order as `field_names`."""
    return tuple(getattr(self, name) for name in self.field_names())

  def copy(self):
    """Returns a copy of the PyTree structure (but not the JAX arrays)."""
    (flattened, structure) = jax.tree_util.tree_flatten(self)
    return jax.tree_util.tree_unflatten(structure, flattened)

  def tree_flatten(self) -> Tuple[Tuple[ArrayTree, ...], None]:
    """Flattens the instance into its field values, with no auxiliary data."""
    return self.field_values, None

  @classmethod
  def tree_unflatten(
      cls,
      aux_data: None,
      children: Tuple[ArrayTree, ...],
  ):
    """Rebuilds an instance from the children produced by `tree_flatten`."""
    del aux_data  # not used
    return cls(**dict(zip(cls.field_names(), children)))

  def __repr__(self) -> str:
    # `field_values` contains only the values, so they have to be paired with
    # the field names explicitly before formatting `name=value` entries.
    named_values = zip(self.field_names(), self.field_values)
    return (f"{self.__class__.__name__}(" +
            ",".join(f"{name}={v!r}" for name, v in named_values) +
            ")")


def register_state_class(class_type: Type[Any]) -> Type[Any]:
"""Extended dataclass decorator, which also registers the class as a PyTree.
The function is equivalent to `dataclasses.dataclass`, but additionally
Expand All @@ -106,27 +149,67 @@ def pytree_dataclass(class_type: Type[Any]) -> Type[Any]:
The transformed `class_type` which is now a dataclass and also registered as
a PyTree.
"""
if not issubclass(class_type, State):
raise ValueError(
f"Class {class_type} is not a subclass of kfac_jax.utils.State."
)

class_type = dataclasses.dataclass(class_type)
fields_names = tuple(field.name for field in dataclasses.fields(class_type))
class_type = jax.tree_util.register_pytree_node_class(class_type)
class_name = f"{class_type.__module__}.{class_type.__qualname__}"
STATE_CLASSES_SERIALIZATION_DICT[class_name] = class_type
return class_type

def flatten(instance) -> Tuple[Tuple[Any, ...], Any]:
return tuple(getattr(instance, name) for name in fields_names), None

def unflatten(_: Any, args: Sequence[Any]) -> Any:
return class_type(*args)
def serialize_state_tree(instance: StateTree) -> ArrayTree:
"""Returns a recursively constructed dictionary of the state."""
if isinstance(instance, State):
result_dict = {name: serialize_state_tree(getattr(instance, name))
for name in instance.field_names()}
cls = instance.__class__
result_dict["__class__"] = f"{cls.__module__}.{cls.__qualname__}"
return result_dict

jax.tree_util.register_pytree_node(class_type, flatten, unflatten)
elif isinstance(instance, list):
return [serialize_state_tree(v) for v in instance]

return class_type
elif isinstance(instance, tuple):
return tuple(serialize_state_tree(v) for v in instance)

elif isinstance(instance, set):
return set(serialize_state_tree(v) for v in instance)

@pytree_dataclass
class State(object):
elif isinstance(instance, dict):
return {k: serialize_state_tree(v) for k, v in instance.items()}

def copy(self):
"""Returns a copy of the PyTree structure (but not the JAX arrays)."""
(flattened, structure) = jax.tree_util.tree_flatten(self)
return jax.tree_util.tree_unflatten(structure, flattened)
else:
return instance


def deserialize_state_tree(representation: ArrayTree) -> StateTree:
  """Recursively reconstructs state classes from a serialized representation.

  Lists, tuples, sets and dictionaries are traversed recursively. A dictionary
  that contains a `"__class__"` entry is converted into an instance of the
  class registered under that name in `STATE_CLASSES_SERIALIZATION_DICT`, with
  its remaining (deserialized) entries passed as keyword arguments. Any other
  value is returned unchanged.

  Args:
    representation: The structure to deserialize. It is not modified.

  Returns:
    The structure with every tagged dictionary replaced by a class instance.

  Raises:
    ValueError: If a `"__class__"` entry names a class that is not present in
      `STATE_CLASSES_SERIALIZATION_DICT`.
  """
  if isinstance(representation, list):
    return [deserialize_state_tree(v) for v in representation]

  elif isinstance(representation, tuple):
    return tuple(deserialize_state_tree(v) for v in representation)

  elif isinstance(representation, set):
    return set(deserialize_state_tree(v) for v in representation)

  elif isinstance(representation, dict):
    if "__class__" not in representation:
      return {k: deserialize_state_tree(v) for k, v in representation.items()}

    # Read the marker instead of popping it: the caller's dictionary must stay
    # intact so it can be deserialized again, including after a failed lookup.
    class_name = representation["__class__"]
    if class_name not in STATE_CLASSES_SERIALIZATION_DICT:
      raise ValueError(f"Did not find how to reconstruct class {class_name}.")

    dict_rep = {
        k: deserialize_state_tree(v)
        for k, v in representation.items()
        if k != "__class__"
    }
    return STATE_CLASSES_SERIALIZATION_DICT[class_name](**dict_rep)

  else:
    return representation


class Finalizable(abc.ABC):
Expand Down

0 comments on commit 95df497

Please sign in to comment.