diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
index 9aebe68772..cddde5aa10 100644
--- a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
+++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
@@ -14,6 +14,48 @@
 using Tensor = at::Tensor;
 using torch::autograd::AutogradContext;
 using torch::autograd::variable_list;
+template <typename scalar_t>
+void permute_multi_embs_kernel_cpu(
+    const std::vector<Tensor>& inputs,
+    std::vector<Tensor>& outputs,
+    const Tensor& permutes,
+    const int32_t start,
+    const int32_t end,
+    const bool& reverse_permute) {
+  int32_t in_tensor, out_tensor, in_offset, out_offset, length, next;
+  for (const auto i : c10::irange(permutes.size(0))) {
+    int32_t* __restrict__ pp = permutes[i].data_ptr<int32_t>();
+    if (reverse_permute) {
+      out_tensor = pp[PermuteParam::in_tensor];
+      in_tensor = pp[PermuteParam::out_tensor];
+      out_offset = pp[PermuteParam::in_offset];
+      in_offset = pp[PermuteParam::out_offset];
+      next = pp[PermuteParam::next];
+    } else {
+      in_tensor = pp[PermuteParam::in_tensor];
+      out_tensor = pp[PermuteParam::out_tensor];
+      in_offset = pp[PermuteParam::in_offset];
+      out_offset = pp[PermuteParam::out_offset];
+    }
+    length = pp[PermuteParam::length];
+    if (reverse_permute && next < 0) {
+      for (auto b : c10::irange(start, end)) {
+        auto outp = outputs[out_tensor][b].data_ptr<scalar_t>() + out_offset;
+        auto inp = inputs[in_tensor][b].data_ptr<scalar_t>() + in_offset;
+        for (const auto j : c10::irange(length)) {
+          outp[j] += inp[j];
+        }
+      }
+    } else {
+      for (auto b : c10::irange(start, end)) {
+        auto outp = outputs[out_tensor][b].data_ptr<scalar_t>() + out_offset;
+        auto inp = inputs[in_tensor][b].data_ptr<scalar_t>() + in_offset;
+        std::memcpy(outp, inp, length * inputs[0].itemsize());
+      }
+    }
+  }
+}
+
 std::vector<Tensor> permute_multi_embedding_function_cpu(
     const at::TensorList& pooled_embs,
     const Tensor& permutes,
@@ -36,40 +78,13 @@ std::vector<Tensor> permute_multi_embedding_function_cpu(
     outputs.push_back(at::empty({B, out_lengths[i]}, pooled_embs[0].options()));
     TORCH_CHECK(outputs[i].is_contiguous());
   }
-  at::parallel_for(0, B, 0, [&](int32_t start, int32_t end) {
-    int32_t in_tensor, out_tensor, in_offset, out_offset, length, next;
-    for (const auto i : c10::irange(permutes.size(0))) {
-      int32_t* __restrict__ pp = permutes[i].data_ptr<int32_t>();
-      if (reverse_permute) {
-        out_tensor = pp[PermuteParam::in_tensor];
-        in_tensor = pp[PermuteParam::out_tensor];
-        out_offset = pp[PermuteParam::in_offset];
-        in_offset = pp[PermuteParam::out_offset];
-        next = pp[PermuteParam::next];
-      } else {
-        in_tensor = pp[PermuteParam::in_tensor];
-        out_tensor = pp[PermuteParam::out_tensor];
-        in_offset = pp[PermuteParam::in_offset];
-        out_offset = pp[PermuteParam::out_offset];
-      }
-      length = pp[PermuteParam::length];
-      if (reverse_permute && next < 0) {
-        for (auto b : c10::irange(start, end)) {
-          auto outp = outputs[out_tensor][b].data_ptr<float>() + out_offset;
-          auto inp = inputs[in_tensor][b].data_ptr<float>() + in_offset;
-          for (const auto j : c10::irange(length)) {
-            outp[j] += inp[j];
-          }
-        }
-      } else {
-        for (auto b : c10::irange(start, end)) {
-          auto outp = outputs[out_tensor][b].data_ptr<float>() + out_offset;
-          auto inp = inputs[in_tensor][b].data_ptr<float>() + in_offset;
-          std::memcpy(outp, inp, length * pooled_embs[0].itemsize());
-        }
-      }
-    }
-  });
+  FBGEMM_DISPATCH_FLOATING_TYPES(
+      pooled_embs[0].scalar_type(), "permute_multi_embs_cpu", [&] {
+        at::parallel_for(0, B, 0, [&](int32_t start, int32_t end) {
+          permute_multi_embs_kernel_cpu<scalar_t>(
+              inputs, outputs, permutes, start, end, reverse_permute);
+        });
+      });
   return outputs;
 }
 
@@ -96,12 +111,12 @@ std::vector<Tensor> permute_multi_embedding_function_meta(
 /// @brief permute and regroup keyed tensors
 ///
 /// We often need to regroup keyed tensors (KTs) in a batch. For example, we
-/// have two KTs A and B, where A contains the pooled embeddings of two features
-/// (keys) F1 and F2, and B contains the pooled embeddings of two features
-/// (keys) F3 and F4. Both KTs have the same batch size.
+/// have two KTs A and B, where A contains the pooled embeddings of two
+/// features (keys) F1 and F2, and B contains the pooled embeddings of two
+/// features (keys) F3 and F4. Both KTs have the same batch size.
 ///
-/// We want to permute and regroup the KTs so that in the new KTs, F1 and F3 are
-/// grouped together, and F2 and F4 are grouped together.
+/// We want to permute and regroup the KTs so that in the new KTs, F1 and F3
+/// are grouped together, and F2 and F4 are grouped together.
 ///
 /// **Example:**
 /// ```python
@@ -125,19 +140,20 @@ std::vector<Tensor> permute_multi_embedding_function_meta(
 ///
 ///
 /// @param pooled_embs list of tensors that from KTs' values
-/// @param permutes a 2D tensor with each row representing a permute operation.
-/// a permute operation is about how to move/copy a feature from the input KT to
-/// the output KT. the first column is the input tensor index, and the second
-/// column is the output tensor index. the third column is the feature's offset
-/// of input tensor, and the fourth column is the feature's offset of output
-/// tensor. the fifth column is the length of the feature in a permute, and the
-/// last column is a next permute row to operate on (used in backward only).
-/// @param in_shapes a 1D tensor with each element representing the length of an
-/// input KT.
+/// @param permutes a 2D tensor with each row representing a permute
+/// operation. a permute operation is about how to move/copy a feature from
+/// the input KT to the output KT. the first column is the input tensor index,
+/// and the second column is the output tensor index. the third column is the
+/// feature's offset of input tensor, and the fourth column is the feature's
+/// offset of output tensor. the fifth column is the length of the feature in
+/// a permute, and the last column is a next permute row to operate on (used
+/// in backward only).
+/// @param in_shapes a 1D tensor with each element representing the length of
+/// an input KT.
 /// @param out_shapes a 1D tensor with each element representing the length of
 /// an output KT.
-/// @param out_lengths a 1D vector with each element representing the length of
-/// an output KT.
+/// @param out_lengths a 1D vector with each element representing the length
+/// of an output KT.
 ///
 /// @return the values of the output KTs.
 ///
@@ -145,8 +161,8 @@ std::vector<Tensor> permute_multi_embedding_function_meta(
 /// @note This operator supports autograd, and duplications in the output KTs
 /// are supported, such as [["F1", "F3"], ["F2", "F4"], ["F1", "F3"]]
 ///
-/// @warning when a feature is omitted from the output KTs, the gradient of the
-/// feature won't be set to 0.
+/// @warning when a feature is omitted from the output KTs, the gradient of
+/// the feature won't be set to 0.
 ///
 std::vector<Tensor> permute_multi_embedding_autograd(
     const at::TensorList& pooled_embs,
@@ -227,8 +243,8 @@ Tensor from_cpu(const std::vector<int32_t>& input) {
 /// feature/key in a KT, and a list of lengths represents a KT
 /// @param groups List[List[str]], each string represents a feature/key in an
 /// output KT a list of strings represents one output KT
-/// @return tuple of permutes, in_shapes, out_shapes and output_lengths. See the
-/// inputs of permute_multi_embedding for more details. The output tensors
+/// @return tuple of permutes, in_shapes, out_shapes and output_lengths. See
+/// the inputs of permute_multi_embedding for more details. The output tensors
 /// should be contiguous, and on the same device as the input tensor.
 ///
 /// @note This operator doesn't need autograd since it's purely about index.
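Review note (not part of the patch): the `permutes` layout that `permute_multi_embs_kernel_cpu` reads through `PermuteParam` can be hard to picture from the doc comment alone. Below is a minimal illustrative sketch of such a tensor for the doc-comment example (KT A = [F1, F2], KT B = [F3, F4], regrouped into [F1, F3] and [F2, F4]), under the assumption that every feature pools to 2 elements. The column order follows the `@param permutes` description above; the helper name `make_example_permutes` is hypothetical, and the `next` column is left at 0 here since, per the docs, it is only consulted by the backward pass.

```cpp
#include <array>

#include <ATen/ATen.h>

// Hypothetical helper (illustration only): builds a 4x6 int32 "permutes"
// tensor, one row per feature copy, with columns
// [in_tensor, out_tensor, in_offset, out_offset, length, next].
// Assumption for this sketch: every feature's pooled embedding is 2 wide.
at::Tensor make_example_permutes() {
  std::array<int32_t, 24> rows = {
      0, 0, 0, 0, 2, 0, // F1: input KT 0, offset 0 -> output KT 0, offset 0
      1, 0, 0, 2, 2, 0, // F3: input KT 1, offset 0 -> output KT 0, offset 2
      0, 1, 2, 0, 2, 0, // F2: input KT 0, offset 2 -> output KT 1, offset 0
      1, 1, 2, 2, 2, 0, // F4: input KT 1, offset 2 -> output KT 1, offset 2
  };
  // Copy the host data into a contiguous int32 tensor of shape [4, 6].
  return at::from_blob(rows.data(), {4, 6}, at::kInt).clone();
}
```

With these assumed widths, `in_shapes`, `out_shapes`, and `out_lengths` would all be `[4, 4]`, and offsets are expressed in elements, matching the typed-pointer arithmetic (`data_ptr<scalar_t>() + offset`) used by the new kernel.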