Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Nov 24, 2023
1 parent 18ae06b commit 6aa5ed1
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 3 deletions.
7 changes: 6 additions & 1 deletion experiments/mnist/mnist_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,17 @@ def update(i, opt_state, batch):

# print(opt_state)
for _ in range(num_batches):
batch = jax.tree_map(lambda v: jsa.scaled_array(v, scale=np.float32(1.0)), next(batches))
batch = next(batches)
batch = jax.tree_map(lambda v: jsa.scaled_array(v, scale=np.float32(1.0)), batch)
# batch = jsa.scaled_array(next(batches), np.float32(0))
opt_state = update(next(itercount), opt_state, batch)
epoch_time = time.time() - start_time

params = get_params(opt_state)

params = jax.tree_map(lambda v: jsa.core.asarray(v), params, is_leaf=jsa.core.is_scaled_leaf)
# print(params)

train_acc = accuracy(params, (train_images, train_labels))
test_acc = accuracy(params, (test_images, test_labels))
print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
Expand Down
10 changes: 9 additions & 1 deletion jax_scaled_arithmetics/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .datatype import DTypeLike, ScaledArray, Shape, asarray, is_scaled_leaf, scaled_array # noqa: F401
from .datatype import ( # noqa: F401
DTypeLike,
ScaledArray,
Shape,
as_scaled_array,
asarray,
is_scaled_leaf,
scaled_array,
)
from .interpreters import ( # noqa: F401
ScaledPrimitiveType,
autoscale,
Expand Down
7 changes: 7 additions & 0 deletions jax_scaled_arithmetics/core/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ def scaled_array(data: ArrayLike, scale: ArrayLike, dtype: DTypeLike = None, npa
return ScaledArray(data, scale)


def as_scaled_array(val: Any) -> ScaledArray:
""" """
if isinstance(val, ScaledArray):
return val
return ScaledArray(data=val, scale=np.array(1, dtype=val.dtype))


def asarray(val: Any, dtype: DTypeLike = None) -> GenericArray:
"""Convert back to a common JAX/Numpy array.
Expand Down
9 changes: 8 additions & 1 deletion jax_scaled_arithmetics/lax/scaled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax._src.ad_util import add_any_p

from jax_scaled_arithmetics import core
from jax_scaled_arithmetics.core import DTypeLike, ScaledArray, Shape
from jax_scaled_arithmetics.core import DTypeLike, ScaledArray, Shape, as_scaled_array, register_scaled_op

from .base_scaling_primitives import scaled_set_scaling

Expand Down Expand Up @@ -87,6 +88,9 @@ def scaled_reshape(A: ScaledArray, new_sizes: Sequence[int], dimensions: Sequenc

@core.register_scaled_lax_op
def scaled_mul(A: ScaledArray, B: ScaledArray) -> ScaledArray:
# print("INPUTS:", A, B)
A = as_scaled_array(A)
B = as_scaled_array(B)
return ScaledArray(A.data * B.data, A.scale * B.scale)


Expand All @@ -102,6 +106,9 @@ def scaled_add(A: ScaledArray, B: ScaledArray) -> ScaledArray:
return ScaledArray(output_data, output_scale)


register_scaled_op(add_any_p, scaled_add)


@core.register_scaled_lax_op
def scaled_sub(A: ScaledArray, B: ScaledArray) -> ScaledArray:
check_scalar_scales(A, B)
Expand Down

0 comments on commit 6aa5ed1

Please sign in to comment.