Forwarding broadcasted scalar metadata in Scalify tracer. #97

Merged
8 changes: 6 additions & 2 deletions jax_scaled_arithmetics/core/datatype.py
@@ -171,6 +171,7 @@ def as_scaled_array_base(
if isinstance(val, ScaledArray):
return val

assert scale is None or scale_dtype is None
# Simple case => we can ignore the scaling factor (i.e. 1 implicitly).
is_static_one_scale: bool = scale is None or is_static_one_scalar(scale) # type:ignore
# Trivial cases: bool, int, float.
@@ -189,12 +190,15 @@

scale_dtype = scale_dtype or val.dtype
scale = np.array(1, dtype=scale_dtype) if scale is None else scale
if isinstance(val, (np.ndarray, Array)):
if isinstance(val, (np.ndarray, *ArrayTypes)):
if is_static_one_scale:
return ScaledArray(val, scale)
else:
return ScaledArray(val / scale.astype(val.dtype), scale) # type:ignore
return scaled_array_base(val, scale)

# TODO: fix bug when scale is not 1.
raise NotImplementedError(f"Constructing `ScaledArray` from {val} and {scale} not supported.")
# return scaled_array_base(val, scale)


def as_scaled_array(val: Any, scale: Optional[ArrayLike] = None) -> ScaledArray:
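A minimal usage sketch of the public `as_scaled_array` entry point around these branches (illustrative values; behaviour mirrors the updated tests further down, not the diff itself):

```python
import numpy as np
import jax.numpy as jnp
from jax_scaled_arithmetics.core import ScaledArray, as_scaled_array

# Float array with an explicit (non-trivial) scale: data gets divided by the scale factor.
out = as_scaled_array(jnp.array([1.0, 2.0]), scale=np.float32(2))
assert isinstance(out, ScaledArray)

# bool/int inputs are returned unchanged (no scaling semantics for them).
same = as_scaled_array(jnp.array([1, 2]))
assert not isinstance(same, ScaledArray)
```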
114 changes: 89 additions & 25 deletions jax_scaled_arithmetics/core/interpreters.py
@@ -2,11 +2,11 @@
from dataclasses import dataclass
from enum import IntEnum
from functools import partial, wraps
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union

import jax
import numpy as np
from jax import core
from jax import core, lax
from jax._src.custom_derivatives import (
custom_jvp_call_jaxpr_p,
custom_jvp_call_p,
@@ -82,6 +82,15 @@ class ScaledPrimitiveType(IntEnum):
"""


_scalar_preserving_primitives: Set[core.Primitive] = set()
"""Scalar preserving JAX primitives

More specifically: if all inputs are (broadcasted) scalars, then the output(s)
are broadcasted scalars. Keeping track of broadcasted scalars is allowing
proper conversion to ScaledArrays (instead of assigning default scale 1).
"""


def _get_lax_prim(scaled_func: Any) -> core.Primitive:
try:
prim_name = scaled_func.__name__.replace("scaled_", "") + "_p"
@@ -107,15 +116,6 @@ def _get_data(val: Any) -> Array:
return val


def promote_scalar_to_scaled_array(val: Any, scale_dtype: Optional[DTypeLike] = None) -> ScaledArray:
"""Promote a scalar (Numpy, JAX, ...) to a Scaled Array.

Note: needs to work with any input type, including JAX tracer ones.
"""
# Use `as_scaled_array` promotion rules.
return as_scaled_array_base(val, scale_dtype=scale_dtype)


def register_scaled_op(
prim: core.Primitive, scaled_func: Any, scaled_type: ScaledPrimitiveType = ScaledPrimitiveType.FORWARD
) -> None:
@@ -160,15 +160,6 @@ def find_registered_scaled_op(prim: core.Primitive) -> Tuple[Any, ScaledPrimitiv
return _scaled_ops_registry.get(prim, (None, ScaledPrimitiveType.NEVER))


def promote_to_scaled_array(val, scale_dtype: Optional[DTypeLike] = None):
if isinstance(val, ScaledArray):
return val
elif np.ndim(val) == 0:
return promote_scalar_to_scaled_array(val, scale_dtype)
# No promotion rule => just return as such.
return val


@register_pytree_node_class
@dataclass(frozen=True, init=False)
class ScalifyTracerArray:
@@ -210,6 +201,10 @@ def tree_unflatten(cls, aux_data, children):
assert len(children) == 1
return cls(children[0], aux_data[0])

@property
def dtype(self) -> DTypeLike:
return self.array.dtype

@property
def size(self) -> int:
return self.array.size
@@ -223,12 +218,36 @@ def is_scaled_array(self) -> bool:
return isinstance(self.array, ScaledArray)

def to_scaled_array(self, scale_dtype: Optional[DTypeLike] = None) -> ScaledArray:
if self.is_scaled_array:
"""(Tentatively) converting to a scaled array.

Supporting the following cases:
- scalar array;
- broadcasted scalar array;

Not supporting:
- bool/int dtypes;
- any other array;

TODO: support (constant) Numpy arrays.
"""
# Already scaled array, or not a floating point dtype.
if isinstance(self.array, ScaledArray) or not np.issubdtype(self.dtype, np.floating):
return self.array
# TODO: improve the logic for broadcasted scalar arrays!
return promote_to_scaled_array(self.array, scale_dtype)

if np.ndim(self.array) == 0:
# Single value => "easy case".
return as_scaled_array_base(self.array, scale_dtype=scale_dtype)
elif self.is_broadcasted_scalar:
# Broadcasted scalar => convert as a scalar.
scalar_val = self.array.ravel()[0]
scaled_scalar = as_scaled_array_base(scalar_val, scale_dtype=scale_dtype)
return as_scaled_array_base(self.array, scale=scaled_scalar.scale)

# No promotion rule found => just return as such.
return self.array

def to_array(self) -> Array:
"""Converting to a (normal) JAX/Numpy array."""
if not self.is_scaled_array:
return self.array
return self.array.to_array()
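A short usage sketch of the broadcasted-scalar path in `to_scaled_array` above, mirroring the unit tests added in `tests/core/test_interpreter.py` below (dtypes and values are illustrative):

```python
import numpy as np
from jax_scaled_arithmetics.core import ScaledArray
from jax_scaled_arithmetics.core.interpreters import ScalifyTracerArray

# A constant-filled array flagged as a broadcasted scalar is converted with a
# real (power-of-two) scale instead of the default scale of 1.
data = np.array([5, 5], dtype=np.float16)
scaled = ScalifyTracerArray(data, is_broadcasted_scalar=True).to_scaled_array(scale_dtype=np.float32)
assert isinstance(scaled, ScaledArray)

# Non-floating dtypes are returned as-is, without promotion.
raw = ScalifyTracerArray(np.array(3)).to_scaled_array()
assert not isinstance(raw, ScaledArray)
```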
@@ -303,6 +322,7 @@ def write(var, val: ScalifyTracerArray):

# A few initial checks to make sure there is consistency.
assert len(jaxpr.invars) == len(args)
assert len(jaxpr.constvars) == len(consts)
safe_map(write, jaxpr.invars, args)
safe_map(write, jaxpr.constvars, consts)

@@ -321,12 +341,17 @@ def write(var, val: ScalifyTracerArray):
any_scaled_inputs = any([v.is_scaled_array for v in invals_tracer])
# Is there a scaled primitive associated?
scaled_prim_fn, scaled_prim_type = _scaled_ops_registry.get(eqn.primitive, (None, ScaledPrimitiveType.NEVER))
# Are outputs broadcasted scalars?
are_outputs_broadcasted_scalars = (
all([v.is_broadcasted_scalar for v in invals_tracer]) and eqn.primitive in _scalar_preserving_primitives
)
scalify_array_init_fn = lambda v: ScalifyTracerArray(v, is_broadcasted_scalar=are_outputs_broadcasted_scalars)

if not any_scaled_inputs and scaled_prim_type != ScaledPrimitiveType.ALWAYS_SCALE:
# Using the normal JAX primitive: no scaled inputs and no always-scale rule.
invals = [v.to_array() for v in invals_tracer]
outvals = jaxpr_eqn_bind(eqn, invals)
outvals_tracer = list(map(ScalifyTracerArray, outvals))
outvals_tracer = list(map(scalify_array_init_fn, outvals))
elif scaled_prim_fn is None:
raise NotImplementedError(
f"'{eqn.primitive}' JAX primitive does not have an implementation for ScaledArray inputs yet."
@@ -337,7 +362,7 @@ def write(var, val: ScalifyTracerArray):
outvals = scaled_prim_fn(*scaled_invals, **eqn.params)
if not eqn.primitive.multiple_results:
outvals = [outvals]
outvals_tracer = list(map(ScalifyTracerArray, outvals))
outvals_tracer = list(map(scalify_array_init_fn, outvals))

# Check consistency with normal JAX mode. Helps catch dtype promotion errors.
# NOTE: ignoring when no outputs! (e.g. debug_callback).
@@ -451,3 +476,42 @@ def scaled_custom_vjp_call_translation(*args: ScalifyTracerArray, **params: Any)

_scaled_jaxpr_ops_registry[custom_vjp_call_p] = scaled_custom_vjp_call_translation
_scaled_jaxpr_ops_registry[custom_vjp_call_jaxpr_p] = scaled_custom_vjp_call_translation


# Default collection of scalar preserving JAX primitives.
_scalar_preserving_primitives |= {
lax.abs_p,
lax.acos_p,
lax.acosh_p,
lax.add_p,
lax.asin_p,
lax.asinh_p,
lax.atan_p,
lax.atan2_p,
lax.atanh_p,
lax.bitcast_convert_type_p,
lax.broadcast_in_dim_p,
lax.cbrt_p,
lax.clamp_p,
lax.convert_element_type_p,
lax.integer_pow_p,
lax.min_p,
lax.max_p,
lax.mul_p,
lax.neg_p,
lax.reduce_prod_p,
lax.reduce_sum_p,
lax.reduce_max_p,
lax.reduce_min_p,
lax.reduce_precision_p,
lax.reshape_p,
lax.rem_p,
lax.slice_p,
lax.sin_p,
lax.sinh_p,
lax.sub_p,
lax.sqrt_p,
lax.tan_p,
lax.tanh_p,
lax.transpose_p,
}
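The collection above is a plain module-level set, so further scalar-preserving primitives can be registered the same way; a hedged sketch (assuming `lax.exp_p` should indeed propagate the broadcasted-scalar flag, which this PR does not assert):

```python
from jax import lax
from jax_scaled_arithmetics.core import interpreters

# Treat `exp` outputs as broadcasted scalars whenever all of its inputs are.
interpreters._scalar_preserving_primitives |= {lax.exp_p}
```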
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/lax/scaled_ops_common.py
@@ -153,7 +153,7 @@ def scaled_mul(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
@core.register_scaled_lax_op
def scaled_div(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
# TODO: understand when promotion is really required?
lhs, rhs = as_scaled_array((lhs, rhs)) # type:ignore
# lhs, rhs = as_scaled_array((lhs, rhs)) # type:ignore
# TODO: investigate different rule?
return ScaledArray(lhs.data / rhs.data, lhs.scale / rhs.scale)

3 changes: 1 addition & 2 deletions jax_scaled_arithmetics/lax/scaled_ops_l2.py
@@ -10,7 +10,6 @@
from jax_scaled_arithmetics.core import (
DTypeLike,
ScaledArray,
as_scaled_array,
get_autoscale_config,
pow2_round,
register_scaled_op,
@@ -23,7 +22,7 @@
def scaled_add_sub(A: ScaledArray, B: ScaledArray, binary_op: Any) -> ScaledArray:
"""Scaled add/sub generic implementation."""
# TODO: understand when promotion is really required?
A, B = as_scaled_array((A, B)) # type:ignore
# A, B = as_scaled_array((A, B)) # type:ignore
check_scalar_scales(A, B)
A, B = promote_scale_types(A, B)
assert np.issubdtype(A.scale.dtype, np.floating)
13 changes: 10 additions & 3 deletions tests/core/test_datatype.py
@@ -240,14 +240,21 @@ def test__as_scaled_array__unscaled_bool_int_output(self, data):
output = as_scaled_array(data)
assert output is data

@chex.variants(with_jit=True, without_jit=True)
def test__as_scaled_array__complex_pytree(self):
input = {"x": jnp.array([1, 2]), "y": jnp.array([1.0, 2]), "z": as_scaled_array(jnp.array([1.0, 2]))}
output = as_scaled_array(input)
output = self.variant(as_scaled_array)(input, scale=np.float32(2))
assert isinstance(output, dict)
assert len(output) == 3
assert output["x"] is input["x"]

npt.assert_array_equal(output["x"], input["x"])
npt.assert_array_equal(output["y"], input["y"])
assert output["z"] is input["z"]
npt.assert_array_equal(output["z"], input["z"])
npt.assert_almost_equal(output["y"].scale, 2)

if "without_jit" in self.variant.__name__:
assert output["x"] is input["x"]
assert output["z"] is input["z"]

@chex.variants(with_jit=True, without_jit=True)
@parameterized.parameters(
76 changes: 52 additions & 24 deletions tests/core/test_interpreter.py
@@ -19,7 +19,7 @@
register_scaled_op,
scaled_array,
)
from jax_scaled_arithmetics.core.interpreters import ScalifyTracerArray, promote_scalar_to_scaled_array
from jax_scaled_arithmetics.core.interpreters import ScalifyTracerArray


class ScalifyTracerArrayTests(chex.TestCase):
@@ -73,6 +73,38 @@ def test__scalify_tracer_array__flatten__proper_pytree(self):
assert tracer_arr_out.is_broadcasted_scalar == tracer_arr_in.is_broadcasted_scalar
npt.assert_array_equal(np.asarray(tracer_arr_out.array), np.asarray(tracer_arr_in.array))

@parameterized.parameters(
{"input": 3.0},
{"input": np.float32(3.0)},
{"input": np.array(3.0)},
{"input": jnp.array(3.0)},
)
def test__scalify_tracer_array__to_scaled_array__scalar_input(self, input):
scaled_val = ScalifyTracerArray(input).to_scaled_array()
assert isinstance(scaled_val, ScaledArray)
assert scaled_val.data.dtype == scaled_val.scale.dtype
# NOTE: scale is a power-of-two.
npt.assert_almost_equal(np.asarray(scaled_val), input)

@parameterized.parameters(
{"input": np.array(3)},
{"input": jnp.array(3)},
{"input": np.int32(2)},
)
def test__scalify_tracer_array__to_scaled_array__not_promoted_input(self, input):
out = ScalifyTracerArray(input).to_scaled_array()
assert out is input

def test__scalify_tracer_array__to_scaled_array__broadcasted_scalar_input(self):
data = np.array([5, 5], dtype=np.float16)
scaled_out = ScalifyTracerArray(data, is_broadcasted_scalar=True).to_scaled_array(scale_dtype=np.float32)

assert isinstance(scaled_out, ScaledArray)
assert scaled_out.dtype == data.dtype
assert scaled_out.scale.dtype == np.float32
npt.assert_almost_equal(scaled_out.scale, 4)
npt.assert_array_equal(np.asarray(scaled_out), data)


class AutoScaleInterpreterTests(chex.TestCase):
def test__register_scaled_op__error_if_already_registered(self):
@@ -195,6 +227,25 @@ def test__autoscale_decorator__proper_graph_transformation_and_result(self, fn,
assert scaled_out.dtype == exp_out.dtype
npt.assert_array_almost_equal(scaled_out, exp_out, decimal=4)

@chex.variants(with_jit=True, without_jit=True)
def test__autoscale_decorator__promotion_broadcasted_scalar_array(self):
def fn(sa, b):
# Forcing broadcasting before the `lax.mul`
b = jax.lax.broadcast_in_dim(b, sa.shape, ())
return sa * b

sa = scaled_array([0.5, 1.0], np.float32(4.0), dtype=np.float32)
b = jnp.array(4.0, dtype=np.float16)

scaled_fn = self.variant(autoscale(fn))
sout = scaled_fn(sa, b)
expected_out = fn(np.asarray(sa), b)

assert isinstance(sout, ScaledArray)
# Proper output scale, with `b` treated as scaled scalar.
npt.assert_equal(np.asarray(sout.scale), np.float32(16))
npt.assert_array_equal(np.asarray(sout), expected_out)

@chex.variants(with_jit=True, without_jit=True)
def test__autoscale_decorator__custom_jvp__proper_graph_transformation_and_result(self):
# JAX official `jvp` example.
@@ -264,29 +315,6 @@ def fn(x, y):
assert isinstance(sg, ScaledArray)
npt.assert_array_almost_equal(sg, g)

@parameterized.parameters(
{"input": 3.0},
{"input": np.float32(3.0)},
{"input": np.array(3.0)},
{"input": jnp.array(3.0)},
)
def test__promote_scalar_to_scaled_array__promoted_to_scaled_array(self, input):
scaled_val = promote_scalar_to_scaled_array(input)
assert isinstance(scaled_val, ScaledArray)
assert scaled_val.data.dtype == scaled_val.scale.dtype
# NOTE: scale is a power-of-two.
npt.assert_almost_equal(np.asarray(scaled_val), input)

@parameterized.parameters(
{"input": np.array(3)},
{"input": jnp.array(3)},
{"input": 3},
{"input": np.int32(2)},
)
def test__promote_scalar_to_scaled_array__not_promoted_to_scaled_array(self, input):
out = promote_scalar_to_scaled_array(input)
assert out is input

def test__autoscale_config__default_values(self):
cfg = get_autoscale_config()
assert isinstance(cfg, AutoScaleConfig)