diff --git a/test/test_ops.py b/test/test_ops.py index 5619b28..432d3ce 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -232,8 +232,9 @@ def pyg_fn(): fn = pyg_fn if engine == "pyg" else fasten_fn fn() flops = get_matmul_flops(tensor_slice, other) + flops = 2 * flops if phase == "backward" else flops bytes = get_matmul_bytes(tensor_slice, other) - with proton.scope(f"random_{phase}_{engine}_{K}_{T}", metrics={"flops32": flops, "bytes": bytes}): + with proton.scope(f"random_{phase}_{engine}_{K}_{T}", metrics={"flops": flops, "bytes": bytes}): fn()