
Commit

wip
balancap committed Jan 18, 2024
1 parent 478cf62 commit 3c21e05
Showing 2 changed files with 53 additions and 3 deletions.
54 changes: 52 additions & 2 deletions experiments/mnist/mnist_classifier_from_scratch.py
@@ -28,9 +28,17 @@
from jax import grad, jit, lax

import jax_scaled_arithmetics as jsa
from functools import partial

# from jax.scipy.special import logsumexp

def print_mean_std(name, v):
    data, scale = jsa.lax.get_data_scale(v)
    # Always use np.float32, to avoid floating-point errors in descaling + stats.
    data = jsa.asarray(data, dtype=np.float32)
    m, s, mn, mx = np.mean(data), np.std(data), np.min(data), np.max(data)
    print(f"{name}: MEAN({m:.5f}) / STD({s:.5f}) / MIN({mn:.5f}) / MAX({mx:.5f}) / SCALE({scale:.5f})")

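Note: a JSA scaled tensor is stored as a pair (data, scale) whose product is the represented value; get_data_scale unpacks that pair, so the statistics above are computed on the raw data term while the scale is reported separately. A minimal sketch of the decomposition (as_scaled_array is assumed from the library API; values are illustrative):

    x = jsa.as_scaled_array(np.array([2.0, 4.0], dtype=np.float32))  # assumed constructor
    data, scale = jsa.lax.get_data_scale(x)
    # The represented value is recovered as data * scale.
    np.testing.assert_allclose(np.asarray(data, dtype=np.float32) * scale, [2.0, 4.0])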

def logsumexp(a, axis=None, keepdims=False):
    dims = (axis,)
@@ -47,6 +55,27 @@ def logsumexp(a, axis=None, keepdims=False):
def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
    return [(scale * rng.randn(m, n), scale * rng.randn(n)) for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]

# jax.nn.logsumexp

def one_hot_dot(logits, mask, axis: int):
    size = logits.shape[axis]

    mask = jsa.lax.rebalance(mask, np.float32(1.0 / 8.0))

    jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
    (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
    jsa.ops.debug_callback(partial(print_mean_std, "Mask"), mask)

    r = jnp.sum(logits * mask, axis=axis)
    jsa.ops.debug_callback(partial(print_mean_std, "Out"), r)
    print("SIZE:", size, jsa.core.pow2_round_down(np.float32(size)))
    (r,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "OutGrad"), r)
    return r


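Note: with one-hot targets, summing logits * mask along the class axis simply selects the log-probability of the true class for each sample; the rebalance call is only meant to redistribute data and scale, not to change the values. A plain NumPy sketch of the selection (illustrative values):

    import numpy as np

    logits = np.log(np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]))  # log-probs, shape (batch, classes)
    targets = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])         # one-hot labels
    picked = np.sum(logits * targets, axis=1)                      # log-prob of the true class
    print(picked)  # [log(0.7), log(0.8)]
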
def predict(params, inputs):
    activations = inputs
@@ -58,14 +87,32 @@ def predict(params, inputs):
    final_w, final_b = params[-1]
    logits = jnp.dot(activations, final_w) + final_b
    # Dynamic rescaling of the gradient, as the logits gradient is not properly scaled.
-   logits = jsa.ops.dynamic_rescale_l2_grad(logits)
-   logits = logits - logsumexp(logits, axis=1, keepdims=True)
    # logits = jsa.ops.dynamic_rescale_l2_grad(logits)

    # print("LOGITS", logits)
    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad0"), logits)
    logsumlogits = logsumexp(logits, axis=1, keepdims=True)
    # (logsumlogits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsLogSumGrad"), logsumlogits)
    logits = logits - logsumlogits
    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad1"), logits)

    # logits = jsa.ops.dynamic_rescale_l1_grad(logits)
    return logits
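
Note: subtracting logsumexp over the class axis is the standard log-softmax normalization; the commit only splits out an explicit logsumlogits step so its gradient can be inspected separately. For reference, a plain NumPy version of the same normalization (illustrative, not the scaled-arithmetic path):

    def log_softmax(logits):
        # Shift by the row max for numerical stability; the shift cancels out.
        shifted = logits - logits.max(axis=1, keepdims=True)
        return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))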


def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    loss = one_hot_dot(preds, targets, axis=1)
    # loss = jnp.sum(preds * targets, axis=1)
    # loss = jsa.ops.dynamic_rescale_l1_grad(loss)
    (loss,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LossGrad2"), loss)
    loss = -jnp.mean(loss)
    jsa.ops.debug_callback(partial(print_mean_std, "Loss"), loss)
    (loss,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LossGrad"), loss)
    return loss
    # return -jnp.mean(jnp.sum(preds * targets, axis=1))


@@ -111,6 +158,9 @@ def update(params, batch):
    grads = grad(loss)(params, batch)
    return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]

num_batches = 2
num_epochs = 1

for epoch in range(num_epochs):
    start_time = time.time()
    for _ in range(num_batches):
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/ops/rescaling.py
@@ -51,7 +51,7 @@ def dynamic_rescale_max_base(arr: ScaledArray) -> ScaledArray:
    data_sq = jax.lax.abs(data)
    axes = tuple(range(data.ndim))
    # Get MAX norm + pow2 rounding.
-   norm = jax.lax.reduce_max_p.bind(data_sq, axes=axes)
+   norm = jax.lax.reduce_max_p.bind(data_sq, axes=axes) / 64
    norm = jax.lax.max(pow2_round(norm).astype(scale.dtype), eps.astype(scale.dtype))
    # Rebalancing based on norm.
    return rebalance(arr, norm)
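
Note: the net effect of this function is to pick a new power-of-two scale from the (now damped) max norm and rebalance the array onto it, leaving the represented value unchanged. A rough NumPy sketch of that logic, assuming pow2_round rounds to the nearest power of two and an initial scale of 1 (illustrative only):

    import numpy as np

    data = np.array([0.5, -3.0, 96.0], dtype=np.float32)
    norm = np.max(np.abs(data)) / 64          # damped max norm, as in the diff
    pow2 = np.exp2(np.round(np.log2(norm)))   # pow2_round: nearest power of two
    new_data, new_scale = data / pow2, pow2   # rebalance: the split changes, the value does not
    np.testing.assert_allclose(new_data * new_scale, data)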
