From 828749a91b985cb56fc0888c26939080be4f677d Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 23 Mar 2023 06:50:02 +0000 Subject: [PATCH] style: enable `ruff` for tests --- pyproject.toml | 8 +- tests/helpers.py | 2 +- tests/test_ops.py | 133 +++++++++++++++++---------- tests/test_prefix_errors.py | 174 +++++++++++++++++++++++++++++------- tests/test_registry.py | 55 +++++++----- tests/test_treespec.py | 31 ++++--- tests/test_typing.py | 4 +- tests/test_utils.py | 3 +- 8 files changed, 295 insertions(+), 115 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f690e894..747706a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -141,7 +141,6 @@ target-version = "py37" line-length = 100 show-source = true src = ["optree", "tests"] -extend-exclude = ["tests"] select = [ "E", "W", # pycodestyle "F", # pyflakes @@ -205,6 +204,13 @@ typing-modules = ["optree.typing"] "benchmark.py" = [ "PLW2901", # redefined-loop-name ] +"tests/**/*.py" = [ + "ANN", # flake8-annotations + "S", # flake8-bandit + "BLE", # flake8-blind-except + "SIM", # flake8-simplify + "PL", # pylint +] [tool.ruff.flake8-annotations] allow-star-arg-any = true diff --git a/tests/helpers.py b/tests/helpers.py index c13d5d6b..d19cfbec 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -74,7 +74,7 @@ def __repr__(self): @optree.register_pytree_node_class( - namespace=optree.registry.__GLOBAL_NAMESPACE # pylint: disable=protected-access + namespace=optree.registry.__GLOBAL_NAMESPACE, # pylint: disable=protected-access ) class Vector2D: def __init__(self, x, y): diff --git a/tests/test_ops.py b/tests/test_ops.py index 127701dc..271538cf 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -46,46 +46,50 @@ def dummy_func(*args, **kwargs): # pylint: disable=unused-argument dummy_partial_func = functools.partial(dummy_func, a=1) -def is_tuple(t): - return isinstance(t, tuple) +def is_tuple(tup): + return isinstance(tup, tuple) -def is_list(l): - return isinstance(l, list) +def is_list(lst): + return isinstance(lst, list) -def is_none(n): - return n is None +def is_none(none): + return none is None -def always(o): # pylint: disable=unused-argument +def always(obj): # pylint: disable=unused-argument return True -def never(o): # pylint: disable=unused-argument +def never(obj): # pylint: disable=unused-argument return False def test_max_depth(): - l = [1] + lst = [1] for _ in range(optree.MAX_RECURSION_DEPTH - 1): - l = [l] - optree.tree_flatten(l) - optree.tree_flatten_with_path(l) + lst = [lst] + optree.tree_flatten(lst) + optree.tree_flatten_with_path(lst) - l = [l] + lst = [lst] with pytest.raises( - RecursionError, match='Maximum recursion depth exceeded during flattening the tree.' + RecursionError, + match='Maximum recursion depth exceeded during flattening the tree.', ): - optree.tree_flatten(l) + optree.tree_flatten(lst) with pytest.raises( - RecursionError, match='Maximum recursion depth exceeded during flattening the tree.' + RecursionError, + match='Maximum recursion depth exceeded during flattening the tree.', ): - optree.tree_flatten_with_path(l) + optree.tree_flatten_with_path(lst) @parametrize( - tree=list(TREES + LEAVES), none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'] + tree=list(TREES + LEAVES), + none_is_leaf=[False, True], + namespace=['', 'undefined', 'namespace'], ) def test_round_trip(tree, none_is_leaf, namespace): leaves, treespec = optree.tree_flatten(tree, none_is_leaf=none_is_leaf, namespace=namespace) @@ -94,7 +98,9 @@ def test_round_trip(tree, none_is_leaf, namespace): @parametrize( - tree=list(TREES + LEAVES), none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'] + tree=list(TREES + LEAVES), + none_is_leaf=[False, True], + namespace=['', 'undefined', 'namespace'], ) def test_round_trip_with_flatten_up_to(tree, none_is_leaf, namespace): _, treespec = optree.tree_flatten(tree, none_is_leaf=none_is_leaf, namespace=namespace) @@ -234,7 +240,7 @@ def f_leaf(leaf): def test_flatten_up_to(): _, treespec = optree.tree_flatten([(1, 2), None, CustomTuple(foo=3, bar=7)]) subtrees = treespec.flatten_up_to( - [({'foo': 7}, (3, 4)), None, CustomTuple(foo=(11, 9), bar=None)] + [({'foo': 7}, (3, 4)), None, CustomTuple(foo=(11, 9), bar=None)], ) assert subtrees == [{'foo': 7}, (3, 4), (11, 9), None] @@ -242,7 +248,7 @@ def test_flatten_up_to(): def test_flatten_up_to_none_is_leaf(): _, treespec = optree.tree_flatten([(1, 2), None, CustomTuple(foo=3, bar=7)], none_is_leaf=True) subtrees = treespec.flatten_up_to( - [({'foo': 7}, (3, 4)), None, CustomTuple(foo=(11, 9), bar=None)] + [({'foo': 7}, (3, 4)), None, CustomTuple(foo=(11, 9), bar=None)], ) assert subtrees == [{'foo': 7}, (3, 4), None, (11, 9), None] @@ -251,7 +257,7 @@ def test_flatten_up_to_none_is_leaf(): leaves_fn=[ optree.tree_leaves, lambda tree, is_leaf: optree.tree_flatten(tree, is_leaf)[0], - ] + ], ) def test_flatten_is_leaf(leaves_fn): x = [(1, 2), (3, 4), (5, 6)] @@ -277,7 +283,7 @@ def test_flatten_is_leaf(leaves_fn): structure_fn=[ optree.tree_structure, lambda tree, is_leaf: optree.tree_flatten(tree, is_leaf)[1], - ] + ], ) def test_structure_is_leaf(structure_fn): x = [(1, 2), (3, 4), (5, 6)] @@ -300,8 +306,8 @@ def test_structure_is_leaf(structure_fn): itertools.chain( zip(TREES, TREE_PATHS[False], itertools.repeat(False)), zip(TREES, TREE_PATHS[True], itertools.repeat(True)), - ) - ) + ), + ), ) def test_paths(data): tree, expected_paths, none_is_leaf = data @@ -332,7 +338,10 @@ def test_paths(data): ) def test_round_trip_is_leaf(tree, is_leaf, none_is_leaf, namespace): subtrees, treespec = optree.tree_flatten( - tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace + tree, + is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, ) actual = optree.tree_unflatten(treespec, subtrees) assert actual == tree @@ -506,7 +515,11 @@ def test_tree_map_with_is_leaf_none_is_leaf(): x = ((1, 2, None), [3, 4, 5]) y = (([3], None, 4), ({'foo': 'bar'}, 7, [5, 6])) out = optree.tree_map( - lambda *xs: tuple(xs), x, y, is_leaf=lambda n: isinstance(n, list), none_is_leaf=True + lambda *xs: tuple(xs), + x, + y, + is_leaf=lambda n: isinstance(n, list), + none_is_leaf=True, ) assert out == (((1, [3]), (2, None), (None, 4)), (([3, 4, 5], ({'foo': 'bar'}, 7, [5, 6])))) @@ -671,7 +684,9 @@ def test_tree_transpose(tree): return with pytest.raises(ValueError, match='Tree structures must have the same none_is_leaf value.'): optree.tree_transpose( - outer_treespec, optree.tree_structure([1, 1, 1], none_is_leaf=True), nested + outer_treespec, + optree.tree_structure([1, 1, 1], none_is_leaf=True), + nested, ) actual = optree.tree_transpose(outer_treespec, inner_treespec, nested) assert actual == [tree, tree, tree] @@ -698,7 +713,9 @@ def test_tree_transpose_with_custom_object(): inner_treespec = optree.tree_structure([1, 2]) expected = [FlatCache({'a': 3, 'b': 5}), FlatCache({'a': 4, 'b': 6})] actual = optree.tree_transpose( - outer_treespec, inner_treespec, FlatCache({'a': [3, 4], 'b': [5, 6]}) + outer_treespec, + inner_treespec, + FlatCache({'a': [3, 4], 'b': [5, 6]}), ) assert actual == expected @@ -706,13 +723,14 @@ def test_tree_transpose_with_custom_object(): def test_tree_transpose_with_custom_namespace(): outer_treespec = optree.tree_structure(MyAnotherDict({'a': 1, 'b': 2}), namespace='namespace') inner_treespec = optree.tree_structure( - MyAnotherDict({'c': 1, 'd': 2, 'e': 3}), namespace='namespace' + MyAnotherDict({'c': 1, 'd': 2, 'e': 3}), + namespace='namespace', ) nested = MyAnotherDict( { 'a': MyAnotherDict({'c': 1, 'd': 2, 'e': 3}), 'b': MyAnotherDict({'c': 4, 'd': 5, 'e': 6}), - } + }, ) actual = optree.tree_transpose(outer_treespec, inner_treespec, nested) assert actual == MyAnotherDict( @@ -720,7 +738,7 @@ def test_tree_transpose_with_custom_namespace(): 'c': MyAnotherDict({'a': 1, 'b': 4}), 'd': MyAnotherDict({'a': 2, 'b': 5}), 'e': MyAnotherDict({'a': 3, 'b': 6}), - } + }, ) @@ -731,20 +749,22 @@ class MyExtraDict(MyAnotherDict): outer_treespec = optree.tree_structure(MyAnotherDict({'a': 1, 'b': 2}), namespace='namespace') inner_treespec = optree.tree_structure( - MyExtraDict({'c': 1, 'd': 2, 'e': 3}), namespace='subnamespace' + MyExtraDict({'c': 1, 'd': 2, 'e': 3}), + namespace='subnamespace', ) nested = MyAnotherDict( { 'a': MyExtraDict({'c': 1, 'd': 2, 'e': 3}), 'b': MyExtraDict({'c': 4, 'd': 5, 'e': 6}), - } + }, ) with pytest.raises(ValueError, match='Tree structures must have the same namespace.'): optree.tree_transpose(outer_treespec, inner_treespec, nested) optree.register_pytree_node_class(MyExtraDict, namespace='namespace') inner_treespec = optree.tree_structure( - MyExtraDict({'c': 1, 'd': 2, 'e': 3}), namespace='namespace' + MyExtraDict({'c': 1, 'd': 2, 'e': 3}), + namespace='namespace', ) actual = optree.tree_transpose(outer_treespec, inner_treespec, nested) assert actual == MyExtraDict( @@ -752,7 +772,7 @@ class MyExtraDict(MyAnotherDict): 'c': MyAnotherDict({'a': 1, 'b': 4}), 'd': MyAnotherDict({'a': 2, 'b': 5}), 'e': MyAnotherDict({'a': 3, 'b': 6}), - } + }, ) @@ -760,7 +780,8 @@ def test_tree_broadcast_prefix(): assert optree.tree_broadcast_prefix(1, [1, 2, 3]) == [1, 1, 1] assert optree.tree_broadcast_prefix([1, 2, 3], [1, 2, 3]) == [1, 2, 3] with pytest.raises( - ValueError, match=re.escape('list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].') + ValueError, + match=re.escape('list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].'), ): optree.tree_broadcast_prefix([1, 2, 3], [1, 2, 3, 4]) assert optree.tree_broadcast_prefix([1, 2, 3], [1, 2, (3, 4)]) == [1, 2, (3, 3)] @@ -770,7 +791,9 @@ def test_tree_broadcast_prefix(): {'a': 3, 'b': 3, 'c': (None, 3)}, ] assert optree.tree_broadcast_prefix( - [1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}], none_is_leaf=True + [1, 2, 3], + [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}], + none_is_leaf=True, ) == [1, 2, {'a': 3, 'b': 3, 'c': (3, 3)}] @@ -778,7 +801,8 @@ def test_broadcast_prefix(): assert optree.broadcast_prefix(1, [1, 2, 3]) == [1, 1, 1] assert optree.broadcast_prefix([1, 2, 3], [1, 2, 3]) == [1, 2, 3] with pytest.raises( - ValueError, match=re.escape('list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].') + ValueError, + match=re.escape('list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].'), ): optree.broadcast_prefix([1, 2, 3], [1, 2, 3, 4]) assert optree.broadcast_prefix([1, 2, 3], [1, 2, (3, 4)]) == [1, 2, 3, 3] @@ -790,7 +814,9 @@ def test_broadcast_prefix(): 3, ] assert optree.broadcast_prefix( - [1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}], none_is_leaf=True + [1, 2, 3], + [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}], + none_is_leaf=True, ) == [1, 2, 3, 3, 3, 3] @@ -810,13 +836,18 @@ def test_tree_reduce(): assert optree.tree_reduce(lambda x, y: x and y, {'x': 1, 'y': (2, None), 'z': 3}) == 3 assert ( optree.tree_reduce( - lambda x, y: x and y, {'x': 1, 'y': (2, None), 'z': 3}, none_is_leaf=True + lambda x, y: x and y, + {'x': 1, 'y': (2, None), 'z': 3}, + none_is_leaf=True, ) is None ) assert ( optree.tree_reduce( - lambda x, y: x and y, {'x': 1, 'y': (2, None), 'z': 3}, False, none_is_leaf=True + lambda x, y: x and y, + {'x': 1, 'y': (2, None), 'z': 3}, + False, + none_is_leaf=True, ) is False ) @@ -826,13 +857,16 @@ def test_tree_sum(): assert optree.tree_sum({'x': 1, 'y': (2, 3)}) == 6 assert optree.tree_sum({'x': 1, 'y': (2, None), 'z': 3}) == 6 with pytest.raises( - TypeError, match=re.escape("unsupported operand type(s) for +: 'int' and 'NoneType'") + TypeError, + match=re.escape("unsupported operand type(s) for +: 'int' and 'NoneType'"), ): optree.tree_sum({'x': 1, 'y': (2, None), 'z': 3}, none_is_leaf=True) assert optree.tree_sum({'x': 'a', 'y': ('b', None), 'z': 'c'}, start='') == 'abc' assert optree.tree_sum({'x': b'a', 'y': (b'b', None), 'z': b'c'}, start=b'') == b'abc' assert optree.tree_sum( - {'x': [1], 'y': ([2], [None]), 'z': [3]}, start=[], is_leaf=lambda x: isinstance(x, list) + {'x': [1], 'y': ([2], [None]), 'z': [3]}, + start=[], + is_leaf=lambda x: isinstance(x, list), ) == [1, 2, None, 3] @@ -893,7 +927,7 @@ def test_tree_any(): @parametrize(tree=TREES, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace']) -def test_flatten_one_level(tree, none_is_leaf, namespace): +def test_flatten_one_level(tree, none_is_leaf, namespace): # noqa: C901 stack = [tree] actual_leaves = [] expected_leaves = optree.tree_leaves(tree, none_is_leaf=none_is_leaf, namespace=namespace) @@ -902,7 +936,7 @@ def test_flatten_one_level(tree, none_is_leaf, namespace): counter = Counter() expected_children, one_level_treespec = optree.tree_flatten( node, - is_leaf=lambda x: counter.increment() > 1, + is_leaf=lambda x: counter.increment() > 1, # noqa: B023 none_is_leaf=none_is_leaf, namespace=namespace, ) @@ -910,13 +944,16 @@ def test_flatten_one_level(tree, none_is_leaf, namespace): if one_level_treespec.is_leaf(): assert expected_children == [node] with pytest.raises( - ValueError, match=re.escape(f'Cannot flatten leaf-type: {node_type}.') + ValueError, + match=re.escape(f'Cannot flatten leaf-type: {node_type}.'), ): optree.ops.flatten_one_level(node, none_is_leaf=none_is_leaf, namespace=namespace) actual_leaves.append(node) else: children, metadata, entries = optree.ops.flatten_one_level( - node, none_is_leaf=none_is_leaf, namespace=namespace + node, + none_is_leaf=none_is_leaf, + namespace=namespace, ) assert children == expected_children if node_type in (list, tuple, type(None)): diff --git a/tests/test_prefix_errors.py b/tests/test_prefix_errors.py index 4e1d5ce4..f3142d00 100644 --- a/tests/test_prefix_errors.py +++ b/tests/test_prefix_errors.py @@ -16,6 +16,7 @@ # pylint: disable=missing-function-docstring,invalid-name,implicit-str-concat import re +import textwrap from collections import OrderedDict, defaultdict, deque import pytest @@ -40,7 +41,12 @@ def test_different_types(): (e,) = optree.prefix_errors(lhs, rhs) expected = re.escape( - 'pytree structure error: different types at key path\n' ' in_axes tree root' + textwrap.dedent( + """ + pytree structure error: different types at key path + in_axes tree root + """, + ).strip(), ) with pytest.raises(ValueError, match=expected): raise e('in_axes') @@ -57,7 +63,12 @@ def test_different_types(): (e,) = optree.prefix_errors(lhs, rhs) expected = re.escape( - 'pytree structure error: different types at key path\n' ' in_axes tree root' + textwrap.dedent( + """ + pytree structure error: different types at key path + in_axes tree root + """, + ).strip(), ) with pytest.raises(ValueError, match=expected): raise e('in_axes') @@ -77,7 +88,14 @@ def test_different_types_nested(): optree.tree_map_(lambda x, y: None, lhs, rhs) (e,) = optree.prefix_errors(lhs, rhs) - expected = re.escape('pytree structure error: different types at key path\n' ' in_axes[0]') + expected = re.escape( + textwrap.dedent( + """ + pytree structure error: different types at key path + in_axes[0] + """, + ).strip(), + ) with pytest.raises(ValueError, match=expected): raise e('in_axes') @@ -88,10 +106,24 @@ def test_different_types_multiple(): optree.tree_map_(lambda x, y: None, lhs, rhs) e1, e2 = optree.prefix_errors(lhs, rhs) - expected = re.escape('pytree structure error: different types at key path\n' ' in_axes[0]') + expected = re.escape( + textwrap.dedent( + """ + pytree structure error: different types at key path + in_axes[0] + """, + ).strip(), + ) with pytest.raises(ValueError, match=expected): raise e1('in_axes') - expected = re.escape('pytree structure error: different types at key path\n' ' in_axes[1]') + expected = re.escape( + textwrap.dedent( + """ + pytree structure error: different types at key path + in_axes[1] + """, + ).strip(), + ) with pytest.raises(ValueError, match=expected): raise e2('in_axes') @@ -99,14 +131,19 @@ def test_different_types_multiple(): def test_different_num_children(): lhs, rhs = (1,), (2, 3) with pytest.raises( - ValueError, match=r'tuple arity mismatch; expected: \d+, got: \d+; tuple: .*\.' + ValueError, + match=r'tuple arity mismatch; expected: \d+, got: \d+; tuple: .*\.', ): optree.tree_map_(lambda x, y: None, lhs, rhs) (e,) = optree.prefix_errors(lhs, rhs) expected = re.escape( - 'pytree structure error: different numbers of pytree children at key path\n' - ' in_axes tree root' + textwrap.dedent( + """ + pytree structure error: different numbers of pytree children at key path + in_axes tree root + """, + ).strip(), ) with pytest.raises(ValueError, match=expected): raise e('in_axes') @@ -115,14 +152,19 @@ def test_different_num_children(): def test_different_num_children_nested(): lhs, rhs = [[1]], [[2, 3]] with pytest.raises( - ValueError, match=r'list arity mismatch; expected: \d+, got: \d+; list: .*\.' + ValueError, + match=r'list arity mismatch; expected: \d+, got: \d+; list: .*\.', ): optree.tree_map_(lambda x, y: None, lhs, rhs) (e,) = optree.prefix_errors(lhs, rhs) expected = re.escape( - 'pytree structure error: different numbers of pytree children at key path\n' - ' in_axes[0]' + textwrap.dedent( + """ + pytree structure error: different numbers of pytree children at key path + in_axes[0] + """, + ).strip(), ) with pytest.raises(ValueError, match=expected): raise e('in_axes') @@ -131,20 +173,29 @@ def test_different_num_children_nested(): def test_different_num_children_multiple(): lhs, rhs = [[1], [2]], [[3, 4], [5, 6]] with pytest.raises( - ValueError, match=r'list arity mismatch; expected: \d+, got: \d+; list: .*\.' + ValueError, + match=r'list arity mismatch; expected: \d+, got: \d+; list: .*\.', ): optree.tree_map_(lambda x, y: None, lhs, rhs) e1, e2 = optree.prefix_errors(lhs, rhs) expected = re.escape( - 'pytree structure error: different numbers of pytree children at key path\n' - ' in_axes[0]' + textwrap.dedent( + """ + pytree structure error: different numbers of pytree children at key path + in_axes[0] + """, + ).strip(), ) with pytest.raises(ValueError, match=expected): raise e1('in_axes') expected = re.escape( - 'pytree structure error: different numbers of pytree children at key path\n' - ' in_axes[1]' + textwrap.dedent( + """ + pytree structure error: different numbers of pytree children at key path + in_axes[1] + """, + ).strip(), ) with pytest.raises(ValueError, match=expected): raise e2('in_axes') @@ -160,7 +211,12 @@ def test_different_metadata(): (e,) = optree.prefix_errors(lhs, rhs) expected = re.escape( - 'pytree structure error: different pytree keys at key path\n' ' in_axes tree root' + textwrap.dedent( + """ + pytree structure error: different pytree keys at key path + in_axes tree root + """, + ).strip(), ) with pytest.raises(ValueError, match=expected): raise e('in_axes') @@ -174,7 +230,12 @@ def test_different_metadata(): (e,) = optree.prefix_errors(lhs, rhs) expected = re.escape( - 'pytree structure error: different pytree keys at key path\n' ' in_axes tree root' + textwrap.dedent( + """ + pytree structure error: different pytree keys at key path + in_axes tree root + """, + ).strip(), ) with pytest.raises(ValueError, match=expected): raise e('in_axes') @@ -208,7 +269,12 @@ def test_different_metadata(): (e,) = optree.prefix_errors(lhs, rhs) expected = re.escape( - 'pytree structure error: different pytree metadata at key path\n' ' in_axes tree root' + textwrap.dedent( + """ + pytree structure error: different pytree metadata at key path + in_axes tree root + """, + ).strip(), ) with pytest.raises(ValueError, match=expected): raise e('in_axes') @@ -224,7 +290,12 @@ def test_different_metadata_nested(): (e,) = optree.prefix_errors(lhs, rhs) expected = re.escape( - 'pytree structure error: different pytree keys at key path\n' ' in_axes[0]' + textwrap.dedent( + """ + pytree structure error: different pytree keys at key path + in_axes[0] + """, + ).strip(), ) with pytest.raises(ValueError, match=expected): raise e('in_axes') @@ -240,12 +311,22 @@ def test_different_metadata_multiple(): e1, e2 = optree.prefix_errors(lhs, rhs) expected = re.escape( - 'pytree structure error: different pytree keys at key path\n' ' in_axes[0]' + textwrap.dedent( + """ + pytree structure error: different pytree keys at key path + in_axes[0] + """, + ).strip(), ) with pytest.raises(ValueError, match=expected): raise e1('in_axes') expected = re.escape( - 'pytree structure error: different pytree keys at key path\n' ' in_axes[1]' + textwrap.dedent( + """ + pytree structure error: different pytree keys at key path + in_axes[1] + """, + ).strip(), ) with pytest.raises(ValueError, match=expected): raise e2('in_axes') @@ -258,7 +339,12 @@ def test_namedtuple(): (e,) = optree.prefix_errors(lhs, rhs) expected = re.escape( - 'pytree structure error: different types at key path\n' ' in_axes.bar[1]' + textwrap.dedent( + """ + pytree structure error: different types at key path + in_axes.bar[1] + """, + ).strip(), ) with pytest.raises(ValueError, match=expected): raise e('in_axes') @@ -271,7 +357,12 @@ def test_structseq(): (e,) = optree.prefix_errors(lhs, rhs) expected = re.escape( - 'pytree structure error: different types at key path\n' ' in_axes.tm_mon[1]' + textwrap.dedent( + """ + pytree structure error: different types at key path + in_axes.tm_mon[1] + """, + ).strip(), ) with pytest.raises(ValueError, match=expected): raise e('in_axes') @@ -280,7 +371,12 @@ def test_structseq(): def test_fallback_keypath(): (e,) = optree.prefix_errors(Vector2D(1, [2]), Vector2D(3, 4)) expected = re.escape( - 'pytree structure error: different types at key path\n' ' in_axes[]' + textwrap.dedent( + """ + pytree structure error: different types at key path + in_axes[] + """, + ).strip(), ) with pytest.raises(ValueError, match=expected): raise e('in_axes') @@ -298,15 +394,24 @@ def test_no_errors(): def test_different_structure_no_children(): (e,) = optree.prefix_errors((), ([],)) expected = re.escape( - 'pytree structure error: different numbers of pytree children at key path\n' - ' in_axes tree root' + textwrap.dedent( + """ + pytree structure error: different numbers of pytree children at key path + in_axes tree root + """, + ).strip(), ) with pytest.raises(ValueError, match=expected): raise e('in_axes') (e,) = optree.prefix_errors({}, {'a': []}) expected = re.escape( - 'pytree structure error: different pytree keys at key path\n' ' in_axes tree root' + textwrap.dedent( + """ + pytree structure error: different pytree keys at key path + in_axes tree root + """, + ).strip(), ) with pytest.raises(ValueError, match=expected): raise e('in_axes') @@ -334,11 +439,13 @@ def test_key_path(): 1 + sequence_key_path with pytest.raises( - TypeError, match=re.escape("unsupported operand type(s) for +: 'KeyPath' and 'int'") + TypeError, + match=re.escape("unsupported operand type(s) for +: 'KeyPath' and 'int'"), ): root + 1 with pytest.raises( - TypeError, match=re.escape("unsupported operand type(s) for +: 'int' and 'KeyPath'") + TypeError, + match=re.escape("unsupported operand type(s) for +: 'int' and 'KeyPath'"), ): 1 + root @@ -368,7 +475,12 @@ def test_key_path(): assert (namedtuple_key_path + fallback_key_path).pprint() == '.attr[]' assert (fallback_key_path + namedtuple_key_path).pprint() == '[].attr' assert sequence_key_path + dict_key_path + namedtuple_key_path + fallback_key_path == KeyPath( - (sequence_key_path, dict_key_path, namedtuple_key_path, fallback_key_path) + ( + sequence_key_path, + dict_key_path, + namedtuple_key_path, + fallback_key_path, + ), ) assert ( sequence_key_path + dict_key_path + namedtuple_key_path + fallback_key_path diff --git a/tests/test_registry.py b/tests/test_registry.py index 787ca887..c62848bd 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -25,7 +25,8 @@ def test_register_pytree_node_class_with_no_namespace(): with pytest.raises( - ValueError, match='Must specify `namespace` when the first argument is a class.' + ValueError, + match='Must specify `namespace` when the first argument is a class.', ): @optree.register_pytree_node_class @@ -40,7 +41,8 @@ def tree_unflatten(cls, metadata, children): def test_register_pytree_node_class_with_duplicate_namespace(): with pytest.raises( - ValueError, match='Cannot specify `namespace` when the first argument is a string.' + ValueError, + match='Cannot specify `namespace` when the first argument is a string.', ): @optree.register_pytree_node_class('mylist', namespace='mylist') @@ -62,7 +64,10 @@ def func(): with pytest.raises(TypeError, match='Expected a class'): optree.register_pytree_node( - 1, lambda s: (sorted(s), None, None), lambda _, s: set(s), namespace='non-class' + 1, + lambda s: (sorted(s), None, None), + lambda _, s: set(s), + namespace='non-class', ) @@ -86,11 +91,13 @@ def tree_unflatten(cls, metadata, children): return cls(children) with pytest.raises( - ValueError, match=r"PyTree type.*is already registered in namespace 'mylist1'\." + ValueError, + match=r"PyTree type.*is already registered in namespace 'mylist1'\.", ): optree.register_pytree_node_class(MyList1, namespace='mylist1') with pytest.raises( - ValueError, match=r"PyTree type.*is already registered in namespace 'mylist2'\." + ValueError, + match=r"PyTree type.*is already registered in namespace 'mylist2'\.", ): optree.register_pytree_node_class(MyList2, namespace='mylist2') @@ -101,7 +108,7 @@ def test_register_pytree_node_with_invalid_namespace(): with pytest.raises(TypeError, match='The namespace must be a string'): @optree.register_pytree_node_class(namespace=1) - class MyList(UserList): + class MyList1(UserList): def tree_flatten(self): return self.data, None, None @@ -112,7 +119,7 @@ def tree_unflatten(cls, metadata, children): with pytest.raises(ValueError, match='The namespace cannot be an empty string.'): @optree.register_pytree_node_class('') - class MyList(UserList): + class MyList2(UserList): def tree_flatten(self): return self.data, None, None @@ -123,7 +130,7 @@ def tree_unflatten(cls, metadata, children): with pytest.raises(ValueError, match='The namespace cannot be an empty string.'): @optree.register_pytree_node_class(namespace='') - class MyList(UserList): + class MyList3(UserList): def tree_flatten(self): return self.data, None, None @@ -133,12 +140,18 @@ def tree_unflatten(cls, metadata, children): with pytest.raises(TypeError, match='The namespace must be a string'): optree.register_pytree_node( - set, lambda s: (sorted(s), None, None), lambda _, s: set(s), namespace=1 + set, + lambda s: (sorted(s), None, None), + lambda _, s: set(s), + namespace=1, ) with pytest.raises(ValueError, match='The namespace cannot be an empty string.'): optree.register_pytree_node( - set, lambda s: (sorted(s), None, None), lambda _, s: set(s), namespace='' + set, + lambda s: (sorted(s), None, None), + lambda _, s: set(s), + namespace='', ) @@ -171,8 +184,8 @@ def test_register_pytree_node_duplicate_builtin_namespace(): ): optree.register_pytree_node( list, - lambda l: (l, None, None), - lambda _, l: l, + lambda lst: (lst, None, None), + lambda _, lst: lst, namespace=optree.registry.__GLOBAL_NAMESPACE, ) with pytest.raises( @@ -181,8 +194,8 @@ def test_register_pytree_node_duplicate_builtin_namespace(): ): optree.register_pytree_node( list, - lambda l: (l, None, None), - lambda _, l: l, + lambda lst: (lst, None, None), + lambda _, lst: lst, namespace='list', ) @@ -194,7 +207,7 @@ def test_register_pytree_node_namedtuple(): match=re.escape( r"PyTree type is a subclass of `collections.namedtuple`, " r'which is already registered in the global namespace. ' - r'Override it with custom flatten/unflatten functions.' + r'Override it with custom flatten/unflatten functions.', ), ): optree.register_pytree_node( @@ -206,7 +219,7 @@ def test_register_pytree_node_namedtuple(): with pytest.raises( ValueError, match=re.escape( - r"PyTree type is already registered in the global namespace." + r"PyTree type is already registered in the global namespace.", ), ): optree.register_pytree_node( @@ -228,7 +241,7 @@ def test_register_pytree_node_namedtuple(): match=re.escape( r"PyTree type is a subclass of `collections.namedtuple`, " r'which is already registered in the global namespace. ' - r"Override it with custom flatten/unflatten functions in namespace 'mytuple'." + r"Override it with custom flatten/unflatten functions in namespace 'mytuple'.", ), ): optree.register_pytree_node( @@ -326,13 +339,13 @@ def tree_unflatten(cls, metadata, children): def test_pytree_node_registry_get(): handler = optree.register_pytree_node.get(list) assert handler is not None - l = [1, 2, 3] - assert handler.to_iterable(l)[:2] == (l, None) + lst = [1, 2, 3] + assert handler.to_iterable(lst)[:2] == (lst, None) handler = optree.register_pytree_node.get(list, namespace='any') assert handler is not None - l = [1, 2, 3] - assert handler.to_iterable(l)[:2] == (l, None) + lst = [1, 2, 3] + assert handler.to_iterable(lst)[:2] == (lst, None) handler = optree.register_pytree_node.get(set) assert handler is None diff --git a/tests/test_treespec.py b/tests/test_treespec.py index b0bd9134..db889149 100644 --- a/tests/test_treespec.py +++ b/tests/test_treespec.py @@ -105,8 +105,8 @@ def build_subtree(x): itertools.chain( zip(TREES, TREE_STRINGS[False], itertools.repeat(False)), zip(TREES, TREE_STRINGS[True], itertools.repeat(True)), - ) - ) + ), + ), ) def test_treespec_string_representation(data): tree, correct_string, none_is_leaf = data @@ -122,7 +122,9 @@ def test_with_namespace(): assert leaves == [tree] assert str(treespec) == ('PyTreeSpec(*)') paths, leaves, treespec = optree.tree_flatten_with_path( - tree, none_is_leaf=False, namespace=namespace + tree, + none_is_leaf=False, + namespace=namespace, ) assert paths == [()] assert leaves == [tree] @@ -133,7 +135,9 @@ def test_with_namespace(): assert leaves == [tree] assert str(treespec) == ('PyTreeSpec(*, NoneIsLeaf)') paths, leaves, treespec = optree.tree_flatten_with_path( - tree, none_is_leaf=True, namespace=namespace + tree, + none_is_leaf=True, + namespace=namespace, ) assert paths == [()] assert leaves == [tree] @@ -145,7 +149,9 @@ def test_with_namespace(): assert leaves == [2, 1, 101] assert str(treespec) == expected_string paths, leaves, treespec = optree.tree_flatten_with_path( - tree, none_is_leaf=False, namespace='namespace' + tree, + none_is_leaf=False, + namespace='namespace', ) assert paths == [('foo', 'b'), ('foo', 'a'), ('baz',)] assert leaves == [2, 1, 101] @@ -157,7 +163,9 @@ def test_with_namespace(): assert leaves == [None, 2, 1, 101] assert str(treespec) == expected_string paths, leaves, treespec = optree.tree_flatten_with_path( - tree, none_is_leaf=True, namespace='namespace' + tree, + none_is_leaf=True, + namespace='namespace', ) assert paths == [('foo', 'c'), ('foo', 'b'), ('foo', 'a'), ('baz',)] assert leaves == [None, 2, 1, 101] @@ -181,7 +189,7 @@ def test_treespec_pickle_round_trip(tree, none_is_leaf, namespace): assert actual == expected if expected.type is dict or expected.type is defaultdict: assert list(optree.tree_unflatten(actual, range(len(actual)))) == list( - optree.tree_unflatten(expected, range(len(expected))) + optree.tree_unflatten(expected, range(len(expected))), ) @@ -402,13 +410,15 @@ def test_treespec_tuple_from_children(none_is_leaf): ) def test_treespec_tuple_compares_equal(none_is_leaf): actual = optree.treespec_tuple( - (optree.tree_structure(3, none_is_leaf=none_is_leaf),), none_is_leaf=none_is_leaf + (optree.tree_structure(3, none_is_leaf=none_is_leaf),), + none_is_leaf=none_is_leaf, ) expected = optree.tree_structure((3,), none_is_leaf=none_is_leaf) assert actual == expected actual = optree.treespec_tuple( - (optree.tree_structure(None, none_is_leaf=none_is_leaf),), none_is_leaf=none_is_leaf + (optree.tree_structure(None, none_is_leaf=none_is_leaf),), + none_is_leaf=none_is_leaf, ) expected = optree.tree_structure((None,), none_is_leaf=none_is_leaf) assert actual == expected @@ -460,7 +470,8 @@ def test_treespec_leaf_none(): assert optree.treespec_leaf(none_is_leaf=True) == optree.tree_structure(1, none_is_leaf=True) assert optree.treespec_leaf(none_is_leaf=True) == optree.tree_structure(None, none_is_leaf=True) assert optree.treespec_leaf(none_is_leaf=True) != optree.tree_structure( - None, none_is_leaf=False + None, + none_is_leaf=False, ) assert optree.treespec_leaf(none_is_leaf=True) == optree.treespec_none(none_is_leaf=True) assert optree.treespec_leaf(none_is_leaf=True) != optree.treespec_none(none_is_leaf=False) diff --git a/tests/test_typing.py b/tests/test_typing.py index fafe1355..6f9be56a 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -124,7 +124,7 @@ def test_namedtuple_fields(): TypeError, match=re.escape( r'Expected an instance of collections.namedtuple type, ' - r'got time.struct_time(tm_year=0, tm_mon=1, tm_mday=2, tm_hour=3, tm_min=4, tm_sec=5, tm_wday=6, tm_yday=7, tm_isdst=8).' + r'got time.struct_time(tm_year=0, tm_mon=1, tm_mday=2, tm_hour=3, tm_min=4, tm_sec=5, tm_wday=6, tm_yday=7, tm_isdst=8).', ), ): optree.namedtuple_fields(time.struct_time(range(9))) @@ -221,7 +221,7 @@ def test_structseq_fields(): with pytest.raises( TypeError, match=re.escape( - r'Expected an instance of PyStructSequence type, got CustomTuple(foo=1, bar=2).' + r'Expected an instance of PyStructSequence type, got CustomTuple(foo=1, bar=2).', ), ): optree.structseq_fields(CustomTuple(1, 2)) diff --git a/tests/test_utils.py b/tests/test_utils.py index 4cf791f9..368d9db4 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -29,7 +29,8 @@ def test_total_order_sorted(): assert total_order_sorted([1, 5, 4, '20', '3']) == [1, 4, 5, '20', '3'] assert total_order_sorted([1, 5, 4.5, '20', '3']) == [4.5, 1, 5, '20', '3'] assert total_order_sorted( - {1: 1, 5: 2, 4.5: 3, '20': 4, '3': 5}.items(), key=lambda kv: kv[0] + {1: 1, 5: 2, 4.5: 3, '20': 4, '3': 5}.items(), + key=lambda kv: kv[0], ) == [(4.5, 3), (1, 1), (5, 2), ('20', 4), ('3', 5)] class NonSortable: