Skip to content

Commit

Permalink
style: enable ruff for tests
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Mar 23, 2023
1 parent c3b6d77 commit 828749a
Show file tree
Hide file tree
Showing 8 changed files with 295 additions and 115 deletions.
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ target-version = "py37"
line-length = 100
show-source = true
src = ["optree", "tests"]
extend-exclude = ["tests"]
select = [
"E", "W", # pycodestyle
"F", # pyflakes
Expand Down Expand Up @@ -205,6 +204,13 @@ typing-modules = ["optree.typing"]
"benchmark.py" = [
"PLW2901", # redefined-loop-name
]
"tests/**/*.py" = [
"ANN", # flake8-annotations
"S", # flake8-bandit
"BLE", # flake8-blind-except
"SIM", # flake8-simplify
"PL", # pylint
]

[tool.ruff.flake8-annotations]
allow-star-arg-any = true
Expand Down
2 changes: 1 addition & 1 deletion tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __repr__(self):


@optree.register_pytree_node_class(
namespace=optree.registry.__GLOBAL_NAMESPACE # pylint: disable=protected-access
namespace=optree.registry.__GLOBAL_NAMESPACE, # pylint: disable=protected-access
)
class Vector2D:
def __init__(self, x, y):
Expand Down
133 changes: 85 additions & 48 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,46 +46,50 @@ def dummy_func(*args, **kwargs): # pylint: disable=unused-argument
dummy_partial_func = functools.partial(dummy_func, a=1)


def is_tuple(t):
return isinstance(t, tuple)
def is_tuple(tup):
return isinstance(tup, tuple)


def is_list(l):
return isinstance(l, list)
def is_list(lst):
return isinstance(lst, list)


def is_none(n):
return n is None
def is_none(none):
return none is None


def always(o): # pylint: disable=unused-argument
def always(obj): # pylint: disable=unused-argument
return True


def never(o): # pylint: disable=unused-argument
def never(obj): # pylint: disable=unused-argument
return False


def test_max_depth():
l = [1]
lst = [1]
for _ in range(optree.MAX_RECURSION_DEPTH - 1):
l = [l]
optree.tree_flatten(l)
optree.tree_flatten_with_path(l)
lst = [lst]
optree.tree_flatten(lst)
optree.tree_flatten_with_path(lst)

l = [l]
lst = [lst]
with pytest.raises(
RecursionError, match='Maximum recursion depth exceeded during flattening the tree.'
RecursionError,
match='Maximum recursion depth exceeded during flattening the tree.',
):
optree.tree_flatten(l)
optree.tree_flatten(lst)
with pytest.raises(
RecursionError, match='Maximum recursion depth exceeded during flattening the tree.'
RecursionError,
match='Maximum recursion depth exceeded during flattening the tree.',
):
optree.tree_flatten_with_path(l)
optree.tree_flatten_with_path(lst)


@parametrize(
tree=list(TREES + LEAVES), none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace']
tree=list(TREES + LEAVES),
none_is_leaf=[False, True],
namespace=['', 'undefined', 'namespace'],
)
def test_round_trip(tree, none_is_leaf, namespace):
leaves, treespec = optree.tree_flatten(tree, none_is_leaf=none_is_leaf, namespace=namespace)
Expand All @@ -94,7 +98,9 @@ def test_round_trip(tree, none_is_leaf, namespace):


@parametrize(
tree=list(TREES + LEAVES), none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace']
tree=list(TREES + LEAVES),
none_is_leaf=[False, True],
namespace=['', 'undefined', 'namespace'],
)
def test_round_trip_with_flatten_up_to(tree, none_is_leaf, namespace):
_, treespec = optree.tree_flatten(tree, none_is_leaf=none_is_leaf, namespace=namespace)
Expand Down Expand Up @@ -234,15 +240,15 @@ def f_leaf(leaf):
def test_flatten_up_to():
_, treespec = optree.tree_flatten([(1, 2), None, CustomTuple(foo=3, bar=7)])
subtrees = treespec.flatten_up_to(
[({'foo': 7}, (3, 4)), None, CustomTuple(foo=(11, 9), bar=None)]
[({'foo': 7}, (3, 4)), None, CustomTuple(foo=(11, 9), bar=None)],
)
assert subtrees == [{'foo': 7}, (3, 4), (11, 9), None]


def test_flatten_up_to_none_is_leaf():
_, treespec = optree.tree_flatten([(1, 2), None, CustomTuple(foo=3, bar=7)], none_is_leaf=True)
subtrees = treespec.flatten_up_to(
[({'foo': 7}, (3, 4)), None, CustomTuple(foo=(11, 9), bar=None)]
[({'foo': 7}, (3, 4)), None, CustomTuple(foo=(11, 9), bar=None)],
)
assert subtrees == [{'foo': 7}, (3, 4), None, (11, 9), None]

Expand All @@ -251,7 +257,7 @@ def test_flatten_up_to_none_is_leaf():
leaves_fn=[
optree.tree_leaves,
lambda tree, is_leaf: optree.tree_flatten(tree, is_leaf)[0],
]
],
)
def test_flatten_is_leaf(leaves_fn):
x = [(1, 2), (3, 4), (5, 6)]
Expand All @@ -277,7 +283,7 @@ def test_flatten_is_leaf(leaves_fn):
structure_fn=[
optree.tree_structure,
lambda tree, is_leaf: optree.tree_flatten(tree, is_leaf)[1],
]
],
)
def test_structure_is_leaf(structure_fn):
x = [(1, 2), (3, 4), (5, 6)]
Expand All @@ -300,8 +306,8 @@ def test_structure_is_leaf(structure_fn):
itertools.chain(
zip(TREES, TREE_PATHS[False], itertools.repeat(False)),
zip(TREES, TREE_PATHS[True], itertools.repeat(True)),
)
)
),
),
)
def test_paths(data):
tree, expected_paths, none_is_leaf = data
Expand Down Expand Up @@ -332,7 +338,10 @@ def test_paths(data):
)
def test_round_trip_is_leaf(tree, is_leaf, none_is_leaf, namespace):
subtrees, treespec = optree.tree_flatten(
tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace
tree,
is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
actual = optree.tree_unflatten(treespec, subtrees)
assert actual == tree
Expand Down Expand Up @@ -506,7 +515,11 @@ def test_tree_map_with_is_leaf_none_is_leaf():
x = ((1, 2, None), [3, 4, 5])
y = (([3], None, 4), ({'foo': 'bar'}, 7, [5, 6]))
out = optree.tree_map(
lambda *xs: tuple(xs), x, y, is_leaf=lambda n: isinstance(n, list), none_is_leaf=True
lambda *xs: tuple(xs),
x,
y,
is_leaf=lambda n: isinstance(n, list),
none_is_leaf=True,
)
assert out == (((1, [3]), (2, None), (None, 4)), (([3, 4, 5], ({'foo': 'bar'}, 7, [5, 6]))))

Expand Down Expand Up @@ -671,7 +684,9 @@ def test_tree_transpose(tree):
return
with pytest.raises(ValueError, match='Tree structures must have the same none_is_leaf value.'):
optree.tree_transpose(
outer_treespec, optree.tree_structure([1, 1, 1], none_is_leaf=True), nested
outer_treespec,
optree.tree_structure([1, 1, 1], none_is_leaf=True),
nested,
)
actual = optree.tree_transpose(outer_treespec, inner_treespec, nested)
assert actual == [tree, tree, tree]
Expand All @@ -698,29 +713,32 @@ def test_tree_transpose_with_custom_object():
inner_treespec = optree.tree_structure([1, 2])
expected = [FlatCache({'a': 3, 'b': 5}), FlatCache({'a': 4, 'b': 6})]
actual = optree.tree_transpose(
outer_treespec, inner_treespec, FlatCache({'a': [3, 4], 'b': [5, 6]})
outer_treespec,
inner_treespec,
FlatCache({'a': [3, 4], 'b': [5, 6]}),
)
assert actual == expected


def test_tree_transpose_with_custom_namespace():
outer_treespec = optree.tree_structure(MyAnotherDict({'a': 1, 'b': 2}), namespace='namespace')
inner_treespec = optree.tree_structure(
MyAnotherDict({'c': 1, 'd': 2, 'e': 3}), namespace='namespace'
MyAnotherDict({'c': 1, 'd': 2, 'e': 3}),
namespace='namespace',
)
nested = MyAnotherDict(
{
'a': MyAnotherDict({'c': 1, 'd': 2, 'e': 3}),
'b': MyAnotherDict({'c': 4, 'd': 5, 'e': 6}),
}
},
)
actual = optree.tree_transpose(outer_treespec, inner_treespec, nested)
assert actual == MyAnotherDict(
{
'c': MyAnotherDict({'a': 1, 'b': 4}),
'd': MyAnotherDict({'a': 2, 'b': 5}),
'e': MyAnotherDict({'a': 3, 'b': 6}),
}
},
)


