Add a default scale dtype to AutoScaleConfig #79

Merged: 1 commit, Jan 11, 2024
7 changes: 6 additions & 1 deletion jax_scaled_arithmetics/core/interpreters.py
@@ -15,7 +15,7 @@
)
from jax._src.util import safe_map

from .datatype import NDArray, ScaledArray, as_scaled_array_base, is_scaled_leaf
from .datatype import DTypeLike, NDArray, ScaledArray, as_scaled_array_base, is_scaled_leaf
from .utils import Pow2RoundMode


@@ -25,9 +25,14 @@ class AutoScaleConfig:

    NOTE: this config can be locally changed using a Python context manager:
    `with AutoScaleConfig(...):`

    Args:
        rounding_mode: Power-of-2 rounding mode.
        scale_dtype: Default scale datatype.
    """

    rounding_mode: Pow2RoundMode = Pow2RoundMode.DOWN
    scale_dtype: DTypeLike = None

    def __enter__(self):
        global _autoscale_config_stack
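For context, a minimal usage sketch of the new field, following the context-manager pattern described in the docstring above and exercised in the tests further down (assuming AutoScaleConfig, Pow2RoundMode and get_autoscale_config are all importable from jax_scaled_arithmetics.core, as in the test files):

import numpy as np
from jax_scaled_arithmetics.core import AutoScaleConfig, Pow2RoundMode, get_autoscale_config

# Locally override the default scale dtype (and rounding mode) for a block of code.
with AutoScaleConfig(rounding_mode=Pow2RoundMode.DOWN, scale_dtype=np.float32):
    cfg = get_autoscale_config()
    assert cfg.scale_dtype == np.float32  # Scale tensors created in this block default to float32.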
28 changes: 25 additions & 3 deletions jax_scaled_arithmetics/lax/base_scaling_primitives.py
@@ -1,4 +1,5 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import logging
from typing import Optional, Sequence, Union

import numpy as np
@@ -12,6 +13,7 @@
    ScaledArray,
    ScaledPrimitiveType,
    asarray,
    get_autoscale_config,
    is_static_one_scalar,
    register_scaled_op,
    safe_div,
@@ -163,6 +165,11 @@ def scaled_stop_scaling(values: ScaledArray, dtype: Optional[DTypeLike] = None)
"""


def get_scale_dtype() -> Optional[DTypeLike]:
    """Get the scale dtype, if set in the AutoScale config."""
    return get_autoscale_config().scale_dtype


def get_data_scale(values: Array) -> Array:
    """`get_data_scale` primitive call method."""
    return get_data_scale_p.bind(values)
@@ -171,27 +178,42 @@ def get_data_scale(values: Array) -> Array:
def get_data_scale_impl(values: Array) -> Array:
    if isinstance(values, ScaledArray):
        return (values.data, values.scale)
    scale = np.ones((), dtype=values.dtype)
    # Use array dtype for scale by default.
    scale_dtype = get_scale_dtype() or values.dtype
    scale = np.ones((), dtype=scale_dtype)
    return values, scale


def get_data_scale_abstract_eval(values: core.ShapedArray) -> core.ShapedArray:
    if isinstance(values, ScaledArray):
        return (values.data, values.scale)
    return values, core.ShapedArray((), dtype=values.dtype)
    # Use array dtype for scale by default.
    scale_dtype = get_scale_dtype() or values.dtype
    return values, core.ShapedArray((), dtype=scale_dtype)


def get_data_scale_mlir_lowering(
    ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]]
) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]:
    # Just forward the `values` term, adding a constant scalar scale of 1.
    assert len(args) == 1
    scale = ir_constant(np.ones((), dtype=ctx.avals_in[0].dtype))
    assert len(ctx.avals_in) == 1
    assert len(ctx.avals_out) == 2
    # Scale dtype is decided during initial JAX tracing (see abstract eval above).
    scale_dtype = ctx.avals_out[1].dtype
    scale = ir_constant(np.ones((), dtype=scale_dtype))
    return (args[0], scale)


def scaled_get_data_scale(values: ScaledArray) -> Array:
    """Scaled `get_data_scale` implementation: return scale tensor."""
    scale_dtype = get_scale_dtype()
    # A mismatch may create issues (i.e. inconsistent scale dtype after the autoscale tracer)!
    if scale_dtype is not None and scale_dtype != values.scale.dtype:
        logging.warning(
            f"AutoScale config scale dtype '{scale_dtype}' does not match the ScaledArray scale dtype "
            f"'{values.scale.dtype}'. The AutoScale graph transformation may fail because of this."
        )
    return values.data, values.scale


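To illustrate the behaviour added above, a short sketch of get_data_scale with and without a configured scale dtype, mirroring the test at the end of this diff (the import path of get_data_scale is the one used in that test; the exact eager-mode return types are an assumption):

import jax.numpy as jnp
import numpy as np
from jax_scaled_arithmetics.core import AutoScaleConfig
from jax_scaled_arithmetics.lax.base_scaling_primitives import get_data_scale

arr = jnp.array([2, 3], dtype=np.float16)

# No scale dtype configured: the unit scale falls back to the array dtype (float16 here).
data, scale = get_data_scale(arr)
assert scale.dtype == np.float16

# With a default scale dtype set, the returned unit scale uses that dtype instead.
with AutoScaleConfig(scale_dtype=np.float32):
    data, scale = get_data_scale(arr)
assert scale.dtype == np.float32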
4 changes: 3 additions & 1 deletion tests/core/test_interpreter.py
@@ -237,9 +237,11 @@ def test__autoscale_config__default_values(self):
        cfg = get_autoscale_config()
        assert isinstance(cfg, AutoScaleConfig)
        assert cfg.rounding_mode == Pow2RoundMode.DOWN
        assert cfg.scale_dtype is None

    def test__autoscale_config__context_manager(self):
        with AutoScaleConfig(rounding_mode=Pow2RoundMode.NONE):
        with AutoScaleConfig(rounding_mode=Pow2RoundMode.NONE, scale_dtype=np.float32):
            cfg = get_autoscale_config()
            assert isinstance(cfg, AutoScaleConfig)
            assert cfg.rounding_mode == Pow2RoundMode.NONE
            assert cfg.scale_dtype == np.float32
10 changes: 7 additions & 3 deletions tests/lax/test_base_scaling_primitives.py
@@ -5,7 +5,7 @@
import numpy.testing as npt
from absl.testing import parameterized

from jax_scaled_arithmetics.core import Array, ScaledArray, autoscale, scaled_array
from jax_scaled_arithmetics.core import Array, AutoScaleConfig, ScaledArray, autoscale, scaled_array
from jax_scaled_arithmetics.lax.base_scaling_primitives import (
    get_data_scale,
    rebalance,
@@ -146,13 +146,17 @@ class GetDataScalePrimitiveTests(chex.TestCase):
    @chex.variants(with_jit=True, without_jit=True)
    def test__get_data_scale_primitive__proper_result_without_autoscale(self):
        def fn(arr):
            return get_data_scale(arr)
            # Set a default scale dtype.
            with AutoScaleConfig(scale_dtype=np.float32):
                return get_data_scale(arr)

        fn = self.variant(fn)
        arr = jnp.array([2, 3], dtype=np.float16)
        data, scale = fn(arr)
        assert data.dtype == np.float16
        assert scale.dtype == np.float32
        npt.assert_array_equal(data, arr)
        npt.assert_equal(scale, np.array(1, arr.dtype))
        npt.assert_equal(scale, np.array(1, np.float32))

    @chex.variants(with_jit=True, without_jit=True)
    def test__get_data_scale_primitive__proper_result_with_autoscale(self):