Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Oct 11, 2024
1 parent 7d5ddb3 commit 4d4859f
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 7 deletions.
4 changes: 0 additions & 4 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,16 +794,12 @@ def _enzyme_primal_lowering(
for f in mod.regions[0].blocks[0]:
fns.append(f.sym_name.value)

print("pass_pipeline:\n", pass_pipeline)
print("source:\n", source)
name, nmod = enzyme_call.run_pass_pipeline(fns, source, pass_pipeline)
if print_mlir:
if type(print_mlir) != type(True):
print_mlir.write(nmod)
else:
print(str(nmod), flush=True)
print("post pass_pipeline:\n", pass_pipeline)
print("post source:\n", source)
nmod = ir.Module.parse(nmod)
fn = None
pushtop = []
Expand Down
6 changes: 3 additions & 3 deletions test/bench_vs_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,9 @@ def forward(position):
self.fn = forward
self.name = "scatter_sum"

self.ins = [[2.0, 4.0, 6.0, 8.0]]
self.dins = [2.7, 3.1, 5.9, 4.2]
self.douts = [x.copy() for x in self.dins]
self.ins = [jnp.array([2.0, 4.0, 6.0, 8.0])]
self.dins = [jnp.array([2.7, 3.1, 5.9, 4.2])]
self.douts = self.dins
self.AllPipelines = pipelines
# No support for stablehlo.scatter atm
self.mlirad_rev = False
Expand Down

0 comments on commit 4d4859f

Please sign in to comment.