From 22f51d549bcc1a1fff34d3d31bd639470bcb5f64 Mon Sep 17 00:00:00 2001
From: Paul Balanca
Date: Fri, 1 Dec 2023 17:09:56 +0000
Subject: [PATCH] wip

---
 experiments/mnist/flax/train.py          | 3 ++-
 jax_scaled_arithmetics/lax/scaled_ops.py | 5 +++++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/experiments/mnist/flax/train.py b/experiments/mnist/flax/train.py
index 0c520e8..63b876f 100644
--- a/experiments/mnist/flax/train.py
+++ b/experiments/mnist/flax/train.py
@@ -104,6 +104,7 @@ def train_epoch(state, train_ds, batch_size, rng):
         # print(set([v.primitive for v in jaxpr.eqns]))
         # assert False
 
+        batch_images = jsa.as_scaled_array(batch_images)
         state, loss, accuracy = apply_and_update_model(state, batch_images, batch_labels)
         epoch_loss.append(loss)
 
@@ -154,7 +155,7 @@ def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> train
     state = create_train_state(init_rng, config)
     # print(state)
     # print(jax.tree_util.tree_flatten(state))
-    # state = jsa.as_scaled_array(state)
+    state = jsa.as_scaled_array(state)
 
     logging.info("Start Flax MNIST training...")
 
diff --git a/jax_scaled_arithmetics/lax/scaled_ops.py b/jax_scaled_arithmetics/lax/scaled_ops.py
index 6fa6ea6..1651aa2 100644
--- a/jax_scaled_arithmetics/lax/scaled_ops.py
+++ b/jax_scaled_arithmetics/lax/scaled_ops.py
@@ -16,6 +16,7 @@
 
 def check_scalar_scales(*args: ScaledArray):
     """Check all ScaledArrays have scalar scaling."""
+    print(args)
     for val in args:
         assert np.ndim(val.scale) == 0
 
@@ -100,6 +101,7 @@ def scaled_rev(val: ScaledArray, dimensions: Sequence[int]) -> ScaledArray:
 @core.register_scaled_lax_op
 def scaled_pad(val: ScaledArray, padding_value: Any, padding_config: Any) -> ScaledArray:
     # Only supporting constant zero padding for now.
+    print(padding_value)
     assert float(padding_value) == 0.0
     return ScaledArray(lax.pad(val.data, padding_value, padding_config), val.scale)
 
@@ -135,12 +137,15 @@ def scaled_mul(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
 
 @core.register_scaled_lax_op
 def scaled_div(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
+    # TODO: understand why/when this conversion kicks in?
+    lhs, rhs = as_scaled_array((lhs, rhs))  # type:ignore
     # TODO: investigate different rule?
    return ScaledArray(lhs.data / rhs.data, lhs.scale / rhs.scale)
 
 
 @core.register_scaled_lax_op
 def scaled_add(A: ScaledArray, B: ScaledArray) -> ScaledArray:
+    A, B = as_scaled_array((A, B))  # type:ignore
     check_scalar_scales(A, B)
     A, B = promote_scale_types(A, B)
     assert np.issubdtype(A.scale.dtype, np.floating)
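
Note (editorial, not part of the patch to apply): a minimal sketch of how the `as_scaled_array` conversion wired in above is used. The `jsa` import alias, the array shapes, and the plain params dict below are illustrative assumptions; only `as_scaled_array` and the `ScaledArray(data, scale)` representation are taken from the patch itself.

    import numpy as np
    import jax_scaled_arithmetics as jsa  # assumed import alias, matching the `jsa.` calls in train.py

    # A plain array is wrapped into a ScaledArray(data, scale) pair with a scalar scale.
    batch_images = np.random.rand(8, 28, 28, 1).astype(np.float32)
    scaled_images = jsa.as_scaled_array(batch_images)

    # Pytrees are converted as well, as done for the Flax train state in train.py
    # (here a plain dict of parameters stands in for the TrainState).
    params = {"kernel": np.ones((784, 10), np.float32), "bias": np.zeros((10,), np.float32)}
    scaled_params = jsa.as_scaled_array(params)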