Expand All @@ -731,36 +749,39 @@ class MyExtraDict(MyAnotherDict):

outer_treespec = optree.tree_structure(MyAnotherDict({'a': 1, 'b': 2}), namespace='namespace')
inner_treespec = optree.tree_structure(
MyExtraDict({'c': 1, 'd': 2, 'e': 3}), namespace='subnamespace'
MyExtraDict({'c': 1, 'd': 2, 'e': 3}),
namespace='subnamespace',
)
nested = MyAnotherDict(
{
'a': MyExtraDict({'c': 1, 'd': 2, 'e': 3}),
'b': MyExtraDict({'c': 4, 'd': 5, 'e': 6}),
}
},
)
with pytest.raises(ValueError, match='Tree structures must have the same namespace.'):
optree.tree_transpose(outer_treespec, inner_treespec, nested)

optree.register_pytree_node_class(MyExtraDict, namespace='namespace')
inner_treespec = optree.tree_structure(
MyExtraDict({'c': 1, 'd': 2, 'e': 3}), namespace='namespace'
MyExtraDict({'c': 1, 'd': 2, 'e': 3}),
namespace='namespace',
)
actual = optree.tree_transpose(outer_treespec, inner_treespec, nested)
assert actual == MyExtraDict(
{
'c': MyAnotherDict({'a': 1, 'b': 4}),
'd': MyAnotherDict({'a': 2, 'b': 5}),
'e': MyAnotherDict({'a': 3, 'b': 6}),
}
},
)


