Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ScaledArray aval and create factory method. #10

Merged
merged 1 commit into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from . import lax
from ._version import __version__
from .core import ScaledArray, autoscale # noqa: F401
from .core import ScaledArray, autoscale, scaled_array # noqa: F401
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .datatype import ScaledArray # noqa: F401
from .datatype import ScaledArray, scaled_array # noqa: F401
from .interpreters import autoscale, register_scaled_lax_op, register_scaled_op # noqa: F401
25 changes: 22 additions & 3 deletions jax_scaled_arithmetics/core/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from typing import Any, Union

import jax
import jax.numpy as jnp
import numpy as np
from chex import Shape
from jax.core import ShapedArray
from jax.tree_util import register_pytree_node_class
from numpy.typing import DTypeLike, NDArray
from numpy.typing import ArrayLike, DTypeLike, NDArray

GenericArray = Union[jax.Array, np.ndarray]

Expand Down Expand Up @@ -80,5 +82,22 @@ def __array__(self, dtype: DTypeLike = None) -> NDArray[Any]:
return np.asarray(self.to_array(dtype))

@property
def aval(self):
return self.data * self.scale
def aval(self) -> ShapedArray:
"""Abstract value of the scaled array, i.e. shape and dtype."""
return ShapedArray(self.data.shape, self.data.dtype)


def scaled_array(data: ArrayLike, scale: ArrayLike, dtype: DTypeLike = None, npapi: Any = jnp) -> ScaledArray:
"""ScaledArray (helper) factory method, similar to `(j)np.array`.

Args:
data: Main data/values.
scale: Scale tensor.
dtype: Optional dtype to use for the data.
npapi: Numpy API to use.
Returns:
Scaled array instance.
"""
data = npapi.asarray(data, dtype=dtype)
scale = npapi.asarray(scale)
return ScaledArray(data, scale)
59 changes: 39 additions & 20 deletions tests/core/test_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,54 +4,73 @@
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized
from jax.core import ShapedArray

from jax_scaled_arithmetics.core import ScaledArray
from jax_scaled_arithmetics import ScaledArray, scaled_array


class ScaledArrayDataclassTests(chex.TestCase):
@parameterized.parameters(
{"npb": np},
{"npb": jnp},
{"npapi": np},
{"npapi": jnp},
)
def test__scaled_array__init__multi_numpy_backend(self, npb):
sarr = ScaledArray(data=npb.array([1.0, 2.0], dtype=np.float32), scale=npb.array(1))
assert isinstance(sarr.data, npb.ndarray)
assert isinstance(sarr.scale, npb.ndarray)
def test__scaled_array__init__multi_numpy_backend(self, npapi):
sarr = ScaledArray(data=npapi.array([1.0, 2.0], dtype=np.float32), scale=npapi.array(1))
assert isinstance(sarr.data, npapi.ndarray)
assert isinstance(sarr.scale, npapi.ndarray)
assert sarr.scale.shape == ()

def test__scaled_array__basic_properties(self):
sarr = ScaledArray(data=jnp.array([1.0, 2.0]), scale=jnp.array(1))
@parameterized.parameters(
{"npapi": np},
{"npapi": jnp},
)
def test__scaled_array__factory_method__multi_numpy_backend(self, npapi):
sarr = scaled_array(data=[1.0, 2.0], scale=3, dtype=np.float16, npapi=npapi)
assert isinstance(sarr, ScaledArray)
assert isinstance(sarr.data, npapi.ndarray)
assert isinstance(sarr.scale, npapi.ndarray)
assert sarr.data.dtype == ShapedArray((2,), np.float16)
assert sarr.scale.shape == ()
npt.assert_array_almost_equal(sarr, [3, 6])

@parameterized.parameters(
{"npapi": np},
{"npapi": jnp},
)
def test__scaled_array__basic_properties(self, npapi):
sarr = ScaledArray(data=npapi.array([1.0, 2.0], dtype=np.float32), scale=npapi.array(1))
assert sarr.dtype == np.float32
assert sarr.shape == (2,)
assert sarr.aval == ShapedArray((2,), np.float32)

