From 6ace3dbe66491ad062d640fb94e66265894d4910 Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Wed, 18 Dec 2024 16:19:26 -0800
Subject: [PATCH] turning off compile on loss function

[ghstack-poisoned]
---
 train.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/train.py b/train.py
index 3b157ad1..8dbe80f5 100644
--- a/train.py
+++ b/train.py
@@ -130,8 +130,9 @@ def loss_fn(pred, labels):
             pred.flatten(0, 1).float(), labels.flatten(0, 1)
         )
 
-    if job_config.training.compile:
-        loss_fn = torch.compile(loss_fn)
+    # TODO: compiling loss function causes CUDA errors, turning off for now
+    # if job_config.training.compile:
+    #     loss_fn = torch.compile(loss_fn)
 
     # move sharded model to CPU/GPU and initialize weights via DTensor
     if job_config.checkpoint.create_seed_checkpoint: