-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Forwarding broadcasted scalar metadata in Scalify tracer.
`scalify` interpreter/tracer is now properly tracking which tensors are just broadcasted scalars, helping then to refine the conversion rule to ScaledArray for these. In practice: it means (finally!) proper full scale propagation in MNIST training, resulting in stable training with dynamic rescale. TODO: we still need to understand why `scaled_mul` requires ScaledArray promotion to get the MNIST training example running. This requirement has been lifted in `div/add/sub` thanks to this PR.
- Loading branch information
Showing
6 changed files
with
176 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
import chex | ||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
|
||
from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array | ||
|
||
|
||
class ScaledJaxNumpyFunctions(chex.TestCase): | ||
def setUp(self): | ||
super().setUp() | ||
# Use random state for reproducibility! | ||
self.rs = np.random.RandomState(42) | ||
|
||
@chex.variants(with_jit=True, without_jit=True) | ||
def test__numpy_mean__proper_gradient_scale_propagation(self): | ||
def mean_fn(x): | ||
# Taking the square to "force" ScaledArray gradient. | ||
# Numpy mean constant rescaling creating trouble on backward pass! | ||
return jax.grad(lambda v: jnp.mean(v * v))(x) | ||
|
||
# size = 8 * 16 | ||
input_scaled = scaled_array(self.rs.rand(8, 16).astype(np.float32), np.float32(1)) | ||
output_grad_scaled = self.variant(autoscale(mean_fn))(input_scaled) | ||
|
||
assert isinstance(output_grad_scaled, ScaledArray) | ||
# Proper scale propagation on the backward pass (rough interval) | ||
assert np.std(output_grad_scaled.data) >= 0.25 | ||
assert np.std(output_grad_scaled.data) <= 1.0 | ||
# "small" scale. | ||
assert output_grad_scaled.scale <= 0.01 |