Skip to content

Commit

Permalink
turning off compile on loss function
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
tianyu-l committed Dec 19, 2024
1 parent 6274377 commit 6ace3db
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,9 @@ def loss_fn(pred, labels):
pred.flatten(0, 1).float(), labels.flatten(0, 1)
)

if job_config.training.compile:
loss_fn = torch.compile(loss_fn)
# TODO: compiling loss function causes CUDA errors, turning off for now
# if job_config.training.compile:
# loss_fn = torch.compile(loss_fn)

# move sharded model to CPU/GPU and initialize weights via DTensor
if job_config.checkpoint.create_seed_checkpoint:
Expand Down

0 comments on commit 6ace3db

Please sign in to comment.