def test_tree_broadcast_prefix():
assert optree.tree_broadcast_prefix(1, [1, 2, 3]) == [1, 1, 1]
assert optree.tree_broadcast_prefix([1, 2, 3], [1, 2, 3]) == [1, 2, 3]
with pytest.raises(
ValueError, match=re.escape('list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].')
ValueError,
match=re.escape('list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].'),
):
optree.tree_broadcast_prefix([1, 2, 3], [1, 2, 3, 4])
assert optree.tree_broadcast_prefix([1, 2, 3], [1, 2, (3, 4)]) == [1, 2, (3, 3)]
Expand All @@ -770,15 +791,18 @@ def test_tree_broadcast_prefix():
{'a': 3, 'b': 3, 'c': (None, 3)},
]
assert optree.tree_broadcast_prefix(
[1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}], none_is_leaf=True
[1, 2, 3],
[1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}],
none_is_leaf=True,
) == [1, 2, {'a': 3, 'b': 3, 'c': (3, 3)}]


def test_broadcast_prefix():
assert optree.broadcast_prefix(1, [1, 2, 3]) == [1, 1, 1]
assert optree.broadcast_prefix([1, 2, 3], [1, 2, 3]) == [1, 2, 3]
with pytest.raises(
ValueError, match=re.escape('list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].')
ValueError,
match=re.escape('list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].'),
):
optree.broadcast_prefix([1, 2, 3], [1, 2, 3, 4])
assert optree.broadcast_prefix([1, 2, 3], [1, 2, (3, 4)]) == [1, 2, 3, 3]
Expand All @@ -790,7 +814,9 @@ def test_broadcast_prefix():
3,
]
assert optree.broadcast_prefix(
[1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}], none_is_leaf=True
[1, 2, 3],
[1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}],
none_is_leaf=True,
) == [1, 2, 3, 3, 3, 3]


