From 4ee70bd3f27c59bd2cb9bad29342edd21e989577 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 8 Dec 2023 23:55:31 -0500 Subject: [PATCH] continuing fixups --- WORKSPACE | 2 +- enzyme_jax/clang_compile.cc | 16 +++++++++++----- enzyme_jax/enzyme_call.cc | 19 ++++++++++++++++++- test/bench_vs_xla.py | 19 +++++++++++++++++++ test/llama.py | 22 +++++++++++----------- 5 files changed, 60 insertions(+), 18 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 0b0a5550..d7669336 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -60,7 +60,7 @@ load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependen pip_install_dependencies() -ENZYME_COMMIT = "77b4fff47701a240b537a93a2e722626f7421342" +ENZYME_COMMIT = "cbb970161fd41ce55da028f0960a441382b07112" ENZYME_SHA256 = "" http_archive( diff --git a/enzyme_jax/clang_compile.cc b/enzyme_jax/clang_compile.cc index ecf1a281..7f7f1428 100644 --- a/enzyme_jax/clang_compile.cc +++ b/enzyme_jax/clang_compile.cc @@ -620,12 +620,15 @@ struct tensor if (II->getIntrinsicID() == llvm::Intrinsic::dbg_value) continue; } + if (isa(cur)) + continue; if (auto SI = dyn_cast(cur)) { assert(SI->getPointerOperand() == prev); auto C = dyn_cast(SI->getValueOperand()); - if (auto CF = dyn_cast_or_null(C)) + if (C && C->isNullValue()) { + } else if (auto CF = dyn_cast_or_null(C)) { assert(CF->isZero()); - else { + } else { llvm::errs() << "SI: " << *SI << " C: " << *SI->getValueOperand() << "\n"; assert(0); @@ -633,9 +636,12 @@ struct tensor toErase.insert(SI); continue; } - llvm::errs() << " unsupported value to erase:\n"; - llvm::errs() << " cur: " << *cur << " prev: " << *prev << "\n"; - assert(0); + std::string err_str; + llvm::raw_string_ostream ss(err_str); + ss << *mod << "\n"; + ss << " unsupported value to erase:\n"; + ss << " cur: " << *cur << " prev: " << *prev << "\n"; + throw pybind11::value_error(ss.str()); } for (auto I : toErase) { I->eraseFromParent(); diff --git a/enzyme_jax/enzyme_call.cc b/enzyme_jax/enzyme_call.cc index 2330213d..1bfaaecd 100644 --- a/enzyme_jax/enzyme_call.cc +++ b/enzyme_jax/enzyme_call.cc @@ -43,6 +43,8 @@ #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/cpu/cpu_executable.h" +#include "Enzyme/FunctionUtils.h" + enum class ABI { Primal, Forward, Augmented, Reverse, Tape }; enum class Language : int { CPP = 0, LLVM = 1, MHLO = 2 }; @@ -168,8 +170,22 @@ class CpuKernel { F->setName(fn); assert(!F->empty()); for (auto &F2 : *linkMod) - if (!F2.empty()) + if (!F2.empty()) { F2.addFnAttr(llvm::Attribute::AlwaysInline); + // Remove invariant_load if we expect enzyme to cache explicitly all + // data. Otherwise invariant_load allows Enzyme to assume it need + // not cache, and it is illegal for us to pass in nullptr as the + // primal (since it may be needed). + if (mode == ABI::Augmented || mode == ABI::Reverse || + mode == ABI::Tape) { + for (auto &BB : F2) + for (auto &I : BB) + if (auto LI = llvm::dyn_cast(&I)) + if (LI->hasMetadata(llvm::LLVMContext::MD_invariant_load)) + LI->setMetadata(llvm::LLVMContext::MD_invariant_load, + nullptr); + } + } } ss << " extern \"C\" void " << fn << "(void* retval, void* run_options, void* params, void* " @@ -910,6 +926,7 @@ PYBIND11_MODULE(enzyme_call, m) { llvm::InitializeAllTargetMCs(); llvm::InitializeAllAsmPrinters(); llvm::InitializeAllAsmParsers(); + EnzymeAlwaysInlineDiff.setValue(true); pybind11::enum_(m, "Language") .value("CPP", Language::CPP) diff --git a/test/bench_vs_xla.py b/test/bench_vs_xla.py index 4a6d690d..617f658f 100644 --- a/test/bench_vs_xla.py +++ b/test/bench_vs_xla.py @@ -178,3 +178,22 @@ def sumrev(in0): print(primals, grads) assert jnp.abs(primals-50*49/2)<1e-6 assert (jnp.abs(grads[0]-1) <1e-6).all() + +@enzyme_jax_ir() +def ecache(x): + return x * x[0] + +@jax.jit +def cacherev(in0, din0): + primals, f_vjp = jax.vjp(ecache, in0) + grads = f_vjp(din0) + return grads + +dim = 288 + +x = jnp.array(range(dim), dtype=jnp.float32) +dx = jnp.array(range(dim), dtype=jnp.float32) + +grads = cacherev(x, dx) +assert jnp.abs(grads[0][0]-287*288*(2*287+1)/6)<1e-6 +assert (jnp.abs(grads[0][1:]) <1e-6).all() diff --git a/test/llama.py b/test/llama.py index aabc456c..f90aa28a 100644 --- a/test/llama.py +++ b/test/llama.py @@ -252,12 +252,12 @@ def jfunc(x, weights, key_cache, value_cache): def efunc(x, weights, key_cache, value_cache): return func(x, weights, key_cache, value_cache) -eres = efunc(x, weights, key_cache, value_cache) -print("Enzyme primal", eres) -res = func(x, weights, key_cache, value_cache) -print("Jax primal", res) -print (" max error", jnp.max(jnp.abs(eres-res))) -assert (jnp.abs(eres - res) < 1e-3).all() +# eres = efunc(x, weights, key_cache, value_cache) +# print("Enzyme primal", eres) +# res = func(x, weights, key_cache, value_cache) +# print("Jax primal", res) +# print (" max error", jnp.max(jnp.abs(eres-res))) +# assert (jnp.abs(eres - res) < 1e-3).all() #jfunc = jax.jit(partial(forward, config)) # mlir = jax.jit(partial(forward, config)).lower(1, weights, key_cache, value_cache).compiler_ir(dialect="mhlo") @@ -270,11 +270,11 @@ def jfwd(x, dx, weights, dweights, kc, dkc, vc, dvc): def efwd(x, dx, weights, dweights, kc, dkc, vc, dvc): return jax.jvp(efunc, (x, weights, kc, vc), (x, weights, dkc, dvc)) -print("pre fwd diff") -eres = efwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache) -print("Enzyme fwd", eres) -jres = jfwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache) -print("Jax fwd", jres) +# print("pre fwd diff") +# eres = efwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache) +# print("Enzyme fwd", eres) +# jres = jfwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache) +# print("Jax fwd", jres) @jax.jit