Expand unit testing of autoscale interpreter decorator. (#19)
Using parameterized testing to easily extend test coverage.
This allowed catching an issue with scalar Numpy constants and the output JAX PyTree.
balancap authored Nov 16, 2023
1 parent 6373362 commit d476f74
Showing 5 changed files with 114 additions and 56 deletions.
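
At a glance, the behaviour pinned down by the new tests: output PyTrees are reconstructed with ScaledArray leaves, and scalar constants in the traced graph no longer break `autoscale`. A minimal sketch of the two cases (the inputs here are illustrative, not taken from the diff):

    import numpy as np

    from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array

    x = scaled_array([1.0, 2.0], 3, dtype=np.float32)

    # Non-trivial output PyTree: the dict/tuple structure is preserved,
    # with ScaledArray leaves (previously outputs came back flattened).
    out = autoscale(lambda v: {"x": (v,)})(x)
    assert isinstance(out["x"][0], ScaledArray)

    # Scalar constant in the graph: the literal 2 is traced as a Numpy
    # scalar and promoted to a ScaledArray instead of raising an error.
    out2 = autoscale(lambda v: v * 2)(x)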
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/core/__init__.py
@@ -1,3 +1,3 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
-from .datatype import DTypeLike, ScaledArray, Shape, scaled_array  # noqa: F401
+from .datatype import DTypeLike, ScaledArray, Shape, is_scaled_leaf, scaled_array  # noqa: F401
 from .interpreters import autoscale, register_scaled_lax_op, register_scaled_op  # noqa: F401
10 changes: 10 additions & 0 deletions jax_scaled_arithmetics/core/datatype.py
@@ -101,3 +101,13 @@ def scaled_array(data: ArrayLike, scale: ArrayLike, dtype: DTypeLike = None, npa
     data = npapi.asarray(data, dtype=dtype)
     scale = npapi.asarray(scale)
     return ScaledArray(data, scale)
+
+
+def is_scaled_leaf(val: Any) -> bool:
+    """Is input a JAX PyTree (scaled) leaf, including ScaledArray.
+    This function is useful for JAX PyTree handling where the user wants
+    to keep the ScaledArray datastructures (i.e. not flattened as a pair of arrays).
+    """
+    # TODO: check scalars as well?
+    return isinstance(val, (jax.Array, np.ndarray, ScaledArray, int, float))
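
As its docstring says, `is_scaled_leaf` is intended as an `is_leaf` predicate for JAX tree utilities, so that ScaledArray instances survive flattening whole — exactly how the new interpreter test below uses it. A small usage sketch with illustrative inputs:

    import jax
    import numpy as np

    from jax_scaled_arithmetics.core import is_scaled_leaf, scaled_array

    tree = {"a": scaled_array([1.0, 2.0], 3, dtype=np.float32), "b": np.ones(2)}
    leaves, treedef = jax.tree_util.tree_flatten(tree, is_leaf=is_scaled_leaf)
    # Two leaves: the ScaledArray is kept whole rather than being
    # flattened into its (data, scale) pair.
    assert len(leaves) == 2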
47 changes: 37 additions & 10 deletions jax_scaled_arithmetics/core/interpreters.py
@@ -4,14 +4,26 @@
 from typing import Any, Dict
 
 import jax
+import numpy as np
 from jax import core
 from jax._src.util import safe_map
 
-from ..core import ScaledArray
+from .datatype import NDArray, ScaledArray
 
 _scaled_ops_registry: Dict[core.Primitive, Any] = {}
 
 
+def numpy_constant_scaled_array(val: NDArray[Any]) -> ScaledArray:
+    """Get the ScaledArray corresponding to a Numpy constant.
+    Only supporting Numpy scalars at the moment.
+    """
+    # TODO: generalized rules!
+    assert val.shape == ()
+    assert np.issubdtype(val.dtype, np.floating)
+    return ScaledArray(data=np.array(1.0, dtype=val.dtype), scale=np.copy(val))
+
+
 def register_scaled_op(prim: core.Primitive, scaled_func: Any) -> None:
     """Register the scaled translation of JAX primitive.
@@ -57,11 +69,16 @@ def autoscale(fun):
     @wraps(fun)
     def wrapped(*args, **kwargs):
         aval_args = safe_map(lambda x: x.aval, args)
-        # get jaxpr of unscaled graph
-        closed_jaxpr = jax.make_jaxpr(fun)(*aval_args, **kwargs)
-        # convert to scaled graph
-        out = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
-        return out
+        # Get jaxpr of unscaled/normal graph. Getting output Pytree shape as well.
+        closed_jaxpr, outshape = jax.make_jaxpr(fun, return_shape=True)(*aval_args, **kwargs)
+        out_leaves, out_pytree = jax.tree_util.tree_flatten(outshape)
+        # Trace the graph & convert to scaled one.
+        outputs_flat = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
+        # Reconstruct the output Pytree, with scaled arrays.
+        # NOTE: this step is also handling single vs multi outputs.
+        assert len(out_leaves) == len(outputs_flat)
+        output = jax.tree_util.tree_unflatten(out_pytree, outputs_flat)
+        return output
 
     return wrapped
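
The reconstruction above is the standard JAX PyTree round-trip: `make_jaxpr(..., return_shape=True)` yields the output tree structure, whose treedef is reused to unflatten the scaled results. A standalone sketch of that mechanism, independent of this library:

    import jax

    # Stand-in for the ShapeDtypeStruct tree returned by make_jaxpr.
    outshape = {"x": (1.0, 2.0)}
    out_leaves, out_pytree = jax.tree_util.tree_flatten(outshape)
    # Flat results (e.g. from autoscale_jaxpr) slot back into the same structure.
    output = jax.tree_util.tree_unflatten(out_pytree, out_leaves)
    assert output == outshape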

@@ -77,11 +94,24 @@ def read(var):
     def write(var, val):
         env[var] = val
 
+    def to_scaled_array(val):
+        if isinstance(val, ScaledArray):
+            return val
+        elif isinstance(val, np.ndarray):
+            return numpy_constant_scaled_array(val)
+        raise TypeError(f"Can not convert '{val}' to a scaled array.")
+
     safe_map(write, jaxpr.invars, args)
     safe_map(write, jaxpr.constvars, consts)
 
     for eqn in jaxpr.eqns:
         invals = safe_map(read, eqn.invars)
+        # Make sure all inputs are scaled arrays
+        invals = list(map(to_scaled_array, invals))
+        assert all([isinstance(v, ScaledArray) for v in invals])
+        # TODO: handle `stop_scale` case? integer/boolean dtypes?
+
+        # Primitive is supported by `autoscale`?
         if eqn.primitive not in _scaled_ops_registry:
             raise NotImplementedError(f"{eqn.primitive} does not have an implementation for ScaledArray inputs yet")
         outvals = _scaled_ops_registry[eqn.primitive](*invals, **eqn.params)
@@ -90,7 +120,4 @@ def write(var, val):
         safe_map(write, eqn.outvars, outvals)
 
     outvals = safe_map(read, jaxpr.outvars)
-    if len(outvals) == 1:
-        return outvals[0]
-    else:
-        return outvals
+    return outvals
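
To make the constant path concrete: a scalar Numpy constant becomes a ScaledArray whose data is 1 and whose scale carries the value. A rough sketch of the expected values, assuming the module path `jax_scaled_arithmetics.core.interpreters` stays as in this diff:

    import numpy as np

    from jax_scaled_arithmetics.core.interpreters import numpy_constant_scaled_array

    sa = numpy_constant_scaled_array(np.float32(2.0))
    # data == 1.0 in the constant's dtype; scale == the constant itself.
    np.testing.assert_array_equal(np.asarray(sa.data), np.float32(1.0))
    np.testing.assert_array_equal(np.asarray(sa.scale), np.float32(2.0))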
10 changes: 9 additions & 1 deletion tests/core/test_datatype.py
@@ -6,7 +6,7 @@
 from absl.testing import parameterized
 from jax.core import ShapedArray
 
-from jax_scaled_arithmetics import ScaledArray, scaled_array
+from jax_scaled_arithmetics.core import ScaledArray, is_scaled_leaf, scaled_array
 
 
 class ScaledArrayDataclassTests(chex.TestCase):
@@ -74,3 +74,11 @@ def test__scaled_array__numpy_array_interface(self, npapi):
         out = np.asarray(sarr)
         assert isinstance(out, np.ndarray)
         npt.assert_array_equal(out, sarr.data * sarr.scale)
+
+    def test__is_scaled_leaf__consistent_with_jax(self):
+        assert is_scaled_leaf(8)
+        assert is_scaled_leaf(2.0)
+        assert is_scaled_leaf(np.array(3))
+        assert is_scaled_leaf(np.array([3]))
+        assert is_scaled_leaf(jnp.array([3]))
+        assert is_scaled_leaf(scaled_array(data=[1.0, 2.0], scale=3, dtype=np.float16))
101 changes: 57 additions & 44 deletions tests/core/test_interpreter.py
@@ -2,11 +2,13 @@
 
 import chex
 import jax
-import jax.numpy as jnp
+
+# import jax.numpy as jnp
 import numpy as np
 import numpy.testing as npt
 from absl.testing import parameterized
 
-from jax_scaled_arithmetics.core import ScaledArray, autoscale, register_scaled_op, scaled_array
+from jax_scaled_arithmetics.core import ScaledArray, autoscale, is_scaled_leaf, register_scaled_op, scaled_array
 
 
 class AutoScaleInterpreterTests(chex.TestCase):
@@ -15,53 +17,64 @@ def test__register_scaled_op__error_if_already_registered(self):
             register_scaled_op(jax.lax.mul_p, lambda a, _: a)
 
     @chex.variants(with_jit=True, without_jit=True)
-    def test__scaled_identity_function(self):
+    def test__autoscale_interpreter__proper_signature(self):
         def func(x):
-            return x
-
-        # Autoscale + (optional) jitting.
-        asfunc = self.variant(autoscale(func))
-
-        scaled_inputs = scaled_array([1.0, 2.0], 1, dtype=np.float32)
-        scaled_outputs = asfunc(scaled_inputs)
-        expected = jnp.array([1.0, 2.0])
-
-        assert isinstance(scaled_outputs, ScaledArray)
-        npt.assert_array_almost_equal(scaled_outputs, expected)
-        jaxpr = jax.make_jaxpr(asfunc)(scaled_inputs).jaxpr
-        assert jaxpr.invars[0].aval.shape == scaled_inputs.shape
+            return x * 2
+
+        scaled_func = self.variant(autoscale(func))
+        scaled_input = scaled_array([1.0, 2.0], 3, dtype=np.float32)
+        jaxpr = jax.make_jaxpr(scaled_func)(scaled_input).jaxpr
+        # Vars need to be primitive data types (e.g., f32) -> 2 Vars per ScaledArray
+        assert jaxpr.invars[0].aval.shape == scaled_input.shape
         assert jaxpr.invars[1].aval.shape == ()
 
-        assert jaxpr.outvars[0].aval.shape == expected.shape
+        assert jaxpr.outvars[0].aval.shape == scaled_input.shape
         assert jaxpr.outvars[1].aval.shape == ()
 
     @chex.variants(with_jit=True, without_jit=True)
-    def test__scaled_mul__no_attributes(self):
-        def func(x, y):
-            return x * y
-
-        # Autoscale + (optional) jitting.
-        asfunc = self.variant(autoscale(func))
-
-        x = scaled_array([-2.0, 2.0], 0.5, dtype=np.float32)
-        y = scaled_array([1.5, 1.5], 2, dtype=np.float32)
-        expected = jnp.array([-3.0, 3.0])
-
-        out = asfunc(x, y)
-        assert isinstance(out, ScaledArray)
-        npt.assert_array_almost_equal(out, expected)
-
-    @chex.variants(with_jit=True, without_jit=True)
-    def test__scaled_convert_element_type__attributes_passing(self):
-        def func(x):
-            return jax.lax.convert_element_type(x, np.float16)
-
-        # Autoscale + (optional) jitting.
-        asfunc = self.variant(autoscale(func))
-        x = scaled_array([-4.0, 2.0], 0.5, dtype=np.float32)
-        out = asfunc(x)
-        assert isinstance(out, ScaledArray)
-        assert out.dtype == np.float16
-        npt.assert_array_almost_equal(out, x)
+    @parameterized.parameters(
+        # Identity function!
+        {"fn": lambda x: x, "inputs": [scaled_array([1.0, 2.0], 3, dtype=np.float32)]},
+        # Non-trivial output JAX pytree
+        {"fn": lambda x: {"x": (x,)}, "inputs": [scaled_array([1.0, 2.0], 3, dtype=np.float32)]},
+        # Multi-inputs operation.
+        {
+            "fn": lambda x, y: x * y,
+            "inputs": [scaled_array([-2.0, 0.5], 0.5, dtype=np.float32), scaled_array([1.5, 1.5], 2, dtype=np.float32)],
+        },
+        # Proper forwarding of attributes.
+        {
+            "fn": lambda x: jax.lax.convert_element_type(x, np.float16),
+            "inputs": [scaled_array([-4.0, 2.0], 0.5, dtype=np.float32)],
+        },
+        # Proper constant scalar handling.
+        {
+            "fn": lambda x: x * 2,
+            "inputs": [scaled_array([[-2.0, 0.5]], 0.5, dtype=np.float32)],
+        },
+        # TODO/FIXME: Proper constant Numpy array handling.
+        # {
+        #     "fn": lambda x: x * np.array([2.0, 3.0], dtype=np.float32),
+        #     "inputs": [scaled_array([[-2.0], [0.5]], 0.5, dtype=np.float32)],
+        # },
+    )
+    def test__autoscale_decorator__proper_graph_transformation_and_result(self, fn, inputs):
+        # Autoscale function + (optional) jitting.
+        scaled_fn = self.variant(autoscale(fn))
+        scaled_output = scaled_fn(*inputs)
+        # Normal JAX path, without scaled arrays.
+        raw_inputs = [np.asarray(v) for v in inputs]
+        expected_output = self.variant(fn)(*raw_inputs)
+
+        # Do we re-construct properly the output type (i.e. handling Pytree properly)?
+        if not isinstance(expected_output, (np.ndarray, jax.Array)):
+            assert type(scaled_output) is type(expected_output)
+
+        # Check each output in the flatten tree.
+        scaled_outputs_flat, _ = jax.tree_util.tree_flatten(scaled_output, is_leaf=is_scaled_leaf)
+        expected_outputs_flat, _ = jax.tree_util.tree_flatten(expected_output)
+        for scaled_out, exp_out in zip(scaled_outputs_flat, expected_outputs_flat):
+            assert isinstance(scaled_out, ScaledArray)
+            assert scaled_out.scale.shape == ()
+            assert scaled_out.dtype == exp_out.dtype
+            npt.assert_array_almost_equal(scaled_out, exp_out, decimal=4)
