From 9956420489a1962e66536d236cb6d5a82782758c Mon Sep 17 00:00:00 2001 From: "Shen, Wanglei" Date: Tue, 10 Sep 2024 23:33:23 +0800 Subject: [PATCH] initial fix --- .../intel_cpu/src/cpu_streams_calculation.cpp | 202 ++++++++++-------- .../streams_info/streams_info_table_test.cpp | 34 ++- 2 files changed, 134 insertions(+), 102 deletions(-) diff --git a/src/plugins/intel_cpu/src/cpu_streams_calculation.cpp b/src/plugins/intel_cpu/src/cpu_streams_calculation.cpp index fb8294e5ce29fd..c0e1e96547cec7 100644 --- a/src/plugins/intel_cpu/src/cpu_streams_calculation.cpp +++ b/src/plugins/intel_cpu/src/cpu_streams_calculation.cpp @@ -102,9 +102,6 @@ std::vector> get_streams_info_table(const int input_streams, const std::vector>& one_proc_table, const int num_threads, const IStreamsExecutor::Config::StreamsMode sub_streams_model) { - stream_info[THREADS_PER_STREAM] = sub_streams_model == IStreamsExecutor::Config::StreamsMode::SUB_STREAMS_NULL - ? num_threads - : std::min(TP_CPU_LIMIT, num_threads); if ((one_proc_info[PROC_NUMA_NODE_ID] < 0) || (one_proc_info[PROC_SOCKET_ID] < 0) || (((one_proc_info[MAIN_CORE_PROC] > 0) && (one_proc_info[MAIN_CORE_PROC] < stream_info[THREADS_PER_STREAM])) || @@ -178,7 +175,11 @@ std::vector> get_streams_info_table(const int input_streams, std::unordered_set socket_id_list(proc_type_table.size()); for (size_t i = 1; i < proc_type_table.size(); i++) { if (!socket_id_list.count(proc_type_table[i][PROC_SOCKET_ID])) { - proc_socket_table.push_back(proc_type_table[i]); + if (proc_type_table[i][PROC_SOCKET_ID] == input_current_socket_id) { + proc_socket_table.insert(proc_socket_table.begin(), proc_type_table[i]); + } else { + proc_socket_table.push_back(proc_type_table[i]); + } socket_id_list.insert(proc_type_table[i][PROC_SOCKET_ID]); } else { for (auto& row : proc_socket_table) { @@ -333,58 +334,28 @@ std::vector> get_streams_info_table(const int input_streams, int total_streams = n_streams; if (stream_info[PROC_TYPE] == INIT_VAL) { - stream_info[THREADS_PER_STREAM] = n_threads_per_stream; - - for (int n_type = MAIN_CORE_PROC; (n_type <= HYPER_THREADING_PROC) && (n_streams > 0); n_type++) { - if (proc_type_table.size() == 1) { - if (proc_type_table[0][n_type] >= stream_info[THREADS_PER_STREAM]) { - update_streams_per_node(n_type, proc_type_table[0]); - } - } else { - for (size_t n_node = 1; (n_node < proc_type_table.size()) && (n_streams > 0); n_node++) { - if ((proc_type_table[n_node][n_type] >= stream_info[THREADS_PER_STREAM]) && - ((current_socket_id < 0) || (proc_type_table[n_node][PROC_SOCKET_ID] == current_socket_id))) { - update_streams_per_node(n_type, proc_type_table[n_node]); - } - } - } - } - - if (total_streams == n_streams) { - if (proc_type_table.size() == 1) { - if (proc_type_table[0][ALL_PROC] >= stream_info[THREADS_PER_STREAM]) { - update_mix_stream_info(proc_type_table[0], - proc_type_table, - n_threads_per_stream, - IStreamsExecutor::Config::StreamsMode::SUB_STREAMS_NULL, - ALL_PROC); - n_streams--; - } - } else { - for (size_t n_node = 0; (n_node < proc_socket_table.size()) && (n_streams > 0); n_node++) { - if ((proc_socket_table[n_node][ALL_PROC] >= stream_info[THREADS_PER_STREAM]) && - ((current_socket_id < 0) || (proc_socket_table[n_node][PROC_SOCKET_ID] == current_socket_id))) { - update_mix_stream_info(proc_socket_table[n_node], - proc_type_table, - n_threads_per_stream, - IStreamsExecutor::Config::StreamsMode::SUB_STREAMS_NULL, - ALL_PROC); - n_streams--; + if ((n_streams == 1) && (proc_type_table.size() > 1) && + ((hint_model_distribution_policy.find(ov::hint::ModelDistributionPolicy::TENSOR_PARALLEL) != + hint_model_distribution_policy.end()))) { + for (auto& row : proc_socket_table) { + stream_info[THREADS_PER_STREAM] = std::min(TP_CPU_LIMIT, n_threads_per_stream); + for (size_t i = 1; i < proc_type_table.size(); i++) { + if ((proc_type_table[i][PROC_SOCKET_ID] == row[PROC_SOCKET_ID]) && + (proc_type_table[i][MAIN_CORE_PROC] >= stream_info[THREADS_PER_STREAM])) { + create_one_stream(proc_type_table[i], + {proc_type_table[i]}, + stream_info[THREADS_PER_STREAM], + IStreamsExecutor::Config::StreamsMode::SUB_STREAMS_FOR_SOCKET); + break; } } - } - } - - if (total_streams == n_streams) { - create_one_stream(proc_socket_table[current_socket_id], - proc_type_table, - proc_socket_table[current_socket_id][ALL_PROC], - IStreamsExecutor::Config::StreamsMode::SUB_STREAMS_FOR_SOCKET); - for (size_t n_node = 0; n_node < proc_socket_table.size(); n_node++) { - if (n_node != size_t(current_socket_id)) { - create_one_stream(proc_socket_table[n_node], + if (stream_info[STREAM_SOCKET_ID] == row[PROC_SOCKET_ID]) { + continue; + } else { + stream_info[THREADS_PER_STREAM] = std::min(stream_info[THREADS_PER_STREAM], row[ALL_PROC]); + create_one_stream(row, proc_type_table, - proc_socket_table[n_node][ALL_PROC], + stream_info[THREADS_PER_STREAM], IStreamsExecutor::Config::StreamsMode::SUB_STREAMS_FOR_SOCKET); } } @@ -405,57 +376,102 @@ std::vector> get_streams_info_table(const int input_streams, } } streams_info_table.insert(streams_info_table.begin(), stream_info); - n_streams--; - } + } else { + stream_info[THREADS_PER_STREAM] = n_threads_per_stream; - if (n_streams > 0) { - std::vector> remain_proc_type_table(proc_type_table); - size_t stream_table_size = streams_info_table.size(); - - for (size_t i = 0; i < stream_table_size; i++) { - if ((streams_info_table[i][STREAM_NUMA_NODE_ID] >= 0) && - (streams_info_table[i][STREAM_SOCKET_ID] >= 0)) { - for (auto& row : remain_proc_type_table) { - if ((streams_info_table[i][STREAM_NUMA_NODE_ID] == row[PROC_NUMA_NODE_ID]) && - (streams_info_table[i][STREAM_SOCKET_ID] == row[PROC_SOCKET_ID])) { - row[streams_info_table[i][PROC_TYPE]] -= (streams_info_table[i][NUMBER_OF_STREAMS] == 0 - ? 1 - : streams_info_table[i][NUMBER_OF_STREAMS]) * - streams_info_table[i][THREADS_PER_STREAM]; + for (int n_type = MAIN_CORE_PROC; (n_type <= HYPER_THREADING_PROC) && (n_streams > 0); n_type++) { + if (proc_type_table.size() == 1) { + if (proc_type_table[0][n_type] >= stream_info[THREADS_PER_STREAM]) { + update_streams_per_node(n_type, proc_type_table[0]); + } + } else { + for (size_t n_node = 1; (n_node < proc_type_table.size()) && (n_streams > 0); n_node++) { + if ((proc_type_table[n_node][n_type] >= stream_info[THREADS_PER_STREAM]) && + ((current_socket_id < 0) || + (proc_type_table[n_node][PROC_SOCKET_ID] == current_socket_id))) { + update_streams_per_node(n_type, proc_type_table[n_node]); } } } } - while (n_streams > 0) { - update_mix_stream_info(proc_type_table[0], - remain_proc_type_table, - n_threads_per_stream, - IStreamsExecutor::Config::StreamsMode::SUB_STREAMS_NULL, - ALL_PROC); + if (total_streams == n_streams) { + if (proc_type_table.size() == 1) { + if (proc_type_table[0][ALL_PROC] >= stream_info[THREADS_PER_STREAM]) { + update_mix_stream_info(proc_type_table[0], + proc_type_table, + n_threads_per_stream, + IStreamsExecutor::Config::StreamsMode::SUB_STREAMS_NULL, + ALL_PROC); + n_streams--; + } + } else { + for (size_t n_node = 0; (n_node < proc_socket_table.size()) && (n_streams > 0); n_node++) { + if ((proc_socket_table[n_node][ALL_PROC] >= stream_info[THREADS_PER_STREAM]) && + ((current_socket_id < 0) || + (proc_socket_table[n_node][PROC_SOCKET_ID] == current_socket_id))) { + update_mix_stream_info(proc_socket_table[n_node], + proc_type_table, + n_threads_per_stream, + IStreamsExecutor::Config::StreamsMode::SUB_STREAMS_NULL, + ALL_PROC); + n_streams--; + } + } + } + } - if (stream_table_size == streams_info_table.size()) { - break; + if (n_streams > 0) { + std::vector> remain_proc_type_table(proc_type_table); + size_t stream_table_size = streams_info_table.size(); + + for (size_t i = 0; i < stream_table_size; i++) { + if ((streams_info_table[i][STREAM_NUMA_NODE_ID] >= 0) && + (streams_info_table[i][STREAM_SOCKET_ID] >= 0)) { + for (auto& row : remain_proc_type_table) { + if ((streams_info_table[i][STREAM_NUMA_NODE_ID] == row[PROC_NUMA_NODE_ID]) && + (streams_info_table[i][STREAM_SOCKET_ID] == row[PROC_SOCKET_ID])) { + row[streams_info_table[i][PROC_TYPE]] -= + (streams_info_table[i][NUMBER_OF_STREAMS] == 0 + ? 1 + : streams_info_table[i][NUMBER_OF_STREAMS]) * + streams_info_table[i][THREADS_PER_STREAM]; + } + } + } } - n_streams--; - int numa_node_id = streams_info_table[stream_table_size + 1][STREAM_NUMA_NODE_ID]; - int socket_id = streams_info_table[stream_table_size + 1][STREAM_SOCKET_ID]; - for (size_t i = stream_table_size + 1; i < streams_info_table.size(); i++) { - numa_node_id = numa_node_id == streams_info_table[i][STREAM_NUMA_NODE_ID] ? numa_node_id : -1; - socket_id = socket_id == streams_info_table[i][STREAM_SOCKET_ID] ? socket_id : -1; - for (auto& row : remain_proc_type_table) { - if ((streams_info_table[i][STREAM_NUMA_NODE_ID] == row[PROC_NUMA_NODE_ID]) && - (streams_info_table[i][STREAM_SOCKET_ID] == row[PROC_SOCKET_ID])) { - row[streams_info_table[i][PROC_TYPE]] -= (streams_info_table[i][NUMBER_OF_STREAMS] == 0 - ? 1 - : streams_info_table[i][NUMBER_OF_STREAMS]) * - streams_info_table[i][THREADS_PER_STREAM]; + + while (n_streams > 0) { + update_mix_stream_info(proc_type_table[0], + remain_proc_type_table, + n_threads_per_stream, + IStreamsExecutor::Config::StreamsMode::SUB_STREAMS_NULL, + ALL_PROC); + + if (stream_table_size == streams_info_table.size()) { + break; + } + n_streams--; + int numa_node_id = streams_info_table[stream_table_size + 1][STREAM_NUMA_NODE_ID]; + int socket_id = streams_info_table[stream_table_size + 1][STREAM_SOCKET_ID]; + for (size_t i = stream_table_size + 1; i < streams_info_table.size(); i++) { + numa_node_id = numa_node_id == streams_info_table[i][STREAM_NUMA_NODE_ID] ? numa_node_id : -1; + socket_id = socket_id == streams_info_table[i][STREAM_SOCKET_ID] ? socket_id : -1; + for (auto& row : remain_proc_type_table) { + if ((streams_info_table[i][STREAM_NUMA_NODE_ID] == row[PROC_NUMA_NODE_ID]) && + (streams_info_table[i][STREAM_SOCKET_ID] == row[PROC_SOCKET_ID])) { + row[streams_info_table[i][PROC_TYPE]] -= + (streams_info_table[i][NUMBER_OF_STREAMS] == 0 + ? 1 + : streams_info_table[i][NUMBER_OF_STREAMS]) * + streams_info_table[i][THREADS_PER_STREAM]; + } } } + streams_info_table[stream_table_size][STREAM_NUMA_NODE_ID] = numa_node_id; + streams_info_table[stream_table_size][STREAM_SOCKET_ID] = socket_id; + stream_table_size = streams_info_table.size(); } - streams_info_table[stream_table_size][STREAM_NUMA_NODE_ID] = numa_node_id; - streams_info_table[stream_table_size][STREAM_SOCKET_ID] = socket_id; - stream_table_size = streams_info_table.size(); } } } else { diff --git a/src/plugins/intel_cpu/tests/unit/streams_info/streams_info_table_test.cpp b/src/plugins/intel_cpu/tests/unit/streams_info/streams_info_table_test.cpp index 4cc795eeaf2910..93bac90be95a04 100644 --- a/src/plugins/intel_cpu/tests/unit/streams_info/streams_info_table_test.cpp +++ b/src/plugins/intel_cpu/tests/unit/streams_info/streams_info_table_test.cpp @@ -2337,8 +2337,7 @@ StreamsCalculationTestCase _2sockets_mock_latency_25 = { {40, 20, 0, 20, 2, 1}, {20, 10, 0, 10, 3, 1}}, {{1, ALL_PROC, 64, -1, -1}, - {-1, ALL_PROC, 32, -1, 0}, - {0, MAIN_CORE_PROC, 32, 0, 0}, + {-1, MAIN_CORE_PROC, 32, 0, 0}, {-1, ALL_PROC, 32, -1, 1}, {0, MAIN_CORE_PROC, 20, 2, 1}, {0, MAIN_CORE_PROC, 10, 3, 1}, @@ -2399,8 +2398,7 @@ StreamsCalculationTestCase _2sockets_mock_latency_28 = { {40, 20, 0, 20, 2, 1}, {20, 10, 0, 10, 3, 1}}, {{1, ALL_PROC, 64, -1, -1}, - {-1, ALL_PROC, 32, -1, 0}, - {0, MAIN_CORE_PROC, 32, 0, 0}, + {-1, MAIN_CORE_PROC, 32, 0, 0}, {-1, ALL_PROC, 32, -1, 1}, {0, MAIN_CORE_PROC, 20, 2, 1}, {0, MAIN_CORE_PROC, 10, 3, 1}, @@ -2460,11 +2458,12 @@ StreamsCalculationTestCase _2sockets_mock_latency_31 = { {60, 30, 0, 30, 1, 0}, {40, 20, 0, 20, 2, 1}, {20, 10, 0, 10, 3, 1}}, - {{1, ALL_PROC, 140, -1, 0}, - {0, MAIN_CORE_PROC, 40, 0, 0}, - {0, MAIN_CORE_PROC, 30, 1, 0}, - {0, HYPER_THREADING_PROC, 40, 0, 0}, - {0, HYPER_THREADING_PROC, 30, 1, 0}}, + {{1, ALL_PROC, 64, -1, -1}, + {-1, MAIN_CORE_PROC, 32, 0, 0}, + {-1, ALL_PROC, 32, -1, 1}, + {0, MAIN_CORE_PROC, 20, 2, 1}, + {0, MAIN_CORE_PROC, 10, 3, 1}, + {0, HYPER_THREADING_PROC, 2, 2, 1}}, }; StreamsCalculationTestCase _2sockets_mock_latency_32 = { 1, @@ -2634,6 +2633,22 @@ StreamsCalculationTestCase _2sockets_mock_latency_37 = { {{48, 48, 0, 0, -1, -1}, {24, 24, 0, 0, 0, 0}, {24, 24, 0, 0, 1, 1}}, {{1, MAIN_CORE_PROC, 48, -1, -1}, {-1, MAIN_CORE_PROC, 24, 1, 1}, {-1, MAIN_CORE_PROC, 24, 0, 0}}, }; +StreamsCalculationTestCase _2sockets_mock_latency_38 = { + 1, + false, + 0, + 0, + 0, + 0, + "LATENCY", + {ov::hint::ModelDistributionPolicy::TENSOR_PARALLEL}, + {{256, 128, 0, 128, 0, 0}, + {64, 32, 0, 32, 0, 0}, + {64, 32, 0, 32, 1, 0}, + {64, 32, 0, 32, 2, 1}, + {64, 32, 0, 32, 3, 1}}, + {{1, MAIN_CORE_PROC, 64, -1, -1}, {-1, MAIN_CORE_PROC, 32, 0, 0}, {-1, MAIN_CORE_PROC, 32, 2, 1}}, +}; TEST_P(StreamsCalculationTests, StreamsCalculation) {} INSTANTIATE_TEST_SUITE_P(StreamsInfoTable, @@ -2816,6 +2831,7 @@ INSTANTIATE_TEST_SUITE_P(StreamsInfoTable, _2sockets_mock_latency_35, _2sockets_mock_latency_36, _2sockets_mock_latency_37, + _2sockets_mock_latency_38, _1sockets_mock_latency_1, _1sockets_mock_latency_2, _1sockets_mock_latency_3,