From 4aac3064f48a1134e32bd8376ad90b0236034ecd Mon Sep 17 00:00:00 2001
From: Hye Soo Yang
Date: Fri, 15 Sep 2023 19:42:08 -0700
Subject: [PATCH] Open source op & kernel for
 `GetMinibatchSplitsWithPhysicalReplica`

PiperOrigin-RevId: 565837005
---
 tensorflow/core/tpu/kernels/BUILD             |   2 +
 .../kernels/sparse_core_ops_stats_handler.h   |   4 +
 .../tpu/kernels/sparse_core_preprocess_ops.cc | 360 ++++++++++++++++++
 .../tpu/kernels/sparse_core_preprocess_ops.h  |  40 ++
 tensorflow/core/tpu/ops/BUILD                 |   1 +
 .../tpu/ops/sparse_core_preprocess_ops.cc     |  42 ++
 6 files changed, 449 insertions(+)

diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD
index d0bb2cf97d69dc..0c600b45d37dbd 100644
--- a/tensorflow/core/tpu/kernels/BUILD
+++ b/tensorflow/core/tpu/kernels/BUILD
@@ -152,6 +152,8 @@ cc_library(
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
+        "@com_google_highway//:hwy",
+        "@com_google_highway//hwy/contrib/sort:vqsort",
         "@local_xla//xla:util",
         "@local_xla//xla/stream_executor/tpu:tpu_api",
         "@local_xla//xla/stream_executor/tpu:tpu_ops_c_api_hdrs",
diff --git a/tensorflow/core/tpu/kernels/sparse_core_ops_stats_handler.h b/tensorflow/core/tpu/kernels/sparse_core_ops_stats_handler.h
index 26c7b76a74c1a3..6b10af885ca328 100644
--- a/tensorflow/core/tpu/kernels/sparse_core_ops_stats_handler.h
+++ b/tensorflow/core/tpu/kernels/sparse_core_ops_stats_handler.h
@@ -20,6 +20,10 @@ limitations under the License.
 
 enum class StatsType {
   NUM_MINIBATCHES_PER_SC,
+  MAX_IDS_PER_PARTITION,
+  MAX_UNIQUE_IDS_PER_PARTITION,
+  IDS_PER_PARTITION,
+  UNIQUE_IDS_PER_PARTITION,
 };
 
 class SparseCoreOpsStatsHandler {
diff --git a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc
index 5b4f6173fb0218..d6453096c9d00a 100644
--- a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc
+++ b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc
@@ -27,6 +27,10 @@ limitations under the License.
 #include "absl/log/log.h"
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "highway/hwy/base.h"  // from @com_google_highway
+#include "highway/hwy/contrib/sort/order.h"  // from @com_google_highway
+#include "highway/hwy/contrib/sort/vqsort.h"  // from @com_google_highway
 #include "xla/stream_executor/tpu/tpu_api.h"
 #include "xla/stream_executor/tpu/tpu_ops_c_api.h"
 #include "xla/util.h"
@@ -518,4 +522,360 @@ REGISTER_KERNEL_BUILDER(
     GetMinibatchesInCsrWithPhysicalReplicaOp)
 #endif
 
+GetMinibatchSplitsWithPhysicalReplicaOp::
+    GetMinibatchSplitsWithPhysicalReplicaOp(OpKernelConstruction* ctx)
+    : OpKernel(ctx) {
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("table_name", &table_name_));
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("num_replica", &num_replica_));
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("sample_count", &sample_count_));
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("table_vocab_size", &table_vocab_size_));
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("feature_width", &feature_width_));
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("num_sc_per_chip", &num_sc_per_chip_));
+  OP_REQUIRES(ctx, sample_count_ % num_sc_per_chip_ == 0,
+              absl::InvalidArgumentError(absl::StrCat(
+                  "sample_count ", sample_count_,
+                  " is not divisible by the number of sparsecores per chip ",
+                  num_sc_per_chip_)));
+  device_name_ = ctx->device()->name();
+
+  // Create default instance of stats handler. May get overwritten by
+  // subclass.
+  sprase_core_ops_stats_handler_ =
+      std::make_unique<SparseCoreOpsStatsHandler>();
+}
+
+void GetMinibatchSplitsWithPhysicalReplicaOp::Compute(OpKernelContext* ctx) {
+  // TODO(patn): Allow clients to provide the max_ids and max_uniques directly
+  // making program_key optional. This would be useful if there's a need to
+  // use this op without the bridge.
+  const Tensor* program_key_t;
+  OP_REQUIRES_OK(ctx, ctx->input("program_key", &program_key_t));
+  tstring program_key = program_key_t->vec<tstring>()(0);
+
+  int64_t per_sparse_core_batch_size = sample_count_ / num_sc_per_chip_;
+
+  int64_t max_ids_per_partition = -1;
+  int64_t max_unique_ids_per_partition = -1;
+
+  GetMaxIdsAndUniques(ctx, program_key, table_name_,
+                      per_sparse_core_batch_size, feature_width_,
+                      &max_ids_per_partition, &max_unique_ids_per_partition);
+
+  sprase_core_ops_stats_handler_->Record(StatsType::MAX_IDS_PER_PARTITION,
+                                         max_ids_per_partition, device_name_,
+                                         table_name_);
+  sprase_core_ops_stats_handler_->Record(
+      StatsType::MAX_UNIQUE_IDS_PER_PARTITION, max_unique_ids_per_partition,
+      device_name_, table_name_);
+
+  const Tensor* row_ids;
+  OP_REQUIRES_OK(ctx, ctx->input("row_ids", &row_ids));
+  const Tensor* col_ids;
+  OP_REQUIRES_OK(ctx, ctx->input("col_ids", &col_ids));
+  const Tensor* gains;
+  OP_REQUIRES_OK(ctx, ctx->input("gains", &gains));
+
+  const int32 total_id_count = row_ids->NumElements();
+
+  const int32* row_ids_ptr = row_ids->flat<int32>().data();
+  const int32* col_ids_ptr = col_ids->flat<int32>().data();
+  const float* gains_ptr = gains->flat<float>().data();
+
+  const int num_sc_per_replica = 4;
+  const int num_physical_replica = num_replica_ * num_sc_per_replica;
+
+  OP_REQUIRES(ctx, sample_count_ % num_sc_per_replica == 0,
+              absl::InvalidArgumentError(
+                  absl::StrCat("sample_count has to be a multiple of "
+                               "num_sc_per_replica, which is 4, but got ",
+                               sample_count_, " samples.")));
+
+  int32 per_sc_sample_count = sample_count_ / num_sc_per_replica;
+
+  const int max_division_level = GetMinibatchMaxDivisionLevel();
+
+  const int32 kMaxDivisions = 1 << max_division_level;
+
+  // The id counts tensor is the running sum of the number of ids for all
+  // buckets for all the replicas on each SparseCore. It is used by the later
+  // minibatch forming op to craft each minibatch.
+  Tensor* id_counts_tensor;
+  OP_REQUIRES_OK(
+      ctx,
+      ctx->allocate_output(
+          "id_counts",
+          TensorShape(
+              {kMaxDivisions * num_sc_per_replica * num_physical_replica + 1}),
+          &id_counts_tensor));
+  int32* id_counts_tensor_ptr = id_counts_tensor->flat<int32>().data();
+  *id_counts_tensor_ptr = 0;
+
+  const int32_t division_size =
+      (table_vocab_size_ + kMaxDivisions - 1) / kMaxDivisions;
+
+  // Index pointer into the original row_ids/col_ids/gains arrays.
+  uint32_t index = 0;
+  // Splits which should be interpreted in binary format.
+  // E.g. splits = 11 with table size 1024 indicates:
+  //    0001011 -> 0001 01 1
+  // which means split at level 0 section 0, level 1 section 0 and level 2
+  // section 0. The split points are [128, 256, 512].
+  int64 pre_merge_splits = 0;
+  int64 after_merge_splits = 0;
+  // Vector of uint64_t storing the col ids in the upper 32 bits and the index
+  // into the original id array in the lower 32 bits.
+  std::vector<std::vector<uint64_t>> col_ids_index_list(
+      num_sc_per_replica, std::vector<uint64_t>());
+
+  // Vector which stores, for each id, the index of the id it can be deduped
+  // with.
+  // For example:
+  //    [0, 1, 1, 1] means that the third and the fourth ids can be deduped
+  //    with the second id.
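+  // An id that is kept (not deduped) maps to its own index.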
+  std::vector<uint32_t> dedup_ids_index_mapping(total_id_count);
+
+  // The gains after deduplication. If the same id appears multiple times in
+  // the same sample, we keep a single copy of that id and add up the gains.
+  std::vector<float> gains_after_dedup(total_id_count);
+
+  // Arrays which store the id counts and unique id counts for each minibatch
+  // bucket on all physical replicas.
+  std::vector<int32> total_id_counter(num_physical_replica *
+                                      (kMaxDivisions + 1));
+  std::vector<int32> total_unique_id_counter(num_physical_replica *
+                                             (kMaxDivisions + 1));
+  // Array which keeps track of the output index for each physical replica.
+  std::vector<int32> per_physical_replica_index(num_physical_replica);
+
+  // Accumulated sum of the id count for each physical replica.
+  std::vector<int32> physical_replica_id_count((num_physical_replica + 1) *
+                                               num_sc_per_replica);
+
+  // Id counts for each SparseCore input.
+  std::vector<int32> per_sc_id_count(num_sc_per_replica, 0);
+
+  // Keep track of the maximum number of (unique) ids we see in this current
+  // batch. If it gets too close to the configured max, we can increase
+  // the value in the FDO configs.
+  int32 this_max_ids = 0;
+  int32 this_max_uniques = 0;
+  // Row ids (sample ids) are already sorted.
+  for (int sc_id = 0; sc_id < num_sc_per_replica; ++sc_id) {
+    col_ids_index_list[sc_id].reserve(total_id_count);
+    while (index < total_id_count &&
+           *(row_ids_ptr + index) < (sc_id + 1) * per_sc_sample_count) {
+      col_ids_index_list[sc_id].push_back(
+          (static_cast<uint64_t>(*(col_ids_ptr + index)) << 32) + index);
+      ++index;
+    }
+    // Perform high speed sorting based on col ids.
+    hwy::VQSort(col_ids_index_list[sc_id].data(),
+                col_ids_index_list[sc_id].size(), hwy::SortAscending());
+
+    memset(total_id_counter.data(), 0,
+           num_physical_replica * (kMaxDivisions + 1) * sizeof(int32));
+    memset(total_unique_id_counter.data(), 0,
+           num_physical_replica * (kMaxDivisions + 1) * sizeof(int32));
+
+    // Loop through the col ids to count the ids and unique ids.
+    int32_t previous_col_id = -1;
+    int32_t previous_row_id = -1;
+    uint32_t previous_id_array_index = 0;
+    for (uint64_t item : col_ids_index_list[sc_id]) {
+      int32 col_id = item >> 32;
+      uint32_t id_array_index = item & 0xffffffff;
+      int32_t row_id = *(row_ids_ptr + id_array_index);
+      // If the row id and the col id are both the same as the previous one,
+      // dedup the id by adding up the gains.
+      if (row_id != previous_row_id || col_id != previous_col_id) {
+        dedup_ids_index_mapping[id_array_index] = id_array_index;
+        gains_after_dedup[id_array_index] = *(gains_ptr + id_array_index);
+        int32 replica_id = col_id % num_physical_replica;
+        int32 bucket_id = col_id / division_size + 1;
+        uint32_t id_counter_index =
+            replica_id * (kMaxDivisions + 1) + bucket_id;
+        total_id_counter[id_counter_index]++;
+        if (col_id != previous_col_id)
+          total_unique_id_counter[id_counter_index]++;
+      } else {
+        // Dedup the id if both the row id and the col id are the same.
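+        // Point this duplicate at its first occurrence (the "parent"), fold
+        // its gain into the parent, and skip it later when emitting the
+        // sorted outputs.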
+        uint32_t parent_idx = dedup_ids_index_mapping[previous_id_array_index];
+        dedup_ids_index_mapping[id_array_index] = parent_idx;
+        gains_after_dedup[parent_idx] += *(gains_ptr + id_array_index);
+      }
+      previous_col_id = col_id;
+      previous_id_array_index = id_array_index;
+      previous_row_id = row_id;
+    }
+
+    for (int replica_id = 0; replica_id < num_physical_replica; ++replica_id) {
+      absl::Span<int32> id_counter = absl::MakeSpan(
+          total_id_counter.data() + replica_id * (kMaxDivisions + 1),
+          kMaxDivisions + 1);
+      absl::Span<int32> unique_id_counter = absl::MakeSpan(
+          total_unique_id_counter.data() + replica_id * (kMaxDivisions + 1),
+          kMaxDivisions + 1);
+      for (int i = 1; i < kMaxDivisions + 1; ++i) {
+        // Check that even the smallest division does not exceed the max_ids
+        // and max_unique_ids limits.
+        OP_REQUIRES(ctx,
+                    id_counter[i] <= max_ids_per_partition &&
+                        unique_id_counter[i] <= max_unique_ids_per_partition,
+                    absl::InvalidArgumentError(absl::StrCat(
+                        "Table ", table_name_, " has too many ids for replica ",
+                        replica_id, " on sparse core ", sc_id,
+                        ". The max_ids_per_partition is ",
+                        max_ids_per_partition, " but got ", id_counter[i],
+                        " ids. The max_unique_ids_per_partition is ",
+                        max_unique_ids_per_partition, " but got ",
+                        unique_id_counter[i], " unique ids.",
+                        " Consider making the max_division_level higher.")));
+        // Save the running sum of the id counts.
+        const int global_division_id =
+            sc_id * num_physical_replica + replica_id;
+        *(id_counts_tensor_ptr + global_division_id * kMaxDivisions + i) =
+            *(id_counts_tensor_ptr + global_division_id * kMaxDivisions + i -
+              1) +
+            id_counter[i];
+        id_counter[i] += id_counter[i - 1];
+        unique_id_counter[i] += unique_id_counter[i - 1];
+      }
+      this_max_ids = std::max(this_max_ids, id_counter[kMaxDivisions]);
+      this_max_uniques =
+          std::max(this_max_uniques, unique_id_counter[kMaxDivisions]);
+      physical_replica_id_count[sc_id * (num_physical_replica + 1) +
+                                replica_id + 1] =
+          physical_replica_id_count[sc_id * (num_physical_replica + 1) +
+                                    replica_id] +
+          id_counter[kMaxDivisions];
+      per_sc_id_count[sc_id] += id_counter[kMaxDivisions];
+
+      for (int level = 0; level < max_division_level; ++level) {
+        // Skip this level if the previous level doesn't split.
+        if (level > 0 && (pre_merge_splits >> ((1LL << (level - 1)) - 1)) == 0)
+          continue;
+        int32_t section_size = 1 << (max_division_level - level);
+        for (int section = 0; section < (1 << level); ++section) {
+          // Skip this section if the corresponding section on the previous
+          // level doesn't split.
+          int pre_start_bit_pos = level > 0 ? (1 << (level - 1)) - 1 : 0;
+          if (level > 0 &&
+              (pre_merge_splits &
+               (1LL << (pre_start_bit_pos + (section >> 1)))) == 0)
+            continue;
+          int32 id_count = id_counter[(section + 1) * section_size] -
+                           id_counter[section * section_size];
+          int32 unique_id_count =
+              unique_id_counter[(section + 1) * section_size] -
+              unique_id_counter[section * section_size];
+          // If the number of ids or unique ids exceeds the limit, we need to
+          // split.
+          if (id_count > max_ids_per_partition ||
+              unique_id_count > max_unique_ids_per_partition) {
+            int start_bit_pos = (1 << level) - 1;
+            pre_merge_splits =
+                pre_merge_splits | (1LL << (start_bit_pos + section));
+          }
+        }
+      }
+      // Convert the binary representation of the splits into indices of
+      // buckets.
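+      // For example, with max_division_level == 3 (kMaxDivisions == 8), the
+      // splits value 0b1011 from the comment above (splits at level 0
+      // section 0, level 1 section 0 and level 2 section 0) corresponds to
+      // the bucket indices [1, 2, 4].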
+      std::vector<int32> per_replica_splits = ConvertBinarySplitsToBucketSplits(
+          pre_merge_splits, max_division_level);
+
+      per_replica_splits.insert(per_replica_splits.begin(), 0);
+      per_replica_splits.push_back(kMaxDivisions);
+
+      std::vector<int32> merged_per_replica_splits;
+      // Iterate through all the buckets and merge them greedily.
+      int start_index = 0;
+      for (int i = 1; i < per_replica_splits.size(); ++i) {
+        if (unique_id_counter[per_replica_splits[i]] -
+                    unique_id_counter[per_replica_splits[start_index]] <=
+                max_unique_ids_per_partition &&
+            id_counter[per_replica_splits[i]] -
+                    id_counter[per_replica_splits[start_index]] <=
+                max_ids_per_partition) {
+          continue;
+        } else {
+          merged_per_replica_splits.push_back(per_replica_splits[i - 1]);
+          start_index = i - 1;
+        }
+      }
+      // Convert the bucket indices back to the binary representation.
+      after_merge_splits |= ConvertBucketSplitsToBinarySplits(
+          merged_per_replica_splits, max_division_level);
+    }
+  }
+
+  int64_t updated_total_id_count = absl::c_accumulate(per_sc_id_count, 0);
+
+  Tensor* sorted_row_ids_tensor;
+  OP_REQUIRES_OK(ctx,
+                 ctx->allocate_output("sorted_row_ids",
+                                      TensorShape({updated_total_id_count}),
+                                      &sorted_row_ids_tensor));
+  Tensor* sorted_col_ids_tensor;
+  OP_REQUIRES_OK(ctx,
+                 ctx->allocate_output("sorted_col_ids",
+                                      TensorShape({updated_total_id_count}),
+                                      &sorted_col_ids_tensor));
+  Tensor* sorted_gains_tensor;
+  OP_REQUIRES_OK(ctx,
+                 ctx->allocate_output("sorted_gains",
+                                      TensorShape({updated_total_id_count}),
+                                      &sorted_gains_tensor));
+
+  int32_t* sorted_row_ids_tensor_ptr =
+      sorted_row_ids_tensor->flat<int32_t>().data();
+  int32_t* sorted_col_ids_tensor_ptr =
+      sorted_col_ids_tensor->flat<int32_t>().data();
+  float* sorted_gains_tensor_ptr = sorted_gains_tensor->flat<float>().data();
+
+  int32_t previous_index = 0;
+
+  for (int sc_id = 0; sc_id < num_sc_per_replica; ++sc_id) {
+    memset(per_physical_replica_index.data(), 0,
+           num_physical_replica * sizeof(int32));
+    for (uint64_t item : col_ids_index_list[sc_id]) {
+      uint32_t id_array_index = item & 0xffffffff;
+      // Skip deduped ids.
+      if (id_array_index != dedup_ids_index_mapping[id_array_index]) {
+        continue;
+      }
+      int32_t col_id = item >> 32;
+      int32_t replica_id = col_id % num_physical_replica;
+      int32_t main_index =
+          per_physical_replica_index[replica_id] + previous_index +
+          physical_replica_id_count[sc_id * (num_physical_replica + 1) +
+                                    replica_id];
+      *(sorted_row_ids_tensor_ptr + main_index) =
+          *(row_ids_ptr + id_array_index) % per_sc_sample_count;
+      *(sorted_col_ids_tensor_ptr + main_index) =
+          col_id / num_physical_replica;
+      // Use the updated gains instead.
+      *(sorted_gains_tensor_ptr + main_index) =
+          gains_after_dedup[id_array_index];
+      per_physical_replica_index[replica_id]++;
+    }
+    previous_index += per_sc_id_count[sc_id];
+  }
+
+  sprase_core_ops_stats_handler_->Record(
+      StatsType::IDS_PER_PARTITION, this_max_ids, device_name_, table_name_);
+  sprase_core_ops_stats_handler_->Record(StatsType::UNIQUE_IDS_PER_PARTITION,
+                                         this_max_uniques, device_name_,
+                                         table_name_);
+
+  CalculateHeadroom(this_max_ids, this_max_uniques, program_key,
+                    max_ids_per_partition, max_unique_ids_per_partition);
+
+  Tensor* splits_tensor;
+  OP_REQUIRES_OK(
+      ctx, ctx->allocate_output("splits", TensorShape({}), &splits_tensor));
+  splits_tensor->flat<int64>()(0) = after_merge_splits;
+}
+
+#ifdef LIBTPU_ON_GCE
+REGISTER_KERNEL_BUILDER(
+    Name("GetMinibatchSplitsWithPhysicalReplica").Device(DEVICE_CPU),
+    GetMinibatchSplitsWithPhysicalReplicaOp)
+#endif
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h
index d9d9ab730cd32a..481f8f6730b7f1 100644
--- a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h
+++ b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h
@@ -15,11 +15,15 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_PREPROCESS_OPS_H_
 #define TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_PREPROCESS_OPS_H_
 
+#include <memory>
+#include <string>
 #include <vector>
 
 #include "absl/status/status.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/tstring.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/tpu/kernels/sparse_core_ops_stats_handler.h"
 
 namespace tensorflow {
@@ -85,6 +89,42 @@ class GetMinibatchesInCsrWithPhysicalReplicaOp : public OpKernel {
   std::string device_name_;
 };
 
+class GetMinibatchSplitsWithPhysicalReplicaOp : public OpKernel {
+ public:
+  explicit GetMinibatchSplitsWithPhysicalReplicaOp(OpKernelConstruction* ctx);
+  ~GetMinibatchSplitsWithPhysicalReplicaOp() override = default;
+  GetMinibatchSplitsWithPhysicalReplicaOp(
+      const GetMinibatchSplitsWithPhysicalReplicaOp&) = delete;
+  GetMinibatchSplitsWithPhysicalReplicaOp& operator=(
+      const GetMinibatchSplitsWithPhysicalReplicaOp&) = delete;
+
+  void Compute(OpKernelContext* ctx) override;
+
+ protected:
+  virtual void GetMaxIdsAndUniques(OpKernelContext* ctx,
+                                   const std::string& program_key,
+                                   const std::string& table_name,
+                                   int64_t num_samples_per_sparse_core,
+                                   int64_t feature_width,
+                                   int64_t* max_ids_per_partition,
+                                   int64_t* max_unique_ids_per_partition) {}
+  virtual void CalculateHeadroom(int32 this_max_ids, int32 this_max_uniques,
+                                 tstring program_key,
+                                 int64_t max_ids_per_partition,
+                                 int64_t max_unique_ids_per_partition) {}
+
+  std::string device_name_;
+  std::string table_name_;
+  std::unique_ptr<SparseCoreOpsStatsHandler> sprase_core_ops_stats_handler_;
+
+ private:
+  int num_replica_ = 1;
+  int sample_count_ = 1;
+  int table_vocab_size_ = 1;
+  int feature_width_ = 1;
+  int64_t num_sc_per_chip_;
+};
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_PREPROCESS_OPS_H_
diff --git a/tensorflow/core/tpu/ops/BUILD b/tensorflow/core/tpu/ops/BUILD
index d073429eec6774..774172befa4847 100644
--- a/tensorflow/core/tpu/ops/BUILD
+++ b/tensorflow/core/tpu/ops/BUILD
@@ -186,6 +186,7 @@ cc_library(
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
+        "//tensorflow/core/tpu/kernels:sparse_core_ops_utils",
         "@local_xla//xla:util",
     ],
     alwayslink = 1,
diff --git a/tensorflow/core/tpu/ops/sparse_core_preprocess_ops.cc b/tensorflow/core/tpu/ops/sparse_core_preprocess_ops.cc
index 321fff6b955cdd..6609b754c417be 100644
--- a/tensorflow/core/tpu/ops/sparse_core_preprocess_ops.cc
+++ b/tensorflow/core/tpu/ops/sparse_core_preprocess_ops.cc
@@ -17,6 +17,7 @@ limitations under the License.
 #include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/tpu/kernels/sparse_core_ops_utils.h"
 #include "tsl/platform/errors.h"
 
 namespace tensorflow {
@@ -106,4 +107,45 @@ REGISTER_OP("GetMinibatchesInCsrWithPhysicalReplica")
       return OkStatus();
     });
 
+REGISTER_OP("GetMinibatchSplitsWithPhysicalReplica")
+    .Input("program_key: string")
+    .Input("row_ids: int32")
+    .Input("col_ids: int32")
+    .Input("gains: float32")
+    .Output("sorted_row_ids: int32")
+    .Output("sorted_col_ids: int32")
+    .Output("sorted_gains: float32")
+    .Output("splits: int64")
+    .Output("id_counts: int32")
+    .Attr("sample_count : int >= 1")
+    .Attr("num_replica: int >= 1")
+    .Attr("table_vocab_size: int >= 1")
+    .Attr("feature_width: int >= 1")
+    .Attr("num_sc_per_chip: int >= 1")
+    .Attr("table_name: string")
+    .Attr("mini_batch_splits: string")
+    .SetIsStateful()
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->UnknownShapeOfRank(1));
+      c->set_output(1, c->UnknownShapeOfRank(1));
+      c->set_output(2, c->UnknownShapeOfRank(1));
+      int32 num_replica;
+      TF_RETURN_IF_ERROR(c->GetAttr("num_replica", &num_replica));
+
+      int32 num_sc_per_chip;
+      TF_RETURN_IF_ERROR(c->GetAttr("num_sc_per_chip", &num_sc_per_chip));
+
+      const int max_division_level = GetMinibatchMaxDivisionLevel();
+
+      const int num_physical_replica = num_replica * num_sc_per_chip;
+
+      const int32 kMaxDivisions = 1 << max_division_level;
+
+      c->set_output(3, c->Scalar());
+      c->set_output(
+          4, c->MakeShape(
+                 {num_physical_replica * kMaxDivisions * num_sc_per_chip + 1}));
+      return OkStatus();
+    });
+
 }  // namespace tensorflow
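Context for the two protected hooks added in sparse_core_preprocess_ops.h: `GetMaxIdsAndUniques` and `CalculateHeadroom` are no-ops in this open-source kernel and are meant to be overridden by a subclass (the constructor comment notes the default stats handler may also be replaced). Below is a minimal sketch of such a subclass; the class name, the fixed limits, and the logging are hypothetical and only illustrate the extension point, since a real implementation would derive the limits from the compiled program identified by `program_key`.

// Illustrative sketch only, not part of the patch: a hypothetical subclass
// that supplies fixed per-partition limits instead of querying the program.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h"

namespace tensorflow {

class HardcodedLimitsMinibatchSplitsOp
    : public GetMinibatchSplitsWithPhysicalReplicaOp {
 public:
  using GetMinibatchSplitsWithPhysicalReplicaOp::
      GetMinibatchSplitsWithPhysicalReplicaOp;

 protected:
  void GetMaxIdsAndUniques(OpKernelContext* ctx,
                           const std::string& program_key,
                           const std::string& table_name,
                           int64_t num_samples_per_sparse_core,
                           int64_t feature_width,
                           int64_t* max_ids_per_partition,
                           int64_t* max_unique_ids_per_partition) override {
    // Hypothetical fixed limits; a real subclass would look these up for the
    // compiled TPU program identified by `program_key`.
    *max_ids_per_partition = 256;
    *max_unique_ids_per_partition = 256;
  }

  void CalculateHeadroom(int32 this_max_ids, int32 this_max_uniques,
                         tstring program_key, int64_t max_ids_per_partition,
                         int64_t max_unique_ids_per_partition) override {
    // Hypothetical bookkeeping: log how close this batch came to the limits.
    VLOG(1) << "ids headroom: " << (max_ids_per_partition - this_max_ids)
            << ", unique ids headroom: "
            << (max_unique_ids_per_partition - this_max_uniques);
  }
};

}  // namespace tensorflow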