call empty_cache for dynamo tests (#126377)
Summary:
When running a batch of models, omitting an `empty_cache()` call between runs can cause OOM failures for subsequent models.

This PR unifies the `empty_cache` call for both CUDA and XPU.
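
For context, the failure mode comes from PyTorch's caching allocator: memory freed by one model stays reserved as cache and can starve the next model's allocations. The sketch below is purely illustrative (the driver loop and tensor size are made up and not part of this PR; the real helper in the diff below additionally warns on unsupported devices) and shows the device-dispatch idea plus the effect of flushing between runs:

```python
import torch


def empty_gpu_cache(device: str) -> None:
    # Dispatch on the device string so call sites stay backend-agnostic.
    if device == "cuda":
        torch.cuda.empty_cache()
    elif device == "xpu":
        torch.xpu.empty_cache()


if __name__ == "__main__":
    if torch.cuda.is_available():
        for step in range(3):  # stand-in for running a batch of models
            x = torch.empty(256 * 1024 * 1024, device="cuda")  # ~1 GiB fp32 workload
            del x  # freed, but the caching allocator keeps the blocks reserved
            print(f"step {step}: reserved before flush: "
                  f"{torch.cuda.memory_reserved() / 2**20:.0f} MiB")
            empty_gpu_cache("cuda")
            print(f"step {step}: reserved after flush:  "
                  f"{torch.cuda.memory_reserved() / 2**20:.0f} MiB")
```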

X-link: pytorch/pytorch#126377
Approved by: https://github.com/EikanWang, https://github.com/guangyey, https://github.com/desertfire

Reviewed By: huydhn

Differential Revision: D57518757

fbshipit-source-id: a42ae31e7fb81bb05217fd672a3427bd68478a50
Stonepia authored and facebook-github-bot committed May 18, 2024
1 parent b80b9cf commit fd05b86
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions userbenchmark/dynamo/dynamobench/common.py
@@ -354,6 +354,24 @@ def deterministic_torch_manual_seed(*args, **kwargs):
     torch.manual_seed = deterministic_torch_manual_seed
 
 
+def empty_gpu_cache(device):
+    """
+    Explicitly empty gpu cache to avoid OOM in subsequent run.
+    """
+
+    if device not in ["cuda", "xpu"]:
+        log.warning(
+            "Trying to call the empty_gpu_cache for device: %s, which is not in list [cuda, xpu]",
+            device,
+        )
+        return
+
+    if device == "cuda":
+        torch.cuda.empty_cache()
+    elif device == "xpu":
+        torch.xpu.empty_cache()
+
+
 def synchronize():
     pass
 
@@ -2278,7 +2296,7 @@ def decay_batch_exp(self, batch_size, factor=0.5, divisor=2):
     def batch_size_finder(self, device, model_name, initial_batch_size=1024):
         batch_size = initial_batch_size
         while batch_size >= 1:
-            torch.cuda.empty_cache()
+            empty_gpu_cache(current_device)
             try:
                 device, name, model, example_inputs, _ = self.load_model(
                     device,
@@ -2468,7 +2486,7 @@ def record_status(accuracy_status, dynamo_start_stats):
                 fp64_outputs = None
             finally:
                 del model_fp64, inputs_fp64
-                torch.cuda.empty_cache()
+                empty_gpu_cache(current_device)
 
             tolerance, cos_similarity = self.get_tolerance_and_cosine_flag(
                 self.args.training, current_device, name
@@ -2497,7 +2515,7 @@ def record_status(accuracy_status, dynamo_start_stats):
                 return record_status(accuracy_status, dynamo_start_stats=start_stats)
             finally:
                 del model_copy
-                torch.cuda.empty_cache()
+                empty_gpu_cache(current_device)
 
             # Rerun native pytorch
             reset_rng_state()
@@ -2518,7 +2536,7 @@ def record_status(accuracy_status, dynamo_start_stats):
                 return record_status(accuracy_status, dynamo_start_stats=start_stats)
             finally:
                 del model_copy
-                torch.cuda.empty_cache()
+                empty_gpu_cache(current_device)
 
             # Two eager runs should have exactly same result
             is_same = True
@@ -2719,7 +2737,7 @@ def warmup(fn, model, example_inputs, mode, niters=5):
             try:
                 if current_device == "cuda":
                     torch.cuda.reset_peak_memory_stats()
-                    torch.cuda.empty_cache()
+                    empty_gpu_cache(current_device)
                 t0 = time.perf_counter()
                 for _ in range(niters):
                     fn(model, example_inputs)
@@ -2949,7 +2967,7 @@ def run_one_model(
                 name, model, example_inputs, optimize_ctx, experiment, tag
             )
             print(status)
-        torch.cuda.empty_cache()
+        empty_gpu_cache(current_device)
 
         self.maybe_preserve_compile_debug(name, status)
 
