diff --git a/torchbenchmark/models/clip/__init__.py b/torchbenchmark/models/clip/__init__.py
index 67e5ad4002..9f6a35f765 100644
--- a/torchbenchmark/models/clip/__init__.py
+++ b/torchbenchmark/models/clip/__init__.py
@@ -37,6 +37,15 @@ def __init__(self, test, device, jit=False, batch_size=1, extra_args=[]):
         self.model = clip_vit_b32()
         self.model.to(self.device)
 
+        # Create loss function and optimizer once; reused across train() calls
+        self.loss_fn = ContrastiveLossWithTemperature()
+        self.optimizer = torch.optim.AdamW(
+            list(self.model.parameters()) + list(self.loss_fn.parameters()),
+            lr=5.0e-4,
+            weight_decay=1.0e-4,
+            eps=1.0e-6,
+        )
+
     def get_module(self):
@@ -46,25 +55,16 @@ def get_module(self):
     def train(self):
         self.model.train()
 
-        # Create optimizer
-        loss_fn = ContrastiveLossWithTemperature()
-        optimizer = torch.optim.AdamW(
-            list(self.model.parameters()) + list(loss_fn.parameters()),
-            lr=5.0e-4,
-            weight_decay=1.0e-4,
-            eps=1.0e-6,
-        )
-
         total_loss = 0
-        optimizer.zero_grad()
+        self.optimizer.zero_grad()
 
         # Forward pass
         image_embedding, text_embedding = self.model(self.image_tensor, self.text_tensor)
 
         # Backward pass
-        loss = loss_fn(image_embedding, text_embedding)
+        loss = self.loss_fn(image_embedding, text_embedding)
         loss.backward()
-        optimizer.step()
+        self.optimizer.step()
 
         total_loss += loss.item()
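
Why hoisting the optimizer into __init__ matters, as a minimal standalone sketch (the Linear/MSELoss stand-ins below are illustrative, not the benchmark's CLIP model or ContrastiveLossWithTemperature): AdamW keeps per-parameter moment estimates in optimizer.state, so constructing it inside train() discarded that state on every call and also charged optimizer construction to the measured region.

import torch

# Toy stand-ins for the benchmark's model and loss (illustrative only).
model = torch.nn.Linear(8, 8)
loss_fn = torch.nn.MSELoss()

# Built once, as the patched __init__ now does: AdamW's exp_avg/exp_avg_sq
# buffers persist across steps instead of being re-created from scratch.
optimizer = torch.optim.AdamW(
    model.parameters(), lr=5.0e-4, weight_decay=1.0e-4, eps=1.0e-6
)

x = torch.randn(4, 8)
target = torch.zeros(4, 8)
for _ in range(3):
    optimizer.zero_grad()
    loss = loss_fn(model(x), target)
    loss.backward()
    optimizer.step()

# Non-empty once step() has run; a freshly constructed optimizer would
# start with empty state again, resetting Adam's moment estimates.
print(len(optimizer.state) > 0)  # True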