Make the speedup benchmark work with DDP + CompiledAutograd
Summary:
X-link: pytorch/pytorch#120454

With DDP + CompiledAutograd, the speedup benchmark cannot reuse the same parallelized model for both the eager baseline and the compiled run, so this PR deep-copies the original model and parallelizes a separate copy for the compiled path.
ghstack-source-id: 217034133
exported-using-ghexport

Reviewed By: xmfan

Differential Revision: D54094257

fbshipit-source-id: 29fb31f9653a50d1a1e5f5ff398a668b7e4209e5
fegin authored and facebook-github-bot committed Mar 1, 2024
1 parent 993bc7b commit 38ab769
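
The change is the general two-copy pattern: give the eager baseline and the compiled run their own independent model instances so they never share parameters, DDP hooks, or autograd state. A minimal single-process sketch of that idea, assuming a toy torch.nn.Linear model (the names eager_model and compiled_model are illustrative, not the benchmark's):

import copy

import torch

base = torch.nn.Linear(8, 8)

# Two independent copies: one stays eager, one is compiled.
# In the benchmark, each copy would additionally be DDP-wrapped on its own.
eager_model = copy.deepcopy(base)
compiled_model = torch.compile(copy.deepcopy(base))

x = torch.randn(4, 8)
eager_out = eager_model(x)
compiled_out = compiled_model(x)

# Same weights and numerically matching results, but no shared module instance.
torch.testing.assert_close(eager_out, compiled_out)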
Showing 1 changed file with 23 additions and 11 deletions.
userbenchmark/dynamo/dynamobench/common.py
@@ -674,8 +674,9 @@ def maybe_mark_profile(*args, **kwargs):
         with maybe_mark_profile(p=p, mark="actual"), maybe_enable_compiled_autograd(
             args.compiled_autograd
         ):
+            compiled_model = kwargs.get("compiled_model", model)
             timings[rep, 1], actual_output = timed(
-                model,
+                compiled_model,
                 frozen_model_iter_fn,
                 inputs,
                 return_result=True,
@@ -740,11 +741,16 @@ def maybe_mark_profile(*args, **kwargs):
         for k, v in kwargs["dynamo_stats"].items():
             headers.append(k)
             row.append(v)
-    output_csv(
-        output_filename,
-        headers,
-        row,
-    )
+    if (
+        not torch.distributed.is_available()  # no distributed is built
+        or not torch.distributed.is_initialized()  # single gpu
+        or torch.distributed.get_rank() == 0  # distributed + rank0
+    ):
+        output_csv(
+            output_filename,
+            headers,
+            row,
+        )
     headers, data = torch._dynamo.utils.compile_times(repr="csv", aggregate=True)
     assert (
         output_filename.find(".csv") > 0
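
The new guard above is the usual "write shared files from one process only" rule: the CSV is written when distributed support is not built, when no process group was initialized (single-GPU run), or when the current process is rank 0. A standalone sketch of the same check, with an illustrative write_csv_row helper that is not part of the benchmark:

import csv

import torch


def should_write_output() -> bool:
    # True when there is no distributed runtime, when it was never
    # initialized (single-process run), or when this process is rank 0.
    return (
        not torch.distributed.is_available()
        or not torch.distributed.is_initialized()
        or torch.distributed.get_rank() == 0
    )


def write_csv_row(path, headers, row):
    if not should_write_output():
        return
    with open(path, "a", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(headers)
        writer.writerow(row)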
@@ -2643,10 +2649,15 @@ def warmup(fn, model, example_inputs, mode, niters=5):
             return latency, peak_mem, dynamo_stats

         # Cast the model to float16/float32 as necessary
-        model, example_inputs = self.maybe_cast(model, example_inputs)
+        orig_model, example_inputs = self.maybe_cast(model, example_inputs)

         # Use distributed wrapping as necessary
-        model = self.deepcopy_and_maybe_parallelize(model)
+        model = self.deepcopy_and_maybe_parallelize(orig_model)
+        if experiment.func is speedup_experiment:
+            # If DDP + compiler is enabled, we need to use a different
+            compiled_model = self.deepcopy_and_maybe_parallelize(orig_model)
+        else:
+            compiled_model = model

         self.init_optimizer(name, current_device, model.parameters())
         with self.pick_grad(name, self.args.training):
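
Because deepcopy_and_maybe_parallelize is now called once per copy, each copy gets its own distributed wrapper (and therefore its own gradient hooks), which is what the CompiledAutograd run needs. A rough sketch of that shape, assuming a plain use_ddp flag instead of the benchmark's actual argument handling:

import copy

import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def deepcopy_and_maybe_parallelize_sketch(model, use_ddp=False):
    # Every caller gets an independent copy; any DDP hooks are registered
    # on that copy only and never shared with another run.
    model = copy.deepcopy(model)
    if use_ddp and torch.distributed.is_initialized():
        model = DDP(model)
    return model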
@@ -2670,7 +2681,7 @@ def warmup(fn, model, example_inputs, mode, niters=5):

             with maybe_enable_compiled_autograd(self.args.compiled_autograd):
                 dynamo_latency, dynamo_peak_mem, dynamo_stats = warmup(
-                    optimized_model_iter_fn, model, example_inputs, "dynamo"
+                    optimized_model_iter_fn, compiled_model, example_inputs, "dynamo"
                 )

             compilation_time = dynamo_latency - eager_latency + aot_compilation_time
@@ -2696,10 +2707,10 @@ def warmup(fn, model, example_inputs, mode, niters=5):
             results = []
             # run with torch._dynamo few times to populate the cache
             for _ in range(3):
-                optimized_model_iter_fn(model, example_inputs)
+                optimized_model_iter_fn(compiled_model, example_inputs)
             _, frames_second_pass = Stats.reset_counters()  # should be 0
             if frames_second_pass > 0:
-                optimized_model_iter_fn(model, example_inputs)
+                optimized_model_iter_fn(compiled_model, example_inputs)
                 _, frames_third_pass = Stats.reset_counters()  # should be 0
             else:
                 frames_third_pass = 0
@@ -2715,6 +2726,7 @@ def warmup(fn, model, example_inputs, mode, niters=5):

             if not hasattr(model, name):
                 model.name = name
+            experiment_kwargs["compiled_model"] = compiled_model
             results.append(experiment(model, example_inputs, **experiment_kwargs))
             return " ".join(map(str, results))

