Skip to content

Commit

Permalink
apply autofix for ruff rules
Browse files Browse the repository at this point in the history
Summary:
This PR replaces empty-collection factory calls with the equivalent Python literals:

- `list()` -> `[]`
- `tuple()` -> `()`
- `dict()` -> `{}`

The Python literals are both faster and safer than the factory calls. For example, compare the bytecode generated for building an empty dictionary:

```bash
$ python3 -m dis - <<EOS
import collections

d1 = {}
d2 = dict()

dict = collections.OrderedDict
d3 = dict()
EOS
```

```text
  0           0 RESUME                   0

  1           2 LOAD_CONST               0 (0)
              4 LOAD_CONST               1 (None)
              6 IMPORT_NAME              0 (collections)
              8 STORE_NAME               0 (collections)

  3          10 BUILD_MAP                0
             12 STORE_NAME               1 (d1)

  4          14 PUSH_NULL
             16 LOAD_NAME                2 (dict)
             18 CALL                     0
             26 STORE_NAME               3 (d2)

  6          28 LOAD_NAME                0 (collections)
             30 LOAD_ATTR                8 (OrderedDict)
             50 STORE_NAME               2 (dict)

  7          52 PUSH_NULL
             54 LOAD_NAME                2 (dict)
             56 CALL                     0
             64 STORE_NAME               5 (d3)
             66 RETURN_CONST             1 (None)
```

The dict literal `{}` compiles to a single bytecode instruction, `BUILD_MAP`, while the factory call `dict()` requires three: `PUSH_NULL + LOAD_NAME + CALL`. The factory call is also not safe when the name `dict` has been rebound in `locals` or `globals` — see the example above, where `dict` is reassigned to `collections.OrderedDict`, silently changing what `dict()` constructs.

X-link: pytorch/pytorch#130199
Approved by: https://github.com/malfet

Reviewed By: izaitsevfb

Differential Revision: D59661447

fbshipit-source-id: 8027c90d4fd82bf3826197a4bdbf4ad5674394ff
  • Loading branch information
XuehaiPan authored and facebook-github-bot committed Jul 13, 2024
1 parent b74c016 commit 59596d4
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 15 deletions.
6 changes: 3 additions & 3 deletions userbenchmark/dynamo/dynamobench/_dynamo/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ def collect_results(model, prediction, loss, example_inputs):
# f"High loss value alert - {loss:.2f}. Can result in unstable gradients."
# )

grads = dict()
params = dict()
grads = {}
params = {}
for name, param in model.named_parameters():
if isinstance(model, eval_frame.OptimizedModule):
name = remove_optimized_module_prefix(name)
Expand All @@ -71,7 +71,7 @@ def collect_results(model, prediction, loss, example_inputs):
params[name] = param_copy
results.append(grads)
results.append(params)
buffers = dict()
buffers = {}
for name, buffer in model.named_buffers():
if isinstance(model, eval_frame.OptimizedModule):
name = remove_optimized_module_prefix(name)
Expand Down
14 changes: 7 additions & 7 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
np.random: tnp.random,
}
else:
NP_SUPPORTED_MODULES = tuple()
NP_SUPPORTED_MODULES = ()

NP_TO_TNP_MODULE = {}
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
Expand Down Expand Up @@ -463,8 +463,8 @@ class ExactWeakKeyDictionary:
"""Similar to weakref.WeakKeyDictionary, but use `is`/`id` rather than `==` to compare equality"""

def __init__(self):
self.values = dict()
self.refs = dict()
self.values = {}
self.refs = {}

def __getitem__(self, key):
return self.values[id(key)]
Expand Down Expand Up @@ -1144,10 +1144,10 @@ def check_numpy_ndarray_args(args, kwargs):
)


dict_keys: Type[KeysView[Any]] = type(dict().keys())
dict_values: Type[ValuesView[Any]] = type(dict().values())
dict_keys: Type[KeysView[Any]] = type({}.keys())
dict_values: Type[ValuesView[Any]] = type({}.values())
odict_values: Type[ValuesView[Any]] = type(collections.OrderedDict().values())
tuple_iterator: Type[Iterator[Any]] = type(iter(tuple()))
tuple_iterator: Type[Iterator[Any]] = type(iter(()))
tuple_iterator_len = tuple_iterator.__length_hint__ # type: ignore[attr-defined]
object_new = object.__new__

Expand Down Expand Up @@ -1610,7 +1610,7 @@ def disable_cache_limit():
guard_failures: DefaultDict[Any, List[Any]] = collections.defaultdict(list)

# Keep a record of graph break reasons for logging
graph_break_reasons: List["torch._dynamo.output_graph.GraphCompileReason"] = list()
graph_break_reasons: List["torch._dynamo.output_graph.GraphCompileReason"] = []

# keep record of compiled code, if we are in "error if recompile"
# to track code that dynamo has compiled previously
Expand Down
2 changes: 1 addition & 1 deletion userbenchmark/dynamo/dynamobench/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,7 +1167,7 @@ def try_script(model, example_inputs):


class AOTInductorModelCache:
cache = dict()
cache = {}

@classmethod
def load(cls, model, example_inputs, device):
Expand Down
4 changes: 2 additions & 2 deletions userbenchmark/dynamo/dynamobench/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def process_hf_reformer_output(out):
# combination of models supported by HF Fx parser and some manually supplied
# models. For these models, we already know the largest batch size that can fit
# on A100 GPUs - 40 GB.
BATCH_SIZE_KNOWN_MODELS = dict()
BATCH_SIZE_KNOWN_MODELS = {}


# Get the list of models and their batch sizes
Expand Down Expand Up @@ -619,7 +619,7 @@ def refresh_model_names_and_batch_sizes():
"""
import transformers.utils.fx as hf_fx

family = dict()
family = {}
lm_seen = set()
family_seen = set()
for cls_name in hf_fx._SUPPORTED_MODELS:
Expand Down
4 changes: 2 additions & 2 deletions userbenchmark/dynamo/dynamobench/timm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def pip_install(package):
from timm.data import resolve_data_config
from timm.models import create_model

TIMM_MODELS = dict()
TIMM_MODELS = {}
filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")

with open(filename) as fh:
Expand Down Expand Up @@ -174,7 +174,7 @@ def get_family_name(name):
return name.split("_")[0]

def populate_family(models):
family = dict()
family = {}
for model_name in models:
family_name = get_family_name(model_name)
if family_name not in family:
Expand Down

0 comments on commit 59596d4

Please sign in to comment.