Skip to content

Commit

Permalink
push
Browse files Browse the repository at this point in the history
  • Loading branch information
msaroufim committed Sep 26, 2023
1 parent 356c7d2 commit 17ddf6b
Showing 1 changed file with 18 additions and 16 deletions.
34 changes: 18 additions & 16 deletions torchbenchmark/util/env_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def torch_clone(x):

def clone_inputs(example_inputs):
import torch
if type(example_inputs) is dict:
if isinstance(example_inputs, Mapping):
res = dict(example_inputs)
for key, value in res.items():
assert isinstance(value, torch.Tensor)
Expand Down Expand Up @@ -355,6 +355,8 @@ def optimizer_step(optimizer):
def forward_pass(mod, inputs, contexts, _collect_outputs=True):
cloned_inputs = clone_inputs(inputs)
with nested(*contexts):
print("====")
print(type(cloned_inputs))
if isinstance(cloned_inputs, Mapping):
return mod(**inputs)
else:
Expand Down Expand Up @@ -481,21 +483,21 @@ def maybe_cast(tbmodel, model, example_inputs):
with pick_grad(name, is_training):
# Get results of native pytorch
reset_rng_state()
try:
model_copy = deepcopy_model(model, is_deepcopy)
optimizer = init_optimizer(name, current_device, model_copy.parameters(), is_training)
correct_result = run_n_iterations(
model_copy, clone_inputs(example_inputs), contexts, optimizer, is_training
)
except Exception as e:
accuracy_status = (
"eager_1st_run_OOM"
if isinstance(e, torch.cuda.OutOfMemoryError)
else "eager_1st_run_fail"
)
print(e)
log.exception(e)
return accuracy_status
# try:
model_copy = deepcopy_model(model, is_deepcopy)
optimizer = init_optimizer(name, current_device, model_copy.parameters(), is_training)
correct_result = run_n_iterations(
model_copy, clone_inputs(example_inputs), contexts, optimizer, is_training
)
# except Exception as e:
# accuracy_status = (
# "eager_1st_run_OOM"
# if isinstance(e, torch.cuda.OutOfMemoryError)
# else "eager_1st_run_fail"
# )
# print(e)
# log.exception(e)
return accuracy_status

# Rerun native pytorch
reset_rng_state()
Expand Down

0 comments on commit 17ddf6b

Please sign in to comment.