Skip to content

Commit

Permalink
initial fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wangleis committed Sep 10, 2024
1 parent a87851d commit 9956420
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 102 deletions.
202 changes: 109 additions & 93 deletions src/plugins/intel_cpu/src/cpu_streams_calculation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,6 @@ std::vector<std::vector<int>> get_streams_info_table(const int input_streams,
const std::vector<std::vector<int>>& one_proc_table,
const int num_threads,
const IStreamsExecutor::Config::StreamsMode sub_streams_model) {
stream_info[THREADS_PER_STREAM] = sub_streams_model == IStreamsExecutor::Config::StreamsMode::SUB_STREAMS_NULL
? num_threads
: std::min(TP_CPU_LIMIT, num_threads);
if ((one_proc_info[PROC_NUMA_NODE_ID] < 0) || (one_proc_info[PROC_SOCKET_ID] < 0) ||
(((one_proc_info[MAIN_CORE_PROC] > 0) &&
(one_proc_info[MAIN_CORE_PROC] < stream_info[THREADS_PER_STREAM])) ||
Expand Down Expand Up @@ -178,7 +175,11 @@ std::vector<std::vector<int>> get_streams_info_table(const int input_streams,
std::unordered_set<int> socket_id_list(proc_type_table.size());
for (size_t i = 1; i < proc_type_table.size(); i++) {
if (!socket_id_list.count(proc_type_table[i][PROC_SOCKET_ID])) {
proc_socket_table.push_back(proc_type_table[i]);
if (proc_type_table[i][PROC_SOCKET_ID] == input_current_socket_id) {
proc_socket_table.insert(proc_socket_table.begin(), proc_type_table[i]);
} else {
proc_socket_table.push_back(proc_type_table[i]);
}
socket_id_list.insert(proc_type_table[i][PROC_SOCKET_ID]);
} else {
for (auto& row : proc_socket_table) {
Expand Down Expand Up @@ -333,58 +334,28 @@ std::vector<std::vector<int>> get_streams_info_table(const int input_streams,
int total_streams = n_streams;

if (stream_info[PROC_TYPE] == INIT_VAL) {
stream_info[THREADS_PER_STREAM] = n_threads_per_stream;

for (int n_type = MAIN_CORE_PROC; (n_type <= HYPER_THREADING_PROC) && (n_streams > 0); n_type++) {
if (proc_type_table.size() == 1) {
if (proc_type_table[0][n_type] >= stream_info[THREADS_PER_STREAM]) {
update_streams_per_node(n_type, proc_type_table[0]);
}
} else {
for (size_t n_node = 1; (n_node < proc_type_table.size()) && (n_streams > 0); n_node++) {
if ((proc_type_table[n_node][n_type] >= stream_info[THREADS_PER_STREAM]) &&
((current_socket_id < 0) || (proc_type_table[n_node][PROC_SOCKET_ID] == current_socket_id))) {
update_streams_per_node(n_type, proc_type_table[n_node]);
}
}
}
}

if (total_streams == n_streams) {
if (proc_type_table.size() == 1) {
if (proc_type_table[0][ALL_PROC] >= stream_info[THREADS_PER_STREAM]) {
update_mix_stream_info(proc_type_table[0],
proc_type_table,
n_threads_per_stream,
IStreamsExecutor::Config::StreamsMode::SUB_STREAMS_NULL,
ALL_PROC);
n_streams--;
}
} else {
for (size_t n_node = 0; (n_node < proc_socket_table.size()) && (n_streams > 0); n_node++) {
if ((proc_socket_table[n_node][ALL_PROC] >= stream_info[THREADS_PER_STREAM]) &&
((current_socket_id < 0) || (proc_socket_table[n_node][PROC_SOCKET_ID] == current_socket_id))) {
update_mix_stream_info(proc_socket_table[n_node],
proc_type_table,
n_threads_per_stream,
IStreamsExecutor::Config::StreamsMode::SUB_STREAMS_NULL,
ALL_PROC);
n_streams--;
if ((n_streams == 1) && (proc_type_table.size() > 1) &&
((hint_model_distribution_policy.find(ov::hint::ModelDistributionPolicy::TENSOR_PARALLEL) !=
hint_model_distribution_policy.end()))) {
for (auto& row : proc_socket_table) {
stream_info[THREADS_PER_STREAM] = std::min(TP_CPU_LIMIT, n_threads_per_stream);
for (size_t i = 1; i < proc_type_table.size(); i++) {
if ((proc_type_table[i][PROC_SOCKET_ID] == row[PROC_SOCKET_ID]) &&
(proc_type_table[i][MAIN_CORE_PROC] >= stream_info[THREADS_PER_STREAM])) {
create_one_stream(proc_type_table[i],
{proc_type_table[i]},
stream_info[THREADS_PER_STREAM],
IStreamsExecutor::Config::StreamsMode::SUB_STREAMS_FOR_SOCKET);
break;
}
}
}
}

if (total_streams == n_streams) {
create_one_stream(proc_socket_table[current_socket_id],
proc_type_table,
proc_socket_table[current_socket_id][ALL_PROC],
IStreamsExecutor::Config::StreamsMode::SUB_STREAMS_FOR_SOCKET);
for (size_t n_node = 0; n_node < proc_socket_table.size(); n_node++) {
if (n_node != size_t(current_socket_id)) {
create_one_stream(proc_socket_table[n_node],
if (stream_info[STREAM_SOCKET_ID] == row[PROC_SOCKET_ID]) {
continue;
} else {
stream_info[THREADS_PER_STREAM] = std::min(stream_info[THREADS_PER_STREAM], row[ALL_PROC]);
create_one_stream(row,
proc_type_table,
proc_socket_table[n_node][ALL_PROC],
stream_info[THREADS_PER_STREAM],
IStreamsExecutor::Config::StreamsMode::SUB_STREAMS_FOR_SOCKET);
}
}
Expand All @@ -405,57 +376,102 @@ std::vector<std::vector<int>> get_streams_info_table(const int input_streams,
}
}
streams_info_table.insert(streams_info_table.begin(), stream_info);
n_streams--;
}
} else {
stream_info[THREADS_PER_STREAM] = n_threads_per_stream;

if (n_streams > 0) {
std::vector<std::vector<int>> remain_proc_type_table(proc_type_table);
size_t stream_table_size = streams_info_table.size();

for (size_t i = 0; i < stream_table_size; i++) {
if ((streams_info_table[i][STREAM_NUMA_NODE_ID] >= 0) &&
(streams_info_table[i][STREAM_SOCKET_ID] >= 0)) {
for (auto& row : remain_proc_type_table) {
if ((streams_info_table[i][STREAM_NUMA_NODE_ID] == row[PROC_NUMA_NODE_ID]) &&
(streams_info_table[i][STREAM_SOCKET_ID] == row[PROC_SOCKET_ID])) {
row[streams_info_table[i][PROC_TYPE]] -= (streams_info_table[i][NUMBER_OF_STREAMS] == 0
? 1
: streams_info_table[i][NUMBER_OF_STREAMS]) *
streams_info_table[i][THREADS_PER_STREAM];
for (int n_type = MAIN_CORE_PROC; (n_type <= HYPER_THREADING_PROC) && (n_streams > 0); n_type++) {
if (proc_type_table.size() == 1) {
if (proc_type_table[0][n_type] >= stream_info[THREADS_PER_STREAM]) {
update_streams_per_node(n_type, proc_type_table[0]);
}
} else {
for (size_t n_node = 1; (n_node < proc_type_table.size()) && (n_streams > 0); n_node++) {
if ((proc_type_table[n_node][n_type] >= stream_info[THREADS_PER_STREAM]) &&
((current_socket_id < 0) ||
(proc_type_table[n_node][PROC_SOCKET_ID] == current_socket_id))) {
update_streams_per_node(n_type, proc_type_table[n_node]);
}
}
}
}

while (n_streams > 0) {
update_mix_stream_info(proc_type_table[0],
remain_proc_type_table,
n_threads_per_stream,
IStreamsExecutor::Config::StreamsMode::SUB_STREAMS_NULL,
ALL_PROC);
if (total_streams == n_streams) {
if (proc_type_table.size() == 1) {
if (proc_type_table[0][ALL_PROC] >= stream_info[THREADS_PER_STREAM]) {
update_mix_stream_info(proc_type_table[0],
proc_type_table,
n_threads_per_stream,
IStreamsExecutor::Config::StreamsMode::SUB_STREAMS_NULL,
ALL_PROC);
n_streams--;
}
} else {
for (size_t n_node = 0; (n_node < proc_socket_table.size()) && (n_streams > 0); n_node++) {
if ((proc_socket_table[n_node][ALL_PROC] >= stream_info[THREADS_PER_STREAM]) &&
((current_socket_id < 0) ||
(proc_socket_table[n_node][PROC_SOCKET_ID] == current_socket_id))) {
update_mix_stream_info(proc_socket_table[n_node],
proc_type_table,
n_threads_per_stream,
IStreamsExecutor::Config::StreamsMode::SUB_STREAMS_NULL,
ALL_PROC);
n_streams--;
}
}
}
}

if (stream_table_size == streams_info_table.size()) {
break;
if (n_streams > 0) {
std::vector<std::vector<int>> remain_proc_type_table(proc_type_table);
size_t stream_table_size = streams_info_table.size();

for (size_t i = 0; i < stream_table_size; i++) {
if ((streams_info_table[i][STREAM_NUMA_NODE_ID] >= 0) &&
(streams_info_table[i][STREAM_SOCKET_ID] >= 0)) {
for (auto& row : remain_proc_type_table) {
if ((streams_info_table[i][STREAM_NUMA_NODE_ID] == row[PROC_NUMA_NODE_ID]) &&
(streams_info_table[i][STREAM_SOCKET_ID] == row[PROC_SOCKET_ID])) {
row[streams_info_table[i][PROC_TYPE]] -=
(streams_info_table[i][NUMBER_OF_STREAMS] == 0
? 1
: streams_info_table[i][NUMBER_OF_STREAMS]) *
streams_info_table[i][THREADS_PER_STREAM];
}
}
}
}
n_streams--;
int numa_node_id = streams_info_table[stream_table_size + 1][STREAM_NUMA_NODE_ID];
int socket_id = streams_info_table[stream_table_size + 1][STREAM_SOCKET_ID];
for (size_t i = stream_table_size + 1; i < streams_info_table.size(); i++) {
numa_node_id = numa_node_id == streams_info_table[i][STREAM_NUMA_NODE_ID] ? numa_node_id : -1;
socket_id = socket_id == streams_info_table[i][STREAM_SOCKET_ID] ? socket_id : -1;
for (auto& row : remain_proc_type_table) {
if ((streams_info_table[i][STREAM_NUMA_NODE_ID] == row[PROC_NUMA_NODE_ID]) &&
(streams_info_table[i][STREAM_SOCKET_ID] == row[PROC_SOCKET_ID])) {
row[streams_info_table[i][PROC_TYPE]] -= (streams_info_table[i][NUMBER_OF_STREAMS] == 0
? 1
: streams_info_table[i][NUMBER_OF_STREAMS]) *
streams_info_table[i][THREADS_PER_STREAM];

while (n_streams > 0) {
update_mix_stream_info(proc_type_table[0],
remain_proc_type_table,
n_threads_per_stream,
IStreamsExecutor::Config::StreamsMode::SUB_STREAMS_NULL,
ALL_PROC);

if (stream_table_size == streams_info_table.size()) {
break;
}
n_streams--;
int numa_node_id = streams_info_table[stream_table_size + 1][STREAM_NUMA_NODE_ID];
int socket_id = streams_info_table[stream_table_size + 1][STREAM_SOCKET_ID];
for (size_t i = stream_table_size + 1; i < streams_info_table.size(); i++) {
numa_node_id = numa_node_id == streams_info_table[i][STREAM_NUMA_NODE_ID] ? numa_node_id : -1;
socket_id = socket_id == streams_info_table[i][STREAM_SOCKET_ID] ? socket_id : -1;
for (auto& row : remain_proc_type_table) {
if ((streams_info_table[i][STREAM_NUMA_NODE_ID] == row[PROC_NUMA_NODE_ID]) &&
(streams_info_table[i][STREAM_SOCKET_ID] == row[PROC_SOCKET_ID])) {
row[streams_info_table[i][PROC_TYPE]] -=
(streams_info_table[i][NUMBER_OF_STREAMS] == 0
? 1
: streams_info_table[i][NUMBER_OF_STREAMS]) *
streams_info_table[i][THREADS_PER_STREAM];
}
}
}
streams_info_table[stream_table_size][STREAM_NUMA_NODE_ID] = numa_node_id;
streams_info_table[stream_table_size][STREAM_SOCKET_ID] = socket_id;
stream_table_size = streams_info_table.size();
}
streams_info_table[stream_table_size][STREAM_NUMA_NODE_ID] = numa_node_id;
streams_info_table[stream_table_size][STREAM_SOCKET_ID] = socket_id;
stream_table_size = streams_info_table.size();
}
}
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2337,8 +2337,7 @@ StreamsCalculationTestCase _2sockets_mock_latency_25 = {
{40, 20, 0, 20, 2, 1},
{20, 10, 0, 10, 3, 1}},
{{1, ALL_PROC, 64, -1, -1},
{-1, ALL_PROC, 32, -1, 0},
{0, MAIN_CORE_PROC, 32, 0, 0},
{-1, MAIN_CORE_PROC, 32, 0, 0},
{-1, ALL_PROC, 32, -1, 1},
{0, MAIN_CORE_PROC, 20, 2, 1},
{0, MAIN_CORE_PROC, 10, 3, 1},
Expand Down Expand Up @@ -2399,8 +2398,7 @@ StreamsCalculationTestCase _2sockets_mock_latency_28 = {
{40, 20, 0, 20, 2, 1},
{20, 10, 0, 10, 3, 1}},
{{1, ALL_PROC, 64, -1, -1},
{-1, ALL_PROC, 32, -1, 0},
{0, MAIN_CORE_PROC, 32, 0, 0},
{-1, MAIN_CORE_PROC, 32, 0, 0},
{-1, ALL_PROC, 32, -1, 1},
{0, MAIN_CORE_PROC, 20, 2, 1},
{0, MAIN_CORE_PROC, 10, 3, 1},
Expand Down Expand Up @@ -2460,11 +2458,12 @@ StreamsCalculationTestCase _2sockets_mock_latency_31 = {
{60, 30, 0, 30, 1, 0},
{40, 20, 0, 20, 2, 1},
{20, 10, 0, 10, 3, 1}},
{{1, ALL_PROC, 140, -1, 0},
{0, MAIN_CORE_PROC, 40, 0, 0},
{0, MAIN_CORE_PROC, 30, 1, 0},
{0, HYPER_THREADING_PROC, 40, 0, 0},
{0, HYPER_THREADING_PROC, 30, 1, 0}},
{{1, ALL_PROC, 64, -1, -1},
{-1, MAIN_CORE_PROC, 32, 0, 0},
{-1, ALL_PROC, 32, -1, 1},
{0, MAIN_CORE_PROC, 20, 2, 1},
{0, MAIN_CORE_PROC, 10, 3, 1},
{0, HYPER_THREADING_PROC, 2, 2, 1}},
};
StreamsCalculationTestCase _2sockets_mock_latency_32 = {
1,
Expand Down Expand Up @@ -2634,6 +2633,22 @@ StreamsCalculationTestCase _2sockets_mock_latency_37 = {
{{48, 48, 0, 0, -1, -1}, {24, 24, 0, 0, 0, 0}, {24, 24, 0, 0, 1, 1}},
{{1, MAIN_CORE_PROC, 48, -1, -1}, {-1, MAIN_CORE_PROC, 24, 1, 1}, {-1, MAIN_CORE_PROC, 24, 0, 0}},
};
StreamsCalculationTestCase _2sockets_mock_latency_38 = {
1,
false,
0,
0,
0,
0,
"LATENCY",
{ov::hint::ModelDistributionPolicy::TENSOR_PARALLEL},
{{256, 128, 0, 128, 0, 0},
{64, 32, 0, 32, 0, 0},
{64, 32, 0, 32, 1, 0},
{64, 32, 0, 32, 2, 1},
{64, 32, 0, 32, 3, 1}},
{{1, MAIN_CORE_PROC, 64, -1, -1}, {-1, MAIN_CORE_PROC, 32, 0, 0}, {-1, MAIN_CORE_PROC, 32, 2, 1}},
};
TEST_P(StreamsCalculationTests, StreamsCalculation) {}

INSTANTIATE_TEST_SUITE_P(StreamsInfoTable,
Expand Down Expand Up @@ -2816,6 +2831,7 @@ INSTANTIATE_TEST_SUITE_P(StreamsInfoTable,
_2sockets_mock_latency_35,
_2sockets_mock_latency_36,
_2sockets_mock_latency_37,
_2sockets_mock_latency_38,
_1sockets_mock_latency_1,
_1sockets_mock_latency_2,
_1sockets_mock_latency_3,
Expand Down

0 comments on commit 9956420

Please sign in to comment.