
Commit

wip
balancap committed Dec 4, 2023
1 parent 85774af commit 22f51d5
Showing 2 changed files with 7 additions and 1 deletion.
3 changes: 2 additions & 1 deletion experiments/mnist/flax/train.py
@@ -104,6 +104,7 @@ def train_epoch(state, train_ds, batch_size, rng):
 
         # print(set([v.primitive for v in jaxpr.eqns]))
         # assert False
+        batch_images = jsa.as_scaled_array(batch_images)
 
         state, loss, accuracy = apply_and_update_model(state, batch_images, batch_labels)
         epoch_loss.append(loss)
@@ -154,7 +155,7 @@ def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> train
     state = create_train_state(init_rng, config)
     # print(state)
     # print(jax.tree_util.tree_flatten(state))
-    # state = jsa.as_scaled_array(state)
+    state = jsa.as_scaled_array(state)
 
     logging.info("Start Flax MNIST training...")
 
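Note on the train.py changes: both the input batch and the Flax train state are now wrapped with jsa.as_scaled_array before apply_and_update_model runs, so the model sees ScaledArray inputs end to end. Below is a minimal usage sketch based only on what this diff shows (jsa.as_scaled_array called on a plain array, and the .data / .scale fields used in scaled_ops.py); the assumption that a ScaledArray represents data * scale with a scalar scale is the usual convention, not something stated in this commit.

    import numpy as np
    import jax_scaled_arithmetics as jsa

    # Wrap a raw batch as a ScaledArray: the payload lives in .data and a
    # scalar scale in .scale, so data * scale stands for the original tensor.
    batch_images = np.random.rand(8, 28, 28, 1).astype(np.float32)
    scaled_batch = jsa.as_scaled_array(batch_images)
    print(scaled_batch.data.shape, scaled_batch.scale)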
5 changes: 5 additions & 0 deletions jax_scaled_arithmetics/lax/scaled_ops.py
@@ -16,6 +16,7 @@
 
 def check_scalar_scales(*args: ScaledArray):
     """Check all ScaledArrays have scalar scaling."""
+    print(args)
     for val in args:
         assert np.ndim(val.scale) == 0
 
@@ -100,6 +101,7 @@ def scaled_rev(val: ScaledArray, dimensions: Sequence[int]) -> ScaledArray:
 @core.register_scaled_lax_op
 def scaled_pad(val: ScaledArray, padding_value: Any, padding_config: Any) -> ScaledArray:
     # Only supporting constant zero padding for now.
+    print(padding_value)
     assert float(padding_value) == 0.0
     return ScaledArray(lax.pad(val.data, padding_value, padding_config), val.scale)
 
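A note on the zero-padding restriction in scaled_pad: assuming the data * scale representation, padding the data payload with zeros and keeping the scale unchanged gives the same values as padding the represented tensor, whereas any non-zero padding_value would be implicitly multiplied by the scale. A small self-contained numpy check of that identity (not library code):

    import numpy as np

    data = np.array([1.0, -2.0, 3.0], dtype=np.float32)
    scale = np.float32(0.5)

    # Pad the payload, then apply the scale ...
    padded_then_scaled = np.pad(data, (1, 1), constant_values=0.0) * scale
    # ... versus pad the already-scaled values directly.
    scaled_then_padded = np.pad(data * scale, (1, 1), constant_values=0.0)
    assert np.allclose(padded_then_scaled, scaled_then_padded)
    # With a non-zero constant the two sides differ, hence the
    # `assert float(padding_value) == 0.0` guard above.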
@@ -135,12 +137,15 @@ def scaled_mul(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
 
 @core.register_scaled_lax_op
 def scaled_div(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
+    # TODO: understand why/when this conversion kicks in?
+    lhs, rhs = as_scaled_array((lhs, rhs))  # type:ignore
     # TODO: investigate different rule?
     return ScaledArray(lhs.data / rhs.data, lhs.scale / rhs.scale)
 
 
 @core.register_scaled_lax_op
 def scaled_add(A: ScaledArray, B: ScaledArray) -> ScaledArray:
+    A, B = as_scaled_array((A, B))  # type:ignore
     check_scalar_scales(A, B)
     A, B = promote_scale_types(A, B)
     assert np.issubdtype(A.scale.dtype, np.floating)
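The scaled_div rule divides the data and scale components separately; assuming the data * scale representation, this rests on the identity (d_l * s_l) / (d_r * s_r) = (d_l / d_r) * (s_l / s_r), so the represented quotient is exact and the remaining TODO is only about which output scale keeps the data payload well-conditioned. A quick numeric check of that identity (plain numpy, not library code); scaled_add, by contrast, first normalizes its inputs via check_scalar_scales and promote_scale_types, with its return statement falling outside this hunk.

    import numpy as np

    lhs_data, lhs_scale = np.array([2.0, 4.0], np.float32), np.float32(8.0)
    rhs_data, rhs_scale = np.array([1.0, 2.0], np.float32), np.float32(2.0)

    # Quotient of the represented values ...
    expected = (lhs_data * lhs_scale) / (rhs_data * rhs_scale)
    # ... equals component-wise division of data and scale, as in scaled_div.
    got = (lhs_data / rhs_data) * (lhs_scale / rhs_scale)
    assert np.allclose(expected, got)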
