diff --git a/lib/python/pyflyby/_autoimp.py b/lib/python/pyflyby/_autoimp.py index e2f5a93f..82bb57a1 100644 --- a/lib/python/pyflyby/_autoimp.py +++ b/lib/python/pyflyby/_autoimp.py @@ -878,14 +878,17 @@ def _visit_Store(self, fullname, value=None): def _remove_from_missing_imports(self, fullname): for missing_import in self.missing_imports: # If it was defined inside a class method, then it wouldn't have been added to - # the missing imports anyways. + # the missing imports anyways (except in that case of annotations) # See the following tests: # - tests.test_autoimp.test_method_reference_current_class # - tests.test_autoimp.test_find_missing_imports_class_name_1 # - tests.test_autoimp.test_scan_for_import_issues_class_defined_after_use + scopestack = missing_import[1].scope_info['scopestack'] + in_class_scope = isinstance(scopestack[-1], _ClassScope) inside_class = missing_import[1].scope_info.get('_in_class_def') - if missing_import[1].startswith(fullname) and not inside_class: - self.missing_imports.remove(missing_import) + if missing_import[1].startswith(fullname): + if in_class_scope or not inside_class: + self.missing_imports.remove(missing_import) def _get_scope_info(self): return { diff --git a/tests/test_autoimp.py b/tests/test_autoimp.py index 9fe0cf55..448d078e 100644 --- a/tests/test_autoimp.py +++ b/tests/test_autoimp.py @@ -521,20 +521,58 @@ def foo(): assert unused == [] +def test_annotation_inside_class(): + code = dedent( + """ + class A: + param1: str + param2: B + + class B: + param1: str + """ + ) + missing, unused = scan_for_import_issues(code, [{}]) + assert missing == [] + assert unused == [] + +@pytest.mark.xfail( + reason="Had to deactivate as part of https://github.com/deshaw/pyflyby/pull/269/files conflicting requirements" +) def test_find_missing_imports_class_name_1(): code = dedent( """ - class Corinne(object): + class Corinne: pass - class Bobtail(object): - class Chippewa(object): - Bobtail + class Bobtail: + class Chippewa: + Bobtail # will be name error at runtime Rockton = Passall, Corinne, Chippewa - """) - result = find_missing_imports(code, [{}]) - result = _dilist2strlist(result) - expected = ['Bobtail', 'Passall'] + # ^error, ^ok , ^ok + """ + ) + result = find_missing_imports(code, [{}]) + result = _dilist2strlist(result) + expected = ["Bobtail", "Passall"] + assert expected == result + + +def test_find_missing_imports_class_name_1b(): + code = dedent( + """ + class Corinne: + pass + class Bobtail: + class Chippewa: + Bobtail # will be name error at runtime + Rockton = Passall, Corinne, Chippewa + # ^error, ^ok , ^ok + """ + ) + result = find_missing_imports(code, [{}]) + result = _dilist2strlist(result) + expected = ["Passall"] assert expected == result diff --git a/tests/test_cmdline.py b/tests/test_cmdline.py index 8080f817..74271047 100644 --- a/tests/test_cmdline.py +++ b/tests/test_cmdline.py @@ -13,7 +13,7 @@ import tempfile from textwrap import dedent -from pyflyby._util import EnvVarCtx +from pyflyby._util import EnvVarCtx, CwdCtx import pytest @@ -764,3 +764,49 @@ def test_tidy_imports_sorting(): sympy """).strip().format(f=f) assert result == expected + + +def test_tidy_imports_forward_references(): + with tempfile.TemporaryDirectory() as temp_dir: + foo = os.path.join(temp_dir, "foo.py") + with open(foo, "w") as foo_fp: + foo_fp.write(dedent(""" + from __future__ import annotations + + + class A: + param1: str + param2: B + + + class B: + param1: str + """).lstrip()) + foo_fp.flush() + + dot_pyflyby = os.path.join(temp_dir, ".pyflyby") + with open(dot_pyflyby, "w") as dot_pyflyby_fp: + dot_pyflyby_fp.write(dedent(""" + from foo import A, B + """).lstrip()) + dot_pyflyby_fp.flush() + with CwdCtx(temp_dir): + result = pipe( + [BIN_DIR + "/tidy-imports", foo_fp.name], + env={"PYFLYBY_PATH": dot_pyflyby}, + ) + + expected = dedent( + """ + from __future__ import annotations + + class A: + param1: str + param2: B + + + class B: + param1: str + """ + ).strip() + assert result == expected