Skip to content

Commit

Permalink
Optimize min/max/sum comprehensions C419 (#123960)
Browse files Browse the repository at this point in the history
Summary:
Automatic fixes that replaces certain list comprehensions with generator ones where appropriate so that they are immediately consumed. This is preview functionality in ruff for rule C419 and it was automatically applied.

X-link: pytorch/pytorch#123960
Approved by: https://github.com/malfet

Reviewed By: PaliC

Differential Revision: D56119617

fbshipit-source-id: 9b25a3d55cd2666e27d0f78a9994ffeeffef7ba7

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
  • Loading branch information
2 people authored and facebook-github-bot committed Apr 15, 2024
1 parent 37bb048 commit eb28e08
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions userbenchmark/dynamo/dynamobench/_dynamo/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def reduce_to_scalar_loss(out):
# Mean does not work on integer tensors
return out.sum() / out.numel()
elif isinstance(out, (list, tuple)):
return sum([reduce_to_scalar_loss(x) for x in out]) / len(out)
return sum(reduce_to_scalar_loss(x) for x in out) / len(out)
elif type(out).__name__ in (
"MaskedLMOutput",
"Seq2SeqLMOutput",
Expand All @@ -115,7 +115,7 @@ def reduce_to_scalar_loss(out):
elif type(out).__name__ == "SquashedNormal":
return out.mean.sum()
elif isinstance(out, dict):
return sum([reduce_to_scalar_loss(value) for value in out.values()]) / len(
return sum(reduce_to_scalar_loss(value) for value in out.values()) / len(
out.keys()
)
raise NotImplementedError("Don't know how to reduce", type(out))
Expand Down
2 changes: 1 addition & 1 deletion userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1594,7 +1594,7 @@ def graph_break_report():

def recompilation_report():
if len(gf):
max_recompiles = max([num_recompiles(code) for code in gf])
max_recompiles = max(num_recompiles(code) for code in gf)
recomp_table = tabulate(
summarized_gf,
headers=["Function", "Recompiles", "Recompile Reasons"],
Expand Down

0 comments on commit eb28e08

Please sign in to comment.