Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eqsat ci #140

Merged
merged 75 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
75 commits
Select commit Hold shift + click to select a range
64a2127
Simplify path infra (#128)
wsmoses Oct 1, 2024
c81918c
Bump internals (#130)
wsmoses Oct 2, 2024
5d4e48d
get out hlomodule from wrapper module
smjleo Oct 3, 2024
dbd9488
construct analysis with test gpu specs
jbachurski Oct 3, 2024
631704b
fix build
smjleo Oct 3, 2024
524a5d1
actually compute non-zero costs
jbachurski Oct 3, 2024
3a1a6d8
use analytical cost model
smjleo Oct 3, 2024
34d7d92
dedup dependency for gpu
smjleo Oct 3, 2024
9c8ecbf
Convolution support
aryavohra Oct 4, 2024
a3646e0
Merge pull request #9 from aryavohra/better-benchmark
aryavohra Oct 4, 2024
d89eb53
adding more multi patterns
aryavohra Oct 4, 2024
2abd9b0
Attempt gpu ci fix (#125)
wsmoses Oct 4, 2024
41ad703
formatting and fix multi pattern rewrite
smjleo Oct 5, 2024
2158a80
add debug outputs
smjleo Oct 5, 2024
a746d4f
try infering device description (broken on cpu, maybe it works on gpu)
smjleo Oct 5, 2024
0c1b224
remove debug output for device desc
smjleo Oct 5, 2024
8932ba0
clean up cost model, and make it depend on platform
smjleo Oct 5, 2024
d192260
Merge pull request #10 from aryavohra/analytical-cost-model-debug
smjleo Oct 5, 2024
953fad8
Adding resnet test and eqsat after jvp
aryavohra Oct 6, 2024
788f426
hugging face transformers vision transformer and resnet added
aryavohra Oct 8, 2024
93655ef
Fast path slice contiguous constant (#137)
wsmoses Oct 8, 2024
414f96b
remove unneeded code in resnet
smjleo Oct 8, 2024
f131d88
fixed conv rewrite
aryavohra Oct 8, 2024
f5d2a6d
fixed next multi-pattern for conv
aryavohra Oct 8, 2024
c870d2f
turn off fusion costs by default
smjleo Oct 8, 2024
1801c12
build python 3.11 instead
smjleo Oct 8, 2024
506f443
add batched matmul multi-pattern rewrite
smjleo Oct 8, 2024
c98d059
merging rewrites
aryavohra Oct 8, 2024
61771e4
Adding back deleted
aryavohra Oct 8, 2024
5341985
bert huggingface added
aryavohra Oct 10, 2024
12397dc
add matmul multi-pattern rewrite with different lhs contracting
smjleo Oct 10, 2024
7e38527
Add jaxmd tests (#136)
wsmoses Oct 11, 2024
25c66f7
fix broken asserts
smjleo Oct 12, 2024
9ffb741
fix contracting 2 rewrite
smjleo Oct 12, 2024
cdf049a
adding mistral test
aryavohra Oct 12, 2024
2f1a703
Transpose batch (#138)
wsmoses Oct 13, 2024
70ec360
ci: touch lock file
smjleo Oct 14, 2024
18b362d
try fixing ci
smjleo Oct 14, 2024
316cce8
turn off visibility checks
smjleo Oct 14, 2024
1433e88
try removing dependency?
smjleo Oct 14, 2024
f4b60c0
try adding redzone allocator
smjleo Oct 14, 2024
bbd5947
changes
smjleo Oct 14, 2024
97b9a12
try adding config cuda
smjleo Oct 14, 2024
1efccc7
bump internals
smjleo Oct 14, 2024
698e331
bump again
smjleo Oct 14, 2024
502b659
merge with upstream
smjleo Oct 14, 2024
41ca0cc
Merge branch 'EnzymeAD-main' into eqsat-ci
smjleo Oct 14, 2024
f02cc35
use python 3.12
smjleo Oct 14, 2024
47dae67
remove crosstool in cuda config
smjleo Oct 14, 2024
01e9dc9
Maxtext (#139)
wsmoses Oct 14, 2024
450096e
don't build rocm
smjleo Oct 14, 2024
b2e3c86
bring back crosstool
smjleo Oct 14, 2024
bf9be09
add a bunch of bazel flags
smjleo Oct 14, 2024
67e6be1
turn off tf nvcc clang
smjleo Oct 14, 2024
fe521fd
device description fix
smjleo Oct 15, 2024
cb262de
update cuda/cudnn versions
smjleo Oct 15, 2024
751e525
try printing libs
smjleo Oct 15, 2024
221c7d8
revert
smjleo Oct 15, 2024
a0393c0
does nvrtc work as deps?
smjleo Oct 15, 2024
cb6156d
expose multi pattern rules
smjleo Oct 15, 2024
2042008
put python in path
smjleo Oct 15, 2024
de1cf6e
add ortools as dep
smjleo Oct 15, 2024
570118d
add eqsat env var for test
smjleo Oct 15, 2024
c886488
ortools where are you???
smjleo Oct 15, 2024
5bf3787
descent into madness
smjleo Oct 15, 2024
b529739
please
smjleo Oct 15, 2024
37303d1
Update StableHLOAutoDiffOpInterfaceImpl.cpp
wsmoses Oct 16, 2024
cd72ac4
a
smjleo Oct 16, 2024
4043130
cleanup, and hopefully make it build on cyclops again
smjleo Oct 16, 2024
c395ec5
get results csv from the right place
smjleo Oct 16, 2024
8a35dfd
maxtext
smjleo Oct 17, 2024
284f9c6
Merge branch 'EnzymeAD-main' into eqsat-ci
smjleo Oct 17, 2024
0f5c0ae
add eqsat to maxtext test
smjleo Oct 17, 2024
def8a25
csv location
smjleo Oct 17, 2024
b8df1f6
get csv
smjleo Oct 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion .bazelrc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
build --announce_rc

query --experimental_repo_remote_exec
build --experimental_repo_remote_exec
build --cxxopt=-std=c++17 --host_cxxopt=-std=c++17
build --cxxopt=-w --host_cxxopt=-w
Expand All @@ -11,6 +12,7 @@ build --define framework_shared_object=true
build --define tsl_protobuf_header_only=true
build --define=use_fast_cpp_protos=true
build --define=allow_oversize_protos=true
build --check_visibility=false

build -c opt

Expand All @@ -27,4 +29,22 @@ build:cuda --@local_config_cuda//:enable_cuda
build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true
# Default hermetic CUDA and CUDNN versions.
build:cuda --@local_config_cuda//cuda:include_cuda_libs=true
build:cuda --@local_config_cuda//:cuda_compiler=nvcc
build:cuda --@local_config_cuda//:cuda_compiler=nvcc

build:cudaci --repo_env TF_NEED_CUDA=1
build:cudaci --repo_env TF_NVCC_CLANG=0
build:cudaci --repo_env CUDA_NVCC=1
build:cudaci --repo_env CC="/usr/bin/gcc"
build:cudaci --repo_env TF_NCCL_USE_STUB=1
build:cudaci --repo_env=HERMETIC_CUDA_VERSION="12.6.1"
build:cudaci --repo_env=HERMETIC_CUDNN_VERSION="9.4.0"
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
build:cudaci --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
build:cudaci --crosstool_top="@local_config_cuda//crosstool:toolchain"
build:cudaci --@local_config_cuda//:enable_cuda
build:cudaci --@xla//xla/python:jax_cuda_pip_rpaths=true
# Default hermetic CUDA and CUDNN versions.
build:cudaci --@local_config_cuda//cuda:include_cuda_libs=true
build:cudaci --@local_config_cuda//:cuda_compiler=nvcc

72 changes: 66 additions & 6 deletions .buildkite/gpu_pipeline.yml
Original file line number Diff line number Diff line change
@@ -1,37 +1,97 @@
steps:
- label: "CUDA"
agents:
queue: "juliagpu"
queue: "benchmark"
gpu: "rtx4070"
cuda: "*"
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 60
timeout_in_minutes: 180
commands: |
pwd
env
echo "--- Setup :python: Dependencies"
mkdir -p .local/bin
export PATH="`pwd`/.local/bin:`pwd`/conda/bin:\$PATH"
echo "openssl md5 | cut -d' ' -f2" > .local/bin/md5
chmod +x .local/bin/md5

# No one tells us what to do
unset NV_LIBCUBLAS_VERSION
unset NVIDIA_VISIBLE_DEVICES
unset NV_NVML_DEV_VERSION
unset NV_LIBNCCL_DEV_PACKAGE
unset NV_LIBNCCL_DEV_PACKAGE_VERSION
unset NVIDIA_REQUIRE_CUDA
unset NV_LIBCUBLAS_DEV_PACKAGE
unset NV_NVTX_VERSION

curl -fLO https://github.com/bazelbuild/bazelisk/releases/download/v1.19.0/bazelisk-linux-amd64

mv bazel* .local/bin/bazel
chmod +x .local/bin/bazel
export PATH="`pwd`/.local/bin:\$PATH"

export RUSTUP_HOME=`pwd`/.local/.rustup
export CARGO_HOME=`pwd`/.local/.cargo bash
export CARGO_HOME=`pwd`/.local/.cargo bash

curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
export PATH="`pwd`/.local/.cargo/bin:\$PATH"
rustup default stable
cargo install cxxbridge-cmd

mkdir -p .baztmp
touch src/enzyme_ad/jax/deps/tensat/cargo-bazel-lock.json

echo "--- :python: Test"

# CARGO_BAZEL_REPIN=true HERMETIC_PYTHON_VERSION="3.12" bazel --output_user_root=`pwd`/.baztmp test --check_visibility=false --test_output=errors //test/...
# CARGO_BAZEL_REPIN=true HERMETIC_PYTHON_VERSION="3.12" bazel --output_user_root=`pwd`/.baztmp test --check_visibility=false --cache_test_results=no //test:bench_vs_xla
CARGO_BAZEL_REPIN=true HERMETIC_PYTHON_VERSION="3.12" bazel --output_user_root=`pwd`/.baztmp test --check_visibility=false --cache_test_results=no --config=cuda //test:llama
export CUDA_DIR=`pwd`/bazel-bin/test/llama.runfiles/pypi_nvidia_cuda_nvcc_cu12/site-packages/nvidia/cuda_nvcc
export XLA_FLAGS=--xla_gpu_cuda_data_dir=\$CUDA_DIR
export LD_LIBRARY_PATH="`pwd`/bazel-bin/test/llama.runfiles/pypi_nvidia_cusolver_cu12/site-packages/nvidia/cusolver:\$LD_LIBRARY_PATH"
export LD_LIBRARY_PATH="`pwd`/bazel-bin/test/llama.runfiles/pypi_nvidia_cudnn_cu12/site-packages/nvidia/cudnn/lib:\$LD_LIBRARY_PATH"
export LD_LIBRARY_PATH="`pwd`/bazel-bin/test/test.runfiles/pypi_nvidia_cublas_cu12/site-packages/nvidia/cublas/lib:\$LD_LIBRARY_PATH"
export LD_LIBRARY_PATH="`pwd`/bazel-bin/test/llama.runfiles/pypi_nvidia_cuda_cupti_cu12/site-packages/nvidia/cuda_cupti/lib:\$LD_LIBRARY_PATH"
export LD_LIBRARY_PATH="`pwd`/bazel-bin/test/llama.runfiles/pypi_nvidia_cuda_runtime_cu12/site-packages/nvidia/cuda_runtime/lib:\$LD_LIBRARY_PATH"
export LD_LIBRARY_PATH="`pwd`/bazel-bin/test/llama.runfiles/pypi_nvidia_cuda_nvrtc_cu12/site-packages/nvidia/cuda_nvrtc/lib:\$LD_LIBRARY_PATH"
export PATH="`pwd`/bazel-bin/test/llama.runfiles/pypi_nvidia_cuda_nvcc_cu12/site-packages/nvidia/cuda_nvcc/bin:\$PATH"
export PATH="`pwd`/bazel-bin/test/llama.runfiles/python_x86_64-unknown-linux-gnu/bin:\$PATH"
export TF_CPP_MIN_LOG_LEVEL=0

BAZEL_BUILD_FLAGS=()
BAZEL_BUILD_FLAGS+=(--define=no_aws_support=true)
BAZEL_BUILD_FLAGS+=(--define=no_gcp_support=true)
BAZEL_BUILD_FLAGS+=(--define=no_hdfs_support=true)
BAZEL_BUILD_FLAGS+=(--define=no_kafka_support=true)
BAZEL_BUILD_FLAGS+=(--define=no_ignite_support=true)
BAZEL_BUILD_FLAGS+=(--define=grpc_no_ares=true)

BAZEL_BUILD_FLAGS+=(--define=llvm_enable_zlib=false)

BAZEL_BUILD_FLAGS+=(--verbose_failures)
BAZEL_BUILD_FLAGS+=(--cxxopt=-std=c++17 --host_cxxopt=-std=c++17)
BAZEL_BUILD_FLAGS+=(--cxxopt=-DTCP_USER_TIMEOUT=0)
BAZEL_BUILD_FLAGS+=(--check_visibility=false)
BAZEL_BUILD_FLAGS+=(--experimental_cc_shared_library)

mkdir .eqsat-tmp
export TMPDIR="`pwd`/.eqsat-tmp"
export TMP=$TMPDIR
export TEMP=$TMPDIR
BAZEL_BUILD_FLAGS+=(--action_env=TMP=\$TMPDIR --action_env=TEMP=\$TMPDIR --action_env=TMPDIR=\$TMPDIR --sandbox_tmpfs_path=\$TMPDIR)
BAZEL_BUILD_FLAGS+=(--action_env=EQSAT_PLATFORM=gpu)
BAZEL_BUILD_FLAGS+=(--config=cudaci)

export EQSAT_PLATFORM=gpu

CARGO_BAZEL_REPIN=true HERMETIC_PYTHON_VERSION="3.12" .local/bin/bazel --output_user_root=`pwd`/.baztmp run --repo_env CUDA_DIR --repo_env XLA_FLAGS --action_env XLA_FLAGS --repo_env TF_CPP_MIN_LOG_LEVEL --action_env TF_CPP_MIN_LOG_LEVEL //builddeps:requirements.update || echo "no req update"
HERMETIC_PYTHON_VERSION="3.12" .local/bin/bazel --output_user_root=`pwd`/.baztmp test --repo_env CUDA_DIR --repo_env XLA_FLAGS --action_env XLA_FLAGS --repo_env TF_CPP_MIN_LOG_LEVEL --action_env TF_CPP_MIN_LOG_LEVEL --test_output=errors //test/... || echo "fail1"
find `pwd`/bazel-bin/test/llama.runfiles > finds.txt
# HERMETIC_PYTHON_VERSION="3.12" .local/bin/bazel --output_user_root=`pwd`/.baztmp test --repo_env CUDA_DIR --repo_env XLA_FLAGS --action_env XLA_FLAGS --repo_env TF_CPP_MIN_LOG_LEVEL --action_env TF_CPP_MIN_LOG_LEVEL --cache_test_results=no -s //test:bench_vs_xla || echo "fail2"
HERMETIC_PYTHON_VERSION="3.12" .local/bin/bazel --output_user_root=`pwd`/.baztmp test --repo_env CUDA_DIR --repo_env XLA_FLAGS --action_env XLA_FLAGS --repo_env TF_CPP_MIN_LOG_LEVEL --action_env TF_CPP_MIN_LOG_LEVEL --cache_test_results=no -s \${BAZEL_BUILD_FLAGS[@]} //test:llama || echo "fail3"
HERMETIC_PYTHON_VERSION="3.12" bazel-bin/test/llama
cat bazel-out/*/testlogs/test/llama/test.log
artifact_paths:
- "finds.txt"
- "finds2.txt"
- "bazel-out/*/testlogs/test/llama/test.log"
- "bazel-out/*/testlogs/test/llama/bench_vs_xla.log"
- "bazel-bin/test/llama.runfiles/results_*.csv"
3 changes: 2 additions & 1 deletion .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ steps:
os:
- macos
python:
- "3.12"
- "3.11"
agents:
queue: "juliaecosystem"
os: "{{matrix.os}}"
Expand Down Expand Up @@ -55,6 +55,7 @@ steps:
cargo install cxxbridge-cmd

mkdir -p .baztmp
touch src/enzyme_ad/jax/deps/tensat/cargo-bazel-lock.json
# CARGO_BAZEL_REPIN=true HERMETIC_PYTHON_VERSION={{matrix.python}} bazel --output_user_root=`pwd`/.baztmp test --check_visibility=false --test_output=errors //test/...
# CARGO_BAZEL_REPIN=true HERMETIC_PYTHON_VERSION={{matrix.python}} bazel --output_user_root=`pwd`/.baztmp test --check_visibility=false --cache_test_results=no //test:bench_vs_xla
CARGO_BAZEL_REPIN=true HERMETIC_PYTHON_VERSION={{matrix.python}} bazel --output_user_root=`pwd`/.baztmp test --check_visibility=false --cache_test_results=no //test:llama
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
*.pyc
*.swp
*.swo
bazel-*
Expand Down
3 changes: 1 addition & 2 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@ http_archive(
)

load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
XLA_COMMIT = "7d4f8d1e8a91e67a713ac69796a22f343d292327"
http_archive(
name = "xla",
#sha256 = XLA_SHA256,
sha256 = XLA_SHA256,
strip_prefix = "xla-" + XLA_COMMIT,
urls = ["https://github.com/wsmoses/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)],
patch_cmds = XLA_PATCHES,
Expand Down
4 changes: 2 additions & 2 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ export CUDA_HOME=$HOME/miniconda3/
export PATH=$PATH:$CUDA_HOME/bin
export CUDACXX=$CUDA_HOME/bin/nvcc
BAZEL_BUILD_FLAGS+=(--config=cuda)
HERMETIC_PYTHON_VERSION=3.12 bazel build ${BAZEL_BUILD_FLAGS[@]} :wheel
pip install bazel-bin/enzyme_ad-0.0.8-py312-none-manylinux2014_x86_64.whl --no-deps --force-reinstall
HERMETIC_PYTHON_VERSION=3.11 bazel build ${BAZEL_BUILD_FLAGS[@]} :wheel
pip install bazel-bin/enzyme_ad-0.0.8-py311-none-manylinux2014_x86_64.whl --no-deps --force-reinstall

3 changes: 3 additions & 0 deletions builddeps/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ compile_pip_requirements(
"--build-isolation",
"--rebuild",
],
extra_deps = [
# "@pypi_wheel//:pkg"
],
requirements_in = "requirements.in",
requirements_txt = REQUIREMENTS,
generate_hashes = True,
Expand Down
5 changes: 3 additions & 2 deletions builddeps/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#
-r test-requirements.txt

jax >= 0.4.21
jaxlib >= 0.4.21
jax
jaxlib
absl_py >= 2.0.0
ortools >= 9.10
Loading
Loading