Skip to content

Commit

Permalink
xu feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
msaroufim committed Jul 25, 2023
1 parent f8666cd commit 849ac96
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions torchbenchmark/models/clip/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@ def __init__(self, test, device, jit=False, batch_size=1, extra_args=[]):
self.model = clip_vit_b32()
self.model.to(self.device)

# Create optimizer
self.loss_fn = ContrastiveLossWithTemperature()
self.optimizer = torch.optim.AdamW(
list(self.model.parameters()) + list(loss_fn.parameters()),
lr=5.0e-4,
weight_decay=1.0e-4,
eps=1.0e-6,
)



def get_module(self):
Expand All @@ -46,25 +55,16 @@ def get_module(self):
def train(self):
self.model.train()

# Create optimizer
loss_fn = ContrastiveLossWithTemperature()
optimizer = torch.optim.AdamW(
list(self.model.parameters()) + list(loss_fn.parameters()),
lr=5.0e-4,
weight_decay=1.0e-4,
eps=1.0e-6,
)

total_loss = 0
optimizer.zero_grad()
self.optimizer.zero_grad()

# Forward pass
image_embedding, text_embedding = self.model(self.image_tensor, self.text_tensor)

# Backward pass
loss = loss_fn(image_embedding, text_embedding)
loss = self.loss_fn(image_embedding, text_embedding)
loss.backward()
optimizer.step()
self.optimizer.step()

total_loss += loss.item()

Expand Down

0 comments on commit 849ac96

Please sign in to comment.