Skip to content

Commit

Permalink
Implement batching rules for basic arithmetic ops (pytorch#43362)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#43362

Batching rules implemented for: addition subtraction division
multiplication.

I refactored the original `mul_batching_rule` into a templated function
so that one can insert arbitrary binary operations into it.

add, sub, rsub, mul, and div all work the same way. However, other
binary operations work slightly differently (I'm still figuring out the
differences and why they're different) so those may need a different
implementation.

Test Plan: - "pytest test/test_vmap.py -v": new tests

Reviewed By: ezyang

Differential Revision: D23252317

Pulled By: zou3519

fbshipit-source-id: 6d36cd837a006a2fd31474469323463c1bd797fc
  • Loading branch information
zou3519 authored and facebook-github-bot committed Aug 24, 2020
1 parent db78c07 commit c972e62
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 22 deletions.
44 changes: 34 additions & 10 deletions aten/src/ATen/BatchingRegistrations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,20 +66,22 @@ bool isPhysicalScalarTensor(const Tensor& logical_tensor) {
return true;
}

Tensor mul_batching_rule(const Tensor& self, const Tensor& other) {
template <typename F, F Func, typename... ExtraArgs>
Tensor binary_pointwise_batching_rule(
const Tensor& self, const Tensor& other, ExtraArgs... args) {
if (self.dim() > 0 && other.dim() > 0) {
auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
auto result = at::mul(physical_args[0].tensor(), physical_args[1].tensor());
auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...);
return physical_args[0].newLogicalFromPhysical(result);
}
if (isPhysicalScalarTensor(self)) {
auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
auto result = at::mul(self, other_physical.tensor());
auto result = Func(self, other_physical.tensor(), args...);
return other_physical.newLogicalFromPhysical(result);
}
if (isPhysicalScalarTensor(other)) {
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
auto result = at::mul(self_physical.tensor(), other);
auto result = Func(self_physical.tensor(), other, args...);
return self_physical.newLogicalFromPhysical(result);
}

Expand Down Expand Up @@ -120,7 +122,7 @@ Tensor mul_batching_rule(const Tensor& self, const Tensor& other) {
}
auto physical_args = BroadcastingVmapTransform::logicalToPhysical(
{logical_self, logical_other});
auto result = at::mul(physical_args[0].tensor(), physical_args[1].tensor());
auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...);
return physical_args[0].newLogicalFromPhysical(result);
}

Expand Down Expand Up @@ -289,10 +291,10 @@ Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
return self_physical.newLogicalFromPhysical(result);
}

template <Tensor (*Op)(const Tensor&)>
Tensor unary_pointwise_batching_rule(const Tensor& input) {
template <typename F, F Func, typename... ExtraArgs>
Tensor unary_pointwise_batching_rule(const Tensor& input, ExtraArgs... args) {
auto* input_batched = unsafeGetBatchedImpl(input);
auto output_physical = Op(input_batched->value());
auto output_physical = Func(input_batched->value(), args...);
auto old_bdims = input_batched->bdims();
return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}
Expand All @@ -319,7 +321,6 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("_remove_batch_dim", native::_remove_batch_dim);

m.impl_UNBOXED("sum.dim_IntList", sum_batching_rule);
m.impl_UNBOXED("mul.Tensor", mul_batching_rule);

// view operations
m.impl("chunk", chunk_batching_rule);
Expand Down Expand Up @@ -349,7 +350,8 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("view_as", native::view_as); // composite wrt autograd

// unary pointwise, out-of-place, no additional arguments.
#define UNARY_POINTWISE(op) m.impl(#op, unary_pointwise_batching_rule<at::op>);
#define UNARY_POINTWISE(op) m.impl(#op, \
unary_pointwise_batching_rule<Tensor (*)(const Tensor&), at::op>);
UNARY_POINTWISE(abs);
UNARY_POINTWISE(acos);
UNARY_POINTWISE(asin);
Expand Down Expand Up @@ -391,6 +393,28 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
TO_BATCHING_RULE("to.dtype", ScalarType, bool, bool, optional<MemoryFormat>)
TO_BATCHING_RULE("to.other", const Tensor&, bool, bool, optional<MemoryFormat>)
#undef TO_BATCHING_RULE

using TensorTensorType = Tensor (*)(const Tensor&, const Tensor&);
using TensorScalarType = Tensor (*)(const Tensor&, Scalar);

