diff --git a/bugbear.py b/bugbear.py index f266118..e0434d0 100644 --- a/bugbear.py +++ b/bugbear.py @@ -962,13 +962,18 @@ def _get_names_from_tuple(self, node: ast.Tuple): elif isinstance(dim, ast.Tuple): yield from self._get_names_from_tuple(dim) - def _get_dict_comp_loop_var_names(self, node: ast.DictComp): + def _get_dict_comp_loop_and_named_expr_var_names(self, node: ast.DictComp): + finder = NamedExprFinder() for gen in node.generators: if isinstance(gen.target, ast.Name): yield gen.target.id elif isinstance(gen.target, ast.Tuple): yield from self._get_names_from_tuple(gen.target) + finder.visit(gen.ifs) + + yield from finder.names.keys() + def check_for_b035(self, node: ast.DictComp): """Check that a static key isn't used in a dict comprehension. @@ -980,7 +985,9 @@ def check_for_b035(self, node: ast.DictComp): B035(node.key.lineno, node.key.col_offset, vars=(node.key.value,)) ) elif isinstance(node.key, ast.Name): - if node.key.id not in self._get_dict_comp_loop_var_names(node): + if node.key.id not in self._get_dict_comp_loop_and_named_expr_var_names( + node + ): self.errors.append( B035(node.key.lineno, node.key.col_offset, vars=(node.key.id,)) ) @@ -1539,6 +1546,30 @@ def visit(self, node): return node +@attr.s +class NamedExprFinder(ast.NodeVisitor): + """Finds names defined through an ast.NamedExpr. + + After `.visit(node)` is called, `found` is a dict with all name nodes inside, + key is name string, value is the node (useful for location purposes). + """ + + names: Dict[str, List[ast.Name]] = attr.ib(default=attr.Factory(dict)) + + def visit_NamedExpr(self, node: ast.NamedExpr): + self.names.setdefault(node.target.id, []).append(node.target) + self.generic_visit(node) + + def visit(self, node): + """Like super-visit but supports iteration over lists.""" + if not isinstance(node, list): + return super().visit(node) + + for elem in node: + super().visit(elem) + return node + + class FuntionDefDefaultsVisitor(ast.NodeVisitor): def __init__(self, b008_extend_immutable_calls=None): self.b008_extend_immutable_calls = b008_extend_immutable_calls or set() diff --git a/tests/b035.py b/tests/b035.py index 1a451e3..dd41d66 100644 --- a/tests/b035.py +++ b/tests/b035.py @@ -33,3 +33,19 @@ # bad - variabe not from generator v3 = 1 bad_var_not_from_nested_tuple = {v3: k for k, (v1, v2) in {"a": (1, 2)}.items()} + +# OK - variable from named expression +var_from_named_expr = { + k: v + for v in {"key": "foo", "data": {}} + if (k := v.get("key")) is not None +} + +# nested generators with named expressions +var_from_named_expr_nested = { + k: v + for v in {"keys": [{"key": "foo"}], "data": {}} + if (keys := v.get("keys")) is not None + for item in keys + if (k := item.get("key")) is not None +}