Expand unit testing of autoscale interpreter decorator.
Using parameterized testing to easily extend test coverage.
This allowed catching an issue with scalar Numpy constants and output JAX PyTree handling.
balancap committed Nov 16, 2023
1 parent 6373362 commit 7afba28
Showing 5 changed files with 114 additions and 56 deletions.
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/core/__init__.py
@@ -1,3 +1,3 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .datatype import DTypeLike, ScaledArray, Shape, scaled_array # noqa: F401
from .datatype import DTypeLike, ScaledArray, Shape, is_scaled_leaf, scaled_array # noqa: F401
from .interpreters import autoscale, register_scaled_lax_op, register_scaled_op # noqa: F401
10 changes: 10 additions & 0 deletions jax_scaled_arithmetics/core/datatype.py
@@ -101,3 +101,13 @@ def scaled_array(data: ArrayLike, scale: ArrayLike, dtype: DTypeLike = None, npa
data = npapi.asarray(data, dtype=dtype)
scale = npapi.asarray(scale)
return ScaledArray(data, scale)


def is_scaled_leaf(val: Any) -> bool:
"""Is input a JAX PyTree (scaled) leaf, including ScaledArray.
This function is useful for JAX PyTree handling where the user wants
to keep the ScaledArray datastructures (i.e. not flattened as a pair of arrays).
"""
# TODO: check scalars as well?
return isinstance(val, (jax.Array, np.ndarray, ScaledArray, int, float))
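
For reference, a minimal usage sketch of the new helper (assuming ScaledArray is registered as a JAX PyTree node that would otherwise flatten into its data/scale children, as the docstring implies):

    import jax
    import numpy as np
    from jax_scaled_arithmetics.core import is_scaled_leaf, scaled_array

    tree = {"w": scaled_array([1.0, 2.0], 3, dtype=np.float32), "b": np.zeros(2)}
    # Default flattening splits the ScaledArray into (data, scale): 3 leaves.
    default_leaves, _ = jax.tree_util.tree_flatten(tree)
    # With `is_scaled_leaf`, the ScaledArray is kept whole: 2 leaves.
    scaled_leaves, _ = jax.tree_util.tree_flatten(tree, is_leaf=is_scaled_leaf)
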
47 changes: 37 additions & 10 deletions jax_scaled_arithmetics/core/interpreters.py
@@ -4,14 +4,26 @@
from typing import Any, Dict

import jax
import numpy as np
from jax import core
from jax._src.util import safe_map

from ..core import ScaledArray
from .datatype import NDArray, ScaledArray

_scaled_ops_registry: Dict[core.Primitive, Any] = {}


def numpy_constant_scaled_array(val: NDArray[Any]) -> ScaledArray:
"""Get the ScaledArray corresponding to a Numpy constant.
Only supporting Numpy scalars at the moment.
"""
# TODO: generalized rules!
assert val.shape == ()
assert np.issubdtype(val.dtype, np.floating)
return ScaledArray(data=np.array(1.0, dtype=val.dtype), scale=np.copy(val))
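
A quick sanity check of the rule above (importing the helper directly from the interpreters module; data is set to 1 so that data * scale reproduces the constant):

    import numpy as np
    from jax_scaled_arithmetics.core.interpreters import numpy_constant_scaled_array

    sarr = numpy_constant_scaled_array(np.asarray(2.0, dtype=np.float32))
    # sarr.data  -> array(1., dtype=float32)
    # sarr.scale -> array(2., dtype=float32)
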


def register_scaled_op(prim: core.Primitive, scaled_func: Any) -> None:
"""Register the scaled translation of JAX primitive.
@@ -57,11 +69,16 @@ def autoscale(fun):
@wraps(fun)
def wrapped(*args, **kwargs):
aval_args = safe_map(lambda x: x.aval, args)
# get jaxpr of unscaled graph
closed_jaxpr = jax.make_jaxpr(fun)(*aval_args, **kwargs)
# convert to scaled graph
out = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
return out
# Get jaxpr of unscaled/normal graph. Getting output Pytree shape as well.
closed_jaxpr, outshape = jax.make_jaxpr(fun, return_shape=True)(*aval_args, **kwargs)
out_leaves, out_pytree = jax.tree_util.tree_flatten(outshape)
# Trace the graph & convert to scaled one.
outputs_flat = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
# Reconstruct the output Pytree, with scaled arrays.
# NOTE: this step is also handling single vs multi outputs.
assert len(out_leaves) == len(outputs_flat)
output = jax.tree_util.tree_unflatten(out_pytree, outputs_flat)
return output

return wrapped
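
For context, the standard JAX behaviour the new `wrapped` relies on (a sketch, not part of this diff): with `return_shape=True`, `jax.make_jaxpr` also returns the output PyTree as `jax.ShapeDtypeStruct` leaves, whose treedef is exactly what `tree_unflatten` needs.

    import jax
    import jax.numpy as jnp

    def fn(x):
        return {"out": (x * 2,)}

    closed_jaxpr, outshape = jax.make_jaxpr(fn, return_shape=True)(jnp.ones(2))
    out_leaves, out_pytree = jax.tree_util.tree_flatten(outshape)
    # out_leaves -> [ShapeDtypeStruct(shape=(2,), dtype=float32)]
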

@@ -77,11 +94,24 @@ def read(var):
def write(var, val):
env[var] = val

def to_scaled_array(val):
if isinstance(val, ScaledArray):
return val
elif isinstance(val, np.ndarray):
return numpy_constant_scaled_array(val)
raise TypeError(f"Can not convert '{val}' to a scaled array.")

safe_map(write, jaxpr.invars, args)
safe_map(write, jaxpr.constvars, consts)

for eqn in jaxpr.eqns:
invals = safe_map(read, eqn.invars)
# Make sure all inputs are scaled arrays
invals = list(map(to_scaled_array, invals))
assert all([isinstance(v, ScaledArray) for v in invals])
# TODO: handle `stop_scale` case? integer/boolean dtypes?

# Primitive is supported by `autoscale`?
if eqn.primitive not in _scaled_ops_registry:
raise NotImplementedError(f"{eqn.primitive} does not have an implementation for ScaledArray inputs yet")
outvals = _scaled_ops_registry[eqn.primitive](*invals, **eqn.params)
@@ -90,7 +120,4 @@ def write(var, val):
safe_map(write, eqn.outvars, outvals)

outvals = safe_map(read, jaxpr.outvars)
if len(outvals) == 1:
return outvals[0]
else:
return outvals
return outvals
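
Putting the pieces together, a small end-to-end sketch mirroring the new parameterized tests (PyTree output plus a scalar constant, the two cases this commit fixes):

    import numpy as np
    from jax_scaled_arithmetics.core import autoscale, scaled_array

    def fn(x):
        # `2` reaches the interpreter as a scalar Numpy constant and goes
        # through `to_scaled_array` -> `numpy_constant_scaled_array`.
        return {"y": (x * 2,)}

    out = autoscale(fn)(scaled_array([1.0, 2.0], 3, dtype=np.float32))
    # The dict/tuple structure is preserved and the leaf is a ScaledArray:
    # np.asarray(out["y"][0]) -> array([ 6., 12.], dtype=float32)
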
10 changes: 9 additions & 1 deletion tests/core/test_datatype.py
@@ -6,7 +6,7 @@
from absl.testing import parameterized
from jax.core import ShapedArray

from jax_scaled_arithmetics import ScaledArray, scaled_array
from jax_scaled_arithmetics.core import ScaledArray, is_scaled_leaf, scaled_array


class ScaledArrayDataclassTests(chex.TestCase):
@@ -74,3 +74,11 @@ def test__scaled_array__numpy_array_interface(self, npapi):
out = np.asarray(sarr)
assert isinstance(out, np.ndarray)
npt.assert_array_equal(out, sarr.data * sarr.scale)

def test__is_scaled_leaf__consistent_with_jax(self):
assert is_scaled_leaf(8)
assert is_scaled_leaf(2.0)
assert is_scaled_leaf(np.array(3))
assert is_scaled_leaf(np.array([3]))
assert is_scaled_leaf(jnp.array([3]))
assert is_scaled_leaf(scaled_array(data=[1.0, 2.0], scale=3, dtype=np.float16))
101 changes: 57 additions & 44 deletions tests/core/test_interpreter.py
@@ -2,11 +2,13 @@

import chex
import jax
import jax.numpy as jnp

# import jax.numpy as jnp
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized

from jax_scaled_arithmetics.core import ScaledArray, autoscale, register_scaled_op, scaled_array
from jax_scaled_arithmetics.core import ScaledArray, autoscale, is_scaled_leaf, register_scaled_op, scaled_array


class AutoScaleInterpreterTests(chex.TestCase):
@@ -15,53 +17,64 @@ def test__register_scaled_op__error_if_already_registered(self):
register_scaled_op(jax.lax.mul_p, lambda a, _: a)

@chex.variants(with_jit=True, without_jit=True)
def test__scaled_identity_function(self):
def test__autoscale_interpreter__proper_signature(self):
def func(x):
return x

# Autoscale + (optional) jitting.
asfunc = self.variant(autoscale(func))

scaled_inputs = scaled_array([1.0, 2.0], 1, dtype=np.float32)
scaled_outputs = asfunc(scaled_inputs)
expected = jnp.array([1.0, 2.0])

assert isinstance(scaled_outputs, ScaledArray)
npt.assert_array_almost_equal(scaled_outputs, expected)
jaxpr = jax.make_jaxpr(asfunc)(scaled_inputs).jaxpr
return x * 2

scaled_func = self.variant(autoscale(func))
scaled_input = scaled_array([1.0, 2.0], 3, dtype=np.float32)
jaxpr = jax.make_jaxpr(scaled_func)(scaled_input).jaxpr
# Vars need to be primitive data types (e.g., f32) -> 2 Vars per ScaledArray
assert jaxpr.invars[0].aval.shape == scaled_inputs.shape
assert jaxpr.invars[0].aval.shape == scaled_input.shape
assert jaxpr.invars[1].aval.shape == ()

assert jaxpr.outvars[0].aval.shape == expected.shape
assert jaxpr.outvars[0].aval.shape == scaled_input.shape
assert jaxpr.outvars[1].aval.shape == ()

@chex.variants(with_jit=True, without_jit=True)
def test__scaled_mul__no_attributes(self):
def func(x, y):
return x * y

# Autoscale + (optional) jitting.
asfunc = self.variant(autoscale(func))

x = scaled_array([-2.0, 2.0], 0.5, dtype=np.float32)
y = scaled_array([1.5, 1.5], 2, dtype=np.float32)
expected = jnp.array([-3.0, 3.0])

out = asfunc(x, y)
assert isinstance(out, ScaledArray)
npt.assert_array_almost_equal(out, expected)

@chex.variants(with_jit=True, without_jit=True)
def test__scaled_convert_element_type__attributes_passing(self):
def func(x):
return jax.lax.convert_element_type(x, np.float16)

# Autoscale + (optional) jitting.
asfunc = self.variant(autoscale(func))
x = scaled_array([-4.0, 2.0], 0.5, dtype=np.float32)
out = asfunc(x)
assert isinstance(out, ScaledArray)
assert out.dtype == np.float16
npt.assert_array_almost_equal(out, x)
@parameterized.parameters(
# Identity function!
{"fn": lambda x: x, "inputs": [scaled_array([1.0, 2.0], 3, dtype=np.float32)]},
# Non-trivial output JAX pytree
{"fn": lambda x: {"x": (x,)}, "inputs": [scaled_array([1.0, 2.0], 3, dtype=np.float32)]},
# Multi-inputs operation.
{
"fn": lambda x, y: x * y,
"inputs": [scaled_array([-2.0, 0.5], 0.5, dtype=np.float32), scaled_array([1.5, 1.5], 2, dtype=np.float32)],
},
# Proper forwarding of attributes.
{
"fn": lambda x: jax.lax.convert_element_type(x, np.float16),
"inputs": [scaled_array([-4.0, 2.0], 0.5, dtype=np.float32)],
},
# Proper constant scalar handling.
{
"fn": lambda x: x * 2,
"inputs": [scaled_array([[-2.0, 0.5]], 0.5, dtype=np.float32)],
},
# TODO/FIXME: Proper constant Numpy array handling.
# {
# "fn": lambda x: x * np.array([2.0, 3.0], dtype=np.float32),
# "inputs": [scaled_array([[-2.0], [0.5]], 0.5, dtype=np.float32)],
# },
)
def test__autoscale_decorator__proper_graph_transformation_and_result(self, fn, inputs):
# Autoscale function + (optional) jitting.
scaled_fn = self.variant(autoscale(fn))
scaled_output = scaled_fn(*inputs)
# Normal JAX path, without scaled arrays.
raw_inputs = [np.asarray(v) for v in inputs]
expected_output = self.variant(fn)(*raw_inputs)

# Do we re-construct properly the output type (i.e. handling Pytree properly)?
if not isinstance(expected_output, (np.ndarray, jax.Array)):
assert type(scaled_output) is type(expected_output)

# Check each output in the flatten tree.
scaled_outputs_flat, _ = jax.tree_util.tree_flatten(scaled_output, is_leaf=is_scaled_leaf)
expected_outputs_flat, _ = jax.tree_util.tree_flatten(expected_output)
for scaled_out, exp_out in zip(scaled_outputs_flat, expected_outputs_flat):
assert isinstance(scaled_out, ScaledArray)
assert scaled_out.scale.shape == ()
assert scaled_out.dtype == exp_out.dtype
npt.assert_array_almost_equal(scaled_out, exp_out, decimal=4)
