diff --git a/BUILD b/BUILD
index 286c625a..409ea10c 100644
--- a/BUILD
+++ b/BUILD
@@ -42,6 +42,7 @@ py_wheel(
     deps = ["//src/enzyme_ad/jax:enzyme_jax_internal", ":enzyme_jax_data"],
     strip_path_prefixes = ["src/"],
     requires = [
+        "absl_py >= 2.0.0",
         "jax >= 0.4.21",
         "jaxlib >= 0.4.21",
     ],
diff --git a/test/bench_vs_xla.py b/test/bench_vs_xla.py
index bcef48ac..957a3774 100644
--- a/test/bench_vs_xla.py
+++ b/test/bench_vs_xla.py
@@ -1,6 +1,8 @@
 import jax
 import jax.numpy as jnp
 from enzyme_ad.jax import enzyme_jax_ir
+from absl.testing import absltest
+import timeit
 
 
 @enzyme_jax_ir()
@@ -23,46 +25,6 @@ def add_two_plain(x: jax.Array, z, y) -> jax.Array:
     return x + y
 
 
-in0, in1, in2 = (
-    jnp.array([1.0, 2.0, 3.0]),
-    jnp.array([10.0, 20.0, 30.0]),
-    jnp.array([100.0, 200.0, 300.0]),
-)
-# TODO: this currently throws NYI as it is not yet connected to JIT and runtime.
-# But it should print LLVM IR in the process.
-
-ao = add_one(in0, in1)
-aop = add_one_plain(in0, in1)
-assert (jnp.abs(ao - aop) < 1e-6).all()
-print("Primal success")
-
-at = add_two(in0, in1, in2)
-atp = add_two_plain(in0, in1, in2)
-
-assert (jnp.abs(at - atp) < 1e-6).all()
-print("Primal Deadarg success")
-
-import timeit
-
-print(
-    timeit.Timer(
-        "add_one(in0, in1)", globals={"add_one": add_one, "in0": in0, "in1": in1}
-    ).timeit()
-)
-print(
-    timeit.Timer(
-        "add_one_plain(in0, in1)",
-        globals={"add_one_plain": add_one_plain, "in0": in0, "in1": in1},
-    ).timeit()
-)
-
-din0, din1, din2 = (
-    jnp.array([0.1, 0.2, 0.3]),
-    jnp.array([50.0, 70.0, 110.0]),
-    jnp.array([1300.0, 1700.0, 1900.0]),
-)
-
-
 @jax.jit
 def fwd(in0, in1, din0, din1):
     return jax.jvp(add_one, (in0, in1), (din0, din1))
@@ -73,16 +35,6 @@ def fwd_plain(in0, in1, din0, din1):
     return jax.jvp(add_one_plain, (in0, in1), (din0, din1))
 
 
-primals, tangents = fwd(in0, in1, din0, din1)
-primals_p, tangents_p = fwd_plain(in0, in1, din0, din1)
-
-assert (jnp.abs(primals - primals_p) < 1e-6).all()
-for t, t_p in zip(tangents, tangents_p):
-    assert (jnp.abs(t - t_p) < 1e-6).all()
-
-print("Tangent success")
-
-
 @jax.jit
 def fwd2(in0, in1, in2, din0, din1, din2):
     return jax.jvp(add_two, (in0, in1, in2), (din0, din1, din2))
@@ -93,38 +45,6 @@ def fwd2_plain(in0, in1, in2, din0, din1, din2):
     return jax.jvp(add_two_plain, (in0, in1, in2), (din0, din1, din2))
 
 
-primals, tangents = fwd2(in0, in1, in2, din0, din1, din2)
-primals_p, tangents_p = fwd2_plain(in0, in1, in2, din0, din1, din2)
-
-print(primals, primals_p)
-assert (jnp.abs(primals - primals_p) < 1e-6).all()
-for i, (t, t_p) in enumerate(zip(tangents, tangents_p)):
-    print(i, t, t_p)
-    assert (jnp.abs(t - t_p) < 1e-6).all()
-
-print("Tangent deadarg success")
-
-
-print(
-    timeit.Timer(
-        "fwd(in0, in1, din0, din1)",
-        globals={"fwd": fwd, "in0": in0, "in1": in1, "din0": din0, "din1": din1},
-    ).timeit()
-)
-print(
-    timeit.Timer(
-        "fwd_plain(in0, in1, din0, din1)",
-        globals={
-            "fwd_plain": fwd_plain,
-            "in0": in0,
-            "in1": in1,
-            "din0": din0,
-            "din1": din1,
-        },
-    ).timeit()
-)
-
-
 @jax.jit
 def rev(in0, in1, dout):
     primals, f_vjp = jax.vjp(add_one, in0, in1)
@@ -139,22 +59,6 @@ def rev_plain(in0, in1, dout):
     return primals, grads
 
 
-dout = jnp.array([500.0, 700.0, 110.0])
-
-primals, grads = rev(in0, in1, dout)
-# TODO enzyme will in place 0 the gradient inputs, which may not be expected
-print(dout)
-dout = jnp.array([500.0, 700.0, 110.0])
-primals_p, grads_p = rev_plain(in0, in1, dout)
-
-assert (jnp.abs(primals - primals_p) < 1e-6).all()
-for g, g_p in zip(grads, grads_p):
-    print(i, g, g_p)
-    assert (jnp.abs(g - g_p) < 1e-6).all()
-
-print("Gradient success")
-
-
 @jax.jit
 def rev2(in0, in1, in2, dout):
     primals, f_vjp = jax.vjp(add_two, in0, in1, in2)
@@ -169,35 +73,134 @@ def rev2_plain(in0, in1, in2, dout):
     return primals, grads
 
 
-dout = jnp.array([500.0, 700.0, 110.0])
-primals, grads = rev2(in0, in1, in2, dout)
-# TODO enzyme will in place 0 the gradient inputs, which may not be expected
-print(dout)
-dout = jnp.array([500.0, 700.0, 110.0])
-primals_p, grads_p = rev2_plain(in0, in1, in2, dout)
-
-assert (jnp.abs(primals - primals_p) < 1e-6).all()
-for g, g_p in zip(grads, grads_p):
-    print(i, g, g_p)
-    assert (jnp.abs(g - g_p) < 1e-6).all()
-
-print("Gradient deadarg success")
-
-print(
-    timeit.Timer(
-        "rev(in0, in1, dout)",
-        globals={"rev": rev, "in0": in0, "in1": in1, "dout": dout},
-    ).timeit()
-)
-print(
-    timeit.Timer(
-        "rev_plain(in0, in1, dout)",
-        globals={"rev_plain": rev_plain, "in0": in0, "in1": in1, "dout": dout},
-    ).timeit()
-)
-
-x = jnp.array(range(50), dtype=jnp.float32)
-dx = jnp.array([i * i for i in range(50)], dtype=jnp.float32)
+class AddOneTwo(absltest.TestCase):
+    def setUp(self):
+        self.in0 = jnp.array([1.0, 2.0, 3.0])
+        self.in1 = jnp.array([10.0, 20.0, 30.0])
+        self.in2 = jnp.array([100.0, 200.0, 300.0])
+        self.din0 = jnp.array([0.1, 0.2, 0.3])
+        self.din1 = jnp.array([50.0, 70.0, 110.0])
+        self.din2 = jnp.array([1300.0, 1700.0, 1900.0])
+
+    def test_add_one_primal(self):
+        ao = add_one(self.in0, self.in1)
+        aop = add_one_plain(self.in0, self.in1)
+        self.assertTrue((jnp.abs(ao - aop) < 1e-6).all())
+
+        # Benchmark.
+        print(
+            timeit.Timer(
+                "add_one(in0, in1)",
+                globals={"add_one": add_one, "in0": self.in0, "in1": self.in1},
+            ).timeit()
+        )
+        print(
+            timeit.Timer(
+                "add_one_plain(in0, in1)",
+                globals={
+                    "add_one_plain": add_one_plain,
+                    "in0": self.in0,
+                    "in1": self.in1,
+                },
+            ).timeit()
+        )
+
+    def test_add_two_deadarg(self):
+        at = add_two(self.in0, self.in1, self.in2)
+        atp = add_two_plain(self.in0, self.in1, self.in2)
+        self.assertTrue((jnp.abs(at - atp) < 1e-6).all())
+
+    def test_add_one_forward(self):
+        primals, tangents = fwd(self.in0, self.in1, self.din0, self.din1)
+        primals_p, tangents_p = fwd_plain(self.in0, self.in1, self.din0, self.din1)
+
+        self.assertTrue((jnp.abs(primals - primals_p) < 1e-6).all())
+        for t, t_p in zip(tangents, tangents_p):
+            self.assertTrue((jnp.abs(t - t_p) < 1e-6).all())
+
+        print(
+            timeit.Timer(
+                "fwd(in0, in1, din0, din1)",
+                globals={
+                    "fwd": fwd,
+                    "in0": self.in0,
+                    "in1": self.in1,
+                    "din0": self.din0,
+                    "din1": self.din1,
+                },
+            ).timeit()
+        )
+        print(
+            timeit.Timer(
+                "fwd_plain(in0, in1, din0, din1)",
+                globals={
+                    "fwd_plain": fwd_plain,
+                    "in0": self.in0,
+                    "in1": self.in1,
+                    "din0": self.din0,
+                    "din1": self.din1,
+                },
+            ).timeit()
+        )
+
+    def test_add_two_deadarg_forward(self):
+        primals, tangents = fwd2(
+            self.in0, self.in1, self.in2, self.din0, self.din1, self.din2
+        )
+        primals_p, tangents_p = fwd2_plain(
+            self.in0, self.in1, self.in2, self.din0, self.din1, self.din2
+        )
+
+        print(primals, primals_p)
+        self.assertTrue((jnp.abs(primals - primals_p) < 1e-6).all())
+        for i, (t, t_p) in enumerate(zip(tangents, tangents_p)):
+            print(i, t, t_p)
+            self.assertTrue((jnp.abs(t - t_p) < 1e-6).all())
+
+    def test_add_one_reverse(self):
+        dout = jnp.array([500.0, 700.0, 110.0])
+
+        primals, grads = rev(self.in0, self.in1, dout)
+        # TODO enzyme will in place 0 the gradient inputs, which may not be expected
+        print(dout)
+        dout = jnp.array([500.0, 700.0, 110.0])
+        primals_p, grads_p = rev_plain(self.in0, self.in1, dout)
+
+        self.assertTrue((jnp.abs(primals - primals_p) < 1e-6).all())
+        for i, (g, g_p) in enumerate(zip(grads, grads_p)):
+            print(i, g, g_p)
+            self.assertTrue((jnp.abs(g - g_p) < 1e-6).all())
+
+        print(
+            timeit.Timer(
+                "rev(in0, in1, dout)",
+                globals={"rev": rev, "in0": self.in0, "in1": self.in1, "dout": dout},
+            ).timeit()
+        )
+        print(
+            timeit.Timer(
+                "rev_plain(in0, in1, dout)",
+                globals={
+                    "rev_plain": rev_plain,
+                    "in0": self.in0,
+                    "in1": self.in1,
+                    "dout": dout,
+                },
+            ).timeit()
+        )
+
+    def test_add_two_deadarg_reverse(self):
+        dout = jnp.array([500.0, 700.0, 110.0])
+        primals, grads = rev2(self.in0, self.in1, self.in2, dout)
+        # TODO enzyme will in place 0 the gradient inputs, which may not be expected
+        print(dout)
+        dout = jnp.array([500.0, 700.0, 110.0])
+        primals_p, grads_p = rev2_plain(self.in0, self.in1, self.in2, dout)
+
+        self.assertTrue((jnp.abs(primals - primals_p) < 1e-6).all())
+        for i, (g, g_p) in enumerate(zip(grads, grads_p)):
+            print(i, g, g_p)
+            self.assertTrue((jnp.abs(g - g_p) < 1e-6).all())
 
 
 @enzyme_jax_ir()
@@ -205,22 +208,11 @@ def esum(x):
     return jnp.sum(x)
 
 
-eres = esum(x)
-print(eres)
-assert jnp.abs(eres - 50 * 49 / 2) < 1e-6
-
-
 @jax.jit
 def sumfwd(in0, din0):
     return jax.jvp(esum, (in0,), (din0,))
 
 
-primals, tangents = sumfwd(x, dx)
-print(primals, tangents)
-assert jnp.abs(primals - 50 * 49 / 2) < 1e-6
-assert jnp.abs(tangents - 50 * 49 * 99 / 6) < 1e-6
-
-
 @jax.jit
 def sumrev_p(in0):
     primals, f_vjp = jax.vjp(jnp.sum, in0)
@@ -228,10 +220,6 @@ def sumrev_p(in0):
     return primals, grads
 
 
-primals, grads = sumrev_p(x)
-print(primals, grads)
-
-
 @jax.jit
 def sumrev(in0):
     primals, f_vjp = jax.vjp(esum, in0)
@@ -239,10 +227,31 @@ def sumrev(in0):
     return primals, grads
 
 
-primals, grads = sumrev(x)
-print(primals, grads)
-assert jnp.abs(primals - 50 * 49 / 2) < 1e-6
-assert (jnp.abs(grads[0] - 1) < 1e-6).all()
+class Sum(absltest.TestCase):
+    def setUp(self):
+        self.x = jnp.array(range(50), dtype=jnp.float32)
+        self.dx = jnp.array([i * i for i in range(50)], dtype=jnp.float32)
+
+    def test_primal(self):
+        eres = esum(self.x)
+        print(eres)
+        self.assertTrue(jnp.abs(eres - 50 * 49 / 2) < 1e-6)
+
+    def test_forward(self):
+        primals, tangents = sumfwd(self.x, self.dx)
+        print(primals, tangents)
+        self.assertTrue(jnp.abs(primals - 50 * 49 / 2) < 1e-6)
+        self.assertTrue(jnp.abs(tangents - 50 * 49 * 99 / 6) < 1e-6)
+
+    def test_reverse_p(self):
+        primals, grads = sumrev_p(self.x)
+        print(primals, grads)
+
+    def test_reverse(self):
+        primals, grads = sumrev(self.x)
+        print(primals, grads)
+        self.assertTrue(jnp.abs(primals - 50 * 49 / 2) < 1e-6)
+        self.assertTrue((jnp.abs(grads[0] - 1) < 1e-6).all())
 
 
 @enzyme_jax_ir()
@@ -257,11 +266,19 @@ def cacherev(in0, din0):
     return grads
 
 
-dim = 288
+class Cache(absltest.TestCase):
+    def test_reverse(self):
+        dim = 288
+
+        x = jnp.array(range(dim), dtype=jnp.float32)
+        dx = jnp.array(range(dim), dtype=jnp.float32)
+
+        grads = cacherev(x, dx)
+        self.assertTrue(
+            jnp.abs(grads[0][0] - (dim - 1) * dim * (2 * (dim - 1) + 1) / 6) < 1e-6
+        )
+        self.assertTrue((jnp.abs(grads[0][1:]) < 1e-6).all())
 
-x = jnp.array(range(dim), dtype=jnp.float32)
-dx = jnp.array(range(dim), dtype=jnp.float32)
 
-grads = cacherev(x, dx)
-assert jnp.abs(grads[0][0] - 287 * 288 * (2 * 287 + 1) / 6) < 1e-6
-assert (jnp.abs(grads[0][1:]) < 1e-6).all()
+if __name__ == "__main__":
+    absltest.main()
diff --git a/test/llama.py b/test/llama.py
index cc3649ce..53994d19 100644
--- a/test/llama.py
+++ b/test/llama.py
@@ -1,7 +1,9 @@
+from absl.testing import absltest
 import jax.numpy as jnp
 import jax.random
 import jax.lax
 import enzyme_ad.jax as enzyme_jax
+import numpy as np
 
 
 def rmsnorm(x, weight):
@@ -213,126 +215,119 @@ def forward(x, config, weights, key_cache, value_cache):
     return x
 
 
-import numpy as np
-
-config = {
-    "dim": 288,
-    "hidden_dim": 768,
-    "n_layers": 6,
-    "n_heads": 6,
-    "n_kv_heads": 6,
-    "vocab_size": 32000,
-    "seq_len": 256,
-}
-
-n_layers = config["n_layers"]
-seq_len = config["seq_len"]
-n_heads = config["n_heads"]
-dim = config["dim"]
-n_kv_heads = config["n_kv_heads"]
-vocab_size = config["vocab_size"]
-hidden_dim = config["hidden_dim"]
-kv_dim = dim // n_heads * n_kv_heads
-head_size = dim // n_heads
-
-key = jax.random.PRNGKey(0)
-weights = {}
-dweights = {}
-
-for name, shape in [
-    ("rms_att_weight", (n_layers, dim)),
-    ("wq", (n_layers, dim, n_heads * head_size)),
-    ("wk", (n_layers, dim, n_kv_heads * head_size)),
-    ("wv", (n_layers, dim, n_kv_heads * head_size)),
-    ("wo", (n_layers, dim, dim)),
-    ("rms_ffn_weight", (n_layers, dim)),
-    ("w1", (n_layers, hidden_dim, dim)),
-    ("w2", (n_layers, dim, hidden_dim)),
-    ("w3", (n_layers, hidden_dim, dim)),
-    ("rms_final_weight", (dim,)),
-    ("wcls", (vocab_size, dim)),
-]:
-    key, subkey = jax.random.split(key)
-    key, subkey2 = jax.random.split(key)
-    weights[name] = jax.random.uniform(subkey, shape=shape)
-    dweights[name] = jax.random.uniform(subkey2, shape=shape)
-
-key, subkey = jax.random.split(key)
-x = jax.random.uniform(subkey, shape=(dim,))
-key, subkey = jax.random.split(key)
-dx = jax.random.uniform(subkey, shape=(dim,))
-
-
-def partial(func, config):
-    def sfn(x, weights, key_cache, value_cache):
-        return func(x, config, weights, key_cache, value_cache)
-
-    return sfn
-
-
-pos = 1
-key_cache = jnp.zeros((n_layers, pos, kv_dim))
-value_cache = jnp.zeros((n_layers, pos, kv_dim))
-
-key, subkey = jax.random.split(key)
-dkc = jax.random.uniform(subkey, shape=(n_layers, pos + 1, kv_dim))
-key, subkey = jax.random.split(key)
-dvc = jax.random.uniform(subkey, shape=(n_layers, pos + 1, kv_dim))
-
-func = partial(forward, config)
-
-
-@jax.jit
-def jfunc(x, weights, key_cache, value_cache):
-    return func(x, weights, key_cache, value_cache)
-
-
-@enzyme_jax.enzyme_jax_ir()
-def efunc(x, weights, key_cache, value_cache):
-    return func(x, weights, key_cache, value_cache)
-
-
-# eres = efunc(x, weights, key_cache, value_cache)
-# print("Enzyme primal", eres)
-# res = func(x, weights, key_cache, value_cache)
-# print("Jax primal", res)
-# print (" max error", jnp.max(jnp.abs(eres-res)))
-# assert (jnp.abs(eres - res) < 1e-3).all()
-
-# jfunc = jax.jit(partial(forward, config))
-# mlir = jax.jit(partial(forward, config)).lower(1, weights, key_cache, value_cache).compiler_ir(dialect="mhlo")
-
-
-@jax.jit
-def jfwd(x, dx, weights, dweights, kc, dkc, vc, dvc):
-    return jax.jvp(jfunc, (x, weights, kc, vc), (x, weights, dkc, dvc))
-
-
-@jax.jit
-def efwd(x, dx, weights, dweights, kc, dkc, vc, dvc):
-    return jax.jvp(efunc, (x, weights, kc, vc), (x, weights, dkc, dvc))
-
-
-# print("pre fwd diff")
-# eres = efwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)
-# print("Enzyme fwd", eres)
-# jres = jfwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)
-# print("Jax fwd", jres)
-
-
-@jax.jit
-def jrev(x, weights, kc, vc, dx, dkc, dvc):
-    primals, f_vjp = jax.vjp(jfunc, x, weights, kc, vc)
-    return f_vjp(dx)  # , dkc, dvc)
-
-
-@jax.jit
-def erev(x, weights, kc, vc, dx, dkc, dvc):
-    primals, f_vjp = jax.vjp(efunc, x, weights, kc, vc)
-    return f_vjp(dx)  # , dkc, dvc)
-
-
-eres = erev(x, weights, key_cache, value_cache, dx, dkc, dvc)
-print("Enzyme rev", eres)
-jres = jrev(x, weights, key_cache, value_cache, dx, dkc, dvc)
-print("Jax rev", jres)
+class Llama(absltest.TestCase):
+    def test_llama_random(self):
+        config = {
+            "dim": 288,
+            "hidden_dim": 768,
+            "n_layers": 6,
+            "n_heads": 6,
+            "n_kv_heads": 6,
+            "vocab_size": 32000,
+            "seq_len": 256,
+        }
+
+        n_layers = config["n_layers"]
+        seq_len = config["seq_len"]
+        n_heads = config["n_heads"]
+        dim = config["dim"]
+        n_kv_heads = config["n_kv_heads"]
+        vocab_size = config["vocab_size"]
+        hidden_dim = config["hidden_dim"]
+        kv_dim = dim // n_heads * n_kv_heads
+        head_size = dim // n_heads
+
+        key = jax.random.PRNGKey(0)
+        weights = {}
+        dweights = {}
+
+        for name, shape in [
+            ("rms_att_weight", (n_layers, dim)),
+            ("wq", (n_layers, dim, n_heads * head_size)),
+            ("wk", (n_layers, dim, n_kv_heads * head_size)),
+            ("wv", (n_layers, dim, n_kv_heads * head_size)),
+            ("wo", (n_layers, dim, dim)),
+            ("rms_ffn_weight", (n_layers, dim)),
+            ("w1", (n_layers, hidden_dim, dim)),
+            ("w2", (n_layers, dim, hidden_dim)),
+            ("w3", (n_layers, hidden_dim, dim)),
+            ("rms_final_weight", (dim,)),
+            ("wcls", (vocab_size, dim)),
+        ]:
+            key, subkey = jax.random.split(key)
+            key, subkey2 = jax.random.split(key)
+            weights[name] = jax.random.uniform(subkey, shape=shape)
+            dweights[name] = jax.random.uniform(subkey2, shape=shape)
+
+        key, subkey = jax.random.split(key)
+        x = jax.random.uniform(subkey, shape=(dim,))
+        key, subkey = jax.random.split(key)
+        dx = jax.random.uniform(subkey, shape=(dim,))
+
+        def partial(func, config):
+            def sfn(x, weights, key_cache, value_cache):
+                return func(x, config, weights, key_cache, value_cache)
+
+            return sfn
+
+        pos = 1
+        key_cache = jnp.zeros((n_layers, pos, kv_dim))
+        value_cache = jnp.zeros((n_layers, pos, kv_dim))
+
+        key, subkey = jax.random.split(key)
+        dkc = jax.random.uniform(subkey, shape=(n_layers, pos + 1, kv_dim))
+        key, subkey = jax.random.split(key)
+        dvc = jax.random.uniform(subkey, shape=(n_layers, pos + 1, kv_dim))
+
+        func = partial(forward, config)
+
+        @jax.jit
+        def jfunc(x, weights, key_cache, value_cache):
+            return func(x, weights, key_cache, value_cache)
+
+        @enzyme_jax.enzyme_jax_ir()
+        def efunc(x, weights, key_cache, value_cache):
+            return func(x, weights, key_cache, value_cache)
+
+        # eres = efunc(x, weights, key_cache, value_cache)
+        # print("Enzyme primal", eres)
+        # res = func(x, weights, key_cache, value_cache)
+        # print("Jax primal", res)
+        # print (" max error", jnp.max(jnp.abs(eres-res)))
+        # assert (jnp.abs(eres - res) < 1e-3).all()
+
+        # jfunc = jax.jit(partial(forward, config))
+        # mlir = jax.jit(partial(forward, config)).lower(1, weights, key_cache, value_cache).compiler_ir(dialect="mhlo")
+
+        @jax.jit
+        def jfwd(x, dx, weights, dweights, kc, dkc, vc, dvc):
+            return jax.jvp(jfunc, (x, weights, kc, vc), (x, weights, dkc, dvc))
+
+        @jax.jit
+        def efwd(x, dx, weights, dweights, kc, dkc, vc, dvc):
+            return jax.jvp(efunc, (x, weights, kc, vc), (x, weights, dkc, dvc))
+
+        # print("pre fwd diff")
+        # eres = efwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)
+        # print("Enzyme fwd", eres)
+        # jres = jfwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)
+        # print("Jax fwd", jres)
+
+        @jax.jit
+        def jrev(x, weights, kc, vc, dx, dkc, dvc):
+            primals, f_vjp = jax.vjp(jfunc, x, weights, kc, vc)
+            return f_vjp(dx)  # , dkc, dvc)
+
+        @jax.jit
+        def erev(x, weights, kc, vc, dx, dkc, dvc):
+            primals, f_vjp = jax.vjp(efunc, x, weights, kc, vc)
+            return f_vjp(dx)  # , dkc, dvc)
+
+        eres = erev(x, weights, key_cache, value_cache, dx, dkc, dvc)
+        print("Enzyme rev", eres)
+        jres = jrev(x, weights, key_cache, value_cache, dx, dkc, dvc)
+        print("Jax rev", jres)
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/test/test.py b/test/test.py
index 1553f7d6..2c58cc27 100644
--- a/test/test.py
+++ b/test/test.py
@@ -1,85 +1,89 @@
+from absl.testing import absltest
 import jax
 import jax.numpy as jnp
-from enzyme_ad.jax import cpp_call
+from enzyme_ad.jax import cpp_call, enzyme_jax_ir
 
 
-@jax.jit
-def do_something(ones):
-    shape = jax.core.ShapedArray(ones.shape, ones.dtype)
-    a, b = cpp_call(
-        ones,
-        out_shapes=[shape, shape],
-        source="""
-        template<std::size_t N, std::size_t M>
-        void myfn(enzyme::tensor<float, N, M>& out0, enzyme::tensor<float, N, M>& out1, const enzyme::tensor<float, N, M>& in0) {
-            for (int j=0; j<N; j++) {
-                for (int k=0; k<M; k++) {
-                    out0[j][k] = in0[j][k] + 42;
-                }
-            }
-            for (int j=0; j<2; j++) {
-                for (int k=0; k<3; k++) {
-                    out1[j][k] = in0[j][k] + 2 * 42;
-                }
-            }
-        }
-        """,
-        fn="myfn",
-    )
-    c = cpp_call(
-        a,
-        out_shapes=[jax.core.ShapedArray([4, 4], jnp.float32)],
-        source="""
-        template<typename T1, typename T2>
-        void f(T1& out0, const T2& in1) {
-            out0 = 56.0f;
-        }
-        """,
-    )
-    return a, b, c
-
-
-ones = jnp.ones((2, 3), jnp.float32)
-x, y, z = do_something(ones)
-
-print(x)
-print(y)
-print(z)
-
-primals, tangents = jax.jvp(do_something, (ones,), (ones,))
-print(primals)
-print(tangents)
-
-primals, f_vjp = jax.vjp(do_something, ones)
-(grads,) = f_vjp((x, y, z))
-print(primals)
-print(grads)
-
-# Test enzyme mlir jit
-from enzyme_ad.jax import enzyme_jax_ir
-
-
-@enzyme_jax_ir()
-def add_one(x: jax.Array, y) -> jax.Array:
-    return x + 1 + y
-
-
-# TODO: this currently throws NYI as it is not yet connected to JIT and runtime.
-# But it should print LLVM IR in the process.
-add_one(jnp.array([1.0, 2.0, 3.0]), jnp.array([10.0, 20.0, 30.0]))
-
-primals, tangents = jax.jvp(
-    add_one,
-    (jnp.array([1.0, 2.0, 3.0]), jnp.array([10.0, 20.0, 30.0])),
-    (jnp.array([0.1, 0.2, 0.3]), jnp.array([50.0, 70.0, 110.0])),
-)
-print(primals)
-print(tangents)
-
-primals, f_vjp = jax.vjp(
-    add_one, jnp.array([1.0, 2.0, 3.0]), jnp.array([10.0, 20.0, 30.0])
-)
-grads = f_vjp(jnp.array([500.0, 700.0, 110.0]))
-print(primals)
-print(grads)
+class CppCall(absltest.TestCase):
+    def test_cpp_call(self):
+        @jax.jit
+        def do_something(ones):
+            shape = jax.core.ShapedArray(ones.shape, ones.dtype)
+            a, b = cpp_call(
+                ones,
+                out_shapes=[shape, shape],
+                source="""
+                template<std::size_t N, std::size_t M>
+                void myfn(enzyme::tensor<float, N, M>& out0,
+                          enzyme::tensor<float, N, M>& out1,
+                          const enzyme::tensor<float, N, M>& in0) {
+                    for (int j=0; j<N; j++) {
+                        for (int k=0; k<M; k++) {
+                            out0[j][k] = in0[j][k] + 42;
+                        }
+                    }
+                    for (int j=0; j<2; j++) {
+                        for (int k=0; k<3; k++) {
+                            out1[j][k] = in0[j][k] + 2 * 42;
+                        }
+                    }
+                }
+                """,
+                fn="myfn",
+            )
+            c = cpp_call(
+                a,
+                out_shapes=[jax.core.ShapedArray([4, 4], jnp.float32)],
+                source="""
+                template<typename T1, typename T2>
+                void f(T1& out0, const T2& in1) {
+                    out0 = 56.0f;
+                }
+                """,
+            )
+            return a, b, c
+
+        ones = jnp.ones((2, 3), jnp.float32)
+        x, y, z = do_something(ones)
+        print(x)
+        print(y)
+        print(z)
+
+        # JVP
+        primals, tangents = jax.jvp(do_something, (ones,), (ones,))
+        print(primals)
+        print(tangents)
+
+        # VJP
+        primals, f_vjp = jax.vjp(do_something, ones)
+        (grads,) = f_vjp((x, y, z))
+        print(primals)
+        print(grads)
+
+    def test_enzyme_mlir_jit(self):
+        @enzyme_jax_ir()
+        def add_one(x: jax.Array, y) -> jax.Array:
+            return x + 1 + y
+
+        # But it should print LLVM IR in the process.
+        add_one(jnp.array([1.0, 2.0, 3.0]), jnp.array([10.0, 20.0, 30.0]))
+
+        primals, tangents = jax.jvp(
+            add_one,
+            (jnp.array([1.0, 2.0, 3.0]), jnp.array([10.0, 20.0, 30.0])),
+            (jnp.array([0.1, 0.2, 0.3]), jnp.array([50.0, 70.0, 110.0])),
+        )
+        print(primals)
+        print(tangents)
+
+        primals, f_vjp = jax.vjp(
+            add_one, jnp.array([1.0, 2.0, 3.0]), jnp.array([10.0, 20.0, 30.0])
+        )
+        grads = f_vjp(jnp.array([500.0, 700.0, 110.0]))
+        print(primals)
+        print(grads)
+
+
+if __name__ == "__main__":
+    absltest.main()