forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement batching rules for basic arithmetic ops (pytorch#43362)
Summary: Pull Request resolved: pytorch#43362 Batching rules implemented for: addition subtraction division multiplication. I refactored the original `mul_batching_rule` into a templated function so that one can insert arbitrary binary operations into it. add, sub, rsub, mul, and div all work the same way. However, other binary operations work slightly differently (I'm still figuring out the differences and why they're different) so those may need a different implementation. Test Plan: - "pytest test/test_vmap.py -v": new tests Reviewed By: ezyang Differential Revision: D23252317 Pulled By: zou3519 fbshipit-source-id: 6d36cd837a006a2fd31474469323463c1bd797fc
- Loading branch information
1 parent
db78c07
commit c972e62
Showing
2 changed files
with
59 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters