From a2dc5e1ccc61f972343a270c8a2a499d02e74eeb Mon Sep 17 00:00:00 2001 From: cyyc1 <114281716+cyyc1@users.noreply.github.com> Date: Tue, 4 Oct 2022 02:13:38 -0700 Subject: [PATCH] Deal with *vararg, **kwarg, and * in arg list (#1) --- .github/workflows/python-package.yml | 2 +- README.md | 34 ++-- flake8_indent_in_def.py | 243 +++++++++++++++++++++++++-- setup.cfg | 2 +- tests/not_ok_cases.py | 183 ++++++++++++++++++++ tests/ok_cases.py | 18 ++ tests/test_flake8_indent_in_def.py | 16 +- 7 files changed, 467 insertions(+), 31 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 24ba924..85e5799 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -35,7 +35,7 @@ jobs: # stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # "--exit-zero" treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --max-complexity=10 --max-line-length=80 --statistics + flake8 . --count --max-complexity=10 --max-line-length=88 --statistics - name: Test with pytest run: | pytest diff --git a/README.md b/README.md index 15919ed..e947241 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,15 @@ There is one violation code that this plugin reports: ### _Wrong_ -This plugin, as well as [PEP8](https://peps.python.org/pep-0008/#indentation), considers the following indentation styles wrong: +This plugin considers the following indentation styles wrong: + +```python +def some_function(arg1, + *, + arg2, + arg3): + print(arg1) +``` ```python def some_function(argument1, @@ -33,12 +41,7 @@ def some_function(argument1, print(argument1) ``` -```python -def some_function(arg1, - arg2, - arg3): - print(arg1) -``` +This following style above is the style choice of the [`black` formatter](https://github.com/psf/black). Both this plugin and [PEP8](https://peps.python.org/pep-0008/#indentation) consider it wrong because arguments and function names would be difficult to visually distinghish. ```python def some_function( @@ -49,8 +52,6 @@ def some_function( print(arg1) ``` -Note: this style above is the style choice of the [`black` formatter](https://github.com/psf/black). This style is wrong because arguments and function names would be difficult to visually distinghish. - ### _Correct_ Correspondingly, here are the correct indentation styles: @@ -59,6 +60,7 @@ Correspondingly, here are the correct indentation styles: def some_function( arg1: int, arg2: list, + *, arg3: bool = None, ): print(arg1) @@ -111,9 +113,19 @@ def some_func( ## Rationale -When we only indent by 4 spaces in function definitions, it is difficult to visually distinguish function arguments with the function name and the function body. This reduces readability. +When we only indent by 4 spaces in function definitions, it is difficult to visually distinguish function arguments with the function name and the function body. This reduces readability. It is similar for base classes in class definitions, but it's less of an issue than function definitions. -It is similar for base classes in class definitions, but it's less of an issue than function definitions. +Specifically, the following style is allowed by PEP8 but this plugin still consider it wrong, because it could lead to messy code diff when refactoring: + +```diff +- def some_very_very_very_very_long_func(arg1, ++ def refactored_function_name(arg1, + arg2, + arg3, + +): + return None +``` ## Interaction with other style checkers and formatters diff --git a/flake8_indent_in_def.py b/flake8_indent_in_def.py index c1d4b16..164661b 100644 --- a/flake8_indent_in_def.py +++ b/flake8_indent_in_def.py @@ -1,5 +1,7 @@ import ast -from typing import Generator, Tuple, Type, Any, List, Union +import tokenize +from enum import Enum +from typing import Generator, Tuple, Type, Any, List, Union, Dict, Optional import importlib.metadata as importlib_metadata @@ -17,22 +19,83 @@ EXPECTED_INDENT = 8 # https://peps.python.org/pep-0008/#indentation +OP_TOKEN_CODE = 54 # "OP" means operator token +NL_TOKEN_CODE = 61 # "NL" means new line +NEWLINE_TOKEN_CODE = 4 + + +class ArgType(Enum): + REGULAR = 1 + POS_ONLY = 2 + KW_ONLY = 3 + VARARG = 4 + KWARG = 5 + class Visitor(ast.NodeVisitor): - def __init__(self) -> None: + def __init__(self, tokens: List[tokenize.TokenInfo]) -> None: + self._tokens: List[tokenize.TokenInfo] = tokens self.violations: List[Tuple[int, int, str]] = [] def visit_FunctionDef(self, node: ast.FunctionDef) -> None: - self._visit_func_args_or_class_bases(node, node.args.args, is_func=True) + sorted_args, arg_type_lookup, has_star = self._collect_func_args(node) + self._visit_func_args_or_class_bases( + node=node, + args_or_bases=sorted_args, + is_func=True, + arg_type_lookup=arg_type_lookup, + ) + if has_star: + self._visit_star_in_arg_list(node) def visit_ClassDef(self, node: ast.ClassDef) -> None: self._visit_func_args_or_class_bases(node, node.bases, is_func=False) + @classmethod + def _collect_func_args( + cls, + node: ast.FunctionDef, + ) -> Tuple[List, Dict[ast.arg, ArgType], bool]: + all_args: List[ast.arg] = [] + arg_type_lookup: Dict[ast.arg, ArgType] = {} + + has_star = False # it means there's a '*,' in the argument list + + if node.args.args: + all_args.extend(node.args.args) + for arg_ in node.args.args: + arg_type_lookup[arg_] = ArgType.REGULAR + + if node.args.posonlyargs: + has_star = True + all_args.extend(node.args.posonlyargs) + for arg_ in node.args.posonlyargs: + arg_type_lookup[arg_] = ArgType.POS_ONLY + + if node.args.kwonlyargs: + has_star = True + all_args.extend(node.args.kwonlyargs) + for arg_ in node.args.kwonlyargs: + arg_type_lookup[arg_] = ArgType.KW_ONLY + + if node.args.vararg is not None: + all_args.append(node.args.vararg) + arg_type_lookup[node.args.vararg] = ArgType.VARARG + + if node.args.kwarg is not None: + all_args.append(node.args.kwarg) + arg_type_lookup[node.args.kwarg] = ArgType.KWARG + + sorted_args = sorted(all_args, key=lambda x: x.lineno, reverse=False) + + return sorted_args, arg_type_lookup, has_star + def _visit_func_args_or_class_bases( self, node: Union[ast.FunctionDef, ast.ClassDef], args_or_bases: Union[List[ast.arg], List[ast.Name]], is_func: bool, + arg_type_lookup: Optional[Dict[ast.arg, ArgType]] = None, ) -> None: if is_func: code01 = IND101 @@ -42,41 +105,189 @@ def _visit_func_args_or_class_bases( code02 = IND202 if len(args_or_bases) > 0: - def_line_num = node.lineno + function_def_line_num = node.lineno def_col_offset = node.col_offset - if args_or_bases[0].lineno == def_line_num: + if args_or_bases[0].lineno == function_def_line_num: for item in args_or_bases[1:]: - if item.lineno != def_line_num: - self.violations.append( - (item.lineno, item.col_offset + 1, code02), - ) + if item.lineno != function_def_line_num: + arg_type = arg_type_lookup[item] if is_func else None + col_offset = self._calc_col_offset(item, arg_type) + self.violations.append((item.lineno, col_offset, code02)) for i, item in enumerate(args_or_bases): if i == 0: - prev_item_line_num = def_line_num + prev_item_line_num = function_def_line_num else: prev_item_line_num = args_or_bases[i - 1].lineno # Only enforce indentation when this arg is on a new line if item.lineno > prev_item_line_num: - if item.col_offset - def_col_offset != EXPECTED_INDENT: - self.violations.append( - (item.lineno, item.col_offset + 1, code01), - ) + arg_type = arg_type_lookup[item] if is_func else None + if self._not_expected_indent(item, def_col_offset, arg_type): + col_offset = self._calc_col_offset(item, arg_type) + self.violations.append((item.lineno, col_offset, code01)) self.generic_visit(node) + @classmethod + def _calc_col_offset( + cls, + item: Union[ast.arg, ast.Name], + arg_type: Optional[ArgType] = None, + ) -> int: + if isinstance(item, ast.Name): # this means base class + return item.col_offset + 1 + + arg_type = ArgType.REGULAR if arg_type is None else arg_type + return cls._calc_col_offset_for_func_args(item, arg_type) + + @classmethod + def _calc_col_offset_for_func_args( + cls, + arg_: ast.arg, + arg_type: ArgType, + ) -> int: + if arg_type in {ArgType.REGULAR, ArgType.POS_ONLY, ArgType.KW_ONLY}: + return arg_.col_offset + 1 + + if arg_type == ArgType.VARARG: + return arg_.col_offset + 1 - 1 # '-1' because of '*' before vararg + + if arg_type == ArgType.KWARG: + return arg_.col_offset + 1 - 2 # '-2' because of '**' before kwarg + + @classmethod + def _not_expected_indent( + cls, + item: Union[ast.arg, ast.Name], + def_col_offset: int, + arg_type: Optional[ArgType] = None, + ) -> bool: + if isinstance(item, ast.Name): # this means base class + return item.col_offset - def_col_offset != EXPECTED_INDENT + + arg_type = ArgType.REGULAR if arg_type is None else arg_type + if arg_type in {ArgType.REGULAR, ArgType.POS_ONLY, ArgType.KW_ONLY}: + expected_indent_ = EXPECTED_INDENT + elif arg_type == ArgType.VARARG: + expected_indent_ = EXPECTED_INDENT + 1 # because '*vararg' + elif arg_type == ArgType.KWARG: + expected_indent_ = EXPECTED_INDENT + 2 # because '**kwarg' + else: + # this branch can't be reached in theory + expected_indent_ = EXPECTED_INDENT + + return item.col_offset - def_col_offset != expected_indent_ + + def _visit_star_in_arg_list(self, node: ast.FunctionDef): + func_def_lineno = node.lineno + func_end_lineno = node.end_lineno + + # We skip the last 2 tokens because the last token from the tokenizer + # is always ENDMARKER (type 0) and '*' cannot be the 2nd to last token. + # And then we also skip the 0th token because that is always + # the ENCODING token (type 62). + for i in range(1, len(self._tokens) - 2): + this_token = self._tokens[i] + this_lineno = this_token.start[0] + this_col = this_token.start[1] + 1 + if func_def_lineno < this_lineno <= func_end_lineno: + if self._is_a_violation(node, self._tokens, i): + self.violations.append((this_lineno, this_col, IND101)) + + @classmethod + def _is_a_violation( + cls, + node: ast.FunctionDef, + tokens: List[tokenize.TokenInfo], + index: int, + ) -> bool: + prev_token = tokens[index - 1] + this_token = tokens[index] + next_token = tokens[index + 1] + next_next_token = tokens[index + 2] + + is_qualifying_star = ( + cls._is_star_comma(this=this_token, next=next_token) + or cls._is_star_newline_comma( + this=this_token, + next=next_token, + next_next=next_next_token, + ) + ) + + is_1st_symbol = cls._star_is_1st_non_empty_symbol_on_this_line( + prev=prev_token, + this=this_token, + ) + + not_expected_indent = ( + this_token.start[1] - node.col_offset != EXPECTED_INDENT + ) + + return is_qualifying_star and is_1st_symbol and not_expected_indent + + @classmethod + def _is_star_comma( + cls, + this: tokenize.TokenInfo, + next: tokenize.TokenInfo, + ) -> bool: + return cls._is_star(this) and cls._is_comma(next) + + @classmethod + def _is_star_newline_comma( + cls, + this: tokenize.TokenInfo, + next: tokenize.TokenInfo, + next_next: tokenize.TokenInfo, + ) -> bool: + return ( + cls._is_star(this) + and cls._is_newline(next) + and cls._is_comma(next_next) + ) + + @classmethod + def _is_star(cls, token: tokenize.TokenInfo) -> bool: + return token.type == OP_TOKEN_CODE and token.string == '*' + + @classmethod + def _is_comma(cls, token: tokenize.TokenInfo) -> bool: + return token.type == OP_TOKEN_CODE and token.string == ',' + + @classmethod + def _is_newline(cls, token: tokenize.TokenInfo) -> bool: + return ( + token.type in {NL_TOKEN_CODE, NEWLINE_TOKEN_CODE} + and token.string == '\n' + ) + + @classmethod + def _star_is_1st_non_empty_symbol_on_this_line( + cls, + prev: tokenize.TokenInfo, # the previous token + this: tokenize.TokenInfo, # it's expected to be '*' + ) -> bool: + return prev.start[0] < this.start[0] or prev.string.strip() == '' + class Plugin: name = __name__ version = importlib_metadata.version(__name__) - def __init__(self, tree: ast.AST) -> None: + def __init__( + self, + tree: ast.AST, + file_tokens: List[tokenize.TokenInfo] = None, + ) -> None: self._tree = tree + self._file_tokens = file_tokens def run(self) -> Generator[Tuple[int, int, str, Type[Any]], None, None]: - visitor = Visitor() + visitor = Visitor(self._file_tokens) + visitor.visit(self._tree) for line, col, msg in visitor.violations: yield line, col, msg, type(self) diff --git a/setup.cfg b/setup.cfg index b3dc762..3f6fe4c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = flake8_indent_in_def -version = 0.1.0 +version = 0.1.1 description = A flake8 plugin that enforces 8-space indentation in function/class definitions long_description = file: README.md long_description_content_type = text/markdown diff --git a/tests/not_ok_cases.py b/tests/not_ok_cases.py index f9a6f75..77e8ca1 100644 --- a/tests/not_ok_cases.py +++ b/tests/not_ok_cases.py @@ -279,6 +279,177 @@ def __init__(self, arg1, arg2, arg3, ], ) +# There's no case_5a for legacy reasons +case_5b_src = """ +def func5b( + arg5b1, + *arg5b2, + **arg5b3, +): + pass +""" +case_5b = ( + case_5b_src, + [ + (3, 5, IND101), + (4, 5, IND101), + (5, 5, IND101), + ], +) + + +case_5c_src = """ +def func5c(arg5c1, + arg5c2, + *arg5c3, + **arg5c4): + pass +""" +case_5c = [ + case_5c_src, + [ + (3, 12, IND101), (3, 12, IND102), + (4, 12, IND101), (4, 12, IND102), + (5, 12, IND101), (5, 12, IND102), + ], +] + + +case_5d_src = """ +def func5d(arg5d1, arg5d2, + *arg5d3, + **arg5d4): + pass +""" +case_5d = [ + case_5d_src, + [ + (3, 12, IND101), (3, 12, IND102), + (4, 12, IND101), (4, 12, IND102), + ], +] + + +case_5e_src = """ +def func5e(arg5e1, arg5e2, + *arg5e3, + **arg5e4, +): + pass +""" +case_5e = [ + case_5e_src, + [ + (3, 9, IND102), + (4, 9, IND102), + ], +] + + +case_6a_src = """ +def func6a( + arg6a1, + *, + arg6a2, + arg6a3, +): + pass +""" +case_6a = ( + case_6a_src, + [ + (3, 5, IND101), + (4, 5, IND101), + (5, 5, IND101), + (6, 5, IND101), + ], +) + + +case_6b_src = """ +def func6b( + arg6b1, + *, + arg6b2, + arg6b3, +): + pass +""" +case_6b = ( + case_6b_src, [(4, 5, IND101)], +) + + +case_6c_src = """ +def func6c( + arg6c1, + *, arg6c2, arg6c3, +): + pass +""" +case_6c = ( + case_6c_src, [(3, 5, IND101), (4, 5, IND101), (4, 8, IND101)], +) + + +case_6d_src = """ +def func6d( + arg6d1, *, + arg6d2, arg6d3, +): + pass +""" +case_6d = ( + case_6d_src, [(3, 5, IND101), (4, 5, IND101)], +) + + +case_6e_src = """ +def func6e(arg6e1, *, + arg6e2, arg6e3, +): + pass +""" +case_6e = ( + case_6e_src, [(3, 5, IND101), (3, 5, IND102), (3, 13, IND102)], +) + + +case_6f_src = """ +def func6f(*, + arg6f2, arg6f3, +): + pass +""" +case_6f = ( + case_6f_src, [(3, 5, IND101)], +) + + +case_6g_src = """ +def func6g( + *, + arg6g2, arg6g3, +): + pass +""" +case_6g = ( + case_6g_src, [(3, 5, IND101), (4, 5, IND101)], +) + + +case_6h_src = """ +def func6h( + *, + arg6h2, + arg6h3, +): + pass +""" +case_6h = ( + case_6h_src, [(3, 5, IND101), (4, 5, IND101), (5, 5, IND101)], +) + def collect_all_cases(): return ( @@ -305,4 +476,16 @@ def collect_all_cases(): case_4d, case_4e, case_4f, + case_5b, + case_5c, + case_5d, + case_5e, + case_6a, + case_6b, + case_6c, + case_6d, + case_6e, + case_6f, + case_6g, + case_6h, ) diff --git a/tests/ok_cases.py b/tests/ok_cases.py index 2111c52..29fe870 100644 --- a/tests/ok_cases.py +++ b/tests/ok_cases.py @@ -108,6 +108,22 @@ def __init__(self): """ +case_5a = """ +def func5a(*, arg1, arg2): + print(1) +""" + + +case_5b = """ +def func5b( + *, + arg2, + arg3, +): + print(1) +""" + + def collect_all_cases(): return ( case_0, @@ -123,4 +139,6 @@ def collect_all_cases(): case_4a, case_4b, case_4c, + case_5a, + case_5b, ) diff --git a/tests/test_flake8_indent_in_def.py b/tests/test_flake8_indent_in_def.py index f510121..0b9c4d4 100644 --- a/tests/test_flake8_indent_in_def.py +++ b/tests/test_flake8_indent_in_def.py @@ -1,8 +1,11 @@ import os import ast import sys +import tokenize + import pytest -from typing import Set +import tempfile +from typing import Set, List from flake8_indent_in_def import Plugin @@ -15,12 +18,21 @@ def _results(src_code: str) -> Set[str]: tree = ast.parse(src_code) - plugin = Plugin(tree) + plugin = Plugin(tree=tree, file_tokens=_tokenize_string(string=src_code)) return { f'{line}:{col} {msg}' for line, col, msg, _ in plugin.run() } +def _tokenize_string(string: str) -> List[tokenize.TokenInfo]: + with tempfile.TemporaryFile() as fp: + fp.write(str.encode(string)) + fp.seek(0) + tokens = list(tokenize.tokenize(fp.readline)) + + return tokens + + @pytest.mark.parametrize('src_code', ok_cases.collect_all_cases()) def test_ok_cases(src_code): assert _results(src_code) == set()