diff --git a/tensorflow/core/kernels/unique_ali_op.cc b/tensorflow/core/kernels/unique_ali_op.cc
index 28b5dad1990..6f8e3d1bb76 100644
--- a/tensorflow/core/kernels/unique_ali_op.cc
+++ b/tensorflow/core/kernels/unique_ali_op.cc
@@ -25,8 +25,8 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/kernels/task_runner.h"
 #include "tensorflow/core/kernels/unique_ali_op_util.h"
-#include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/util/env_var.h"
 
 namespace tensorflow {
@@ -41,40 +41,43 @@ const char* kStlHashMapString = "STL";
 const char* kAbslHashMapString = "ABSL";
 const char* kGoogleHashMapString = "GOOGLE";
 const int64 kDefaultUniqueRatioHint = 4;
-}
+}  // namespace
 
 template <typename T, typename TIndex>
 class UniqueAliOp : public OpKernel {
  public:
   explicit UniqueAliOp(OpKernelConstruction* context) : OpKernel(context) {
-    OP_REQUIRES_OK(context, ReadInt64FromEnvVar(kUniqueOpPartitionSizeEnv,
-                                                kPartitionSize, &partition_size_));
-    OP_REQUIRES(context, partition_size_ > 0,
-                errors::InvalidArgument("Invaild PARTITION_SIZE=",
-                                        partition_size_));
+    OP_REQUIRES_OK(
+        context, ReadInt64FromEnvVar(kUniqueOpPartitionSizeEnv, kPartitionSize,
+                                     &partition_size_));
+    OP_REQUIRES(
+        context, partition_size_ > 0,
+        errors::InvalidArgument("Invalid PARTITION_SIZE=", partition_size_));
 
-    OP_REQUIRES_OK(context, ReadBoolFromEnvVar(kUniqueOpSerialEnv,
-                                               false, &serial_));
+    OP_REQUIRES_OK(context,
+                   ReadBoolFromEnvVar(kUniqueOpSerialEnv, false, &serial_));
 
     // NOTE(zycao): Hash map insertion and lookup performance is dominating in
     // Unique Op. Based on benchmark results, 'google::dense_hash_map' will be
     // used as default for most key types except string.
     //
-    // By setting "DEEPREC_UNIQUE_OP_HASH_MAP" environment variable, a particular
-    // hash map could be seleteed to use. Possible choices are listed below:
+    // By setting "DEEPREC_UNIQUE_OP_HASH_MAP" environment variable, a
+    // particular hash map could be selected to use. Possible choices are
+    // listed below:
     //   "MULTIMAP" for multimap parrallel process,
     //   "STL" for std::unordred_map,
     //   "ABSL" for absl::flat_hash_map,
     //   "GOOGLE" for google::dense_hash_map.
     std::string hash_map_str;
-    OP_REQUIRES_OK(context, ReadStringFromEnvVar(kUniqueOpHashMapEnv,
-                                                 kGoogleHashMapString,
-                                                 &hash_map_str));
+    OP_REQUIRES_OK(
+        context, ReadStringFromEnvVar(kUniqueOpHashMapEnv, kGoogleHashMapString,
+                                      &hash_map_str));
     std::transform(hash_map_str.begin(), hash_map_str.end(),
                    hash_map_str.begin(), ::toupper);
 
     OP_REQUIRES_OK(context, ReadInt64FromEnvVar(kUniqueOpUniqRatioHint,
-                                 kDefaultUniqueRatioHint, &unique_ratio_hint_));
+                                                kDefaultUniqueRatioHint,
+                                                &unique_ratio_hint_));
     OP_REQUIRES(context, unique_ratio_hint_ > 0,
                 errors::InvalidArgument("Invaild ", kUniqueOpUniqRatioHint,
                                         "=", unique_ratio_hint_));
@@ -83,7 +86,8 @@ class UniqueAliOp : public OpKernel {
       map_flag_ = MULTIMAP;
       static char print_once = [] {
         LOG(INFO) << "MultiMapCompute preserved "
-                     "dense hash map key: " << kPreseverdEmptyKey;
+                     "dense hash map key: "
+                  << kPreseverdEmptyKey;
         return '\0';
       }();
     } else if (!hash_map_str.compare(kStlHashMapString)) {
@@ -95,7 +99,6 @@ class UniqueAliOp : public OpKernel {
     } else {
       map_flag_ = GOOGLE;
     }
-
   }
 
   void Compute(OpKernelContext* context) override {
@@ -110,16 +113,14 @@ class UniqueAliOp : public OpKernel {
     Tensor output;
     Tensor output_counter;
     if (context->num_inputs() == 1) {
-      UniqueWithoutAxis<T, TIndex>(context, input,
-          &idx, &output, &output_counter, num_outputs(),
-          partition_size_, serial_, unique_ratio_hint_,
-          map_flag_);
+      UniqueWithoutAxis<T, TIndex>(
+          context, input, &idx, &output, &output_counter, num_outputs(),
+          partition_size_, serial_, unique_ratio_hint_, map_flag_);
     } else {
       const Tensor& axis_tensor = context->input(1);
-      UniqueWithAxis<T, TIndex>(context, input,
-          axis_tensor, &idx, &output, &output_counter,
-          num_outputs(), partition_size_, serial_,
-          unique_ratio_hint_, map_flag_);
+      UniqueWithAxis<T, TIndex>(context, input, axis_tensor, &idx, &output,
+                                &output_counter, num_outputs(), partition_size_,
+                                serial_, unique_ratio_hint_, map_flag_);
     }
     context->set_output(0, output);
     context->set_output(1, idx);
@@ -128,33 +129,65 @@ class UniqueAliOp : public OpKernel {
     }
   }
 
+ protected:
   bool serial_ = false;
   int64 partition_size_ = 0;
   int64 unique_ratio_hint_;
   UniqueMaps map_flag_ = GOOGLE;  // "GOOGLE" dense hash map is default
 };
 
+template <typename T, typename TIndex>
+class UniqueWithCountAliOp : public UniqueAliOp<T, TIndex> {
+  using UniqueAliOp<T, TIndex>::serial_;
+  using UniqueAliOp<T, TIndex>::partition_size_;
+  using UniqueAliOp<T, TIndex>::unique_ratio_hint_;
+  using UniqueAliOp<T, TIndex>::map_flag_;
+  using OpKernel::num_outputs;
+
+ public:
+  explicit UniqueWithCountAliOp(OpKernelConstruction* context)
+      : UniqueAliOp<T, TIndex>(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("N", &num_sparse_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input = context->input(0);
+    Tensor idx;
+    Tensor output;
+    Tensor output_counter;
+    UniqueWithExtraCounts<T, TIndex>(
+        context, input, &idx, &output, &output_counter, num_outputs(),
+        partition_size_, serial_, unique_ratio_hint_, num_sparse_, map_flag_);
+    context->set_output(0, output);
+    context->set_output(1, idx);
+    context->set_output(2, output_counter);
+  }
+
+ private:
+  int num_sparse_;
+};
+
 #define REGISTER_UNIQUE(type)                                    \
   REGISTER_KERNEL_BUILDER(Name("Unique")                         \
                               .Device(DEVICE_CPU)                \
                               .TypeConstraint<type>("T")         \
                               .TypeConstraint<int32>("out_idx"), \
-                          UniqueAliOp<type, int32>);             \
+                          UniqueAliOp<type, int32>)              \
   REGISTER_KERNEL_BUILDER(Name("Unique")                         \
                               .Device(DEVICE_CPU)                \
                               .TypeConstraint<type>("T")         \
                               .TypeConstraint<int64>("out_idx"), \
-                          UniqueAliOp<type, int64>);             \
+                          UniqueAliOp<type, int64>)              \
   REGISTER_KERNEL_BUILDER(Name("UniqueV2")                       \
                               .Device(DEVICE_CPU)                \
                               .TypeConstraint<type>("T")         \
                               .TypeConstraint<int32>("out_idx"), \
-                          UniqueAliOp<type, int32>);             \
+                          UniqueAliOp<type, int32>)              \
   REGISTER_KERNEL_BUILDER(Name("UniqueV2")                       \
                               .Device(DEVICE_CPU)                \
                               .TypeConstraint<type>("T")         \
                               .TypeConstraint<int64>("out_idx"), \
-                          UniqueAliOp<type, int64>);             \
+                          UniqueAliOp<type, int64>)              \
   REGISTER_KERNEL_BUILDER(Name("UniqueWithCounts")               \
                               .Device(DEVICE_CPU)                \
                               .TypeConstraint<type>("T")         \
@@ -164,7 +197,7 @@ class UniqueAliOp : public OpKernel {
                               .Device(DEVICE_CPU)                \
                               .TypeConstraint<type>("T")         \
                               .TypeConstraint<int64>("out_idx"), \
-                          UniqueAliOp<type, int64>);             \
+                          UniqueAliOp<type, int64>)              \
   REGISTER_KERNEL_BUILDER(Name("UniqueWithCountsV2")             \
                               .Device(DEVICE_CPU)                \
                               .TypeConstraint<type>("T")         \
@@ -174,7 +207,17 @@ class UniqueAliOp : public OpKernel {
                               .Device(DEVICE_CPU)                \
                               .TypeConstraint<type>("T")         \
                               .TypeConstraint<int64>("out_idx"), \
-                          UniqueAliOp<type, int64>)
+                          UniqueAliOp<type, int64>)              \
+  REGISTER_KERNEL_BUILDER(Name("UniqueWithExtraCounts")          \
+                              .Device(DEVICE_CPU)                \
+                              .TypeConstraint<type>("T")         \
+                              .TypeConstraint<int32>("out_idx"), \
+                          UniqueWithCountAliOp<type, int32>)     \
+  REGISTER_KERNEL_BUILDER(Name("UniqueWithExtraCounts")          \
+                              .Device(DEVICE_CPU)                \
+                              .TypeConstraint<type>("T")         \
+                              .TypeConstraint<int64>("out_idx"), \
+                          UniqueWithCountAliOp<type, int64>)
 TF_CALL_REAL_NUMBER_TYPES(REGISTER_UNIQUE);
 REGISTER_UNIQUE(string)
 #undef REGISTER_UNIQUE
@@ -198,12 +241,22 @@ REGISTER_UNIQUE(string)
                               .HostMemory("count")               \
                               .TypeConstraint<type>("T")         \
                               .TypeConstraint<int64>("out_idx"), \
-                          UniqueAliOp<type, int64>);
+                          UniqueAliOp<type, int64>)              \
+  REGISTER_KERNEL_BUILDER(Name("UniqueWithExtraCounts")          \
+                              .Device(DEVICE_GPU)                \
+                              .TypeConstraint<type>("T")         \
+                              .TypeConstraint<int32>("out_idx"), \
+                          UniqueWithCountAliOp<type, int32>)     \
+  REGISTER_KERNEL_BUILDER(Name("UniqueWithExtraCounts")          \
+                              .Device(DEVICE_GPU)                \
+                              .TypeConstraint<type>("T")         \
+                              .TypeConstraint<int64>("out_idx"), \
+                          UniqueWithCountAliOp<type, int64>);
 TF_CALL_REAL_NUMBER_TYPES(REGISTER_UNIQUE);
 REGISTER_UNIQUE(string)
 #undef REGISTER_UNIQUE
-#endif //GOOGLE_CUDA
-
+#endif  // GOOGLE_CUDA
+
 #ifdef TENSORFLOW_USE_SYCL
 REGISTER_KERNEL_BUILDER(Name("Unique")
                             .Device(DEVICE_SYCL)
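A note on intended semantics before the kernel-utility changes below: UniqueWithExtraCounts behaves like UniqueWithCounts, except that ids listed in extra_indices take their pre-accumulated counts from extra_counts. The following NumPy sketch illustrates that contract only; the helper name is hypothetical, and np.unique's sorted output order stands in for the kernel's first-occurrence order.

    import numpy as np

    def unique_with_extra_counts_reference(x, extra_indices, extra_counts):
      # y/idx behave like np.unique(x, return_inverse=True); count starts as
      # the occurrence count of each id in x.
      y, idx, count = np.unique(x, return_inverse=True, return_counts=True)
      count = count.astype(np.int64)
      seen = set()  # dense_hash_map::emplace keeps only the first count per id
      for ids, counts in zip(extra_indices, extra_counts):
        for i, c in zip(ids, counts):
          pos = int(np.searchsorted(y, i))
          if pos < len(y) and y[pos] == i and pos not in seen:
            # The id already contributed one occurrence, hence "c - 1".
            count[pos] += c - 1
            seen.add(pos)
      return y, idx, count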
diff --git a/tensorflow/core/kernels/unique_ali_op_util.h b/tensorflow/core/kernels/unique_ali_op_util.h
index 6b59ba26e81..0a52d8864e9 100644
--- a/tensorflow/core/kernels/unique_ali_op_util.h
+++ b/tensorflow/core/kernels/unique_ali_op_util.h
@@ -191,7 +191,8 @@ void NewSizes(OpKernelContext* context, const Tensor& input,
 
 template <typename T, typename TIndex, typename HashMap>
 void SerialComputeV1(OpKernelContext* context, const Tensor& input,
-                     Tensor* idx, int64 axis, int64* uniq_size, Tensor* output) {
+                     Tensor* idx, int64 axis, int64* uniq_size, int num_sparse,
+                     google::dense_hash_map<TIndex, TIndex>* counter_map, Tensor* output) {
   auto Tin = input.flat<T>();
   const int64 N = input.NumElements();
   auto idx_vec = idx->template vec<TIndex>();
@@ -205,7 +206,23 @@ void SerialComputeV1(OpKernelContext* context, const Tensor& input,
       ++j;
     }
   }
-
+
+  counter_map->set_empty_key(std::numeric_limits<TIndex>::max());
+  counter_map->resize(2 * N);
+  for (int i = 0; i < num_sparse; ++i) {
+    const Tensor& indices_tensor = context->input(1 + i);
+    auto extra_ids_vec = indices_tensor.template vec<T>();
+    const Tensor& counter_tensor = context->input(1 + num_sparse + i);
+    auto counter_vec = counter_tensor.template vec<TIndex>();
+    for (int64 k = 0; k < extra_ids_vec.size(); ++k) {
+      auto ids = extra_ids_vec(k);
+      auto idx_it = uniq.find(ids);
+      if (idx_it != uniq.end()) {
+        counter_map->emplace(idx_it->second, counter_vec(k));
+      }
+    }
+  }
+
   *uniq_size = static_cast<int64>(uniq.size());
   TensorShape output_shape(input.shape());
   output_shape.set_dim(axis, *uniq_size);
@@ -223,7 +240,8 @@ void SerialComputeV1(OpKernelContext* context, const Tensor& input,
 
 template <typename T, typename TIndex, typename HashMap>
 void ParallelComputeV1(OpKernelContext* context, const Tensor& input,
-                       Tensor* idx, int64 axis, int64* uniq_size, Tensor* output) {
+                       Tensor* idx, int64 axis, int64* uniq_size, int num_sparse,
+                       google::dense_hash_map<TIndex, TIndex>* counter_map, Tensor* output) {
   // Struct INode was used to store an inverse mapping for each node in the
   // hash map container.
   struct INode {
@@ -415,6 +433,25 @@ void ParallelComputeV1(OpKernelContext* context, const Tensor& input,
   TaskRunner t3_runner(GlobalIndexTask, thread_pool, num_tasks_t1);
   t3_runner.Run();
 
+  counter_map->set_empty_key(std::numeric_limits<TIndex>::max());
+  counter_map->resize(2 * N);
+  for (int i = 0; i < num_sparse; ++i) {
+    const Tensor& indices_tensor = context->input(1 + i);
+    auto extra_ids_vec = indices_tensor.template vec<T>();
+    const Tensor& counter_tensor = context->input(1 + num_sparse + i);
+    auto counter_vec = counter_tensor.template vec<TIndex>();
+    for (int64 k = 0; k < extra_ids_vec.size(); ++k) {
+      auto ids = extra_ids_vec(k);
+      for (int j = 0; j < num_tasks_t1; ++j) {
+        const INode* inode = uniq_maps[j].GetINodeByKey(ids);
+        if (inode != nullptr) {
+          counter_map->emplace(inode->index_, counter_vec(k));
+          break;
+        }
+      }
+    }
+  }
+
   // Parallel Step 4: Write output indicies Tensor.
   int32 max_tasks_t4 = (N + kPartitionSize - 1) / kPartitionSize;
   int32 num_tasks_t4 = std::max(std::min(max_threads, max_tasks_t4), 1);
@@ -447,8 +484,8 @@ void ParallelComputeV1(OpKernelContext* context, const Tensor& input,
 template <typename T, typename TIndex, typename HashMap>
 void MultiMapCompute(OpKernelContext* context, const Tensor& input,
                      Tensor* idx, int64 axis, int64* uniq_size_out,
-                     int32 num_buckets, int64 unique_ratio_hint,
-                     Tensor* output) {
+                     int32 num_buckets, int64 unique_ratio_hint, int num_sparse,
+                     google::dense_hash_map<TIndex, TIndex>* counter_map, Tensor* output) {
   auto Tin = input.vec<T>();
   const int64 N = input.NumElements();
@@ -529,6 +566,24 @@ void MultiMapCompute(OpKernelContext* context, const Tensor& input,
   }
   int64 uniq_size = global_offsets[num_buckets - 1] +
                     uniq_maps[num_buckets - 1].size();
+
+  counter_map->set_empty_key(std::numeric_limits<TIndex>::max());
+  counter_map->resize(2 * uniq_size);
+
+  google::dense_hash_map<T, TIndex> extra_unique_id_map;
+  extra_unique_id_map.set_empty_key(std::numeric_limits<T>::max());
+  extra_unique_id_map.resize(2 * uniq_size);
+  for (int i = 0; i < num_sparse; ++i) {
+    const Tensor& indices_tensor = context->input(1 + i);
+    auto extra_ids_vec = indices_tensor.template vec<T>();
+    const Tensor& counter_tensor = context->input(1 + num_sparse + i);
+    auto counter_vec = counter_tensor.template vec<TIndex>();
+    for (int64 k = 0; k < extra_ids_vec.size(); ++k) {
+      auto ids = extra_ids_vec(k);
+      auto counts = counter_vec(k);
+      extra_unique_id_map.emplace(ids, counts);
+    }
+  }
 
   *uniq_size_out = uniq_size;
   AllocatorAttributes attr;
@@ -539,7 +594,7 @@ void MultiMapCompute(OpKernelContext* context, const Tensor& input,
   auto key_output_vec = output->template vec<T>();
 
   auto OutputTask = [&key_output_vec, &uniq_maps, &global_offsets,
-                     &Tin, &idx_vec, &map_parter]
+                     &Tin, &idx_vec, &map_parter, &counter_map, extra_unique_id_map]
       (int32 task_id, int32 num_tasks) {
     TIndex offset = global_offsets[task_id];
     for (auto iter = uniq_maps[task_id].begin();
          iter != uniq_maps[task_id].end(); ++iter) {
@@ -553,7 +608,10 @@ void MultiMapCompute(OpKernelContext* context, const Tensor& input,
         next_idx = idx_vec(cur_idx);
         idx_vec(cur_idx) = offset;
       }
-
+      auto it = extra_unique_id_map.find(iter->first);
+      if (it != extra_unique_id_map.end()) {
+        counter_map->emplace(offset, it->second);
+      }
       ++offset;
     }
   };
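All three compute paths above build the same side table: for every extra id actually present in x, the output slot it occupies is mapped to its externally supplied count. In sketch form, with a plain dict standing in for google::dense_hash_map and a hypothetical helper name:

    def build_counter_map(uniq, num_sparse, op_inputs):
      # uniq: id -> output slot, as produced by the unique pass.
      # op_inputs: flat kernel input list, [x, extra_indices..., extra_counts...].
      counter_map = {}
      for i in range(num_sparse):
        extra_ids = op_inputs[1 + i]
        extra_counts = op_inputs[1 + num_sparse + i]
        for ids, cnt in zip(extra_ids, extra_counts):
          slot = uniq.get(ids)
          if slot is not None:
            counter_map.setdefault(slot, cnt)  # emplace: first write wins
      return counter_map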
@@ -618,8 +676,9 @@ void MultipleElements(OpKernelContext* context, const Tensor& input,
 }
 
 template <typename TIndex>
-void CheckCountOutput(OpKernelContext* context, Tensor* output_counter,
-                      Tensor* idx, int num_outputs, int64 uniq_size) {
+void CheckCountOutput(OpKernelContext* context, Tensor* output, Tensor* output_counter,
+                      Tensor* idx, int num_outputs, int64 uniq_size,
+                      int num_sparse, google::dense_hash_map<TIndex, TIndex> counter_map) {
   if (num_outputs > 2) {
     auto idx_vec = idx->template vec<TIndex>();
     AllocatorAttributes attr;
@@ -633,12 +692,19 @@ void CheckCountOutput(OpKernelContext* context, Tensor* output_counter,
     for (int64 i = 0; i < N; ++i) {
       count_output_vec(idx_vec(i))++;
     }
+    if (num_sparse > 0) {
+      for (auto& it : counter_map) {
+        count_output_vec(it.first) += (it.second - 1);
+      }
+    }
   }
+
 }
 
 template <typename T, typename TIndex, typename HashMap>
 void ComputeInternalWithHashMap(OpKernelContext* context, const Tensor& input,
-    Tensor* idx, int64 axis, int64* uniq_size, int64 N, bool serial, Tensor* output) {
+    Tensor* idx, int64 axis, int64* uniq_size, int64 N, int num_sparse, bool serial,
+    google::dense_hash_map<TIndex, TIndex>* counter_map, Tensor* output) {
   OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
               errors::InvalidArgument("unique expects a 1D vector."));
   // TODO(dga): Make unique polymorphic for returning int32 and int64
@@ -651,10 +717,10 @@ void ComputeInternalWithHashMap(OpKernelContext* context, const Tensor& input,
 
   if (N >= kPartitionLimit && !serial) {
     ParallelComputeV1<T, TIndex, HashMap>
-        (context, input, idx, axis, uniq_size, output);
+        (context, input, idx, axis, uniq_size, num_sparse, counter_map, output);
   } else {
     SerialComputeV1<T, TIndex, HashMap>
-        (context, input, idx, axis, uniq_size, output);
+        (context, input, idx, axis, uniq_size, num_sparse, counter_map, output);
   }
 }
 
@@ -662,7 +728,7 @@ template <typename T, typename TIndex>
 void UniqueInternal(OpKernelContext* context, const Tensor& input,
     Tensor* idx, Tensor* output, Tensor* output_counter, int num_outputs,
     int64 partition_size, bool serial, int64 axis, int64 unique_ratio_hint,
-    std::vector<int64>& new_sizes, UniqueMaps map_flag) {
+    std::vector<int64>& new_sizes, UniqueMaps map_flag, int num_sparse = 0) {
   typedef google::dense_hash_map<T, TIndex> DefaultHashMap;
 
   AllocatorAttributes attr;
@@ -672,6 +738,7 @@ void UniqueInternal(OpKernelContext* context, const Tensor& input,
                         TensorShape({new_sizes[1]}), idx, attr));
 
   int64 uniq_size_out;
+  google::dense_hash_map<TIndex, TIndex> counter_map;
 
   if (new_sizes[0] == 1 && new_sizes[2] == 1) {
     // Specialized and faster implementation when unique is run over single
@@ -687,33 +754,34 @@ void UniqueInternal(OpKernelContext* context, const Tensor& input,
       case MULTIMAP:
        if (num_buckets > 1 && !serial) {
          MultiMapCompute<T, TIndex, google::dense_hash_map<T, TIndex>>
-            (context, input, idx, axis, &uniq_size_out, num_buckets, unique_ratio_hint, output);
+            (context, input, idx, axis, &uniq_size_out, num_buckets, unique_ratio_hint, num_sparse, &counter_map, output);
        } else {
          SerialComputeV1<T, TIndex, DefaultHashMap>
-            (context, input, idx, axis, &uniq_size_out, output);
+            (context, input, idx, axis, &uniq_size_out, num_sparse, &counter_map, output);
        }
        break;
      case STL:
        ComputeInternalWithHashMap<T, TIndex, std::unordered_map<T, TIndex>>
-          (context, input, idx, axis, &uniq_size_out, N, serial, output);
+          (context, input, idx, axis, &uniq_size_out, N, num_sparse, serial, &counter_map, output);
        break;
      case ABSL:
        ComputeInternalWithHashMap<T, TIndex, absl::flat_hash_map<T, TIndex>>
-          (context, input, idx, axis, &uniq_size_out, N, serial, output);
+          (context, input, idx, axis, &uniq_size_out, N, num_sparse, serial, &counter_map, output);
        break;
      case GOOGLE:
        ComputeInternalWithHashMap<T, TIndex, DefaultHashMap>
-          (context, input, idx, axis, &uniq_size_out, N, serial, output);
+          (context, input, idx, axis, &uniq_size_out, N, num_sparse, serial, &counter_map, output);
        break;
      default:
        ComputeInternalWithHashMap<T, TIndex, DefaultHashMap>
-          (context, input, idx, axis, &uniq_size_out, N, serial, output);
+          (context, input, idx, axis, &uniq_size_out, N, num_sparse, serial, &counter_map, output);
     }
   } else {
     MultipleElements<T, TIndex>(context, input, idx, output,
                                 &uniq_size_out, axis, new_sizes);
   }
 
-  CheckCountOutput(context, output_counter, idx, num_outputs, uniq_size_out);
+  CheckCountOutput(context, output, output_counter, idx, num_outputs,
+                   uniq_size_out, num_sparse, counter_map);
 }
 
 template <typename T, typename TIndex>
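The count adjustment rule in CheckCountOutput above, as a minimal self-contained sketch: the extra id already contributed one occurrence to the base count, so the saved count replaces that single occurrence.

    def adjust_counts(base_counts, counter_map):
      # base_counts[slot] is the occurrence count of the id in x; an entry in
      # counter_map swaps in the externally accumulated count for that id.
      for slot, saved_count in counter_map.items():
        base_counts[slot] += saved_count - 1
      return base_counts

    assert adjust_counts([2, 2, 1], {0: 5}) == [6, 2, 1]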
@@ -743,6 +811,20 @@ void UniqueWithAxis(OpKernelContext* context, const Tensor& input,
                           axis, unique_ratio_hint, new_sizes, map_flag);
 }
 
+template <typename T, typename TIndex>
+void UniqueWithExtraCounts(OpKernelContext* context, const Tensor& input,
+    Tensor* idx, Tensor* output, Tensor* output_counter, int num_outputs,
+    int64 partition_size, bool serial, int64 unique_ratio_hint,
+    int num_sparse, UniqueMaps map_flag) {
+  int64 axis = 0;
+  std::vector<int64> new_sizes{1, input.NumElements(), 1};
+  OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
+              errors::InvalidArgument("unique expects a 1D vector."));
+  UniqueInternal<T, TIndex>(context, input, idx, output,
+                            output_counter, num_outputs, partition_size, serial,
+                            axis, unique_ratio_hint, new_sizes, map_flag, num_sparse);
+}
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_UNIQUE_ALI_OP_UTIL_H_
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 27f6811fcff..306026977ef 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -1741,6 +1741,26 @@ REGISTER_OP("UniqueWithCountsV2")
       return Status::OK();
     });
 
+// ---------------------------------------------------
+
+REGISTER_OP("UniqueWithExtraCounts")
+    .Input("x: T")
+    .Input("extra_indices: N * T")
+    .Input("extra_counts: N * out_idx")
+    .Output("y: T")
+    .Output("idx: out_idx")
+    .Output("count: out_idx")
+    .Attr("T: type")
+    .Attr("N: int >= 0")
+    .Attr("out_idx: {int32, int64} = DT_INT32")
+    .SetShapeFn([](InferenceContext* c) {
+      auto uniq = c->Vector(InferenceContext::kUnknownDim);
+      c->set_output(0, uniq);
+      c->set_output(1, c->input(0));
+      c->set_output(2, uniq);
+      return Status::OK();
+    });
+
 namespace {
 
 Status ShapeShapeFn(InferenceContext* c) {
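Given the op definition above and the Python wrapper the tests below use (array_ops.unique_with_extra_counts), a minimal call would look like this. The concrete values are illustrative only, and the output order assumes the CPU kernel's first-occurrence ordering:

    import tensorflow as tf
    from tensorflow.python.ops import array_ops

    x = tf.constant([7, 7, 3, 9, 3], dtype=tf.int64)
    extra_indices = [tf.constant([7], dtype=tf.int64)]  # N = 1, same dtype as x
    extra_counts = [tf.constant([5], dtype=tf.int32)]   # out_idx = int32
    y, idx, count = array_ops.unique_with_extra_counts(x, extra_indices,
                                                       extra_counts)
    # After evaluation:
    # y     -> [7, 3, 9]
    # idx   -> [0, 0, 1, 2, 1]
    # count -> [2 + (5 - 1), 2, 1] = [6, 2, 1]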
diff --git a/tensorflow/python/kernel_tests/unique_op_test.py b/tensorflow/python/kernel_tests/unique_op_test.py
index 9ec0ff74e3e..25c1b26a103 100644
--- a/tensorflow/python/kernel_tests/unique_op_test.py
+++ b/tensorflow/python/kernel_tests/unique_op_test.py
@@ -27,6 +27,7 @@
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import constant_op
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.platform import test
@@ -278,6 +279,73 @@ def testUniqueWithCountsAbslMap(self):
   def testUniqueWithCountsDenseHashMap(self):
     self.RunUniqueWithCountsWithDifferentMaps('GOOGLE')
 
+class UniqueWithExtraCountsTest(test.TestCase):
+
+  def testInt32(self):
+    x = np.random.randint(2, high=1000, size=700000)
+    extra_x = x[:5].tolist()
+    extra_x_tensor = [constant_op.constant(extra_x, dtypes.int64)]
+    extra_count = [500 for _ in range(5)]
+    extra_count_tensor = [constant_op.constant(extra_count, dtypes.int32)]
+    with self.cached_session() as sess:
+      y, idx, count = array_ops.unique_with_extra_counts(
+          x, extra_x_tensor, extra_count_tensor)
+      tf_y, tf_idx, tf_count = sess.run([y, idx, count])
+
+    self.assertEqual(len(x), len(tf_idx))
+    self.assertEqual(len(tf_y), len(np.unique(x)))
+    for i in range(len(x)):
+      self.assertEqual(x[i], tf_y[tf_idx[i]])
+    for value, count in zip(tf_y, tf_count):
+      if value in extra_x:
+        self.assertEqual(count, np.sum(x == value) + 499)
+      else:
+        self.assertEqual(count, np.sum(x == value))
+
+  def testInt32OutIdxInt64(self):
+    x = np.random.randint(2, high=1000, size=700000)
+    extra_x = x[:5].tolist()
+    extra_x_tensor = [constant_op.constant(extra_x, dtypes.int64)]
+    extra_count = [500 for _ in range(5)]
+    extra_count_tensor = [constant_op.constant(extra_count, dtypes.int64)]
+    with self.cached_session() as sess:
+      y, idx, count = array_ops.unique_with_extra_counts(
+          x, extra_x_tensor, extra_count_tensor)
+      tf_y, tf_idx, tf_count = sess.run([y, idx, count])
+
+    self.assertEqual(len(x), len(tf_idx))
+    self.assertEqual(len(tf_y), len(np.unique(x)))
+    for i in range(len(x)):
+      self.assertEqual(x[i], tf_y[tf_idx[i]])
+    for value, count in zip(tf_y, tf_count):
+      if value in extra_x:
+        self.assertEqual(count, np.sum(x == value) + 499)
+      else:
+        self.assertEqual(count, np.sum(x == value))
+
+  def RunUniqueWithCountsWithDifferentMaps(self, map_type):
+    recover_env = False
+    if 'DEEPREC_UNIQUE_OP_HASH_MAP' in os.environ:
+      recover_env = True
+      old_env = os.environ['DEEPREC_UNIQUE_OP_HASH_MAP']
+
+    os.environ['DEEPREC_UNIQUE_OP_HASH_MAP'] = map_type
+    self.testInt32()
+    self.testInt32OutIdxInt64()
+
+    del os.environ['DEEPREC_UNIQUE_OP_HASH_MAP']
+    if recover_env:
+      os.environ['DEEPREC_UNIQUE_OP_HASH_MAP'] = old_env
+
+  def testUniqueWithCountsMultiMap(self):
+    self.RunUniqueWithCountsWithDifferentMaps('MULTIMAP')
+
+  def testUniqueWithCountsStlMap(self):
+    self.RunUniqueWithCountsWithDifferentMaps('STL')
+
+  def testUniqueWithCountsAbslMap(self):
+    self.RunUniqueWithCountsWithDifferentMaps('ABSL')
+
+  def testUniqueWithCountsDenseHashMap(self):
+    self.RunUniqueWithCountsWithDifferentMaps('GOOGLE')
 
 if __name__ == '__main__':
   test.main()
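The "+ 499" the assertions above expect is the replace-one-occurrence rule applied to an extra count of 500; in NumPy terms:

    import numpy as np

    x = np.array([7, 7, 3, 9, 3])
    extra_value, extra_count = 7, 500
    # The extra id already contributes one occurrence to the base count, so
    # an extra count of 500 adds 500 - 1 = 499 on top of it.
    expected = np.sum(x == extra_value) + (extra_count - 1)
    assert expected == 501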
diff --git a/tensorflow/python/ops/embedding_variable_ops_test.py b/tensorflow/python/ops/embedding_variable_ops_test.py
index 81b315e2e43..dbf254d5f14 100644
--- a/tensorflow/python/ops/embedding_variable_ops_test.py
+++ b/tensorflow/python/ops/embedding_variable_ops_test.py
@@ -19,6 +19,7 @@
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.framework import constant_op
 from tensorflow.python.ops import string_ops
 from tensorflow.python.ops.check_ops import assert_equal
 from tensorflow.python.platform import googletest
@@ -2871,6 +2872,39 @@ def testCountsTensor(self):
         value = checkpoint_utils.load_variable(ckpt_path, name)
         self.assertAllEqual(value, [3, 3, 1, 3, 2])
 
+  def testCountsWithSparseAndDenseTensor(self):
+    os.environ["TF_RECORD_FREQ"] = "1"
+    checkpoint_directory = self.get_temp_dir()
+    ckpt_path = os.path.join(checkpoint_directory, "model.ckpt")
+    with ops.Graph().as_default() as g, ops.device('/cpu:0'):
+      var = variable_scope.get_embedding_variable("var_1",
+                                                  embedding_dim = 3)
+      sp1 = sparse_tensor.SparseTensor(
+          indices=[[0,0],[1,0],[2,0],[3,0],[4,0],[5,0]],
+          values=math_ops.cast([0,0,0,1,1,2], dtypes.int64),
+          dense_shape=[6, 1])
+      ids = constant_op.constant([3,3,3,4,4,1], dtype=dtypes.int64)
+      emb1 = embedding_ops.embedding_lookup_sparse(var, sp1, None)
+      emb2 = embedding_ops.embedding_lookup(var, ids)
+      emb = emb1 + emb2
+      fun = math_ops.multiply(emb, 2.0, name='multiply')
+      loss = math_ops.reduce_sum(fun, name='reduce_sum')
+      gs = training_util.get_or_create_global_step()
+      opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs)
+      g_v = opt.compute_gradients(loss)
+      train_op = opt.apply_gradients(g_v)
+      saver = saver_module.Saver()
+      init = variables.global_variables_initializer()
+      with self.test_session(graph=g) as sess:
+        sess.run([init])
+        sess.run(train_op)
+        saver.save(sess, ckpt_path)
+
+    for name, shape in checkpoint_utils.list_variables(ckpt_path):
+      if name == "var_1-freqs":
+        value = checkpoint_utils.load_variable(ckpt_path, name)
+        self.assertAllEqual(value, [3, 3, 1, 3, 2])
+
   def testCountsTensorWithGradientDescent(self):
     os.environ["TF_RECORD_FREQ"] = "1"
     checkpoint_directory = self.get_temp_dir()
@@ -2908,6 +2942,41 @@ def testCountsTensorWithGradientDescent(self):
         self.assertAllEqual(value, [3, 3, 1, 3, 2])
 
     del os.environ["TF_RECORD_FREQ"]
+
+  def testCountsDenseAndSparseTensorWithGradientDescent(self):
+    os.environ["TF_RECORD_FREQ"] = "1"
+    checkpoint_directory = self.get_temp_dir()
+    ckpt_path = os.path.join(checkpoint_directory, "model.ckpt")
+    with ops.Graph().as_default() as g, ops.device('/cpu:0'):
+      var = variable_scope.get_embedding_variable("var_1",
+                                                  embedding_dim = 3)
+      sp1 = sparse_tensor.SparseTensor(
+          indices=[[0,0],[1,0],[2,0],[3,0],[4,0],[5,0]],
+          values=math_ops.cast([0,0,0,1,1,2], dtypes.int64),
+          dense_shape=[6, 1])
+      ids = constant_op.constant([3,3,3,4,4,1], dtype=dtypes.int64)
+      emb1 = embedding_ops.embedding_lookup_sparse(var, sp1, None)
+      emb2 = embedding_ops.embedding_lookup(var, ids)
+      emb = emb1 + emb2
+      fun = math_ops.multiply(emb, 2.0, name='multiply')
+      loss = math_ops.reduce_sum(fun, name='reduce_sum')
+      gs = training_util.get_or_create_global_step()
+      opt = gradient_descent.GradientDescentOptimizer(0.1)
+      g_v = opt.compute_gradients(loss)
+      train_op = opt.apply_gradients(g_v)
+      saver = saver_module.Saver()
+      init = variables.global_variables_initializer()
+      with self.test_session(graph=g) as sess:
+        sess.run([init])
+        sess.run(train_op)
+        saver.save(sess, ckpt_path)
+
+    for name, shape in checkpoint_utils.list_variables(ckpt_path):
+      if name == "var_1-freqs":
+        value = checkpoint_utils.load_variable(ckpt_path, name)
+        self.assertAllEqual(value, [3, 3, 1, 3, 2])
+
+    del os.environ["TF_RECORD_FREQ"]
 
 if __name__ == "__main__":
   googletest.main()
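The optimizer changes below (gradient_descent.py, then optimizer.py) share the same graph-walking bookkeeping: collect the pre-Reshape index tensors feeding `indices` together with their recorded counts tensors, skipping tensors that have no record (the old code raised KeyError on those). A hedged standalone sketch, with collect_extra a hypothetical name and a plain dict standing in for handle._counts_tensor:

    def collect_extra(indices_op, counts_by_tensor):
      # Find the pre-Reshape index tensors feeding `indices`.
      if indices_op.type == "ConcatV2":
        sources = [t.op.inputs[0] for t in indices_op.inputs
                   if t.op.type == "Reshape"]
      elif indices_op.type == "Reshape":
        sources = [indices_op.inputs[0]]
      else:
        sources = []
      # Keep only tensors that actually have a recorded counts tensor.
      extra_indices = [t for t in sources if t in counts_by_tensor]
      extra_counts = [counts_by_tensor[t] for t in extra_indices]
      return extra_indices, extra_counts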
diff --git a/tensorflow/python/training/gradient_descent.py b/tensorflow/python/training/gradient_descent.py
index 799e3c5f5bd..d2576cf2800 100644
--- a/tensorflow/python/training/gradient_descent.py
+++ b/tensorflow/python/training/gradient_descent.py
@@ -19,9 +19,11 @@ from __future__ import print_function
 
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import gen_hash_training_ops
 from tensorflow.python.ops import kv_variable_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.training import optimizer
 from tensorflow.python.training import training_ops
@@ -72,22 +74,28 @@ def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices):
     if isinstance(handle, kv_variable_ops.EmbeddingVariable):
       global_step = training_util.get_or_create_global_step()
       if handle.need_counts() and len(handle._counts_tensor.keys()) != 0:
+        extra_counts, extra_indices = [], []
         if indices.op.type == "ConcatV2":
-          total_counts = []
           for tensor in indices.op.inputs:
             if tensor.op.type == "Reshape":
               indices_tensor = tensor.op.inputs[0]
-              total_counts.append(handle._counts_tensor[indices_tensor])
-          from tensorflow.python.ops import array_ops
-          counts_tensor = array_ops.concat(total_counts, 0)
+              if indices_tensor in handle._counts_tensor:
+                extra_counts.append(handle._counts_tensor[indices_tensor])
+                extra_indices.append(indices_tensor)
         elif indices.op.type == "Reshape":
           indices_tensor = indices.op.inputs[0]
-          counts_tensor = handle._counts_tensor[indices_tensor]
+          if indices_tensor in handle._counts_tensor:
+            extra_counts.append(handle._counts_tensor[indices_tensor])
+            extra_indices.append(indices_tensor)
+        unique_indices, new_index_positions, indices_counts = \
+            array_ops.unique_with_extra_counts(indices, extra_indices, extra_counts)
+        summed_grads = math_ops.unsorted_segment_sum(
+            grad, new_index_positions, array_ops.shape(unique_indices)[0])
         return training_ops.kv_resource_sparse_apply_gradient_descent_with_counts(
             handle.handle, math_ops.cast(self._learning_rate_tensor,
                                          grad.dtype.base_dtype),
-            grad, indices, global_step,
-            counts_tensor, use_locking=self._use_locking)
+            summed_grads, unique_indices, global_step,
+            indices_counts, use_locking=self._use_locking)
       else:
         return training_ops.kv_resource_sparse_apply_gradient_descent(
             handle.handle, math_ops.cast(self._learning_rate_tensor,
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 7523604ccf9..d55d1b503d6 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -93,16 +93,14 @@ def _deduplicate_indexed_slices_with_counts(values, indices):
       array_ops.shape(unique_indices)[0])
   return (summed_values, unique_indices, indices_counts)
 
-def _deduplicate_indexed_slices_with_counts_reduction(values, indices, counts):
+def _deduplicate_indexed_slices_with_counts_reduction(values, indices, extra_counts, extra_indices):
   """Sums `values` associated with any non-unique `indices` and return
   counts of each count in `values`."""
-  unique_indices, new_index_positions = array_ops.unique(indices)
+  unique_indices, new_index_positions, summed_counts = \
+      array_ops.unique_with_extra_counts(indices, extra_indices, extra_counts)
   summed_values = math_ops.unsorted_segment_sum(
       values, new_index_positions,
       array_ops.shape(unique_indices)[0])
-  summed_counts = math_ops.unsorted_segment_sum(
-      counts, new_index_positions,
-      array_ops.shape(unique_indices)[0])
   return (summed_values, unique_indices, summed_counts)
 
 def _var_key(var):
@@ -1105,19 +1103,22 @@ def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices):
             _deduplicate_indexed_slices_with_counts(
                 values=grad, indices=indices)
       else:
+        extra_counts, extra_indices = [], []
         if indices.op.type == "ConcatV2":
-          total_counts = []
           for tensor in indices.op.inputs:
             if tensor.op.type == "Reshape":
               indices_tensor = tensor.op.inputs[0]
-              total_counts.append(handle._counts_tensor[indices_tensor])
-          counts_tensor = array_ops.concat(total_counts, 0)
+              if indices_tensor in handle._counts_tensor:
+                extra_counts.append(handle._counts_tensor[indices_tensor])
+                extra_indices.append(indices_tensor)
         elif indices.op.type == "Reshape":
           indices_tensor = indices.op.inputs[0]
-          counts_tensor = handle._counts_tensor[indices_tensor]
+          if indices_tensor in handle._counts_tensor:
+            extra_counts.append(handle._counts_tensor[indices_tensor])
+            extra_indices.append(indices_tensor)
         summed_grad, unique_indices, indices_counts = \
             _deduplicate_indexed_slices_with_counts_reduction(
-                grad, indices, counts_tensor)
+                grad, indices, extra_counts, extra_indices)
         return self._resource_apply_sparse(
             summed_grad, handle, unique_indices, indices_counts)
       else:
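Taken together, the deduplication both files now perform reduces to the following sketch (a hypothetical standalone form of the shared logic; unique_with_extra_counts folds the recorded per-id counts directly into its count output, which is why the second unsorted_segment_sum over counts could be dropped):

    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import math_ops

    def dedup_with_extra_counts(grad, indices, extra_indices, extra_counts):
      # Deduplicate indices and get the merged counts in one op.
      unique_indices, new_index_positions, counts = \
          array_ops.unique_with_extra_counts(indices, extra_indices,
                                             extra_counts)
      # Sum gradient rows that map to the same unique index.
      summed_grad = math_ops.unsorted_segment_sum(
          grad, new_index_positions, array_ops.shape(unique_indices)[0])
      return summed_grad, unique_indices, counts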