Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add neuralgcm #141

Merged
merged 17 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions .buildkite/gpu_pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,23 +63,18 @@ steps:
HERMETIC_PYTHON_VERSION="3.12" .local/bin/bazel --output_user_root=`pwd`/.baztmp test --repo_env CUDA_DIR --repo_env XLA_FLAGS --action_env XLA_FLAGS --repo_env TF_CPP_MIN_LOG_LEVEL --action_env TF_CPP_MIN_LOG_LEVEL --test_output=errors //test/... || echo "fail1"
HERMETIC_PYTHON_VERSION="3.12" .local/bin/bazel --output_user_root=`pwd`/.baztmp test --repo_env CUDA_DIR --repo_env XLA_FLAGS --action_env XLA_FLAGS --repo_env TF_CPP_MIN_LOG_LEVEL --action_env TF_CPP_MIN_LOG_LEVEL --cache_test_results=no -s //test:bench_vs_xla
HERMETIC_PYTHON_VERSION="3.12" .local/bin/bazel --output_user_root=`pwd`/.baztmp test --repo_env CUDA_DIR --repo_env XLA_FLAGS --action_env XLA_FLAGS --repo_env TF_CPP_MIN_LOG_LEVEL --action_env TF_CPP_MIN_LOG_LEVEL --cache_test_results=no -s //test:llama
HERMETIC_PYTHON_VERSION="3.12" .local/bin/bazel --output_user_root=`pwd`/.baztmp test --repo_env CUDA_DIR --repo_env XLA_FLAGS --action_env XLA_FLAGS --repo_env TF_CPP_MIN_LOG_LEVEL --action_env TF_CPP_MIN_LOG_LEVEL --cache_test_results=no -s //test:jaxmd

if [ -f "test/test_neuralgcm.py" ]; then
HERMETIC_PYTHON_VERSION="3.12" .local/bin/bazel --output_user_root=`pwd`/.baztmp test --repo_env CUDA_DIR --repo_env XLA_FLAGS --action_env XLA_FLAGS --repo_env TF_CPP_MIN_LOG_LEVEL --action_env TF_CPP_MIN_LOG_LEVEL --cache_test_results=no -s //test:test_neuralgcm
fi
HERMETIC_PYTHON_VERSION="3.12" .local/bin/bazel --output_user_root=`pwd`/.baztmp test --repo_env CUDA_DIR --repo_env XLA_FLAGS --action_env XLA_FLAGS --repo_env TF_CPP_MIN_LOG_LEVEL --action_env TF_CPP_MIN_LOG_LEVEL --cache_test_results=no -s //test:jaxmd
HERMETIC_PYTHON_VERSION="3.12" .local/bin/bazel --output_user_root=`pwd`/.baztmp test --repo_env CUDA_DIR --repo_env XLA_FLAGS --action_env XLA_FLAGS --repo_env TF_CPP_MIN_LOG_LEVEL --action_env TF_CPP_MIN_LOG_LEVEL --cache_test_results=no -s //test:neuralgcm_test
bazel-bin/test/llama.runfiles/python_*/bin/python3 -m pip install bazel-bin/*.whl https://github.com/wsmoses/maxtext aqtp tensorboardX google-cloud-storage datasets gcsfs
bazel-bin/test/llama.runfiles/python_*/bin/python3 test/maxtext.py > maxtext.log
cat bazel-out/*/testlogs/test/llama/test.log
cat bazel-out/*/testlogs/test/bench_vs_xla/test.log
cat bazel-out/*/testlogs/test/jaxmd/test.log
if [ -f "test/test_neuralgcm.py" ]; then
cat bazel-out/*/testlogs/test/test_neuralgcm/test.log
fi
cat bazel-out/*/testlogs/test/neuralgcm_test/test.log
cat maxtext.log
artifact_paths:
- "bazel-out/*/testlogs/test/llama/test.log"
- "bazel-out/*/testlogs/test/bench_vs_xla/test.log"
- "bazel-out/*/testlogs/test/jaxmd/test.log"
- "bazel-out/*/testlogs/test/test_neuralgcm/test.log"
- "bazel-out/*/testlogs/test/neuralgcm_test/test.log"
- "maxtext.log"
879 changes: 870 additions & 9 deletions builddeps/requirements_lock_3_10.txt

Large diffs are not rendered by default.

