Consistent accuracy results with dynamobench (#1941)
Summary:
X-link: pytorch/pytorch#110189

Pull Request resolved: #1941

Use the upstream `torch._dynamo.same` function in accuracy checking and remove the self-hosted version in torchbench.

cmf_10x and ads_dhen_5x can now run in deterministic mode, so enable deepcopy and deterministic mode for them.
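For context, here is a minimal sketch of how the upstream checker can be invoked. The tensors, tolerance, and flags below are illustrative assumptions, not values taken from this diff; the real call sites live in `check_accuracy` in `env_check.py`.

```python
# Illustrative sketch only: exercises torch._dynamo.utils.same the way the
# updated accuracy check does; the inputs and tolerance here are made up.
import torch
from torch._dynamo.utils import same

ref = [torch.randn(4, 8)]                      # eager-mode outputs
res = [r + 1e-6 for r in ref]                  # outputs from the optimized run
fp64_ref = [r.to(torch.float64) for r in ref]  # fp64 golden reference, when available

is_same = same(
    ref,
    res,
    fp64_ref=fp64_ref,
    cos_similarity=False,  # models needing cosine tolerance would set this to True
    tol=1e-3,              # assumed tolerance, matching the training default in this file
    equal_nan=True,
)
print("accuracy pass" if is_same else "accuracy fail")
```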

Reviewed By: desertfire, jackiexu1992, mengluy0125

Differential Revision: D49639733

fbshipit-source-id: c2cdc6c57ce6e5190d66a8201fb83493002a2c68
xuzhao9 authored and facebook-github-bot committed Sep 28, 2023
1 parent aaa8baf commit 3480590
275 changes: 25 additions & 250 deletions torchbenchmark/util/env_check.py
@@ -61,6 +61,11 @@
REQUIRE_HIGHER_FP16_TOLERANCE = {
"drq",
}
REQUIRE_HIGHER_BF16_TOLERANCE = {
"doctr_reco_predictor",
"drq",
"hf_Whisper",
}
REQUIRE_COSINE_TOLERACE = {
# Just keeping it here even though it's empty, in case we need this in the future.
}
@@ -85,6 +90,10 @@

log = logging.getLogger(__name__)

class DummyGradScaler:
def scale(self, loss):
return loss

@contextmanager
def nested(*contexts):
"""
@@ -370,7 +379,7 @@ def forward_and_backward_pass(mod, inputs, contexts, optimizer, collect_outputs=
else:
pred = mod(*cloned_inputs)
loss = compute_loss(pred)
loss.backward(retain_graph=True)
DummyGradScaler().scale(loss).backward(retain_graph=True)
optimizer_step(optimizer)
if collect_outputs:
return collect_results(mod, pred, loss, cloned_inputs)
@@ -394,6 +403,11 @@ def get_tolerance_and_cosine_flag(model, is_training, current_device, name):
if name in REQUIRE_HIGHER_FP16_TOLERANCE:
return 1e-2, cosine
return 1e-3, cosine

if model.dargs.precision == "bf16":
if name in REQUIRE_HIGHER_BF16_TOLERANCE:
return 1e-2, cosine

if is_training and current_device == "cuda":
tolerance = 1e-3
if name in REQUIRE_COSINE_TOLERACE:
@@ -412,6 +426,8 @@ def skip_accuracy_check_as_eager_non_deterministic(is_training):
def check_accuracy(tbmodel: 'torchbenchmark.util.model.BenchmarkModel') -> str:
import torch
import functools
from torch.utils._pytree import tree_map
from torch._dynamo.utils import same

def _equal_nan_p(precision):
equal_nan = True
@@ -466,13 +482,20 @@ def maybe_cast(tbmodel, model, example_inputs):
)
optimizer = init_optimizer(name, current_device, model_fp64.parameters(), is_training)
fp64_outputs = run_n_iterations(model_fp64, inputs_fp64, contexts, optimizer, is_training)
fp64_outputs = tree_map(
lambda x: x.to(torch.float64)
if isinstance(x, torch.Tensor) and x.is_floating_point()
else x,
fp64_outputs,
)
except Exception:
log.warning(
"fp64 golden ref were not generated for %s. Setting accuracy check to cosine",
tbmodel.name,
)
tbmodel.dargs.use_cosine_similarity = True
fp64_outputs = None

tolerance, cos_similarity = get_tolerance_and_cosine_flag(
tbmodel, is_training, current_device, name
)
Expand Down Expand Up @@ -512,6 +535,7 @@ def maybe_cast(tbmodel, model, example_inputs):
else "eager_2nd_run_fail"
)
return accuracy_status

# Two eager runs should have exactly same result
is_same = True
try:
@@ -582,252 +606,3 @@ def maybe_cast(tbmodel, model, example_inputs):
return accuracy_status

return accuracy_status

def istype(obj, allowed_types):
"""isinstance() without subclasses"""
if isinstance(allowed_types, (tuple, list, set)):
return type(obj) in allowed_types
return type(obj) is allowed_types

def is_numpy_int_type(value):
if HAS_NUMPY:
import numpy as np
return istype(
value,
(
np.int8,
np.int16,
np.int32,
np.int64,
np.uint8,
np.uint16,
np.uint32,
np.uint64,
),
)
else:
return False


def is_numpy_float_type(value):
if HAS_NUMPY:
import numpy as np
return istype(
value,
(
np.float16,
np.float32,
np.float64,
),
)
else:
return False


def is_numpy_ndarray(value):
if HAS_NUMPY:
import numpy as np
return istype(value, np.ndarray)
else:
return False


def rmse(ref, res):
"""
Calculate root mean squared error
"""
import torch
return torch.sqrt(torch.mean(torch.square(ref - res)))

def same(
ref,
res,
fp64_ref=None,
cos_similarity=False,
tol=1e-4,
equal_nan=False,
exact_dtype=True,
relax_numpy_equality=False,
ignore_non_fp=False,
log_error=log.error,
):
"""Check correctness to see if ref and res match"""
import math
import torch
if fp64_ref is None:
fp64_ref = ref
if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)):
assert isinstance(res, (list, tuple)), f"type mismatch {type(ref)} {type(res)}"
return len(ref) == len(res) and all(
same(
ai,
bi,
fp64_refi,
cos_similarity,
tol,
equal_nan,
exact_dtype,
relax_numpy_equality,
ignore_non_fp,
log_error=log_error,
)
for ai, bi, fp64_refi in zip(ref, res, fp64_ref)
)
elif isinstance(ref, dict):
assert isinstance(res, dict)
assert set(ref.keys()) == set(
res.keys()
), f"keys mismatch {set(ref.keys())} == {set(res.keys())}"
for k in sorted(ref.keys()):
if not (
same(
ref[k],
res[k],
fp64_ref[k],
cos_similarity=cos_similarity,
tol=tol,
equal_nan=equal_nan,
exact_dtype=exact_dtype,
relax_numpy_equality=relax_numpy_equality,
ignore_non_fp=ignore_non_fp,
log_error=log_error,
)
):
log_error("Accuracy failed for key name %s", k)
return False
return True
elif isinstance(ref, torch.Tensor):
assert not isinstance(ref, torch._subclasses.FakeTensor)
assert not isinstance(res, torch._subclasses.FakeTensor)

if ref.is_sparse:
assert res.is_sparse
ref = ref.to_dense()
res = res.to_dense()
assert isinstance(res, torch.Tensor), f"type mismatch {type(ref)} {type(res)}"
if exact_dtype:
if ref.dtype != res.dtype:
log_error("dtype mismatch %s, %s", ref.dtype, res.dtype)
return False
if ref.dtype == torch.bool:
if ignore_non_fp:
return True
# triton stores bool as int8, so add this for more accurate checking
r = torch.allclose(
ref.to(dtype=torch.uint8),
res.to(dtype=torch.uint8),
atol=tol,
rtol=tol,
equal_nan=equal_nan,
)
if not r:
log_error("Accuracy failed: uint8 tensor did not match")
return r
if cos_similarity:
ref = ref.flatten().to(torch.float32)
res = res.flatten().to(torch.float32)
if torch.allclose(ref, res, atol=tol, rtol=tol, equal_nan=True):
# early exit that handles zero/nan better
# cosine_similarity(zeros(10), zeros(10), dim=0) is 0
return True
score = torch.nn.functional.cosine_similarity(ref, res, dim=0, eps=1e-6)
if score < 0.99:
log.warning("Similarity score=%s", score.cpu().detach().item())
return score >= 0.99
else:
if not exact_dtype:
ref = ref.to(res.dtype)

# First try usual allclose
if torch.allclose(ref, res, atol=tol, rtol=tol, equal_nan=equal_nan):
return True

# Check error from fp64 version
if fp64_ref.dtype == torch.float64:
ref_error = rmse(fp64_ref, ref).item()
res_error = rmse(fp64_ref, res).item()
multiplier = 2.0

if (
fp64_ref.numel() < 1000
or (ref.ndim == 4 and ref.shape[-1] == ref.shape[-2] == 1)
# large tol means a benchmark has been specified as REQUIRE_HIGHER_TOLERANCE
or tol >= 2 * 1e-2
):
# In the presence of noise, noise might dominate our error
# metric for smaller tensors.
# Similarly, for 1x1 kernels, there seems to be high noise with amp.
multiplier = 3.0

passes_test = res_error <= (multiplier * ref_error + tol / 10.0)
if not passes_test:
log_error(
"RMSE (res-fp64): %.5f, (ref-fp64): %.5f and shape=%s",
res_error,
ref_error,
res.size(),
)
# import pdb; pdb.set_trace()
return passes_test

if ignore_non_fp:
return True

log_error("Accuracy failed: allclose not within tol=%s", tol)
return False
elif isinstance(ref, (str, int, type(None), bool, torch.device)):
if ignore_non_fp:
return True
r = ref == res
if not r:
log_error("Accuracy failed (%s): %s != %s", type(ref), ref, res)
return r
elif isinstance(ref, float):
r = math.isclose(ref, res, rel_tol=tol, abs_tol=tol)
if not r:
log_error(
"Accuracy failed (float): %s != %s (within tol=%s)", ref, res, tol
)
return r
elif is_numpy_int_type(ref) or is_numpy_float_type(ref):
if relax_numpy_equality and not (
is_numpy_int_type(res) or is_numpy_float_type(res)
):
ref = ref.item()
r = (type(ref) is type(res)) and (ref == res)
if not r:
log_error("Accuracy failed (numpy): %s != %s", ref, res)
return r
elif is_numpy_ndarray(ref):
return (type(ref) is type(res)) and (ref == res).all()
elif type(ref).__name__ in (
"MaskedLMOutput",
"Seq2SeqLMOutput",
"CausalLMOutputWithCrossAttentions",
"LongformerMaskedLMOutput",
"Instances",
"SquashedNormal",
"Boxes",
"Normal",
"TanhTransform",
"Foo",
"Variable",
):
assert type(ref) is type(res)
return all(
same(
getattr(ref, key),
getattr(res, key),
getattr(fp64_ref, key),
cos_similarity=cos_similarity,
tol=tol,
equal_nan=equal_nan,
exact_dtype=exact_dtype,
relax_numpy_equality=relax_numpy_equality,
ignore_non_fp=ignore_non_fp,
log_error=log_error,
)
for key in ref.__dict__.keys()
)
else:
raise RuntimeError(f"unsupported type: {type(ref).__name__}")
