Skip to content

Commit

Permalink
Implementation of jax_scalify.tree sub-module.
Browse files Browse the repository at this point in the history
PyTree methods adapted to `ScaledArray`: `all`, `flatten`, `leaves`, `map`, `structure`, `unflatten`,
in `jax_scalify.tree`.
Additionally, implementing `astype` as quite useful method on PyTrees!
  • Loading branch information
balancap committed Jun 27, 2024
1 parent e25f4cc commit bae9036
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 1 deletion.
2 changes: 1 addition & 1 deletion jax_scalify/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from . import core, lax, ops
from . import core, lax, ops, tree
from ._version import __version__
from .core import ( # noqa: F401
Pow2RoundMode,
Expand Down
2 changes: 2 additions & 0 deletions jax_scalify/tree/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
from .tree_util import all, astype, flatten, leaves, map, structure, unflatten # noqa: F401
125 changes: 125 additions & 0 deletions jax_scalify/tree/tree_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
from typing import Any, Callable

import jax
import jax.numpy as jnp
from jax import tree_util

from jax_scalify.core import DTypeLike, is_scaled_leaf

Leaf = Any


def astype(tree: Any, dtype: DTypeLike, floating_only: bool = False) -> Any:
"""Map `astype` method to all pytree leaves, `Array` or `ScaledArray`.
Args:
tree: the pytree to cast.
dtype: Dtype to cast to.
floating_only: Only convert leaves with floating datatype.
Returns:
A new PyTree with the same structure, with casting to new dtype.
"""
if floating_only:
# Convert only leaves with floating dtype.
cast_fn = lambda v: v.astype(dtype) if jnp.issubdtype(v.dtype, jnp.floating) else v
return tree_util.tree_map(cast_fn, tree, is_leaf=is_scaled_leaf)
return tree_util.tree_map(lambda v: v.astype(dtype), tree, is_leaf=is_scaled_leaf)


def all(tree: Any) -> bool:
"""Call all() over the leaves of a tree, `Array` or `ScaledArray`
Args:
tree: the pytree to evaluate
Returns:
result: boolean True or False
"""
return all(jax.tree_util.tree_leaves(tree, is_leaf=is_scaled_leaf))


def flatten(tree: Any) -> tuple[list[Leaf], tree_util.PyTreeDef]:
"""Flattens a pytree, with `Array` or `ScaledArray` leaves.
The flattening order (i.e. the order of elements in the output list)
is deterministic, corresponding to a left-to-right depth-first tree
traversal.
Args:
tree: a pytree to flatten.
Returns:
A pair where the first element is a list of leaf values and the second
element is a treedef representing the structure of the flattened tree.
See Also:
- :func:`jax_scalify.tree.leaves`
- :func:`jax_scalify.tree.structure`
- :func:`jax_scalify.tree.unflatten`
"""
return tree_util.tree_flatten(tree, is_leaf=is_scaled_leaf)


def leaves(
tree: Any,
) -> list[Leaf]:
"""Gets the leaves (`Array` or `ScaledArray`) of a pytree.
Args:
tree: the pytree for which to get the leaves
Returns:
leaves: a list of tree leaves.
See Also:
- :func:`jax_scalify.tree.flatten`
- :func:`jax_scalify.tree.structure`
- :func:`jax_scalify.tree.unflatten`
"""
return tree_util.tree_leaves(tree, is_leaf=is_scaled_leaf)


def map(f: Callable[..., Any], tree: Any, *rest: Any) -> Any:
"""Maps a multi-input function over pytree args to produce a new pytree.
Args:
f: function that takes ``1 + len(rest)`` arguments, to be applied at the
corresponding leaves of the pytrees.
tree: a pytree to be mapped over, with each leaf providing the first
positional argument to ``f``.
rest: a tuple of pytrees, each of which has the same structure as ``tree``
or has ``tree`` as a prefix.
Returns:
A new pytree with the same structure as ``tree`` but with the value at each
leaf given by ``f(x, *xs)`` where ``x`` is the value at the corresponding
leaf in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in
``rest``.
See Also:
- :func:`jax_scalify.tree.leaves`
- :func:`jax_scalify.tree.reduce`
"""
return tree_util.tree_map(f, tree, *rest, is_leaf=is_scaled_leaf)


def structure(tree: Any) -> tree_util.PyTreeDef:
"""Gets the treedef for a pytree, with `Array` or `ScaledArray` leaves.
Args:
tree: the pytree for which to get the leaves
Returns:
pytreedef: a PyTreeDef representing the structure of the tree.
See Also:
- :func:`jax_scalify.tree.flatten`
- :func:`jax_scalify.tree.leaves`
- :func:`jax_scalify.tree.unflatten`
"""
return tree_util.tree_structure(tree, is_leaf=is_scaled_leaf)


# Alias of JAX tree unflatten.
unflatten = jax.tree.unflatten
51 changes: 51 additions & 0 deletions tests/tree/test_tree_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
import chex
import numpy as np

import jax_scalify as jsa


class ScalifyTreeUtilTests(chex.TestCase):
def test__tree_flatten__proper_result(self):
values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)}
outputs, _ = jsa.tree.flatten(values)
assert len(outputs) == 2
assert outputs[0] == 2
assert isinstance(outputs[1], jsa.ScaledArray)
assert np.asarray(outputs[1]) == 1.5

def test__tree_leaves__proper_result(self):
values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)}
outputs = jsa.tree.leaves(values)
assert len(outputs) == 2
assert outputs[0] == 2
assert isinstance(outputs[1], jsa.ScaledArray)
assert np.asarray(outputs[1]) == 1.5

def test__tree_structure__proper_result(self):
values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)}
pytree = jsa.tree.structure(values)
assert pytree == jsa.tree.flatten(values)[1]

def test__tree_unflatten__proper_result(self):
values_in = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)}
outputs, pytree = jsa.tree.flatten(values_in)
values_out = jsa.tree.unflatten(pytree, outputs)
assert values_out == values_in

def test__tree_map__proper_result(self):
values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)}
outputs = jsa.tree.map(lambda v: v.dtype, values)
assert outputs == {"a": np.int32, "b": np.float32}

def test__tree_astype__all_leaves_casting(self):
values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)}
outputs = jsa.tree.astype(values, dtype=np.float16)
dtypes = jsa.tree.map(lambda v: v.dtype, outputs)
assert dtypes == {"a": np.float16, "b": np.float16}

def test__tree_astype__only_float_casting(self):
values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)}
outputs = jsa.tree.astype(values, dtype=np.float16, floating_only=True)
dtypes = jsa.tree.map(lambda v: v.dtype, outputs)
assert dtypes == {"a": np.int32, "b": np.float16}

0 comments on commit bae9036

Please sign in to comment.