diff --git a/torchbenchmark/util/env_check.py b/torchbenchmark/util/env_check.py
index fcacf4c399..871eb05853 100644
--- a/torchbenchmark/util/env_check.py
+++ b/torchbenchmark/util/env_check.py
@@ -61,6 +61,11 @@ REQUIRE_HIGHER_FP16_TOLERANCE = {
     "drq",
 }
 
+REQUIRE_HIGHER_BF16_TOLERANCE = {
+    "doctr_reco_predictor",
+    "drq",
+    "hf_Whisper",
+}
 REQUIRE_COSINE_TOLERACE = {
     # Just keeping it here even though its empty, if we need this in future.
 }
@@ -85,6 +90,10 @@ log = logging.getLogger(__name__)
 
 
+class DummyGradScaler:
+    def scale(self, loss):
+        return loss
+
 @contextmanager
 def nested(*contexts):
     """
@@ -370,7 +379,7 @@ def forward_and_backward_pass(mod, inputs, contexts, optimizer, collect_outputs=
         else:
             pred = mod(*cloned_inputs)
         loss = compute_loss(pred)
-        loss.backward(retain_graph=True)
+        DummyGradScaler().scale(loss).backward(retain_graph=True)
     optimizer_step(optimizer)
     if collect_outputs:
         return collect_results(mod, pred, loss, cloned_inputs)
@@ -394,6 +403,11 @@ def get_tolerance_and_cosine_flag(model, is_training, current_device, name):
         if name in REQUIRE_HIGHER_FP16_TOLERANCE:
             return 1e-2, cosine
         return 1e-3, cosine
+
+    if model.dargs.precision == "bf16":
+        if name in REQUIRE_HIGHER_BF16_TOLERANCE:
+            return 1e-2, cosine
+
     if is_training and current_device == "cuda":
         tolerance = 1e-3
         if name in REQUIRE_COSINE_TOLERACE:
@@ -412,6 +426,8 @@ def skip_accuracy_check_as_eager_non_deterministic(is_training):
 def check_accuracy(tbmodel: 'torchbenchmark.util.model.BenchmarkModel') -> str:
     import torch
     import functools
+    from torch.utils._pytree import tree_map
+    from torch._dynamo.utils import same
 
     def _equal_nan_p(precision):
         equal_nan = True
@@ -466,6 +482,12 @@ def maybe_cast(tbmodel, model, example_inputs):
         )
         optimizer = init_optimizer(name, current_device, model_fp64.parameters(), is_training)
         fp64_outputs = run_n_iterations(model_fp64, inputs_fp64, contexts, optimizer, is_training)
+        fp64_outputs = tree_map(
+            lambda x: x.to(torch.float64)
+            if isinstance(x, torch.Tensor) and x.is_floating_point()
+            else x,
+            fp64_outputs,
+        )
     except Exception:
         log.warning(
             "fp64 golden ref were not generated for %s. Setting accuracy check to cosine",
@@ -473,6 +495,7 @@ def maybe_cast(tbmodel, model, example_inputs):
         )
         tbmodel.dargs.use_cosine_similarity = True
         fp64_outputs = None
+
     tolerance, cos_similarity = get_tolerance_and_cosine_flag(
         tbmodel, is_training, current_device, name
     )
@@ -512,6 +535,7 @@ def maybe_cast(tbmodel, model, example_inputs):
             else "eager_2nd_run_fail"
         )
         return accuracy_status
+
     # Two eager runs should have exactly same result
     is_same = True
     try:
@@ -582,252 +606,3 @@ def maybe_cast(tbmodel, model, example_inputs):
             return accuracy_status
 
     return accuracy_status
-
-def istype(obj, allowed_types):
-    """isinstance() without subclasses"""
-    if isinstance(allowed_types, (tuple, list, set)):
-        return type(obj) in allowed_types
-    return type(obj) is allowed_types
-
-def is_numpy_int_type(value):
-    if HAS_NUMPY:
-        import numpy as np
-        return istype(
-            value,
-            (
-                np.int8,
-                np.int16,
-                np.int32,
-                np.int64,
-                np.uint8,
-                np.uint16,
-                np.uint32,
-                np.uint64,
-            ),
-        )
-    else:
-        return False
-
-
-def is_numpy_float_type(value):
-    if HAS_NUMPY:
-        import numpy as np
-        return istype(
-            value,
-            (
-                np.float16,
-                np.float32,
-                np.float64,
-            ),
-        )
-    else:
-        return False
-
-
-def is_numpy_ndarray(value):
-    if HAS_NUMPY:
-        import numpy as np
-        return istype(value, np.ndarray)
-    else:
-        return False
-
-
-def rmse(ref, res):
-    """
-    Calculate root mean squared error
-    """
-    import torch
-    return torch.sqrt(torch.mean(torch.square(ref - res)))
-
-def same(
-    ref,
-    res,
-    fp64_ref=None,
-    cos_similarity=False,
-    tol=1e-4,
-    equal_nan=False,
-    exact_dtype=True,
-    relax_numpy_equality=False,
-    ignore_non_fp=False,
-    log_error=log.error,
-):
-    """Check correctness to see if ref and res match"""
-    import math
-    import torch
-    if fp64_ref is None:
-        fp64_ref = ref
-    if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)):
-        assert isinstance(res, (list, tuple)), f"type mismatch {type(ref)} {type(res)}"
-        return len(ref) == len(res) and all(
-            same(
-                ai,
-                bi,
-                fp64_refi,
-                cos_similarity,
-                tol,
-                equal_nan,
-                exact_dtype,
-                relax_numpy_equality,
-                ignore_non_fp,
-                log_error=log_error,
-            )
-            for ai, bi, fp64_refi in zip(ref, res, fp64_ref)
-        )
-    elif isinstance(ref, dict):
-        assert isinstance(res, dict)
-        assert set(ref.keys()) == set(
-            res.keys()
-        ), f"keys mismatch {set(ref.keys())} == {set(res.keys())}"
-        for k in sorted(ref.keys()):
-            if not (
-                same(
-                    ref[k],
-                    res[k],
-                    fp64_ref[k],
-                    cos_similarity=cos_similarity,
-                    tol=tol,
-                    equal_nan=equal_nan,
-                    exact_dtype=exact_dtype,
-                    relax_numpy_equality=relax_numpy_equality,
-                    ignore_non_fp=ignore_non_fp,
-                    log_error=log_error,
-                )
-            ):
-                log_error("Accuracy failed for key name %s", k)
-                return False
-        return True
-    elif isinstance(ref, torch.Tensor):
-        assert not isinstance(ref, torch._subclasses.FakeTensor)
-        assert not isinstance(res, torch._subclasses.FakeTensor)
-
-        if ref.is_sparse:
-            assert res.is_sparse
-            ref = ref.to_dense()
-            res = res.to_dense()
-        assert isinstance(res, torch.Tensor), f"type mismatch {type(ref)} {type(res)}"
-        if exact_dtype:
-            if ref.dtype != res.dtype:
-                log_error("dtype mismatch %s, %s", ref.dtype, res.dtype)
-                return False
-            if ref.dtype == torch.bool:
-                if ignore_non_fp:
-                    return True
-                # triton stores bool as int8, so add this for more accurate checking
-                r = torch.allclose(
-                    ref.to(dtype=torch.uint8),
-                    res.to(dtype=torch.uint8),
-                    atol=tol,
-                    rtol=tol,
-                    equal_nan=equal_nan,
-                )
-                if not r:
-                    log_error("Accuracy failed: uint8 tensor did not match")
-                return r
-        if cos_similarity:
-            ref = ref.flatten().to(torch.float32)
-            res = res.flatten().to(torch.float32)
-            if torch.allclose(ref, res, atol=tol, rtol=tol, equal_nan=True):
-                # early exit that handles zero/nan better
-                # cosine_similarity(zeros(10), zeros(10), dim=0) is 0
-                return True
-            score = torch.nn.functional.cosine_similarity(ref, res, dim=0, eps=1e-6)
-            if score < 0.99:
-                log.warning("Similarity score=%s", score.cpu().detach().item())
-            return score >= 0.99
-        else:
-            if not exact_dtype:
-                ref = ref.to(res.dtype)
-
-            # First try usual allclose
-            if torch.allclose(ref, res, atol=tol, rtol=tol, equal_nan=equal_nan):
-                return True
-
-            # Check error from fp64 version
-            if fp64_ref.dtype == torch.float64:
-                ref_error = rmse(fp64_ref, ref).item()
-                res_error = rmse(fp64_ref, res).item()
-                multiplier = 2.0
-
-                if (
-                    fp64_ref.numel() < 1000
-                    or (ref.ndim == 4 and ref.shape[-1] == ref.shape[-2] == 1)
-                    # large tol means a benchmark has been specified as REQUIRE_HIGHER_TOLERANCE
-                    or tol >= 2 * 1e-2
-                ):
-                    # In the presence of noise, noise might dominate our error
-                    # metric for smaller tensors.
-                    # Similary, for 1x1 kernels, there seems to be high noise with amp.
-                    multiplier = 3.0
-
-                passes_test = res_error <= (multiplier * ref_error + tol / 10.0)
-                if not passes_test:
-                    log_error(
-                        "RMSE (res-fp64): %.5f, (ref-fp64): %.5f and shape=%s",
-                        res_error,
-                        ref_error,
-                        res.size(),
-                    )
-                    # import pdb; pdb.set_trace()
-                return passes_test
-
-            if ignore_non_fp:
-                return True
-
-            log_error("Accuracy failed: allclose not within tol=%s", tol)
-            return False
-    elif isinstance(ref, (str, int, type(None), bool, torch.device)):
-        if ignore_non_fp:
-            return True
-        r = ref == res
-        if not r:
-            log_error("Accuracy failed (%s): %s != %s", type(ref), ref, res)
-        return r
-    elif isinstance(ref, float):
-        r = math.isclose(ref, res, rel_tol=tol, abs_tol=tol)
-        if not r:
-            log_error(
-                "Accuracy failed (float): %s != %s (within tol=%s)", ref, res, tol
-            )
-        return r
-    elif is_numpy_int_type(ref) or is_numpy_float_type(ref):
-        if relax_numpy_equality and not (
-            is_numpy_int_type(res) or is_numpy_float_type(res)
-        ):
-            ref = ref.item()
-        r = (type(ref) is type(res)) and (ref == res)
-        if not r:
-            log_error("Accuracy failed (numpy): %s != %s", ref, res)
-        return r
-    elif is_numpy_ndarray(ref):
-        return (type(ref) is type(res)) and (ref == res).all()
-    elif type(ref).__name__ in (
-        "MaskedLMOutput",
-        "Seq2SeqLMOutput",
-        "CausalLMOutputWithCrossAttentions",
-        "LongformerMaskedLMOutput",
-        "Instances",
-        "SquashedNormal",
-        "Boxes",
-        "Normal",
-        "TanhTransform",
-        "Foo",
-        "Variable",
-    ):
-        assert type(ref) is type(res)
-        return all(
-            same(
-                getattr(ref, key),
-                getattr(res, key),
-                getattr(fp64_ref, key),
-                cos_similarity=cos_similarity,
-                tol=tol,
-                equal_nan=equal_nan,
-                exact_dtype=exact_dtype,
-                relax_numpy_equality=relax_numpy_equality,
-                ignore_non_fp=ignore_non_fp,
-                log_error=log_error,
-            )
-            for key in ref.__dict__.keys()
-        )
-    else:
-        raise RuntimeError(f"unsupported type: {type(ref).__name__}")