871 changes: 864 additions & 7 deletions builddeps/requirements_lock_3_11.txt

Large diffs are not rendered by default.

874 changes: 865 additions & 9 deletions builddeps/requirements_lock_3_12.txt

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions builddeps/test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ https://github.com/wsmoses/jraph/archive/b00d9a03db76c69a258a86df81638b9a2f28829
# https://github.com/wsmoses/maxtext/archive/bc50722be7d89e4003bd830b80e4ac968be658eb.tar.gz

jax-cuda12-plugin[with_cuda]; sys_platform == 'linux'
neuralgcm
gcsfs
requests; sys_platform == 'linux'

# -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Expand Down
12 changes: 12 additions & 0 deletions test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,18 @@ py_test(
timeout='long'
)

py_test(
name = "neuralgcm_test",
srcs = [
"neuralgcm_test.py",
"test_utils.py",
],
imports = ["."],
deps = TEST_DEPS + ["@pypi_neuralgcm//:pkg", "@pypi_gcsfs//:pkg"],
timeout='eternal'
)


# py_test(
# name = "maxtext",
# srcs = [
Expand Down
32 changes: 14 additions & 18 deletions test/bench_vs_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def setUp(self):
jnp.array([0.1, 0.2, 0.3]),
jnp.array([50.0, 70.0, 110.0]),
]
self.douts = [jnp.array([500.0, 700.0, 110.0])]
self.douts = jnp.array([500.0, 700.0, 110.0])

def add_one(x, y):
return x + 1 + y
Expand All @@ -42,7 +42,7 @@ def setUp(self):
jnp.array([50.0, 70.0, 110.0]),
jnp.array([1300.0, 1700.0, 1900.0]),
]
self.douts = [jnp.array([500.0, 700.0, 110.0])]
self.douts = jnp.array([500.0, 700.0, 110.0])

def add_two(x, z, y):
return x + y
Expand All @@ -55,7 +55,7 @@ class Sum(EnzymeJaxTest):
def setUp(self):
self.ins = [jnp.array(range(50), dtype=jnp.float32)]
self.dins = [jnp.array([i * i for i in range(50)], dtype=jnp.float32)]
self.douts = [1.0]
self.douts = jnp.array(1.0)

def sum(x):
return jnp.sum(x)
Expand All @@ -69,7 +69,7 @@ def setUp(self):
dim = 288
self.ins = [jnp.array(range(dim), dtype=jnp.float32)]
self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)]
self.douts = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)]
self.douts = jnp.array([i * i for i in range(dim)], dtype=jnp.float32)

self.primfilter = no_newxla
self.fwdfilter = no_newxla
Expand All @@ -89,7 +89,7 @@ def setUp(self):
self.dins = [
jnp.array([i * i for i in range(dim)], dtype=jnp.float32).reshape(1, dim, 1)
]
self.douts = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)]
self.douts = jnp.array([i * i for i in range(dim)], dtype=jnp.float32)

self.primfilter = no_newxla
self.fwdfilter = no_newxla
Expand All @@ -107,11 +107,9 @@ def setUp(self):
dim = 12
self.ins = [jnp.array(range(dim), dtype=jnp.float32)]
self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)]
self.douts = [
jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32).reshape(
(2, dim)
)
]
self.douts = jnp.array(
[i * i for i in range(2 * dim)], dtype=jnp.float32
).reshape((2, dim))

self.primfilter = no_newxla
self.fwdfilter = no_newxla
Expand All @@ -134,11 +132,9 @@ def setUp(self):
dim = 12
self.ins = [jnp.array(range(dim), dtype=jnp.float32)]
self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)]
self.douts = [
jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32).reshape(
(2, dim)
)
]
self.douts = jnp.array(
[i * i for i in range(2 * dim)], dtype=jnp.float32
).reshape((2, dim))

self.primfilter = no_newxla
self.fwdfilter = no_newxla
Expand Down Expand Up @@ -174,7 +170,7 @@ def setUp(self):
jnp.array([i * i for i in range(dim)], dtype=jnp.float32),
jnp.array([i * i * i / 3.0 for i in range(dim)], dtype=jnp.float32),
]
self.douts = [jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32)]
self.douts = jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32)

