Skip to content

Commit

Permalink
Maxtext (#139)
Browse files Browse the repository at this point in the history
* Maxtext

* fixup

* fix

* try fix

* fix

* fix

* cleanup

* Update test-requirements.txt

* cleanup
  • Loading branch information
wsmoses authored Oct 14, 2024
1 parent 2f1a703 commit 01e9dc9
Show file tree
Hide file tree
Showing 8 changed files with 661 additions and 5 deletions.
1 change: 1 addition & 0 deletions .bazelrc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
build --announce_rc

query --experimental_repo_remote_exec
build --experimental_repo_remote_exec
build --cxxopt=-std=c++17 --host_cxxopt=-std=c++17
build --cxxopt=-w --host_cxxopt=-w
Expand Down
3 changes: 3 additions & 0 deletions builddeps/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ compile_pip_requirements(
"--build-isolation",
"--rebuild",
],
extra_deps = [
# "@pypi_wheel//:pkg"
],
requirements_in = "requirements.in",
requirements_txt = REQUIREMENTS,
generate_hashes = True,
Expand Down
4 changes: 2 additions & 2 deletions builddeps/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
#
-r test-requirements.txt

jax >= 0.4.21
jaxlib >= 0.4.21
jax
jaxlib
absl_py >= 2.0.0
4 changes: 3 additions & 1 deletion builddeps/test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ absl-py
jax
numpy
jaxlib
https://github.com/wsmoses/jax-md/archive/1188490610b95023f8a51166c3f6b92da31e78fe.tar.gz
https://github.com/wsmoses/jax-md/archive/45059b8f63dad0b5cb171feafff71b82162487e7.tar.gz
# maxtext can't be installed concurrently, but installing it fixes
# https://github.com/wsmoses/maxtext/archive/bc50722be7d89e4003bd830b80e4ac968be658eb.tar.gz
jax[cuda12_pip]; sys_platform == 'linux'
requests; sys_platform == 'linux'
# -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Expand Down
10 changes: 10 additions & 0 deletions test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,13 @@ py_test(
deps = TEST_DEPS + ["@pypi_jax_md//:pkg"],
timeout='long'
)

# py_test(
# name = "maxtext",
# srcs = [
# "maxtext.py",
# ],
# imports = ["."],
# deps = TEST_DEPS + ["@pypi_maxtext//:pkg"],
# timeout='long'
# )
2 changes: 0 additions & 2 deletions test/bench_vs_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,6 @@ def f(x, y):

class ConstScatter(EnzymeJaxTest):
def setUp(self):

def forward(c_tau):
Q = c_tau
Q = Q.at[0].multiply(3)
Expand All @@ -284,7 +283,6 @@ def forward(c_tau):

class ScatterSum(EnzymeJaxTest):
def setUp(self):

def energy_fn(R, neighbor):
dR = R[neighbor[0]]
return jnp.sum(jnp.sin(dR))
Expand Down
186 changes: 186 additions & 0 deletions test/maxtext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# Steps for getting results here
# Run:
# 1) pip install https://github.com/wsmoses/maxtext
# 2) bazel build -c opt //:wheel
# 3) pip install ./bazel-bin/*whl
# 4) python test/maxtext.py

from absl.testing import absltest
import jax.numpy as jnp
import jax.random
import jax.lax
import enzyme_ad.jax as enzyme_jax
from enzyme_ad.jax import (
enzyme_jax_ir,
NewXLAPipeline,
OldXLAPipeline,
JaXPipeline,
hlo_opts,
)
import numpy as np
import timeit

argv = ("-I/usr/include/c++/11", "-I/usr/include/x86_64-linux-gnu/c++/11")

import jax.numpy as np
import numpy as onp
from jax import jit
from jax import random
from jax import lax

partialopt = (
"inline{default-pipeline=canonicalize max-iterations=4},"
+ """canonicalize,cse,
enzyme-hlo-generate-td{
patterns=compare_op_canon<16>;
transpose_transpose<16>;
broadcast_in_dim_op_canon<16>;
convert_op_canon<16>;
dynamic_broadcast_in_dim_op_not_actually_dynamic<16>;
chained_dynamic_broadcast_in_dim_canonicalization<16>;
dynamic_broadcast_in_dim_all_dims_non_expanding<16>;
noop_reduce_op_canon<16>;
empty_reduce_op_canon<16>;
dynamic_reshape_op_canon<16>;
get_tuple_element_op_canon<16>;
real_op_canon<16>;
imag_op_canon<16>;
get_dimension_size_op_canon<16>;
gather_op_canon<16>;
reshape_op_canon<16>;
merge_consecutive_reshapes<16>;
transpose_is_reshape<16>;
zero_extent_tensor_canon<16>;
reorder_elementwise_and_shape_op<16>;
cse_broadcast_in_dim<16>;
cse_slice<16>;
cse_transpose<16>;
cse_convert<16>;
cse_pad<16>;
cse_dot_general<16>;
cse_reshape<16>;
cse_mul<16>;
cse_div<16>;
cse_add<16>;
cse_subtract<16>;
cse_min<16>;
cse_max<16>;
cse_neg<16>;
cse_concatenate<16>;
concatenate_op_canon<16>(1024);
select_op_canon<16>(1024);
add_simplify<16>;
sub_simplify<16>;
and_simplify<16>;
max_simplify<16>;
min_simplify<16>;
or_simplify<16>;
negate_simplify<16>;
mul_simplify<16>;
div_simplify<16>;
rem_simplify<16>;
pow_simplify<16>;
sqrt_simplify<16>;
cos_simplify<16>;
sin_simplify<16>;
noop_slice<16>;
const_prop_through_barrier<16>;
slice_slice<16>;
shift_right_logical_simplify<16>;
pad_simplify<16>;
negative_pad_to_slice<16>;
tanh_simplify<16>;
exp_simplify<16>;
slice_simplify<16>;
convert_simplify<16>;
dynamic_slice_to_static<16>;
dynamic_update_slice_elim<16>;
concat_to_broadcast<16>;
reduce_to_reshape<16>;
broadcast_to_reshape<16>;
gather_simplify<16>;
iota_simplify<16>(1024);
broadcast_in_dim_simplify<16>(1024);
convert_concat<1>;
dynamic_update_to_concat<1>;
slice_of_dynamic_update<1>;
slice_elementwise<1>;
slice_pad<1>;
dot_reshape_dot<1>;
concat_const_prop<1>;
concat_fuse<1>;
pad_reshape_pad<1>;
pad_pad<1>;
concat_push_binop_add<1>;
concat_push_binop_mul<1>;
scatter_to_dynamic_update_slice<1>;
reduce_concat<1>;
slice_concat<1>;
bin_broadcast_splat_add<1>;
bin_broadcast_splat_subtract<1>;
bin_broadcast_splat_div<1>;
bin_broadcast_splat_mul<1>;
slice_reshape<1>;
dot_reshape_pad<1>;
pad_dot_general<1>(1);
broadcast_reduce<1>;
},
transform-interpreter,
enzyme-hlo-remove-transform,cse"""
)

pipelines = [
("JaX ", None),
("JaXPipe", JaXPipeline()),
(
"HLOOpt",
JaXPipeline(
"inline{default-pipeline=canonicalize max-iterations=4},"
+ "canonicalize,cse,enzyme-hlo-opt,cse"
),
),
("PartOpt", JaXPipeline(partialopt)),
("DefOpt", JaXPipeline(hlo_opts())),
]


class MaxText(absltest.TestCase):
def setUp(self):
import MaxText
import MaxText.pyconfig

MaxText.pyconfig.initialize(
[
None,
"test/maxtext_configs/base.yml",
"dataset_type=synthetic",
"steps=10",
]
)

def test(self):
import MaxText
import MaxText.pyconfig
import MaxText.train

config = MaxText.pyconfig.config

for name, pipeline in pipelines:
print("name=", name)

def rewrite(fn):
if pipeline is None:
return fn
else:
return enzyme_jax_ir(pipeline_options=pipeline, argv=argv)(fn)

res1 = MaxText.train.train_loop(config, prejit=rewrite)
print("name=", name, res1)


if __name__ == "__main__":
absltest.main()
Loading

0 comments on commit 01e9dc9

Please sign in to comment.