Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Xnnpack] Accept default padding value for torch.constant_pad_nd #7469

Merged
merged 1 commit into from
Jan 3, 2025

Conversation

pssrawat
Copy link
Contributor

@pssrawat pssrawat commented Jan 2, 2025

Summary:
xnnpack delegation for pad op assumes that the pad value is always present. However, constant_pad_nd defults to padding value of 0.0 if it's not present in the op. When absent, we get the following error in xnnpack delegation:
{F1974161274}

This diff defaults to padding value of 0.0 if the arg is absent from torch.constant_pad_nd op.

Differential Revision: D67756862

Copy link

pytorch-bot bot commented Jan 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/7469

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f3904a8 with merge base cb568fa (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 2, 2025
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D67756862

@pssrawat pssrawat changed the title Accept default padding value [Xnnpack] Accept default padding value for torch.constant_pad_nd Jan 2, 2025
pssrawat added a commit to pssrawat/executorch that referenced this pull request Jan 2, 2025
Summary:

xnnpack delegation for pad op assumes that the pad value is always present. However, constant_pad_nd defults to padding value of 0.0 if it's not present in the op. When absent, it errors out in `padding_value = cast(float, node.args[2])` with `IndexError: tuple index out of range`.

{F1974161274}

This diff defaults to padding value of 0.0 if the arg is absent from torch.constant_pad_nd op.

Differential Revision: D67756862
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D67756862

pssrawat added a commit to pssrawat/executorch that referenced this pull request Jan 2, 2025
Summary:

xnnpack delegation for pad op assumes that the pad value is always present. However, constant_pad_nd defults to padding value of 0.0 if it's not present in the op. When absent, it errors out in `padding_value = cast(float, node.args[2])` with `IndexError: tuple index out of range`.

{F1974161274}

This diff defaults to padding value of 0.0 if the arg is absent from torch.constant_pad_nd op.

Reviewed By: tarun292

Differential Revision: D67756862
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D67756862

pssrawat added a commit to pssrawat/executorch that referenced this pull request Jan 2, 2025
Summary:

xnnpack delegation for pad op assumes that the pad value is always present. However, constant_pad_nd defults to padding value of 0.0 if it's not present in the op. When absent, it errors out in `padding_value = cast(float, node.args[2])` with `IndexError: tuple index out of range`.

{F1974161274}

This diff defaults to padding value of 0.0 if the arg is absent from torch.constant_pad_nd op.

Reviewed By: tarun292

Differential Revision: D67756862
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D67756862

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D67756862

Summary:
Pull Request resolved: pytorch#7469

xnnpack delegation for pad op assumes that the pad value is always present. However, constant_pad_nd defults to padding value of 0.0 if it's not present in the op. When absent, it errors out in `padding_value = cast(float, node.args[2])` with `IndexError: tuple index out of range`.

{F1974161274}

This diff defaults to padding value of 0.0 if the arg is absent from torch.constant_pad_nd op.

Reviewed By: tarun292

Differential Revision: D67756862
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D67756862

@facebook-github-bot facebook-github-bot merged commit e66cdaf into pytorch:main Jan 3, 2025
44 of 46 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported topic: not user facing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants