Skip to content

Commit

Permalink
Log load_model failures in csv (#114784)
Browse files Browse the repository at this point in the history
Summary:
Right now, when load_model fails (either because of a loading error or because the eager validation run fails), the result is not logged in the generated CSV files. Let's log these failures in the CSV so that they are monitored by the expected-results check.

X-link: pytorch/pytorch#114784
Approved by: https://github.com/malfet

Reviewed By: atalman

Differential Revision: D51939241

Pulled By: desertfire

fbshipit-source-id: 945e297409654654d2b97aa518bbf1f894c41c8b
  • Loading branch information
desertfire authored and facebook-github-bot committed Dec 8, 2023
1 parent 046478e commit bf8df16
Showing 1 changed file with 40 additions and 32 deletions.
72 changes: 40 additions & 32 deletions userbenchmark/dynamo/dynamobench/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1939,8 +1939,7 @@ def validate_model(self, model, example_inputs):
try:
self.model_iter_fn(model, example_inputs)
except Exception as e:
print(f"Original Error: {str(e)}")
raise NotImplementedError("Eager model failed to run") from e
raise RuntimeError("Eager run failed") from e

def maybe_cast(self, model, example_inputs):
model = self.deepcopy_model(model)
Expand Down Expand Up @@ -2191,6 +2190,7 @@ def record_status(accuracy_status, dynamo_start_stats):
if isinstance(e, torch.cuda.OutOfMemoryError)
else "eager_2nd_run_fail"
)
log.exception(e)
return record_status(accuracy_status, dynamo_start_stats=start_stats)
finally:
del model_copy
Expand Down Expand Up @@ -2546,6 +2546,8 @@ def run_one_model(
name, model, example_inputs, optimize_ctx, experiment, tag
)
print(status)
torch.cuda.empty_cache()

if self.args.timing:
from torch._dynamo.utils import op_count, print_time_report
from torch.utils._stats import simple_call_counter
Expand Down Expand Up @@ -3498,6 +3500,31 @@ def run(runner, args, original_dir=None):
# Go back to main branch
repo.git.checkout(main_branch)
elif args.only:

def write_csv_when_exception(name: str, status: str, device=None):
    """Print ``status`` and write placeholder CSV rows for a failed model.

    One row is passed to ``output_csv`` for each target device. The row
    layout depends on the mode flags read from ``args``: in accuracy mode
    the row ends with ``status``; in performance mode it ends with two
    zeroed metrics; otherwise it ends with a single ``0.0`` and the header
    list is empty.

    Args:
        name: Model name stored in each row.
        status: Failure label. Always printed; stored in the row only in
            accuracy mode.
        device: Single device to report. When ``None``, a row is written
            for every entry in ``args.devices``.
    """
    print(status)
    # Failed runs are logged with a placeholder batch size of 0.
    placeholder_batch_size = 0
    target_devices = args.devices if device is None else [device]

    # Pick the header and the mode-specific tail of each row; every row
    # shares the same leading (device, name, batch_size) columns.
    if args.accuracy:
        headers = ["dev", "name", "batch_size", "accuracy"]
        trailing = [status]
    elif args.performance:
        headers = ["dev", "name", "batch_size", "speedup", "abs_latency"]
        trailing = [0.0, 0.0]
    else:
        headers = []
        trailing = [0.0]

    rows = [
        [dev, name, placeholder_batch_size, *trailing] for dev in target_devices
    ]
    for row in rows:
        output_csv(output_filename, headers, row)

model_name = args.only
for device in args.devices:
batch_size = args.batch_size
Expand All @@ -3513,6 +3540,7 @@ def run(runner, args, original_dir=None):
torch.Tensor, lambda x: x.to(device=device), example_inputs
)
else:
name = model_name
try:
with tqdm(desc="loading model"):
extra_args = []
Expand Down Expand Up @@ -3567,12 +3595,18 @@ def run(runner, args, original_dir=None):
batch_size=batch_size,
extra_args=extra_args,
)
except NotImplementedError as e:
print(e)
except RuntimeError as e:
import traceback

mode = "train" if args.training else "eval"
print(f"{device:4} {mode:5} {name:34} ")
print(traceback.format_exc())
logging.warning("%s failed to load", args.only)
status = (
"model_fail_to_load"
if isinstance(e, NotImplementedError)
else "eager_fail_to_run"
)
write_csv_when_exception(name, status, device)
continue # bad benchmark implementation

if args.trace_on_xla:
Expand Down Expand Up @@ -3657,33 +3691,9 @@ def detect_and_mark_batch(t):
nmodels = len(model_names)
for i, name in enumerate(model_names):
current_name = name
placeholder_batch_size = 0
if args.progress:
print(f"Running model {i+1}/{nmodels}", flush=True)

def write_csv(status):
    """Write one placeholder CSV row per device for the current model.

    Reads ``args``, ``name``, ``placeholder_batch_size`` and
    ``output_filename`` from the enclosing scope. In accuracy mode each row
    ends with ``status``; in performance mode it ends with two zeroed
    metrics; otherwise it ends with a single ``0.0`` and the header list is
    empty.

    Args:
        status: Failure label stored in the row in accuracy mode only.
    """
    # Pick the header and the mode-specific tail of each row; every row
    # shares the same leading (device, name, batch_size) columns.
    if args.accuracy:
        headers = ["dev", "name", "batch_size", "accuracy"]
        trailing = [status]
    elif args.performance:
        headers = ["dev", "name", "batch_size", "speedup", "abs_latency"]
        trailing = [0.0, 0.0]
    else:
        headers = []
        trailing = [0.0]

    rows = [
        [dev, name, placeholder_batch_size, *trailing] for dev in args.devices
    ]
    for row in rows:
        output_csv(output_filename, headers, row)

try:
timeout = args.timeout
if should_diff_branch(args):
Expand All @@ -3692,13 +3702,11 @@ def write_csv(status):
[sys.executable] + sys.argv + [f"--only={name}"], timeout=timeout
)
except subprocess.TimeoutExpired:
print("TIMEOUT", file=sys.stderr)
write_csv("timeout")
write_csv_when_exception(name, "timeout")
except subprocess.CalledProcessError as e:
print("Run failed with return code: ", e.returncode, file=sys.stderr)
print("Output: ", e.output, file=sys.stderr)
print("Error: ", e.stderr, file=sys.stderr)
write_csv("infra_error")
print_summary(output_filename, print_dataframe=args.print_dataframe_summary)


Expand Down

0 comments on commit bf8df16

Please sign in to comment.