Support original indices for FBGEMM block bucketization flag (#2999)
Summary:
Pull Request resolved: #2999

For ZCH we want to operate in the global id space, which requires extending the FBGEMM block bucketization kernels to operate on global ids, not local ones. Renames 'maybe_keep_orig_idx' -> 'keep_orig_idx'.
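To make the flag's effect concrete, here is a minimal Python sketch of the per-index rule the kernels below apply on the uniform-bucket-size path (the helper name and scalar loop are made up for illustration; this is not the FBGEMM implementation). `my_size` is the number of buckets and `blk_size` is the per-feature entry of `block_sizes`:

```python
# Hypothetical helper mirroring the per-index logic of the kernels in this diff
# (uniform bucket sizes; the block_bucketize_pos path applies the same
# keep_orig_idx rule, just with per-bucket boundaries instead of blk_size).
def bucketize_one_index(idx: int, blk_size: int, my_size: int, keep_orig_idx: bool):
    if idx < blk_size * my_size:
        bucket = idx // blk_size  # id falls inside the expected range
        new_idx = idx if keep_orig_idx else idx % blk_size
    else:
        bucket = idx % my_size  # out-of-range ids are spread round-robin
        new_idx = idx if keep_orig_idx else idx // my_size
    return bucket, new_idx

# With blk_size=10 and my_size=2:
#   idx=13 -> bucket 1, new_idx 3  (localized) or 13 (keep_orig_idx=True)
#   idx=25 -> bucket 1, new_idx 12 (localized) or 25 (keep_orig_idx=True)
```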

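A hedged usage sketch of the renamed flag through the registered operator (tensor values are invented for illustration; the snippet assumes an environment where `fbgemm_gpu` is importable, which registers the op under `torch.ops.fbgemm`):

```python
import torch
import fbgemm_gpu  # noqa: F401  # importing registers the fbgemm ops

lengths = torch.tensor([2, 2], dtype=torch.int64)          # one feature, batch of 2
indices = torch.tensor([3, 25, 7, 21], dtype=torch.int64)  # global ids
block_sizes = torch.tensor([10], dtype=torch.int64)        # blk_size per feature

new_lengths, new_indices, _, _, unbucketize_permute = (
    torch.ops.fbgemm.block_bucketize_sparse_features(
        lengths,
        indices,
        bucketize_pos=False,
        sequence=True,
        block_sizes=block_sizes,
        my_size=2,
        keep_orig_idx=True,  # keep global ids instead of localizing them per bucket
    )
)
```
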
Reviewed By: sryap, dracifer

Differential Revision: D61102970

fbshipit-source-id: 97702627c05c6f9fbb979e684b1896ac99847e51
dstaay-fb authored and facebook-github-bot committed Aug 15, 2024
1 parent adb5b83 commit 7143818
Showing 5 changed files with 187 additions and 81 deletions.
2 changes: 1 addition & 1 deletion fbgemm_gpu/fbgemm_gpu/sparse_ops.py
@@ -465,7 +465,7 @@ def block_bucketize_sparse_features_meta(
batch_size_per_feature: Optional[torch.Tensor] = None,
max_B: int = -1,
block_bucketize_pos: Optional[torch.Tensor] = None,
maybe_keep_orig_idx: bool = False,
keep_orig_idx: bool = False,
) -> Tuple[
torch.Tensor,
torch.Tensor,
8 changes: 4 additions & 4 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -177,7 +177,7 @@ block_bucketize_sparse_features_cuda(
const std::optional<at::Tensor>& batch_size_per_feature,
const int64_t max_batch_size,
const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
const bool maybe_keep_orig_idx);
const bool keep_orig_idx);

std::tuple<
at::Tensor,
@@ -198,7 +198,7 @@ block_bucketize_sparse_features_cpu(
const std::optional<at::Tensor>& batch_size_per_feature,
const int64_t max_batch_size,
const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
const bool maybe_keep_orig_idx);
const bool keep_orig_idx);

std::tuple<
at::Tensor,
@@ -220,7 +220,7 @@ block_bucketize_sparse_features_inference_cuda(
const int64_t max_batch_size,
const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
const bool return_bucket_mapping,
const bool maybe_keep_orig_idx);
const bool keep_orig_idx);

///@ingroup sparse-data-cuda
at::Tensor populate_bucketized_permute_cuda(
@@ -249,7 +249,7 @@ block_bucketize_sparse_features_inference_cpu(
const int64_t max_batch_size,
const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
const bool return_bucket_mapping,
const bool maybe_keep_orig_idx);
const bool keep_orig_idx);

///@ingroup sparse-data-cpu
at::Tensor populate_bucketized_permute_cpu(
72 changes: 45 additions & 27 deletions fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu
@@ -157,7 +157,7 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_pooled_sparse_features_cuda
const offset_t* const __restrict__ block_bucketize_pos_concat,
const offset_t* const __restrict__ block_bucketize_pos_offsets,
const offset_t* const __restrict__ indices_to_lb,
const bool maybe_keep_orig_idx) {
const bool keep_orig_idx) {
using uindex_t = std::make_unsigned_t<index_t>;
const auto bt_start = blockIdx.x * blockDim.y + threadIdx.y;
const auto stride = gridDim.x * blockDim.y;
@@ -190,17 +190,26 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_pooled_sparse_features_cuda
const uindex_t idx = static_cast<uindex_t>(indices_data[i]);
uindex_t p = 0;
uindex_t new_idx = 0;
if (!use_block_bucketize_pos) {
if (!use_block_bucketize_pos) { // uniform bucket sizes
p = idx < blk_size * my_size ? idx / blk_size : idx % my_size;
new_idx = idx < blk_size * my_size
? idx % blk_size
: (maybe_keep_orig_idx ? idx : idx / my_size);
} else {
if (keep_orig_idx) {
new_idx = idx;
} else if (idx < blk_size * my_size) {
new_idx = idx % blk_size;
} else {
new_idx = idx / my_size;
}
} else { // variable bucket sizes
uindex_t lb = indices_to_lb[i];
p = lb < my_size ? lb : idx % my_size;
new_idx = lb < my_size ? idx -
block_bucketize_pos_concat[lb + block_bucketize_pos_offsets[t]]
: (maybe_keep_orig_idx ? idx : idx / my_size);
if (keep_orig_idx) {
new_idx = idx;
} else if (lb < my_size) {
new_idx = idx -
block_bucketize_pos_concat[lb + block_bucketize_pos_offsets[t]];
} else {
new_idx = idx / my_size;
}
}
static_assert(
sizeof(unsigned long long int) == sizeof(uint64_t),
@@ -254,7 +263,7 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sequence_sparse_features_cu
const offset_t* const __restrict__ block_bucketize_pos_concat,
const offset_t* const __restrict__ block_bucketize_pos_offsets,
const offset_t* const __restrict__ indices_to_lb,
const bool maybe_keep_orig_idx) {
const bool keep_orig_idx) {
using uindex_t = std::make_unsigned_t<index_t>;
using uoffset_t = std::make_unsigned_t<offset_t>;
CUDA_KERNEL_LOOP(b_t, lengths_size) {
Expand All @@ -276,15 +285,24 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sequence_sparse_features_cu
uindex_t new_idx = 0;
if (!use_block_bucketize_pos) {
p = idx < blk_size * my_size ? idx / blk_size : idx % my_size;
new_idx = idx < blk_size * my_size
? idx % blk_size
: (maybe_keep_orig_idx ? idx : idx / my_size);
if (keep_orig_idx) {
new_idx = idx;
} else if (idx < blk_size * my_size) {
new_idx = idx % blk_size;
} else {
new_idx = idx / my_size;
}
} else {
uindex_t lb = indices_to_lb[i];
p = lb < my_size ? lb : idx % my_size;
new_idx = lb < my_size ? idx -
block_bucketize_pos_concat[lb + block_bucketize_pos_offsets[t]]
: (maybe_keep_orig_idx ? idx : idx / my_size);
if (keep_orig_idx) {
new_idx = idx;
} else if (lb < my_size) {
new_idx = idx -
block_bucketize_pos_concat[lb + block_bucketize_pos_offsets[t]];
} else {
new_idx = idx / my_size;
}
}
uoffset_t pos = new_offsets_data[p * lengths_size + b_t];
new_indices_data[pos] = new_idx;
@@ -345,7 +363,7 @@ _block_bucketize_sparse_features_cuda(
const int64_t max_B,
const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
const bool return_bucket_mapping,
const bool maybe_keep_orig_idx) {
const bool keep_orig_idx) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(lengths, indices);

CUDA_DEVICE_GUARD(lengths);
@@ -530,7 +548,7 @@ _block_bucketize_sparse_features_cuda(
block_bucketize_pos.has_value() \
? indices_to_lb.data_ptr<offset_t>() \
: static_cast<offset_t*>(nullptr), \
maybe_keep_orig_idx); \
keep_orig_idx); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
}); \
}); \
@@ -586,7 +604,7 @@ _block_bucketize_sparse_features_cuda(
block_bucketize_pos.has_value() \
? indices_to_lb.data_ptr<offset_t>() \
: static_cast<offset_t*>(nullptr), \
maybe_keep_orig_idx); \
keep_orig_idx); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
}); \
});
@@ -698,7 +716,7 @@ _block_bucketize_sparse_features_cuda(
block_bucketize_pos.has_value()
? indices_to_lb.data_ptr<offset_t>()
: static_cast<offset_t*>(nullptr),
maybe_keep_orig_idx);
keep_orig_idx);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
@@ -761,7 +779,7 @@ _block_bucketize_sparse_features_cuda(
block_bucketize_pos.has_value()
? indices_to_lb.data_ptr<offset_t>()
: static_cast<offset_t*>(nullptr),
maybe_keep_orig_idx);
keep_orig_idx);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
@@ -816,7 +834,7 @@ _block_bucketize_sparse_features_cuda(
block_bucketize_pos.has_value()
? indices_to_lb.data_ptr<offset_t>()
: static_cast<offset_t*>(nullptr),
maybe_keep_orig_idx);
keep_orig_idx);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
@@ -869,7 +887,7 @@ _block_bucketize_sparse_features_cuda(
block_bucketize_pos.has_value()
? indices_to_lb.data_ptr<offset_t>()
: static_cast<offset_t*>(nullptr),
maybe_keep_orig_idx);
keep_orig_idx);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
@@ -904,7 +922,7 @@ block_bucketize_sparse_features_cuda(
const std::optional<Tensor>& batch_size_per_feature,
const int64_t max_B,
const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
const bool maybe_keep_orig_idx) {
const bool keep_orig_idx) {
Tensor new_lengths;
Tensor new_indices;
std::optional<Tensor> new_weights;
@@ -929,7 +947,7 @@ block_bucketize_sparse_features_cuda(
max_B,
block_bucketize_pos,
false,
maybe_keep_orig_idx);
keep_orig_idx);
return {new_lengths, new_indices, new_weights, new_pos, unbucketize_permute};
}
@@ -954,7 +972,7 @@ block_bucketize_sparse_features_inference_cuda(
const int64_t max_B,
const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
const bool return_bucket_mapping,
const bool maybe_keep_orig_idx) {
const bool keep_orig_idx) {
return _block_bucketize_sparse_features_cuda(
lengths,
indices,
@@ -967,7 +985,7 @@ block_bucketize_sparse_features_inference_cuda(
max_B,
block_bucketize_pos,
return_bucket_mapping,
maybe_keep_orig_idx);
keep_orig_idx);
}
DLL_PUBLIC Tensor populate_bucketized_permute_cuda(
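For the variable-bucket-size path above (taken when `block_bucketize_pos` is supplied and the kernels consult the precomputed `indices_to_lb` lower bounds), the same `keep_orig_idx` rule applies. A rough Python approximation, with `bounds` standing in for one feature's `block_bucketize_pos` entries and `bisect` standing in for the lower-bound computation:

```python
import bisect

# Rough sketch only: bounds holds my_size + 1 increasing boundary positions,
# e.g. [0, 4, 11] splits ids into buckets [0, 4) and [4, 11) for my_size = 2.
def bucketize_variable(idx: int, bounds: list, my_size: int, keep_orig_idx: bool):
    lb = bisect.bisect_right(bounds, idx) - 1  # bucket whose range contains idx
    if lb < my_size:
        bucket = lb
        new_idx = idx if keep_orig_idx else idx - bounds[lb]  # offset within bucket
    else:
        bucket = idx % my_size  # id past the last boundary: round-robin fallback
        new_idx = idx if keep_orig_idx else idx // my_size
    return bucket, new_idx
```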
47 changes: 28 additions & 19 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -290,7 +290,7 @@ void _block_bucketize_sparse_features_cpu_kernel(
const std::optional<Tensor>& batch_size_per_feature,
const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
const std::optional<Tensor>& bucket_mapping,
const bool maybe_keep_orig_idx) {
const bool keep_orig_idx) {
// allocate tensors and buffers
const auto lengths_size = lengths.numel();
const auto new_lengths_size = lengths_size * my_size;
@@ -408,15 +408,24 @@ void _block_bucketize_sparse_features_cpu_kernel(
if (variable_bucket_sizes) {
int64_t lb = lower_bounds[i];
p = lb < my_size ? lb : idx % my_size;
new_idx = lb < my_size ? idx - bucketize_offset[lb]
: (maybe_keep_orig_idx ? idx : idx / my_size);

} else {
p = idx < static_cast<uindex_t>(blk_size * my_size) ? idx / blk_size
: idx % my_size;
new_idx = idx < static_cast<uindex_t>(blk_size * my_size)
? idx % blk_size
: (maybe_keep_orig_idx ? idx : idx / my_size);
if (keep_orig_idx) {
new_idx = idx;
} else if (lb < my_size) {
new_idx = idx - bucketize_offset[lb];
} else {
new_idx = idx / my_size;
}
} else { // uniform bucket size

const uindex_t ub = static_cast<uindex_t>(blk_size * my_size);
p = idx < ub ? idx / blk_size : idx % my_size;
if (keep_orig_idx) {
new_idx = idx;
} else if (idx < ub) {
new_idx = idx % blk_size;
} else {
new_idx = idx / my_size;
}
}
const uoffset_t pos = new_offsets_data[p * lengths_size + b_t];
new_indices_data[pos] = new_idx;
@@ -1026,7 +1035,7 @@ _block_bucketize_sparse_features_cpu(
const int64_t /* max_batch_size */, // Only used in GPU variant
const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
const bool return_bucket_mapping,
const bool maybe_keep_orig_idx) {
const bool keep_orig_idx) {
const auto lengths_size = lengths.numel();
const auto new_lengths_size = lengths_size * my_size;
auto new_lengths = at::zeros({new_lengths_size}, lengths.options());
@@ -1074,7 +1083,7 @@ _block_bucketize_sparse_features_cpu(
batch_size_per_feature, \
block_bucketize_pos, \
bucket_mapping, \
maybe_keep_orig_idx); \
keep_orig_idx); \
}); \
}); \
});
@@ -1109,7 +1118,7 @@ _block_bucketize_sparse_features_cpu(
batch_size_per_feature, \
block_bucketize_pos, \
bucket_mapping, \
maybe_keep_orig_idx); \
keep_orig_idx); \
}); \
});
const auto lengths_sum = indices.numel();
@@ -1178,7 +1187,7 @@ block_bucketize_sparse_features_cpu(
const std::optional<Tensor>& batch_size_per_feature,
const int64_t /* max_batch_size */, // Only used in GPU variant
const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
const bool maybe_keep_orig_idx) {
const bool keep_orig_idx) {
Tensor new_lengths;
Tensor new_indices;
std::optional<Tensor> new_weights;
@@ -1203,7 +1212,7 @@
-1, /* placeholder for max_batch_size */
block_bucketize_pos,
false,
maybe_keep_orig_idx);
keep_orig_idx);
return {new_lengths, new_indices, new_weights, new_pos, unbucketize_permute};
}

@@ -1226,7 +1235,7 @@
const int64_t /* max_batch_size */, // Only used in GPU variant
const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
const bool return_bucket_mapping,
const bool maybe_keep_orig_idx) {
const bool keep_orig_idx) {
return _block_bucketize_sparse_features_cpu(
lengths,
indices,
Expand All @@ -1239,7 +1248,7 @@ block_bucketize_sparse_features_inference_cpu(
-1, /* placeholder for max_batch_size */
block_bucketize_pos,
return_bucket_mapping,
maybe_keep_orig_idx);
keep_orig_idx);
}

// This function partitions sparse features
@@ -3071,9 +3080,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"populate_bucketized_permute(Tensor lengths, Tensor bucketized_lengths, Tensor bucket_mapping) -> Tensor");
m.def(
"block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? block_bucketize_pos=None, bool maybe_keep_orig_idx=False) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)");
"block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? block_bucketize_pos=None, bool keep_orig_idx=False) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)");
m.def(
"block_bucketize_sparse_features_inference(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? block_bucketize_pos=None, bool return_bucket_mapping=False, bool maybe_keep_orig_idx=False) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?, Tensor?)");
"block_bucketize_sparse_features_inference(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? block_bucketize_pos=None, bool return_bucket_mapping=False, bool keep_orig_idx=False) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?, Tensor?)");
m.def(
"bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, SymInt my_size, Tensor? weights=None) -> (Tensor, Tensor, Tensor?, Tensor?)");
m.def(