self.revfilter = justjax
# self.revfilter = nomlir
Expand Down Expand Up @@ -274,7 +270,7 @@ def forward(c_tau):
jnp.array([2.7, 2.7, 2.7]),
]
self.dins = [jnp.array([3.1, 3.1, 3.1])]
self.douts = self.dins
self.douts = (self.dins[0],)
self.revfilter = lambda _: []
# No support for stablehlo.while atm
# self.revfilter = justjax
Expand All @@ -298,7 +294,7 @@ def forward(position):

self.ins = [jnp.array([2.0, 4.0, 6.0, 8.0])]
self.dins = [jnp.array([2.7, 3.1, 5.9, 4.2])]
self.douts = []
self.douts = self.fn(*self.ins)
self.revfilter = lambda _: []
# No support for stablehlo.scatter atm
self.mlirad_rev = False
Expand Down
107 changes: 1 addition & 106 deletions test/jaxmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,111 +24,6 @@

jax.config.update("jax_enable_x64", True)

partialopt = (
"inline{default-pipeline=canonicalize max-iterations=4},"
+ """canonicalize,cse,
enzyme-hlo-generate-td{
patterns=compare_op_canon<16>;
transpose_transpose<16>;
broadcast_in_dim_op_canon<16>;
convert_op_canon<16>;
dynamic_broadcast_in_dim_op_not_actually_dynamic<16>;
chained_dynamic_broadcast_in_dim_canonicalization<16>;
dynamic_broadcast_in_dim_all_dims_non_expanding<16>;
noop_reduce_op_canon<16>;
empty_reduce_op_canon<16>;
dynamic_reshape_op_canon<16>;
get_tuple_element_op_canon<16>;
real_op_canon<16>;
imag_op_canon<16>;
get_dimension_size_op_canon<16>;
gather_op_canon<16>;
reshape_op_canon<16>;
merge_consecutive_reshapes<16>;
transpose_is_reshape<16>;
zero_extent_tensor_canon<16>;
reorder_elementwise_and_shape_op<16>;

cse_broadcast_in_dim<16>;
cse_slice<16>;
cse_transpose<16>;
cse_convert<16>;
cse_pad<16>;
cse_dot_general<16>;
cse_reshape<16>;
cse_mul<16>;
cse_div<16>;
cse_add<16>;
cse_subtract<16>;
cse_min<16>;
cse_max<16>;
cse_neg<16>;
cse_concatenate<16>;

concatenate_op_canon<16>(1024);
select_op_canon<16>(1024);
add_simplify<16>;
sub_simplify<16>;
and_simplify<16>;
max_simplify<16>;
min_simplify<16>;
or_simplify<16>;
negate_simplify<16>;
mul_simplify<16>;
div_simplify<16>;
rem_simplify<16>;
pow_simplify<16>;
sqrt_simplify<16>;
cos_simplify<16>;
sin_simplify<16>;
noop_slice<16>;
const_prop_through_barrier<16>;
slice_slice<16>;
shift_right_logical_simplify<16>;
pad_simplify<16>;
negative_pad_to_slice<16>;
tanh_simplify<16>;
exp_simplify<16>;
slice_simplify<16>;
convert_simplify<16>;
dynamic_slice_to_static<16>;
dynamic_update_slice_elim<16>;
concat_to_broadcast<16>;
reduce_to_reshape<16>;
broadcast_to_reshape<16>;
gather_simplify<16>;
iota_simplify<16>(1024);
broadcast_in_dim_simplify<16>(1024);
convert_concat<1>;
dynamic_update_to_concat<1>;
slice_of_dynamic_update<1>;
slice_elementwise<1>;
slice_pad<1>;
dot_reshape_dot<1>;
concat_const_prop<1>;
concat_fuse<1>;
pad_reshape_pad<1>;
pad_pad<1>;
concat_push_binop_add<1>;
concat_push_binop_mul<1>;
scatter_to_dynamic_update_slice<1>;
reduce_concat<1>;
slice_concat<1>;

bin_broadcast_splat_add<1>;
bin_broadcast_splat_subtract<1>;
bin_broadcast_splat_div<1>;
bin_broadcast_splat_mul<1>;
slice_reshape<1>;

dot_reshape_pad<1>;
pad_dot_general<1>(1);
broadcast_reduce<1>;
},
transform-interpreter,
enzyme-hlo-remove-transform,cse"""
)