Expand All @@ -810,13 +836,18 @@ def test_tree_reduce():
assert optree.tree_reduce(lambda x, y: x and y, {'x': 1, 'y': (2, None), 'z': 3}) == 3
assert (
optree.tree_reduce(
lambda x, y: x and y, {'x': 1, 'y': (2, None), 'z': 3}, none_is_leaf=True
lambda x, y: x and y,
{'x': 1, 'y': (2, None), 'z': 3},
none_is_leaf=True,
)
is None
)
assert (
optree.tree_reduce(
lambda x, y: x and y, {'x': 1, 'y': (2, None), 'z': 3}, False, none_is_leaf=True
lambda x, y: x and y,
{'x': 1, 'y': (2, None), 'z': 3},
False,
none_is_leaf=True,
)
is False
)
Expand All @@ -826,13 +857,16 @@ def test_tree_sum():
assert optree.tree_sum({'x': 1, 'y': (2, 3)}) == 6
assert optree.tree_sum({'x': 1, 'y': (2, None), 'z': 3}) == 6
with pytest.raises(
TypeError, match=re.escape("unsupported operand type(s) for +: 'int' and 'NoneType'")
TypeError,
match=re.escape("unsupported operand type(s) for +: 'int' and 'NoneType'"),
):
optree.tree_sum({'x': 1, 'y': (2, None), 'z': 3}, none_is_leaf=True)
assert optree.tree_sum({'x': 'a', 'y': ('b', None), 'z': 'c'}, start='') == 'abc'
assert optree.tree_sum({'x': b'a', 'y': (b'b', None), 'z': b'c'}, start=b'') == b'abc'
assert optree.tree_sum(
{'x': [1], 'y': ([2], [None]), 'z': [3]}, start=[], is_leaf=lambda x: isinstance(x, list)
{'x': [1], 'y': ([2], [None]), 'z': [3]},
start=[],
is_leaf=lambda x: isinstance(x, list),
) == [1, 2, None, 3]


Expand Down Expand Up @@ -893,7 +927,7 @@ def test_tree_any():


@parametrize(tree=TREES, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'])
def test_flatten_one_level(tree, none_is_leaf, namespace):
def test_flatten_one_level(tree, none_is_leaf, namespace): # noqa: C901
stack = [tree]
actual_leaves = []
expected_leaves = optree.tree_leaves(tree, none_is_leaf=none_is_leaf, namespace=namespace)
Expand All @@ -902,21 +936,24 @@ def test_flatten_one_level(tree, none_is_leaf, namespace):
counter = Counter()
expected_children, one_level_treespec = optree.tree_flatten(
node,
is_leaf=lambda x: counter.increment() > 1,
is_leaf=lambda x: counter.increment() > 1, # noqa: B023
none_is_leaf=none_is_leaf,
namespace=namespace,
)
node_type = type(node)
if one_level_treespec.is_leaf():
assert expected_children == [node]
with pytest.raises(
ValueError, match=re.escape(f'Cannot flatten leaf-type: {node_type}.')
ValueError,
match=re.escape(f'Cannot flatten leaf-type: {node_type}.'),
):
optree.ops.flatten_one_level(node, none_is_leaf=none_is_leaf, namespace=namespace)
actual_leaves.append(node)
else:
children, metadata, entries = optree.ops.flatten_one_level(
node, none_is_leaf=none_is_leaf, namespace=namespace
node,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
assert children == expected_children
if node_type in (list, tuple, type(None)):
Expand Down
Loading

0 comments on commit 828749a

Please sign in to comment.