Skip to content

Commit

Permalink
continuing fixups
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 9, 2023
1 parent a999b6e commit de2c644
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 18 deletions.
2 changes: 1 addition & 1 deletion WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependen

pip_install_dependencies()

ENZYME_COMMIT = "77b4fff47701a240b537a93a2e722626f7421342"
ENZYME_COMMIT = "cbb970161fd41ce55da028f0960a441382b07112"
ENZYME_SHA256 = ""

http_archive(
Expand Down
16 changes: 11 additions & 5 deletions enzyme_jax/clang_compile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -620,22 +620,28 @@ struct tensor<T, n0, N...>
if (II->getIntrinsicID() == llvm::Intrinsic::dbg_value)
continue;
}
if (isa<ICmpInst>(cur))
continue;
if (auto SI = dyn_cast<StoreInst>(cur)) {
assert(SI->getPointerOperand() == prev);
auto C = dyn_cast<Constant>(SI->getValueOperand());
if (auto CF = dyn_cast_or_null<ConstantFP>(C))
if (C && C->isNullValue()) {
} else if (auto CF = dyn_cast_or_null<ConstantFP>(C)) {
assert(CF->isZero());
else {
} else {
llvm::errs() << "SI: " << *SI << " C: " << *SI->getValueOperand()
<< "\n";
assert(0);
}
toErase.insert(SI);
continue;
}
llvm::errs() << " unsupported value to erase:\n";
llvm::errs() << " cur: " << *cur << " prev: " << *prev << "\n";
assert(0);
std::string err_str;
llvm::raw_string_ostream ss(err_str);
ss << *mod << "\n";
ss << " unsupported value to erase:\n";
ss << " cur: " << *cur << " prev: " << *prev << "\n";
throw pybind11::value_error(ss.str());
}
for (auto I : toErase) {
I->eraseFromParent();
Expand Down
19 changes: 18 additions & 1 deletion enzyme_jax/enzyme_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/service/cpu/cpu_executable.h"

#include "Enzyme/FunctionUtils.h"

enum class ABI { Primal, Forward, Augmented, Reverse, Tape };

enum class Language : int { CPP = 0, LLVM = 1, MHLO = 2 };
Expand Down Expand Up @@ -168,8 +170,22 @@ class CpuKernel {
F->setName(fn);
assert(!F->empty());
for (auto &F2 : *linkMod)
if (!F2.empty())
if (!F2.empty()) {
F2.addFnAttr(llvm::Attribute::AlwaysInline);
// Remove invariant_load if we expect Enzyme to explicitly cache all
// data. Otherwise invariant_load allows Enzyme to assume it need
// not cache, and it is illegal for us to pass in nullptr as the
// primal (since it may be needed).
if (mode == ABI::Augmented || mode == ABI::Reverse ||
mode == ABI::Tape) {
for (auto &BB : F2)
for (auto &I : BB)
if (auto LI = llvm::dyn_cast<llvm::LoadInst>(&I))
if (LI->hasMetadata(llvm::LLVMContext::MD_invariant_load))
LI->setMetadata(llvm::LLVMContext::MD_invariant_load,
nullptr);
}
}
}
ss << " extern \"C\" void " << fn
<< "(void* retval, void* run_options, void* params, void* "
Expand Down Expand Up @@ -910,6 +926,7 @@ PYBIND11_MODULE(enzyme_call, m) {
llvm::InitializeAllTargetMCs();
llvm::InitializeAllAsmPrinters();
llvm::InitializeAllAsmParsers();
EnzymeAlwaysInlineDiff.setValue(true);

pybind11::enum_<Language>(m, "Language")
.value("CPP", Language::CPP)
Expand Down
19 changes: 19 additions & 0 deletions test/bench_vs_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,22 @@ def sumrev(in0):
print(primals, grads)
assert jnp.abs(primals-50*49/2)<1e-6
assert (jnp.abs(grads[0]-1) <1e-6).all()

@enzyme_jax_ir()
def ecache(x):
    """Multiply ``x`` by its own first element, ``x[0]``."""
    head = x[0]
    return x * head

@jax.jit
def cacherev(in0, din0):
    """Apply the VJP of ``ecache`` taken at ``in0`` to the cotangent ``din0``.

    Returns whatever the pullback produced by ``jax.vjp`` returns; the
    primal output of ``ecache`` is computed by ``jax.vjp`` but not returned.
    """
    _, pullback = jax.vjp(ecache, in0)
    return pullback(din0)

dim = 288

# The primal input and the cotangent are both the ramp 0, 1, ..., dim - 1.
x = jnp.arange(dim, dtype=jnp.float32)
dx = jnp.arange(dim, dtype=jnp.float32)

grads = cacherev(x, dx)
# Entry 0 of the gradient is expected to equal sum(k * k for k in range(dim)),
# written here in closed form; every other entry is expected to be zero.
assert jnp.abs(grads[0][0] - (dim - 1) * dim * (2 * (dim - 1) + 1) / 6) < 1e-6
assert jnp.all(jnp.abs(grads[0][1:]) < 1e-6)
22 changes: 11 additions & 11 deletions test/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,12 +252,12 @@ def jfunc(x, weights, key_cache, value_cache):
def efunc(x, weights, key_cache, value_cache):
return func(x, weights, key_cache, value_cache)

eres = efunc(x, weights, key_cache, value_cache)
print("Enzyme primal", eres)
res = func(x, weights, key_cache, value_cache)
print("Jax primal", res)
print (" max error", jnp.max(jnp.abs(eres-res)))
assert (jnp.abs(eres - res) < 1e-3).all()
# eres = efunc(x, weights, key_cache, value_cache)
# print("Enzyme primal", eres)
# res = func(x, weights, key_cache, value_cache)
# print("Jax primal", res)
# print (" max error", jnp.max(jnp.abs(eres-res)))
# assert (jnp.abs(eres - res) < 1e-3).all()

#jfunc = jax.jit(partial(forward, config))
# mlir = jax.jit(partial(forward, config)).lower(1, weights, key_cache, value_cache).compiler_ir(dialect="mhlo")
Expand All @@ -270,11 +270,11 @@ def jfwd(x, dx, weights, dweights, kc, dkc, vc, dvc):
def efwd(x, dx, weights, dweights, kc, dkc, vc, dvc):
return jax.jvp(efunc, (x, weights, kc, vc), (x, weights, dkc, dvc))

print("pre fwd diff")
eres = efwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)
print("Enzyme fwd", eres)
jres = jfwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)
print("Jax fwd", jres)
# print("pre fwd diff")
# eres = efwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)
# print("Enzyme fwd", eres)
# jres = jfwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)
# print("Jax fwd", jres)


@jax.jit
Expand Down

0 comments on commit de2c644

Please sign in to comment.