-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
basic interpeter for scaled ops and scaled arrays (#8)
* basic interpeter for scaled ops and scaled arrays * precommit * mypy fixes * use ScaledArray.aval * move registry into core * Fixing module imports * imports in top level __init__.py * linting fixes * return value if in singleton list * return value if in singleton list * comment on reason for multiple vars per scaledarray --------- Co-authored-by: Paul Balanca <paulb@graphcore.ai>
- Loading branch information
Showing
7 changed files
with
136 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
from . import lax | ||
from ._version import __version__ | ||
from .core import ScaledArray, autoscale # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
from .datatype import ScaledArray # noqa: F401 | ||
from .interpreters import autoscale, register_scaled_op # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
|
||
from functools import wraps | ||
from typing import Dict | ||
|
||
import jax | ||
from jax import core | ||
from jax._src.util import safe_map | ||
|
||
from ..core import ScaledArray | ||
|
||
_scaled_ops_registry = {} | ||
|
||
|
||
def register_scaled_op(lax_func, scaled_func): | ||
_scaled_ops_registry[lax_func] = scaled_func | ||
|
||
|
||
def autoscale(fun): | ||
@wraps(fun) | ||
def wrapped(*args, **kwargs): | ||
aval_args = safe_map(lambda x: x.aval, args) | ||
# get jaxpr of unscaled graph | ||
closed_jaxpr = jax.make_jaxpr(fun)(*aval_args, **kwargs) | ||
# convert to scaled graph | ||
out = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args) | ||
return out | ||
|
||
return wrapped | ||
|
||
|
||
def autoscale_jaxpr(jaxpr, consts, *args): | ||
env: Dict[core.Var, ScaledArray] = {} | ||
|
||
def read(var): | ||
if type(var) is core.Literal: | ||
return var.val | ||
return env[var] | ||
|
||
def write(var, val): | ||
env[var] = val | ||
|
||
safe_map(write, jaxpr.invars, args) | ||
safe_map(write, jaxpr.constvars, consts) | ||
|
||
for eqn in jaxpr.eqns: | ||
invals = safe_map(read, eqn.invars) | ||
if eqn.primitive not in _scaled_ops_registry: | ||
raise NotImplementedError(f"{eqn.primitive} does not have an implementation for ScaledArray inputs yet") | ||
outvals = _scaled_ops_registry[eqn.primitive](*invals) | ||
if not eqn.primitive.multiple_results: | ||
outvals = [outvals] | ||
safe_map(write, eqn.outvars, outvals) | ||
|
||
outvals = safe_map(read, jaxpr.outvars) | ||
if len(outvals) == 1: | ||
return outvals[0] | ||
else: | ||
return outvals |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
from .scaled_ops import * # noqa: F401, F403 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
|
||
from jax import lax | ||
|
||
from jax_scaled_arithmetics import core | ||
from jax_scaled_arithmetics.core import ScaledArray | ||
|
||
|
||
def scaled_mul_p(A: ScaledArray, B: ScaledArray) -> ScaledArray: | ||
return ScaledArray(A.data * B.data, A.scale * B.scale) | ||
|
||
|
||
core.register_scaled_op(lax.mul_p, scaled_mul_p) | ||
|
||
__all__ = ["scaled_mul_p"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
|
||
import chex | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
from jax_scaled_arithmetics.core import ScaledArray, autoscale | ||
|
||
|
||
class AutoScaleInterpreterTests(chex.TestCase): | ||
def test__identity(self): | ||
def func(x): | ||
return x | ||
|
||
asfunc = autoscale(func) | ||
|
||
scale = jnp.array(1.0) | ||
inputs = jnp.array([1.0, 2.0]) | ||
expected = jnp.array([1.0, 2.0]) | ||
|
||
scaled_inputs = ScaledArray(inputs, scale) | ||
scaled_outputs = asfunc(scaled_inputs) | ||
|
||
assert jnp.allclose(scaled_outputs.aval, expected) | ||
|
||
jaxpr = jax.make_jaxpr(asfunc)(scaled_inputs).jaxpr | ||
|
||
# Vars need to be primitive data types (e.g., f32) -> 2 Vars per ScaledArray | ||
|
||
assert jaxpr.invars[0].aval.shape == inputs.shape | ||
assert jaxpr.invars[1].aval.shape == () | ||
|
||
assert jaxpr.outvars[0].aval.shape == expected.shape | ||
assert jaxpr.outvars[1].aval.shape == () | ||
|
||
def test__mul(self): | ||
def func(x, y): | ||
return x * y | ||
|
||
asfunc = autoscale(func) | ||
|
||
x_in = jnp.array([-2.0, 2.0]) | ||
x_scale = jnp.array(0.5) | ||
x = ScaledArray(x_in, x_scale) | ||
|
||
y_in = jnp.array([1.5, 1.5]) | ||
y_scale = jnp.array(2.0) | ||
y = ScaledArray(y_in, y_scale) | ||
|
||
expected = jnp.array([-3.0, 3.0]) | ||
|
||
out = asfunc(x, y) | ||
|
||
assert jnp.allclose(out.aval, expected) |