@parameterized.parameters(
{"npb": np},
{"npb": jnp},
{"npapi": np},
{"npapi": jnp},
)
def test__scaled_array__to_array__multi_numpy_backend(self, npb):
sarr = ScaledArray(data=npb.array([1.0, 2.0], dtype=np.float16), scale=npb.array(3))
def test__scaled_array__to_array__multi_numpy_backend(self, npapi):
sarr = scaled_array(data=[1.0, 2.0], scale=3, dtype=np.float16, npapi=npapi)
# No dtype specified.
out = sarr.to_array()
assert isinstance(out, npb.ndarray)
assert isinstance(out, npapi.ndarray)
assert out.dtype == sarr.dtype
npt.assert_array_equal(out, sarr.data * sarr.scale)
# Custom float dtype.
out = sarr.to_array(dtype=np.float32)
assert isinstance(out, npb.ndarray)
assert isinstance(out, npapi.ndarray)
assert out.dtype == np.float32
npt.assert_array_equal(out, sarr.data * sarr.scale)
# Custom int dtype.
out = sarr.to_array(dtype=np.int8)
assert isinstance(out, npb.ndarray)
assert isinstance(out, npapi.ndarray)
assert out.dtype == np.int8
npt.assert_array_equal(out, sarr.data * sarr.scale)

@parameterized.parameters(
{"npb": np},
{"npb": jnp},
{"npapi": np},
{"npapi": jnp},
)
def test__scaled_array__numpy_array_interface(self, npb):
sarr = ScaledArray(data=npb.array([1.0, 2.0], dtype=np.float32), scale=npb.array(3))
def test__scaled_array__numpy_array_interface(self, npapi):
sarr = ScaledArray(data=npapi.array([1.0, 2.0], dtype=np.float32), scale=npapi.array(3))
out = np.asarray(sarr)
assert isinstance(out, np.ndarray)
npt.assert_array_equal(out, sarr.data * sarr.scale)
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import chex
import jax
import jax.numpy as jnp
import numpy as np
import numpy.testing as npt

from jax_scaled_arithmetics.core import ScaledArray, autoscale
from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array


class AutoScaleInterpreterTests(chex.TestCase):
Expand All @@ -14,20 +16,16 @@ def func(x):

asfunc = autoscale(func)

scale = jnp.array(1.0)
inputs = jnp.array([1.0, 2.0])
expected = jnp.array([1.0, 2.0])

scaled_inputs = ScaledArray(inputs, scale)
scaled_inputs = scaled_array([1.0, 2.0], 1, dtype=np.float32)
scaled_outputs = asfunc(scaled_inputs)
expected = jnp.array([1.0, 2.0])

assert jnp.allclose(scaled_outputs.aval, expected)

assert isinstance(scaled_outputs, ScaledArray)
npt.assert_array_almost_equal(scaled_outputs, expected)
jaxpr = jax.make_jaxpr(asfunc)(scaled_inputs).jaxpr

# Vars need to be primitive data types (e.g., f32) -> 2 Vars per ScaledArray

assert jaxpr.invars[0].aval.shape == inputs.shape
assert jaxpr.invars[0].aval.shape == scaled_inputs.shape
assert jaxpr.invars[1].aval.shape == ()

assert jaxpr.outvars[0].aval.shape == expected.shape
Expand All @@ -39,16 +37,10 @@ def func(x, y):

asfunc = autoscale(func)

x_in = jnp.array([-2.0, 2.0])
x_scale = jnp.array(0.5)
x = ScaledArray(x_in, x_scale)

y_in = jnp.array([1.5, 1.5])
y_scale = jnp.array(2.0)
y = ScaledArray(y_in, y_scale)

x = scaled_array([-2.0, 2.0], 0.5, dtype=np.float32)
y = scaled_array([1.5, 1.5], 2, dtype=np.float32)
expected = jnp.array([-3.0, 3.0])

out = asfunc(x, y)

assert jnp.allclose(out.aval, expected)
assert isinstance(out, ScaledArray)
npt.assert_array_almost_equal(out, expected)