add EmbeddingSpMDM8BitBenchmarkOutTypeFloat16 (#2952)
Summary:
X-link: facebookresearch/FBGEMM#70

Pull Request resolved: #2952

Add EmbeddingSpMDM8BitBenchmarkOutTypeFloat16 on ARM.

Reviewed By: sryap

Differential Revision: D60972344
helloguo authored and facebook-github-bot committed Aug 14, 2024
1 parent 3070f88 commit f1b6b7c
Showing 1 changed file: bench/EmbeddingSpMDM8BitBenchmark.cc (77 additions, 18 deletions).
@@ -60,6 +60,7 @@ static vector<vector<int>> GetInputs_() {
 
 vector<double> benchmarkTimes;
 
+template <typename OutType>
 int run_benchmark(
     int batch_size,
     int num_rows,
@@ -68,7 +69,8 @@ int run_benchmark(
     bool normalize_by_lengths,
     bool use_32_bit_indices = false,
     bool prefetch = false,
-    bool stress_multi_threading = false) {
+    bool stress_multi_threading = false,
+    bool is_bf16_out = false) {
   // Create embedding table
   default_random_engine generator;
   normal_distribution<float> embedding_distribution;
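
Note on the two new parameters: the OutType template argument sets the element type of the benchmark's output buffers, while is_bf16_out stays a runtime flag. The split is needed because fp16 and bf16 results travel in the same 16-bit storage type, so the type system alone cannot tell the two formats apart. A minimal sketch of that distinction, assuming float16 aliases std::uint16_t (which the std::is_same<OutType, uint16_t> checks later in the diff imply); this is illustrative, not FBGEMM source:

    #include <cstdint>
    #include <cstring>

    using float16 = std::uint16_t; // assumed alias, implied by the diff

    // bf16 is the top 16 bits of an IEEE float32, so decoding it is just
    // a shift back into position.
    inline float bf16_to_float(float16 v) {
      std::uint32_t bits = static_cast<std::uint32_t>(v) << 16;
      float f;
      std::memcpy(&f, &bits, sizeof(f));
      return f;
    }

    // fp16 (IEEE binary16: 1 sign, 5 exponent, 10 mantissa bits) has a
    // different layout and exponent bias, so it needs a real conversion;
    // the verification hunk below uses cpu_half2float() for that.
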
@@ -127,8 +129,8 @@ int run_benchmark(
     weights[i] = embedding_distribution(generator);
   }
 
-  vector<float> output_sls_ref(batch_size * embedding_dim);
-  vector<float> output_slws_ref(output_sls_ref.size()),
+  vector<OutType> output_sls_ref(batch_size * embedding_dim);
+  vector<OutType> output_slws_ref(output_sls_ref.size()),
       output_sls(output_sls_ref.size()), output_slws(output_sls_ref.size());
 
   constexpr int NUM_WARMUP = 10;
@@ -149,7 +151,7 @@ int run_benchmark(
     has_weight_options.push_back(true);
   }
   for (bool has_weight : has_weight_options) {
-    vector<float>& output_ref = has_weight ? output_slws_ref : output_sls_ref;
+    vector<OutType>& output_ref = has_weight ? output_slws_ref : output_sls_ref;
 
     bool success = false, success_ref = false;
 
@@ -179,17 +181,19 @@ int run_benchmark(
           output_ref.data());
     }
 
-    vector<float>& output = has_weight ? output_slws : output_sls;
+    vector<OutType>& output = has_weight ? output_slws : output_sls;
     vector<bool> flush_cache_options;
     flush_cache_options.push_back(false);
     if (!stress_multi_threading) {
      flush_cache_options.push_back(true);
     }
 
-    auto kernel_32 = GenerateEmbeddingSpMDM<uint8_t, int32_t>(
-        embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0);
-    auto kernel_64 = GenerateEmbeddingSpMDM<uint8_t, int64_t>(
-        embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0);
+    auto kernel_32 =
+        GenerateEmbeddingSpMDM<uint8_t, int32_t, std::int32_t, OutType>(
+            embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0);
+    auto kernel_64 =
+        GenerateEmbeddingSpMDM<uint8_t, int64_t, std::int32_t, OutType>(
+            embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0);
 
 #ifdef _OPENMP
 #pragma omp barrier
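
For readers tracking the template arguments: the two newly explicit parameters follow GenerateEmbeddingSpMDM's ordering of input type, index type, offset type, and output type. An annotated restatement of the 64-bit instantiation above (the annotations are editorial assumptions, not part of the commit):

    auto kernel_64 =
        GenerateEmbeddingSpMDM<
            uint8_t,      // input: 8-bit quantized embedding rows
            int64_t,      // index type (kernel_32 uses int32_t)
            std::int32_t, // offset type, unchanged by this diff
            OutType>(     // output: float, or uint16_t for fp16/bf16
            embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0);
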
@@ -255,9 +259,26 @@ int run_benchmark(
         false && "ERROR: refernce impl and JIT imp did not both succeed");
     } else if (success) {
       for (size_t i = 0; i < output.size(); ++i) {
-        assert(fabs(output[i] - output_ref[i]) < 1e-3);
-        if (fabs(output[i] - output_ref[i]) >= 1e-3) {
-          cout << i << " " << output[i] << " " << output_ref[i] << endl;
+        float tmp1 = 0;
+        float tmp2 = 0;
+        if (std::is_same<OutType, float>::value) {
+          tmp1 = output[i];
+          tmp2 = output_ref[i];
+        } else if (std::is_same<OutType, uint16_t>::value) {
+          if (is_bf16_out) {
+            tmp1 = cpu_bf162float(output[i]);
+            tmp2 = cpu_bf162float(output_ref[i]);
+          } else {
+            tmp1 = cpu_half2float(output[i]);
+            tmp2 = cpu_half2float(output_ref[i]);
+          }
+        } else {
+          assert(false && "ERROR: unsupported output type");
+          cout << "ERROR: unsupported output type" << endl;
+        }
+        assert(fabs(tmp1 - tmp2) < 1e-3);
+        if (fabs(tmp1 - tmp2) >= 1e-3) {
+          cout << i << " " << tmp1 << " " << tmp2 << endl;
         }
       }
     }
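
The verification above decodes each element to float before applying the shared 1e-3 tolerance: fp32 outputs compare directly, and 16-bit outputs go through cpu_bf162float() or cpu_half2float() depending on the flag. Below is a self-contained sketch of the same decode-then-compare pattern with local stand-in decoders (the real converters live in FBGEMM's headers, and float16 is again assumed to alias std::uint16_t); it uses C++17 if constexpr where the diff uses runtime std::is_same checks:

    #include <cassert>
    #include <cmath>
    #include <cstdint>
    #include <cstring>
    #include <type_traits>

    using float16 = std::uint16_t; // assumed alias, as in the earlier sketch

    // Repeated from the earlier sketch so this block stands alone.
    inline float bf16_to_float(float16 v) {
      std::uint32_t bits = static_cast<std::uint32_t>(v) << 16;
      float f;
      std::memcpy(&f, &bits, sizeof(f));
      return f;
    }

    // Stand-in for cpu_half2float(): software decode of IEEE binary16.
    inline float half_to_float(float16 h) {
      std::uint32_t sign = (h & 0x8000u) << 16;
      std::uint32_t exp = (h >> 10) & 0x1Fu;
      std::uint32_t mant = h & 0x3FFu;
      std::uint32_t bits;
      if (exp == 0 && mant == 0) {
        bits = sign; // +/- zero
      } else if (exp == 0) {
        // Subnormal half: renormalize the mantissa into float32 range.
        int shift = 0;
        while ((mant & 0x400u) == 0) {
          mant <<= 1;
          ++shift;
        }
        mant &= 0x3FFu;
        bits = sign | ((113u - shift) << 23) | (mant << 13);
      } else if (exp == 31) {
        bits = sign | 0x7F800000u | (mant << 13); // inf / NaN
      } else {
        bits = sign | ((exp + 112u) << 23) | (mant << 13); // rebias 15 -> 127
      }
      float f;
      std::memcpy(&f, &bits, sizeof(f));
      return f;
    }

    // Decode-then-compare, mirroring the tolerance check in the diff.
    template <typename OutType>
    bool almost_equal(OutType a, OutType b, bool is_bf16_out) {
      float fa, fb;
      if constexpr (std::is_same_v<OutType, float>) {
        fa = a;
        fb = b;
      } else {
        static_assert(std::is_same_v<OutType, float16>, "unsupported type");
        fa = is_bf16_out ? bf16_to_float(a) : half_to_float(a);
        fb = is_bf16_out ? bf16_to_float(b) : half_to_float(b);
      }
      return std::fabs(fa - fb) < 1e-3f;
    }

With cpu_half2float()/cpu_bf162float() substituted for the stand-ins, this is the check the hunk inlines into the benchmark loop.
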
@@ -267,6 +288,19 @@ int run_benchmark(
 #pragma omp barrier
 #endif
     if (fbgemm_get_thread_num() == 0) {
+      if (std::is_same<OutType, float>::value) {
+        cout << "out type fp32";
+      } else if (std::is_same<OutType, uint16_t>::value) {
+        if (is_bf16_out) {
+          cout << "out type bf16";
+        } else {
+          cout << "out type fp16";
+        }
+      } else {
+        assert(false && "ERROR: unsupported output type");
+        cout << "ERROR: unsupported output type" << endl;
+      }
+
       if (has_weight) {
         cout << setw(16) << "SLW(WEIGHTED) ";
       } else {
@@ -332,7 +366,8 @@ int main() {
 #ifdef _OPENMP
 #pragma omp parallel if (stress_multi_threading)
 #endif
-  run_benchmark(
+#if defined(OUT_TYPE_FLOAT16)
+  run_benchmark<float16>(
       batch_size,
       num_rows,
       embedding_dim,
@@ -341,22 +376,46 @@ int main() {
       false,
       false,
       stress_multi_threading);
+#else
+  run_benchmark<float>(
+      batch_size,
+      num_rows,
+      embedding_dim,
+      average_len,
+      false,
+      false,
+      false,
+      stress_multi_threading);
+#endif
   if (stress_multi_threading) {
     return 0;
   }
 
   cout << "64 bit indices with prefetching, ";
-  run_benchmark(
+#if defined(OUT_TYPE_FLOAT16)
+  run_benchmark<float16>(
+      batch_size, num_rows, embedding_dim, average_len, false, false, true);
+#else
+  run_benchmark<float>(
       batch_size, num_rows, embedding_dim, average_len, false, false, true);
+#endif
 
   cout << "32 bit indices, ";
-  run_benchmark(
+#if defined(OUT_TYPE_FLOAT16)
+  run_benchmark<float16>(
+      batch_size, num_rows, embedding_dim, average_len, false, true);
+#else
+  run_benchmark<float>(
       batch_size, num_rows, embedding_dim, average_len, false, true);
+#endif
 
   cout << "32 bit indices with prefetching, ";
-  run_benchmark(
+#if defined(OUT_TYPE_FLOAT16)
+  run_benchmark<float16>(
+      batch_size, num_rows, embedding_dim, average_len, false, true, true);
+#else
+  run_benchmark<float>(
       batch_size, num_rows, embedding_dim, average_len, false, true, true);
+#endif
 
   // running with normalize by lengths
   // run_benchmark(batch_size, num_rows, embedding_dim, average_len,
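
Design note on the dispatch in main(): the output type is fixed at compile time by the OUT_TYPE_FLOAT16 macro rather than chosen per call, so one source file yields two benchmark binaries. The commit does not show where the macro gets defined; presumably the new EmbeddingSpMDM8BitBenchmarkOutTypeFloat16 target named in the summary builds this file with something like a -DOUT_TYPE_FLOAT16 compile flag. The repeated #if/#else blocks condense to the following (BenchOutType is a hypothetical alias, not in the commit):

    #if defined(OUT_TYPE_FLOAT16)
    using BenchOutType = float16; // 16-bit outputs; bf16 additionally
                                  // requires the is_bf16_out flag
    #else
    using BenchOutType = float;   // the pre-existing fp32 path
    #endif
    // ...after which each call site reads run_benchmark<BenchOutType>(...).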
