Skip to content

Commit

Permalink
Add test case for BC Unit test (pytorch#2377)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2377

We need another test case for `keyword_only` arguments - where new arguments aren't necessarily added to the end of the signature.

Also updated logic for `is_signature_compatible` to handle/check for this case

Reviewed By: PaulZhang12

Differential Revision: D62470290

fbshipit-source-id: df08d3b0f28fd2f249e9b90ead825b058ed5ebb2
  • Loading branch information
aporialiao authored and facebook-github-bot committed Sep 11, 2024
1 parent ce378c3 commit 30b88f1
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
11 changes: 11 additions & 0 deletions torchrec/schema/test_schema_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,17 @@ def test_func_keyword_arg_added(
) -> int:
return a

def test_func_keyword_arg_added_in_middle(
a: int,
b: float,
*,
c: int,
e: float = 1.0,
d: float = 1.0,
**kwargs: Dict[str, Any],
) -> int:
return a

def test_func_keyword_arg_shifted(
a: int, b: float, *, d: float = 1.0, c: int, **kwargs: Dict[str, Any]
) -> int:
Expand Down
19 changes: 13 additions & 6 deletions torchrec/schema/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,23 @@ def is_signature_compatible(
expected_args = list(previous_signature.parameters.values())
current_args = list(current_signature.parameters.values())

# Store the names of all keyword only arguments
# to check if all expected keyword only arguments
# are present in current signature
expected_keyword_only_args = set()
current_keyword_only_args = set()

for i in range(len(expected_args)):
expected_arg = expected_args[i]
expected_args_len = len(expected_args)

for i in range(len(current_args)):
current_arg = current_args[i]
if current_arg.kind == current_arg.KEYWORD_ONLY:
current_keyword_only_args.add(current_arg.name)

if i >= expected_args_len:
continue

expected_arg = expected_args[i]

# If the kinds of arguments are different, BC is broken
# unless current arg is a keyword argument
Expand All @@ -65,11 +76,7 @@ def is_signature_compatible(
if expected_arg.default != current_arg.default:
return False
elif expected_arg.kind == expected_arg.KEYWORD_ONLY:
# Store the names of all keyword only arguments
# to check if all expected keyword only arguments
# are present in current signature
expected_keyword_only_args.add(expected_arg.name)
current_keyword_only_args.add(current_arg.name)

# All kwargs in expected signature must be present in current signature
for kwarg in expected_keyword_only_args:
Expand Down

0 comments on commit 30b88f1

Please sign in to comment.