Skip to content

Commit

Permalink
Handle cast return
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 3, 2023
1 parent b750c7c commit 48cd8eb
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 21 deletions.
42 changes: 28 additions & 14 deletions enzyme_jax/enzyme_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,23 +262,37 @@ class CpuKernel {
assert(val->operand_count() == out_shapes.size());
for (size_t i = 0; i < out_shapes.size(); i++) {
ssize_t found = -1;
for (auto &buf : assignment.Allocations()) {
if (!buf.maybe_live_out())
continue;
if (buf.is_tuple())
continue;
bool contains_output = false;
for (auto &pair : buf.assigned_buffers()) {
if (pair.first->instruction() != val->operand(i))
auto operand = val->operand(i);
while (found == -1) {
for (auto &buf : assignment.Allocations()) {
if (!buf.maybe_live_out())
continue;
if (buf.is_tuple())
continue;
assert(!contains_output);
contains_output = true;
assert(pair.second.offset == 0);
bool contains_output = false;
for (auto &pair : buf.assigned_buffers()) {
if (pair.first->instruction() != operand)
continue;
assert(!contains_output);
contains_output = true;
assert(pair.second.offset == 0);
}
if (!contains_output)
continue;
assert(found == -1);
found = buf.index();
}
if (!contains_output)
if (operand->opcode() == xla::HloOpcode::kBitcast) {
operand = operand->operand(0);
continue;
assert(found == -1);
found = buf.index();
}
break;
}
if (found == -1) {
llvm::errs() << "assignment: " << assignment.ToString() << "\n";
llvm::errs() << "val: " << val->ToString() << "\n";
llvm::errs() << "vop: " << val->operand(i)->ToString() << "\n";
llvm::errs() << "i: " << i << "\n";
}
assert(found != -1);
out_idxs.push_back((int)found);
Expand Down
2 changes: 1 addition & 1 deletion test/bench_vs_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def fwd(in0, in1, din0, din1):

@jax.jit
def fwd_plain(in0, in1, din0, din1):
return jax.jvp(add_one, (in0, in1), (din0, din1))
return jax.jvp(add_one_plain, (in0, in1), (din0, din1))

primals, tangents = fwd(in0, in1, din0, din1)
primals, tangents = fwd_plain(in0, in1, din0, din1)
Expand Down
54 changes: 48 additions & 6 deletions test/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def forward(x, config, weights, key_cache, value_cache):

key = jax.random.PRNGKey(0)
weights = {}
dweights = {}

for name, shape in [("rms_att_weight", (n_layers, dim)),
("wq", (n_layers, dim, n_heads * head_size)),
Expand All @@ -226,10 +227,14 @@ def forward(x, config, weights, key_cache, value_cache):
("wcls", (vocab_size, dim))
]:
key, subkey = jax.random.split(key)
weights[name] = jax.random.uniform(key, shape=shape)
key = subkey
key, subkey2 = jax.random.split(key)
weights[name] = jax.random.uniform(subkey, shape=shape)
dweights[name] = jax.random.uniform(subkey2, shape=shape)

x = jax.random.uniform(key, shape=(dim,))
key, subkey = jax.random.split(key)
x = jax.random.uniform(subkey, shape=(dim,))
key, subkey = jax.random.split(key)
dx = jax.random.uniform(subkey, shape=(dim,))

def partial(func, config):
def sfn(x, weights, key_cache, value_cache):
Expand All @@ -240,20 +245,57 @@ def sfn(x, weights, key_cache, value_cache):
key_cache = jnp.zeros((n_layers, pos,kv_dim))
value_cache = jnp.zeros((n_layers, pos,kv_dim))

key, subkey = jax.random.split(key)
dkc = jax.random.uniform(subkey, shape=(n_layers,pos+1,kv_dim))
key, subkey = jax.random.split(key)
dvc = jax.random.uniform(subkey, shape=(n_layers,pos+1,kv_dim))

func = partial(forward, config)

@jax.jit
def jfunc(x, weights, key_cache, value_cache):
return func(x, weights, key_cache, value_cache)

@enzyme_jax.enzyme_jax_ir()
def efunc(x, weights, key_cache, value_cache):
return func(x, weights, key_cache, value_cache)

eres = efunc(x, weights, key_cache, value_cache)[1]
print(eres)
print("Enzyme primal", eres)
res = func(x, weights, key_cache, value_cache)[1]
print(res)
print (jnp.max(jnp.abs(eres-res)))
print("Jax primal", res)
print (" max error", jnp.max(jnp.abs(eres-res)))
assert (jnp.abs(eres - res) < 1e-3).all()

#jfunc = jax.jit(partial(forward, config))
# mlir = jax.jit(partial(forward, config)).lower(1, weights, key_cache, value_cache).compiler_ir(dialect="mhlo")

@jax.jit
def jfwd(x, dx, weights, dweights, kc, dkc, vc, dvc):
return jax.jvp(jfunc, (x, weights, kc, vc), (x, weights, dkc, dvc))

@jax.jit
def efwd(x, dx, weights, dweights, kc, dkc, vc, dvc):
return jax.jvp(efunc, (x, weights, kc, vc), (x, weights, dkc, dvc))

eres = efwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)
print("Enzyme fwd", eres)
jres = jfwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)
print("Jax fwd", jres)


@jax.jit
def jres(x, weights, kc, vc, dx, dkc, dvc):
primals, f_vjp = jax.vjp(jfunc, x, weights, kc, vc)
return f_vjp(dx, dkc, dvc)

@jax.jit
def eres(x, weights, kc, vc, dx, dkc, dvc):
primals, f_vjp = jax.vjp(efunc, x, weights, kc, vc)
return f_vjp(dx, dkc, dvc)

eres = erev(x, weights, key_cache, dx, dkc, dvc)
print("Enzyme rev", eres)
jres = jrev(x, weights, key_cache, dx, dkc, dvc)
print("Jax rev", jres)

0 comments on commit 48cd8eb

Please sign in to comment.