Skip to content

Commit

Permalink
Remove all torch.no_grad contexts in eval.
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed May 3, 2024
1 parent 3a26c36 commit 3450ca2
Show file tree
Hide file tree
Showing 13 changed files with 39 additions and 52 deletions.
3 changes: 1 addition & 2 deletions torchbenchmark/models/DALLE2_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ def set_module(self, new_model):

def eval(self):
model, inputs = self.get_module()
with torch.no_grad():
images = model(*inputs)
images = model(*inputs)
return (images,)

def train(self):
Expand Down
3 changes: 1 addition & 2 deletions torchbenchmark/models/cm3leon_generate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,5 @@ def train(self):
return loss.item()

def eval(self):
with torch.no_grad():
out = self.model(*self.example_inputs)
out = self.model(*self.example_inputs)
return (out,)
3 changes: 1 addition & 2 deletions torchbenchmark/models/functorch_maml_omniglot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,5 @@ def train(self):
def eval(self) -> Tuple[torch.Tensor]:
model, (example_input,) = self.get_module()
model.eval()
with torch.no_grad():
out = model(example_input)
out = model(example_input)
return (out, )
3 changes: 1 addition & 2 deletions torchbenchmark/models/hf_Whisper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ def get_module(self):

def eval(self):
self.model.eval()
with torch.no_grad():
self.model(self.example_inputs["input_ids"])
self.model(self.example_inputs["input_ids"])

def enable_fp16(self):
self.model.half()
Expand Down
3 changes: 1 addition & 2 deletions torchbenchmark/models/lennard_jones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,5 @@ def train(self):
def eval(self) -> Tuple[torch.Tensor]:
model = self.model
model.eval()
with torch.no_grad():
out = make_prediction(model, self.drs)
out = make_prediction(model, self.drs)
return out
3 changes: 1 addition & 2 deletions torchbenchmark/models/maml_omniglot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,5 @@ def train(self):
def eval(self) -> Tuple[torch.Tensor]:
model, (example_input,) = self.get_module()
model.eval()
with torch.no_grad():
out = model(example_input)
out = model(example_input)
return (out, )
3 changes: 1 addition & 2 deletions torchbenchmark/models/opacus_cifar10/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,5 @@ def eval(self) -> Tuple[torch.Tensor]:
(images, ) = self.example_inputs
model.eval()
targets = self.example_target
with torch.no_grad():
out = model(images)
out = model(images)
return (out, )
3 changes: 1 addition & 2 deletions torchbenchmark/models/pyhpc_equation_of_state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,5 @@ def train(self):

def eval(self) -> Tuple[torch.Tensor]:
model, example_inputs = self.get_module()
with torch.no_grad():
out = model(*example_inputs)
out = model(*example_inputs)
return (out, )
3 changes: 1 addition & 2 deletions torchbenchmark/models/pyhpc_isoneutral_mixing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,5 @@ def train(self):

def eval(self) -> Tuple[torch.Tensor]:
model, example_inputs = self.get_module()
with torch.no_grad():
out = model(*example_inputs)
out = model(*example_inputs)
return out
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,5 @@ def train(self):

def eval(self) -> Tuple[torch.Tensor]:
model, example_inputs = self.get_module()
with torch.no_grad():
out = model(*example_inputs)
out = model(*example_inputs)
return out
23 changes: 11 additions & 12 deletions torchbenchmark/models/pytorch_unet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,18 +90,17 @@ def jit_callback(self):
def eval(self) -> Tuple[torch.Tensor]:
torch.backends.cudnn.deterministic = True
self.model.eval()
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=self.args.amp):
mask_pred = self.model(self.example_inputs)

if self.model.n_classes == 1:
mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
else:
mask_pred = (
F.one_hot(mask_pred.argmax(dim=1), self.model.n_classes)
.permute(0, 3, 1, 2)
.float()
)
with torch.cuda.amp.autocast(enabled=self.args.amp):
mask_pred = self.model(self.example_inputs)

if self.model.n_classes == 1:
mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
else:
mask_pred = (
F.one_hot(mask_pred.argmax(dim=1), self.model.n_classes)
.permute(0, 3, 1, 2)
.float()
)
return (mask_pred,)

def _get_args(self):
Expand Down
29 changes: 14 additions & 15 deletions torchbenchmark/models/soft_actor_critic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,21 +258,20 @@ def train(self):

def eval(self) -> Tuple[torch.Tensor]:
niter = 1
with torch.no_grad():
discount = 1.0
episode_return_history = []
for episode in range(niter):
episode_return = 0.0
state, _info = self.test_env.reset()
done, info = False, {}
for step_num in range(self.args.max_episode_steps):
if done:
break
action = self.agent.forward(state)
state, reward, done, info, _unused = self.test_env.step(action)
episode_return += reward * (discount**step_num)
episode_return_history.append(episode_return)
retval = torch.tensor(episode_return_history)
discount = 1.0
episode_return_history = []
for episode in range(niter):
episode_return = 0.0
state, _info = self.test_env.reset()
done, info = False, {}
for step_num in range(self.args.max_episode_steps):
if done:
break
action = self.agent.forward(state)
state, reward, done, info, _unused = self.test_env.step(action)
episode_return += reward * (discount**step_num)
episode_return_history.append(episode_return)
retval = torch.tensor(episode_return_history)
return (torch.tensor(action),)

def get_optimizer(self):
Expand Down
9 changes: 4 additions & 5 deletions torchbenchmark/models/timm_efficientdet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,8 @@ def train(self):
# self.lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

def eval(self) -> Tuple[torch.Tensor]:
with torch.no_grad():
for input, target in self.loader:
with self.amp_autocast():
output = self.model(input, img_info=target)
self.evaluator.add_predictions(output, target)
for input, target in self.loader:
with self.amp_autocast():
output = self.model(input, img_info=target)
self.evaluator.add_predictions(output, target)
return (output, )

0 comments on commit 3450ca2

Please sign in to comment.