add test for supporting torch.float16 and torch.bfloat16
Summary:
# context
* We found that the new operator `permute_multi_embedding` can't support `torch.float16` in an inference test
* added a test to cover the dtype support
* before the operator change, we saw the following error
```
Failures:

  1) torchrec.sparse.tests.test_jagged_tensor.TestKeyedTensorRegroupOp: test_multi_permute_dtype
    1) RuntimeError: expected scalar type Float but found Half
      File "torchrec/sparse/tests/test_jagged_tensor.py", line 2798, in test_multi_permute_dtype
        outputs = torch.ops.fbgemm.permute_multi_embedding(
      File "torch/_ops.py", line 1113, in __call__
        return self._op(*args, **(kwargs or {}))
```
* the suspected cause is that the CPU operator accesses tensor data via `data_ptr<float>`, which restricts the supported dtype to `float32`; a minimal repro of this type check is sketched after the snippet
```
          auto outp = outputs[out_tensor][b].data_ptr<float>() + out_offset;
          auto inp = inputs[in_tensor][b].data_ptr<float>() + in_offset;
```
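As a minimal illustration of that type check (a standalone snippet written for this note, not code from the commit), `data_ptr<T>()` verifies the tensor's scalar type and throws the same error whenever `T` does not match the tensor's dtype:
```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // A half-precision tensor, like the embeddings in the inference test.
  at::Tensor t = at::zeros({4}, at::dtype(at::kHalf));

  // data_ptr<float>() checks that the scalar type is Float; on a Half tensor
  // it throws "expected scalar type Float but found Half".
  try {
    float* p = t.data_ptr<float>();
    (void)p;
  } catch (const std::exception& e) {
    std::cout << e.what() << std::endl;
  }

  // The typed accessor works only when the template argument matches.
  at::Half* hp = t.data_ptr<at::Half>();
  (void)hp;
  return 0;
}
```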

# changes
* use `FBGEMM_DISPATCH_FLOATING_TYPES` to dispatch the dtype to the template parameter `scalar_t`.
* after the change, the operator supports `float16` and `bfloat16` (see the sketch after this list)

WARNING: somehow this operator still can't support `int` types.
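For reference, the fix follows the standard dispatch-to-template pattern. The sketch below is illustrative only: it uses the stock ATen macro `AT_DISPATCH_FLOATING_TYPES_AND2` as a stand-in for `FBGEMM_DISPATCH_FLOATING_TYPES`, and the kernel and function names are invented for the example.
```cpp
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Toy element-wise copy kernel templated on the element type; the real
// operator's kernel is permute_multi_embs_kernel_cpu in the diff below.
template <typename scalar_t>
void copy_kernel(const at::Tensor& src, at::Tensor& dst) {
  const scalar_t* in = src.data_ptr<scalar_t>();
  scalar_t* out = dst.data_ptr<scalar_t>();
  for (int64_t i = 0; i < src.numel(); ++i) {
    out[i] = in[i];
  }
}

at::Tensor copy_any_float(const at::Tensor& src) {
  at::Tensor dst = at::empty_like(src);
  // Dispatch float/double plus Half and BFloat16 to the template above, so
  // the kernel no longer hard-codes data_ptr<float>().
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      src.scalar_type(),
      "copy_any_float",
      [&] { copy_kernel<scalar_t>(src, dst); });
  return dst;
}
```
This compiles the kernel once per floating dtype and selects the right instantiation at runtime from `src.scalar_type()`, which is why the `float16`/`bfloat16` test passes after the change.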

Differential Revision: D57143637
TroyGarden authored and facebook-github-bot committed Aug 14, 2024
1 parent ada1050 commit 0de38a6
Showing 1 changed file with 70 additions and 54 deletions.
@@ -14,6 +14,48 @@ using Tensor = at::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::variable_list;

template <typename scalar_t>
void permute_multi_embs_kernel_cpu(
    const std::vector<Tensor>& inputs,
    std::vector<Tensor>& outputs,
    const Tensor& permutes,
    const int32_t start,
    const int32_t end,
    const bool& reverse_permute) {
  int32_t in_tensor, out_tensor, in_offset, out_offset, length, next;
  for (const auto i : c10::irange(permutes.size(0))) {
    int32_t* __restrict__ pp = permutes[i].data_ptr<int32_t>();
    if (reverse_permute) {
      out_tensor = pp[PermuteParam::in_tensor];
      in_tensor = pp[PermuteParam::out_tensor];
      out_offset = pp[PermuteParam::in_offset];
      in_offset = pp[PermuteParam::out_offset];
      next = pp[PermuteParam::next];
    } else {
      in_tensor = pp[PermuteParam::in_tensor];
      out_tensor = pp[PermuteParam::out_tensor];
      in_offset = pp[PermuteParam::in_offset];
      out_offset = pp[PermuteParam::out_offset];
    }
    length = pp[PermuteParam::length];
    if (reverse_permute && next < 0) {
      for (auto b : c10::irange(start, end)) {
        auto outp = outputs[out_tensor][b].data_ptr<scalar_t>() + out_offset;
        auto inp = inputs[in_tensor][b].data_ptr<scalar_t>() + in_offset;
        for (const auto j : c10::irange(length)) {
          outp[j] += inp[j];
        }
      }
    } else {
      for (auto b : c10::irange(start, end)) {
        auto outp = outputs[out_tensor][b].data_ptr<scalar_t>() + out_offset;
        auto inp = inputs[in_tensor][b].data_ptr<scalar_t>() + in_offset;
        std::memcpy(outp, inp, length * inputs[0].itemsize());
      }
    }
  }
}

std::vector<Tensor> permute_multi_embedding_function_cpu(
    const at::TensorList& pooled_embs,
    const Tensor& permutes,
@@ -36,40 +78,13 @@ std::vector<Tensor> permute_multi_embedding_function_cpu(
    outputs.push_back(at::empty({B, out_lengths[i]}, pooled_embs[0].options()));
    TORCH_CHECK(outputs[i].is_contiguous());
  }
  at::parallel_for(0, B, 0, [&](int32_t start, int32_t end) {
    int32_t in_tensor, out_tensor, in_offset, out_offset, length, next;
    for (const auto i : c10::irange(permutes.size(0))) {
      int32_t* __restrict__ pp = permutes[i].data_ptr<int32_t>();
      if (reverse_permute) {
        out_tensor = pp[PermuteParam::in_tensor];
        in_tensor = pp[PermuteParam::out_tensor];
        out_offset = pp[PermuteParam::in_offset];
        in_offset = pp[PermuteParam::out_offset];
        next = pp[PermuteParam::next];
      } else {
        in_tensor = pp[PermuteParam::in_tensor];
        out_tensor = pp[PermuteParam::out_tensor];
        in_offset = pp[PermuteParam::in_offset];
        out_offset = pp[PermuteParam::out_offset];
      }
      length = pp[PermuteParam::length];
      if (reverse_permute && next < 0) {
        for (auto b : c10::irange(start, end)) {
          auto outp = outputs[out_tensor][b].data_ptr<float>() + out_offset;
          auto inp = inputs[in_tensor][b].data_ptr<float>() + in_offset;
          for (const auto j : c10::irange(length)) {
            outp[j] += inp[j];
          }
        }
      } else {
        for (auto b : c10::irange(start, end)) {
          auto outp = outputs[out_tensor][b].data_ptr<float>() + out_offset;
          auto inp = inputs[in_tensor][b].data_ptr<float>() + in_offset;
          std::memcpy(outp, inp, length * pooled_embs[0].itemsize());
        }
      }
    }
  });
  FBGEMM_DISPATCH_FLOATING_TYPES(
      pooled_embs[0].scalar_type(), "permute_multi_embs_cpu", [&] {
        at::parallel_for(0, B, 0, [&](int32_t start, int32_t end) {
          permute_multi_embs_kernel_cpu<scalar_t>(
              inputs, outputs, permutes, start, end, reverse_permute);
        });
      });
  return outputs;
}

@@ -96,12 +111,12 @@ std::vector<Tensor> permute_multi_embedding_function_meta(
/// @brief permute and regroup keyed tensors
///
/// We often need to regroup keyed tensors (KTs) in a batch. For example, we
/// have two KTs A and B, where A contains the pooled embeddings of two features
/// (keys) F1 and F2, and B contains the pooled embeddings of two features
/// (keys) F3 and F4. Both KTs have the same batch size.
/// have two KTs A and B, where A contains the pooled embeddings of two
/// features (keys) F1 and F2, and B contains the pooled embeddings of two
/// features (keys) F3 and F4. Both KTs have the same batch size.
///
/// We want to permute and regroup the KTs so that in the new KTs, F1 and F3 are
/// grouped together, and F2 and F4 are grouped together.
/// We want to permute and regroup the KTs so that in the new KTs, F1 and F3
/// are grouped together, and F2 and F4 are grouped together.
///
/// **Example:**
/// ```python
@@ -125,28 +140,29 @@ std::vector<Tensor> permute_multi_embedding_function_meta(
///
///
/// @param pooled_embs list of tensors that from KTs' values
/// @param permutes a 2D tensor with each row representing a permute operation.
/// a permute operation is about how to move/copy a feature from the input KT to
/// the output KT. the first column is the input tensor index, and the second
/// column is the output tensor index. the third column is the feature's offset
/// of input tensor, and the fourth column is the feature's offset of output
/// tensor. the fifth column is the length of the feature in a permute, and the
/// last column is a next permute row to operate on (used in backward only).
/// @param in_shapes a 1D tensor with each element representing the length of an
/// input KT.
/// @param permutes a 2D tensor with each row representing a permute
/// operation. a permute operation is about how to move/copy a feature from
/// the input KT to the output KT. the first column is the input tensor index,
/// and the second column is the output tensor index. the third column is the
/// feature's offset of input tensor, and the fourth column is the feature's
/// offset of output tensor. the fifth column is the length of the feature in
/// a permute, and the last column is a next permute row to operate on (used
/// in backward only).
/// @param in_shapes a 1D tensor with each element representing the length of
/// an input KT.
/// @param out_shapes a 1D tensor with each element representing the length of
/// an output KT.
/// @param out_lengths a 1D vector with each element representing the length of
/// an output KT.
/// @param out_lengths a 1D vector with each element representing the length
/// of an output KT.
///
/// @return the values of the output KTs.
///
///
/// @note This operator supports autograd, and duplications in the output KTs
/// are supported, such as [["F1", "F3"], ["F2", "F4"], ["F1", "F3"]]
///
/// @warning when a feature is omitted from the output KTs, the gradient of the
/// feature won't be set to 0.
/// @warning when a feature is omitted from the output KTs, the gradient of
/// the feature won't be set to 0.
///
std::vector<Tensor> permute_multi_embedding_autograd(
const at::TensorList& pooled_embs,
@@ -227,8 +243,8 @@ Tensor from_cpu(const std::vector<index_t>& input) {
/// feature/key in a KT, and a list of lengths represents a KT
/// @param groups List[List[str]], each string represents a feature/key in an
/// output KT a list of strings represents one output KT
/// @return tuple of permutes, in_shapes, out_shapes and output_lengths. See the
/// inputs of permute_multi_embedding for more details. The output tensors
/// @return tuple of permutes, in_shapes, out_shapes and output_lengths. See
/// the inputs of permute_multi_embedding for more details. The output tensors
/// should be contiguous, and on the same device as the input tensor.
///
/// @note This operator doesn't need autograd since it's purely about index.
