From eb28e08c6c23958b066ebdbdaab02c146025f032 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Mon, 15 Apr 2024 09:48:42 -0700 Subject: [PATCH] Optimize min/max/sum comprehensions C419 (#123960) Summary: Automatic fixes that replaces certain list comprehensions with generator ones where appropriate so that they are immediately consumed. This is preview functionality in ruff for rule C419 and it was automatically applied. X-link: https://github.com/pytorch/pytorch/pull/123960 Approved by: https://github.com/malfet Reviewed By: PaliC Differential Revision: D56119617 fbshipit-source-id: 9b25a3d55cd2666e27d0f78a9994ffeeffef7ba7 Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> --- userbenchmark/dynamo/dynamobench/_dynamo/testing.py | 4 ++-- userbenchmark/dynamo/dynamobench/_dynamo/utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/testing.py b/userbenchmark/dynamo/dynamobench/_dynamo/testing.py index 12b545d4d1..c46304369b 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/testing.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/testing.py @@ -105,7 +105,7 @@ def reduce_to_scalar_loss(out): # Mean does not work on integer tensors return out.sum() / out.numel() elif isinstance(out, (list, tuple)): - return sum([reduce_to_scalar_loss(x) for x in out]) / len(out) + return sum(reduce_to_scalar_loss(x) for x in out) / len(out) elif type(out).__name__ in ( "MaskedLMOutput", "Seq2SeqLMOutput", @@ -115,7 +115,7 @@ def reduce_to_scalar_loss(out): elif type(out).__name__ == "SquashedNormal": return out.mean.sum() elif isinstance(out, dict): - return sum([reduce_to_scalar_loss(value) for value in out.values()]) / len( + return sum(reduce_to_scalar_loss(value) for value in out.values()) / len( out.keys() ) raise NotImplementedError("Don't know how to reduce", type(out)) diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py index 25ae60d361..9f1f3aa3a7 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py @@ -1594,7 +1594,7 @@ def graph_break_report(): def recompilation_report(): if len(gf): - max_recompiles = max([num_recompiles(code) for code in gf]) + max_recompiles = max(num_recompiles(code) for code in gf) recomp_table = tabulate( summarized_gf, headers=["Function", "Recompiles", "Recompile Reasons"],