#define BINARY_POINTWISE(op) \
m.impl(#op".Tensor", binary_pointwise_batching_rule<TensorTensorType, at::op>); \
m.impl(#op".Scalar", unary_pointwise_batching_rule<TensorScalarType, at::op, Scalar>);
#define BINARY_POINTWISE_VA(op, ...) \
{ \
using Binop = Tensor (*)(const Tensor&, const Tensor&, __VA_ARGS__); \
using Unop = Tensor (*)(const Tensor&, Scalar, __VA_ARGS__); \
m.impl(#op".Tensor", binary_pointwise_batching_rule<Binop, at::op, __VA_ARGS__>); \
m.impl(#op".Scalar", unary_pointwise_batching_rule<Unop, at::op, Scalar, __VA_ARGS__>); \
}

BINARY_POINTWISE_VA(add, Scalar);
BINARY_POINTWISE_VA(sub, Scalar);
BINARY_POINTWISE_VA(rsub, Scalar);
BINARY_POINTWISE(mul);
BINARY_POINTWISE(div);
#undef BINARY_POINTWISE_VA
#undef BINARY_POINTWISE
}

} // namespace at
37 changes: 25 additions & 12 deletions test/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,32 +446,34 @@ def _assert_uses_vmap_fallback(self, vmap_args, inputs):
self.assertRegex(str(wa[-1].message),
r'falling back to slow \(for loop and stack\) implementation')

def test_fallback_sub(self):
# NB: One day we will implement a batching rule for torch.sub.
def test_fallback_atan2(self):
# NB: One day we will implement a batching rule for torch.atan2.
# If/when we do, this test should be replaced to test the fallback
# path on another operator to avoid bitrot.
op = torch.atan2

x = torch.randn(5, 7, 11)
y = torch.randn(5, 7, 11)

self._assert_uses_vmap_fallback((torch.sub,), (x, y))
self._assert_uses_vmap_fallback((op,), (x, y))

# fallback on torch.sub
x = torch.randn(7, 11, 5)
y = torch.randn(5, 7, 11)
result = vmap(torch.sub, (2, 0))(x, y)
self.assertEqual(result, x.permute(2, 0, 1) - y)
result = vmap(op, (2, 0))(x, y)
self.assertEqual(result, op(x.permute(2, 0, 1), y))

# fallback on torch.sub, nested vmap
x = torch.randn(7, 11, 5)
y = torch.randn(5, 7, 11)
result = vmap(vmap(torch.sub), (2, 0))(x, y)
self.assertEqual(result, x.permute(2, 0, 1) - y)
result = vmap(vmap(op), (2, 0))(x, y)
self.assertEqual(result, op(x.permute(2, 0, 1), y))

# big batch size (total 10000)
x = torch.randn(100, 10, 10, 5)
y = torch.randn(100, 10, 10)
result = vmap(vmap(vmap(torch.sub)))(x, y)
self.assertEqual(result, x - y.view(100, 10, 10, 1))
result = vmap(vmap(vmap(op)))(x, y)
self.assertEqual(result, op(x, y.view(100, 10, 10, 1)))

def test_fallback_masked_fill(self):
# NB: One day we will implement a batching rule for masked_fill
Expand Down Expand Up @@ -742,9 +744,19 @@ def test_binary_pointwise_ops(self):
def get_number(getter):
return getter([]).item()

def make_case(op, input_getter=TensorFactory.randn):
return (op, input_getter)

cases = [
(torch.mul, TensorFactory.rand),
(lambda x, y: x * y, TensorFactory.randn),
# Basic arithmetic
make_case(torch.add),
make_case(lambda x, y: x + y),
make_case(torch.sub),
make_case(lambda x, y: x - y),
make_case(torch.mul),
make_case(lambda x, y: x * y),
make_case(torch.div, input_getter=TensorFactory.randp1),
make_case(lambda x, y: x / y, input_getter=TensorFactory.randp1),
]
test = self._vmap_test

Expand Down Expand Up @@ -785,11 +797,12 @@ def get_number(getter):
test(op, (getter([B0], device, torch.double), getter([B0, 2], device)))

if not torch.cuda.is_available():
return
continue

# Test cross-device scalars
number = get_number(getter)
self._test_unary(lambda t: op(t, number), getter, device='cuda')
self._test_unary(lambda t: op(number, t), getter, device='cuda')
self._test_unary(lambda t: op(t, torch.tensor(number)), getter, device='cuda')

def test_chunk(self):
Expand Down

0 comments on commit c972e62

Please sign in to comment.