pipelines = [
("JaX ", None, CurBackends),
("JaXPipe", JaXPipeline(), CurBackends),
Expand Down Expand Up @@ -245,7 +140,7 @@ def forward(
# for i, v in enumerate(self.ins):
# print("i=", i, v)
self.dins = [x.copy() for x in self.ins]
self.douts = [x.copy() for x in self.ins]
self.douts = tuple(x.copy() for x in self.ins)
self.AllPipelines = pipelines
# No support for stablehlo.while atm
# self.revfilter = justjax
Expand Down
107 changes: 1 addition & 106 deletions test/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,111 +248,6 @@ def forward(x, config, weights, key_cache, value_cache):
return x


partialopt = (
"inline{default-pipeline=canonicalize max-iterations=4},"
+ """canonicalize,cse,
enzyme-hlo-generate-td{
patterns=compare_op_canon<16>;
transpose_transpose<16>;
broadcast_in_dim_op_canon<16>;
convert_op_canon<16>;
dynamic_broadcast_in_dim_op_not_actually_dynamic<16>;
chained_dynamic_broadcast_in_dim_canonicalization<16>;
dynamic_broadcast_in_dim_all_dims_non_expanding<16>;
noop_reduce_op_canon<16>;
empty_reduce_op_canon<16>;
dynamic_reshape_op_canon<16>;
get_tuple_element_op_canon<16>;
real_op_canon<16>;
imag_op_canon<16>;
get_dimension_size_op_canon<16>;
gather_op_canon<16>;
reshape_op_canon<16>;
merge_consecutive_reshapes<16>;
transpose_is_reshape<16>;
zero_extent_tensor_canon<16>;
reorder_elementwise_and_shape_op<16>;

cse_broadcast_in_dim<16>;
cse_slice<16>;
cse_transpose<16>;
cse_convert<16>;
cse_pad<16>;
cse_dot_general<16>;
cse_reshape<16>;
cse_mul<16>;
cse_div<16>;
cse_add<16>;
cse_subtract<16>;
cse_min<16>;
cse_max<16>;
cse_neg<16>;
cse_concatenate<16>;

concatenate_op_canon<16>(1024);
select_op_canon<16>(1024);
add_simplify<16>;
sub_simplify<16>;
and_simplify<16>;
max_simplify<16>;
min_simplify<16>;
or_simplify<16>;
negate_simplify<16>;
mul_simplify<16>;
div_simplify<16>;
rem_simplify<16>;
pow_simplify<16>;
sqrt_simplify<16>;
cos_simplify<16>;
sin_simplify<16>;
noop_slice<16>;
const_prop_through_barrier<16>;
slice_slice<16>;
shift_right_logical_simplify<16>;
pad_simplify<16>;
negative_pad_to_slice<16>;
tanh_simplify<16>;
exp_simplify<16>;
slice_simplify<16>;
convert_simplify<16>;
dynamic_slice_to_static<16>;
dynamic_update_slice_elim<16>;
concat_to_broadcast<16>;
reduce_to_reshape<16>;
broadcast_to_reshape<16>;
gather_simplify<16>;
iota_simplify<16>(1024);
broadcast_in_dim_simplify<16>(1024);
convert_concat<1>;
dynamic_update_to_concat<1>;
slice_of_dynamic_update<1>;
slice_elementwise<1>;
slice_pad<1>;
dot_reshape_dot<1>;
concat_const_prop<1>;
concat_fuse<1>;
pad_reshape_pad<1>;
pad_pad<1>;
concat_push_binop_add<1>;
concat_push_binop_mul<1>;
scatter_to_dynamic_update_slice<1>;
reduce_concat<1>;
slice_concat<1>;

bin_broadcast_splat_add<1>;
bin_broadcast_splat_subtract<1>;
bin_broadcast_splat_div<1>;
bin_broadcast_splat_mul<1>;
slice_reshape<1>;

dot_reshape_pad<1>;
pad_dot_general<1>(1);
broadcast_reduce<1>;
},
transform-interpreter,
enzyme-hlo-remove-transform,cse"""
)

pipelines = [
("JaX ", None, CurBackends),
("JaXPipe", JaXPipeline(), CurBackends),
Expand Down Expand Up @@ -442,7 +337,7 @@ def sfn(x, weights, key_cache, value_cache):

self.ins = [x, weights, key_cache, value_cache]
self.dins = [dx, weights, key_cache, value_cache]
self.douts = [dx]
self.douts = dx
self.tol = 5e-5


Expand Down
Loading
Loading