Skip to content

Commit

Permalink
Update on "[BE] restructure tests and assets folders"
Browse files Browse the repository at this point in the history
1. The `tests` folder now has two parts: the `integration_tests.py` file and the `unit_tests` folder.
2. Moved `version.txt` and `license_header.txt` to the `assets` folder.

[ghstack-poisoned]
  • Loading branch information
tianyu-l committed Dec 17, 2024
2 parents b56630a + 54dd2a1 commit 07d98b5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 14 deletions.
23 changes: 11 additions & 12 deletions tests/integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,17 @@ def build_test_list():
"1D compile",
"1d_compile",
),
# TODO: temporarily disabling to let CI be able to run other tests
# OverrideDefinitions(
# [
# [
# "--training.compile",
# "--activation_checkpoint.mode selective",
# "--activation_checkpoint.selective_ac_option op",
# ],
# ],
# "1D compile with selective op AC",
# "1d_compile_sac_op",
# ),
OverrideDefinitions(
[
[
"--training.compile",
"--activation_checkpoint.mode selective",
"--activation_checkpoint.selective_ac_option op",
],
],
"1D compile with selective op AC",
"1d_compile_sac_op",
),
OverrideDefinitions(
[
[
Expand Down
6 changes: 4 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,15 @@ def loss_fn(pred, labels):
# apply SPMD-style PT-D techniques
models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
m.to_empty(device=init_device)
m.init_weights(buffer_device=buffer_device)
with torch.no_grad():
m.init_weights(buffer_device=buffer_device)
m.train()
else:
# apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
model.to_empty(device=init_device)
model.init_weights(buffer_device=buffer_device)
with torch.no_grad():
model.init_weights(buffer_device=buffer_device)
model.train()

model_parts = [model]
Expand Down

0 comments on commit 07d98b5

Please sign in to comment.