From fd05b8643cd2b9c4a092343e261aef88ff07e4bf Mon Sep 17 00:00:00 2001
From: Stonepia
Date: Fri, 17 May 2024 20:47:42 -0700
Subject: [PATCH] call empty_cache for dynamo tests (#126377)

Summary:
When running a batch of models, lacking `empty_cache()` would result in OOM for subsequent models. This PR unifies the `empty_cache` call for both CUDA and XPU.

X-link: https://github.com/pytorch/pytorch/pull/126377
Approved by: https://github.com/EikanWang, https://github.com/guangyey, https://github.com/desertfire

Reviewed By: huydhn

Differential Revision: D57518757

fbshipit-source-id: a42ae31e7fb81bb05217fd672a3427bd68478a50
---
 userbenchmark/dynamo/dynamobench/common.py | 30 +++++++++++++++++-----
 1 file changed, 24 insertions(+), 6 deletions(-)

diff --git a/userbenchmark/dynamo/dynamobench/common.py b/userbenchmark/dynamo/dynamobench/common.py
index 096dbc48ec..2b877e4344 100644
--- a/userbenchmark/dynamo/dynamobench/common.py
+++ b/userbenchmark/dynamo/dynamobench/common.py
@@ -354,6 +354,24 @@ def deterministic_torch_manual_seed(*args, **kwargs):
     torch.manual_seed = deterministic_torch_manual_seed
 
 
+def empty_gpu_cache(device):
+    """
+    Explicitly empty gpu cache to avoid OOM in subsequent run.
+    """
+
+    if device not in ["cuda", "xpu"]:
+        log.warning(
+            "Trying to call the empty_gpu_cache for device: %s, which is not in list [cuda, xpu]",
+            device,
+        )
+        return
+
+    if device == "cuda":
+        torch.cuda.empty_cache()
+    elif device == "xpu":
+        torch.xpu.empty_cache()
+
+
 def synchronize():
     pass
 
@@ -2278,7 +2296,7 @@ def decay_batch_exp(self, batch_size, factor=0.5, divisor=2):
     def batch_size_finder(self, device, model_name, initial_batch_size=1024):
         batch_size = initial_batch_size
         while batch_size >= 1:
-            torch.cuda.empty_cache()
+            empty_gpu_cache(current_device)
             try:
                 device, name, model, example_inputs, _ = self.load_model(
                     device,
@@ -2468,7 +2486,7 @@ def record_status(accuracy_status, dynamo_start_stats):
                 fp64_outputs = None
             finally:
                 del model_fp64, inputs_fp64
-                torch.cuda.empty_cache()
+                empty_gpu_cache(current_device)
 
             tolerance, cos_similarity = self.get_tolerance_and_cosine_flag(
                 self.args.training, current_device, name
@@ -2497,7 +2515,7 @@ def record_status(accuracy_status, dynamo_start_stats):
                 return record_status(accuracy_status, dynamo_start_stats=start_stats)
             finally:
                 del model_copy
-                torch.cuda.empty_cache()
+                empty_gpu_cache(current_device)
 
             # Rerun native pytorch
             reset_rng_state()
@@ -2518,7 +2536,7 @@ def record_status(accuracy_status, dynamo_start_stats):
                 return record_status(accuracy_status, dynamo_start_stats=start_stats)
             finally:
                 del model_copy
-                torch.cuda.empty_cache()
+                empty_gpu_cache(current_device)
 
             # Two eager runs should have exactly same result
             is_same = True
@@ -2719,7 +2737,7 @@ def warmup(fn, model, example_inputs, mode, niters=5):
             try:
                 if current_device == "cuda":
                     torch.cuda.reset_peak_memory_stats()
-                    torch.cuda.empty_cache()
+                empty_gpu_cache(current_device)
                 t0 = time.perf_counter()
                 for _ in range(niters):
                     fn(model, example_inputs)
@@ -2949,7 +2967,7 @@ def run_one_model(
                 name, model, example_inputs, optimize_ctx, experiment, tag
             )
             print(status)
-        torch.cuda.empty_cache()
+        empty_gpu_cache(current_device)
 
         self.maybe_preserve_compile_debug(name, status)
 
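
Usage note: the sketch below is illustrative only. `run_one_benchmark` and the model names are hypothetical stand-ins, and the cache-clearing dispatch simply mirrors the `empty_gpu_cache` helper added in the diff above; it shows where the call sits in a multi-model sweep so that allocator blocks cached by one model do not cause OOM for the next.

# Minimal sketch; run_one_benchmark and the model list are hypothetical.
import logging

import torch

log = logging.getLogger(__name__)


def empty_gpu_cache(device):
    # Same dispatch as the helper in the patch: clear the caching allocator
    # for the active backend so the next model starts from a clean pool.
    if device not in ["cuda", "xpu"]:
        log.warning("empty_gpu_cache called for unsupported device: %s", device)
        return
    if device == "cuda":
        torch.cuda.empty_cache()
    elif device == "xpu":
        torch.xpu.empty_cache()


def run_one_benchmark(name, device):
    # Hypothetical stand-in for loading and timing a single model.
    pass


def run_sweep(model_names, device="cuda"):
    for name in model_names:
        try:
            run_one_benchmark(name, device)
        finally:
            # Without this, memory cached while running a large model earlier
            # in the sweep can push a later model into OOM.
            empty_gpu_cache(device)


if __name__ == "__main__":
    run_sweep(["model_a", "model_b"])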