Skip to content

Commit

Permalink
WIP: custom dot product quantization experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
Ravenwater committed Aug 26, 2023
1 parent eba6871 commit 9362ba7
Show file tree
Hide file tree
Showing 5 changed files with 327 additions and 9 deletions.
222 changes: 222 additions & 0 deletions benchmark/accuracy/quantization/mpdot.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
// mpdot.cpp: accuracy/quantization measurement of mixed-precision dot products
//
// Copyright (C) 2017-2023 Stillwater Supercomputing, Inc.
//
// This file is part of the universal numbers project, which is released under an MIT Open Source license.
#include <universal/utility/directives.hpp>
#include <universal/number/cfloat/cfloat.hpp>
#include <universal/number/lns/lns.hpp>
#include <universal/blas/blas.hpp>
#include <universal/verification/cfloat_test_suite.hpp>

// NOTE(review): FIELD_WIDTH is not referenced anywhere in the visible part of this file;
// presumably a column width for formatted output - confirm before relying on or removing it
constexpr unsigned FIELD_WIDTH = 8;

namespace sw {
namespace universal {

/// <summary>
/// compute the dot product of a and b as a staged pipeline: both inputs are
/// converted to ProductType, multiplied element-wise, converted to
/// AccumulationType, summed, and the sum is returned as OutputType
/// </summary>
/// <param name="a">left operand vector</param>
/// <param name="b">right operand vector, asserted to have the same length as a</param>
/// <returns>sum of the element-wise products as OutputType</returns>
template<typename InputType, typename ProductType, typename AccumulationType, typename OutputType>
OutputType DotProductExperiment(const blas::vector<InputType>& a, const blas::vector<InputType>& b) {
	const size_t lengthA = size(a);
	const size_t lengthB = size(b);
	assert(lengthA == lengthB && "vectors are not of the same length");

	// stage 1: bring both operands into the product type
	blas::vector<ProductType> lhs(lengthA);
	blas::vector<ProductType> products(lengthB);
	lhs = a;
	products = b;

	// stage 2: element-wise product, carried out in the product type
	products *= lhs;

	// stage 3: bring the products into the accumulation type
	blas::vector<AccumulationType> terms(lengthA);
	terms = products;

	// stage 4: reduce the terms and hand the sum back as the output type
	OutputType total = sum(terms);
	return total;
}

/// <summary>
/// generate the custom dot products: dot(data[0], data[i]) for every vector i of the data set,
/// evaluated through the InputType -> ProductType -> AccumulationType -> OutputType pipeline.
/// Each result is printed in binary and decimal form and stored in dots.
/// </summary>
/// <param name="data">input vectors; data[0] is the common left operand</param>
/// <param name="dots">output dot product results, resized to one entry per input vector</param>
template<typename InputType, typename ProductType, typename AccumulationType, typename OutputType>
void GenerateSpecializedDotProducts(const std::vector<sw::universal::blas::vector<InputType>>& data, std::vector<OutputType>& dots) {
	using namespace sw::universal;
	const size_t N = size(data);
	dots.resize(N);
	// an empty data set never enters the loop, so data[0] is only touched when it exists
	for (size_t i = 0; i < N; ++i) {
		const OutputType result = DotProductExperiment<InputType, ProductType, AccumulationType, OutputType>(data[0], data[i]);
		// publish the result: previously the output vector was resized but never written
		dots[i] = result;
		std::cout << "custom dot product : " << to_binary(result) << " : " << result << '\n';
	}
}
} }

// generate a set of N vectors of length L in double as reference
// Each vector is filled by blas::gaussian_random with mean 0.0 and standard deviation 1.0.
// data is resized to N entries; any previous content beyond that is discarded by the resize.
void GenerateRandomVectors(unsigned N, unsigned L, std::vector<sw::universal::blas::vector<double>>& data) {
	using namespace sw::universal;
	data.resize(N);
	double mean{ 0.0 };
	double stddev{ 1.0 };
	for (unsigned i = 0; i < N; ++i) {
		data[i].resize(L);
		blas::gaussian_random(data[i], mean, stddev);
	}
}

/// <summary>
/// convert a set of double-precision vectors into vectors of InputType
/// </summary>
/// <param name="data">source vectors in double precision</param>
/// <param name="idata">output: one converted vector per source vector, each sized like its source</param>
template<typename InputType>
void ConvertToInputType(const std::vector<sw::universal::blas::vector<double>>& data, std::vector<sw::universal::blas::vector<InputType>>& idata) {
	const size_t N = size(data);
	idata.resize(N);
	for (size_t i = 0; i < N; ++i) {
		// size each target from its own source vector: this avoids reading data[0]
		// of an empty data set and stays correct when the vectors differ in length
		idata[i].resize(size(data[i]));
		idata[i] = data[i];
	}
}

/// <summary>
/// print a banner containing the header text, followed by every vector of the data set on its own line
/// </summary>
/// <param name="header">title shown in the banner</param>
/// <param name="data">vectors to print</param>
template<typename DataType>
void PrintRandomVectors(const std::string& header, const std::vector<sw::universal::blas::vector<DataType>>& data) {
	std::cout << "\n>>>>>>> " << header << " <<<<<<<\n";
	// iterate by const reference: printing does not need a deep copy of each vector
	for (const auto& e : data) {
		std::cout << e << '\n';
	}
}


/// <summary>
/// generate the reference dot products: dot(data[0], data[i]) for every vector i of the data set,
/// with all pipeline stages in double. Each result is printed in binary and decimal form
/// and stored in dots.
/// </summary>
/// <param name="data">input vectors; data[0] is the common left operand</param>
/// <param name="dots">output dot product results, resized to one entry per input vector</param>
void GenerateReferenceDotProducts(const std::vector<sw::universal::blas::vector<double>>& data, std::vector<double>& dots) {
	using namespace sw::universal;
	const size_t N = size(data);
	dots.resize(N);
	// an empty data set never enters the loop, so data[0] is only touched when it exists
	for (size_t i = 0; i < N; ++i) {
		const double result = DotProductExperiment<double, double, double, double>(data[0], data[i]);
		// publish the result: previously the output vector was resized but never written
		dots[i] = result;
		std::cout << "reference dot product : " << to_binary(result) << " : " << result << '\n';
	}
}



// Run the dot product experiment on a reference data set:
//   1. print the data and its reference dot products in double
//   2. convert the data to the sample number type and print it
//   3. print the dot products computed entirely in the sample number type
// The alias catalog below lists candidate number types; only fp8e4m3_tt is exercised so far.
void GenerateParetoSamples(const std::vector<sw::universal::blas::vector<double>>& data) {
	using namespace sw::universal;

	// ---- candidate InputTypes, subnormal and supernormal flags both false
	// using fp4e3m0_ff = cfloat<4, 3, uint8_t, false, false, false>;   // not supported by cfloat<>
	using fp4e2m1_ff   = cfloat< 4, 2, uint8_t, false, false, false>;
	using fp6e3m2_ff   = cfloat< 6, 3, uint8_t, false, false, false>;
	using fp6e4m1_ff   = cfloat< 6, 4, uint8_t, false, false, false>;
	using fp8e4m3_ff   = cfloat< 8, 4, uint8_t, false, false, false>;
	using fp8e5m2_ff   = cfloat< 8, 5, uint8_t, false, false, false>;
	using fp10e5m4_ff  = cfloat<10, 5, uint8_t, false, false, false>;
	using fp10e6m3_ff  = cfloat<10, 6, uint8_t, false, false, false>;
	using fp12e5m6_ff  = cfloat<12, 5, uint8_t, false, false, false>;
	using fp12e6m5_ff  = cfloat<12, 6, uint8_t, false, false, false>;
	using fp12e7m4_ff  = cfloat<12, 7, uint8_t, false, false, false>;
	using fp16e5m10_ff = cfloat<16, 5, uint8_t, false, false, false>;
	using fp16e8m7_ff  = cfloat<16, 8, uint8_t, false, false, false>;
	using fp16e9m6_ff  = cfloat<16, 9, uint8_t, false, false, false>;

	// ---- candidate InputTypes, subnormal and supernormal enabled
	// using fp4e3m0_tt = cfloat<4, 3, uint8_t, true, true, false>;     // not supported by cfloat<>
	using fp4e2m1_tt   = cfloat< 4, 2, uint8_t, true, true, false>;
	using fp6e3m2_tt   = cfloat< 6, 3, uint8_t, true, true, false>;
	using fp6e4m1_tt   = cfloat< 6, 4, uint8_t, true, true, false>;
	using fp8e4m3_tt   = cfloat< 8, 4, uint8_t, true, true, false>;
	using fp8e5m2_tt   = cfloat< 8, 5, uint8_t, true, true, false>;
	using fp10e5m4_tt  = cfloat<10, 5, uint8_t, true, true, false>;
	using fp10e6m3_tt  = cfloat<10, 6, uint8_t, true, true, false>;
	using fp12e5m6_tt  = cfloat<12, 5, uint8_t, true, true, false>;
	using fp12e6m5_tt  = cfloat<12, 6, uint8_t, true, true, false>;
	using fp12e7m4_tt  = cfloat<12, 7, uint8_t, true, true, false>;
	using fp16e5m10_tt = cfloat<16, 5, uint8_t, true, true, false>;
	using fp16e8m7_tt  = cfloat<16, 8, uint8_t, true, true, false>;
	using fp16e9m6_tt  = cfloat<16, 9, uint8_t, true, true, false>;

	// ---- candidate ProductTypes, subnormal and supernormal flags both false
	using fp7e3m3_ff = cfloat<7, 3, uint8_t, false, false, false>;
	using fp9e4m4_ff = cfloat<9, 4, uint8_t, false, false, false>;
	using fp9e6m2_ff = cfloat<9, 6, uint8_t, false, false, false>;
	// ---- candidate ProductTypes, subnormal and supernormal enabled
	using fp7e3m3_tt = cfloat<7, 3, uint8_t, true, true, false>;
	using fp9e4m4_tt = cfloat<9, 4, uint8_t, true, true, false>;
	using fp9e6m2_tt = cfloat<9, 6, uint8_t, true, true, false>;

	// reference run in double
	const size_t nrVectors = size(data);
	PrintRandomVectors("Reference data set", data);
	std::vector<double> doubleDots(nrVectors);
	GenerateReferenceDotProducts(data, doubleDots);

	// experiment run: a single number type serves as input, product, accumulation, and output type,
	// i.e. DotProductExperiment<InputType, ProductType, AccumulationType, OutputType>
	using SampleType = fp8e4m3_tt;
	std::vector<blas::vector<SampleType>> sampleData;
	ConvertToInputType(data, sampleData);
	PrintRandomVectors("InputType data set", sampleData);

	std::vector<SampleType> sampleDots;
	GenerateSpecializedDotProducts<SampleType, SampleType, SampleType, SampleType>(sampleData, sampleDots);
}


/// <summary>
/// print the symmetry range of the representation and accumulation types, then run
/// QuantizationExperiment six times with a growing second argument (50 up to 4000)
/// </summary>
/// <param name="mean">forwarded unchanged to QuantizationExperiment</param>
/// <param name="stddev">forwarded unchanged to QuantizationExperiment</param>
template<typename RepresentationType, typename AccumulationType>
void StatisticalSampling(double mean, double stddev) {
	using namespace sw::universal;
	std::cout << "representation type : " << symmetry_range<RepresentationType>() << '\n';
	std::cout << "accumulation type : " << symmetry_range<AccumulationType>() << '\n';

	unsigned nrSamples{ 10000 };
	// second argument of each experiment, in run order
	// NOTE(review): presumably the vector length per sample - confirm against QuantizationExperiment
	constexpr unsigned experimentSizes[] = { 50, 100, 500, 1000, 2000, 4000 };
	for (const unsigned experimentSize : experimentSizes) {
		QuantizationExperiment<RepresentationType, AccumulationType>(nrSamples, experimentSize, mean, stddev);
	}
}

// Echo the command line to stdout on a single line: the prefix "cmd: ",
// then every argument followed by one space, then a newline.
void print_cmdline(int argc, char** argv) {
	std::cout << "cmd: ";
	int index = 0;
	while (index < argc) {
		std::cout << argv[index] << ' ';
		++index;
	}
	std::cout << '\n';
}

// Entry point: echo the command line, build a reference data set of 5 random vectors
// of length 5, and run the dot product experiment on it. Output uses a precision of 3,
// and the previous stream precision is put back before a normal return.
// Any exception escaping the body is reported on stderr and mapped to EXIT_FAILURE.
int main(int argc, char** argv)
try {
	using namespace sw::universal;

	print_cmdline(argc, argv);

	// precision(n) installs the new precision and hands back the one it replaces
	const std::streamsize savedPrecision = std::cout.precision(3);

	std::vector<blas::vector<double>> data;
	GenerateRandomVectors(5, 5, data);

	GenerateParetoSamples(data);

	std::cout.precision(savedPrecision);

	return EXIT_SUCCESS;
}
catch (char const* msg) {
	// plain C-string diagnostics
	std::cerr << msg << std::endl;
	return EXIT_FAILURE;
}
catch (const sw::universal::universal_arithmetic_exception& err) {
	std::cerr << "Uncaught universal arithmetic exception: " << err.what() << std::endl;
	return EXIT_FAILURE;
}
catch (const sw::universal::universal_internal_exception& err) {
	std::cerr << "Uncaught universal internal exception: " << err.what() << std::endl;
	return EXIT_FAILURE;
}
catch (const std::runtime_error& err) {
	std::cerr << "Uncaught runtime exception: " << err.what() << std::endl;
	return EXIT_FAILURE;
}
catch (...) {
	// last resort: anything not matched by the handlers above
	std::cerr << "Caught unknown exception" << std::endl;
	return EXIT_FAILURE;
}
6 changes: 4 additions & 2 deletions benchmark/accuracy/quantization/mpfma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,16 @@ try {
std::streamsize prec = std::cout.precision();
std::cout << std::setprecision(3);

// generate a set of N vectors of length L in double as reference

using fp12 = cfloat<12, 5, uint16_t, true, true, false>; // accumulation type

double mean{ 0.0 }, stddev{ 1.0 };
double mean{ 0 }, stddev{ 1.0 };
StatisticalSampling<fp8e3m4, fp12>(mean, stddev);
StatisticalSampling<fp8e4m3, fp12>(mean, stddev);
StatisticalSampling<fp8e5m2, fp12>(mean, stddev);

// input fp<4,3>, multiply output fp<6,4>, accumulation output fp<6,4>

std::cout << std::setprecision(prec);

return EXIT_SUCCESS;
Expand Down
8 changes: 4 additions & 4 deletions include/universal/number/cfloat/cfloat_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,10 +369,10 @@ class cfloat {
value.setnormal();
value.setsign(rhs.sign());
value.setscale(rhs.scale());
constexpr unsigned rhsFbits = nnbits - 1ul - ees;
blockbinary<rhsFbits, bbt, BinaryNumberType::Signed> fraction;
//constexpr unsigned rhsFbits = nnbits - 1ul - ees;
//blockbinary<rhsFbits, bbt, BinaryNumberType::Signed> fraction;
//rhs.fraction<rhsFbits>(fraction);
std::cout << "fraction : " << to_binary(fraction) << '\n';
//std::cout << "fraction : " << to_binary(fraction) << '\n';
//value.setfraction(fraction);
convert(value, *this);
}
Expand Down Expand Up @@ -2595,7 +2595,7 @@ class cfloat {
std::cout << "fraction bits : " << to_binary(rawFraction, 32, true) << '\n';
#endif
// construct the target cfloat
uint64_t bits = (s ? 1ull : 0ull);
bits = (s ? 1ull : 0ull);
bits <<= es;
bits |= biasedExponent;
bits <<= fbits;
Expand Down
3 changes: 2 additions & 1 deletion include/universal/number/dbns/dbns_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ namespace sw { namespace universal {
int exponentOverflowDuringSearch;
int roundingFailure;
};
inline std::ostream& operator<<(std::ostream& ostr, const DbnsArithmeticStatistics& stats) {

inline std::ostream& operator<<(std::ostream& ostr, const DbnsArithmeticStatistics stats) {
ostr << "Conversions : " << stats.conversionEvents << '\n';
ostr << "Exponent Overflow During Search : " << stats.exponentOverflowDuringSearch << '\n';
ostr << "Rounding Successes : " << (stats.conversionEvents - stats.roundingFailure) << '\n';
Expand Down
97 changes: 95 additions & 2 deletions static/dbns/arithmetic/multiplication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
//
// This file is part of the universal numbers project, which is released under an MIT Open Source license.
#include <universal/utility/directives.hpp>
#include <algorithm>
#include <cmath>
#include <vector>
// configure the number system
#define DBNS_THROW_ARITHMETIC_EXCEPTION 1
#include <universal/number/dbns/dbns.hpp>
Expand Down Expand Up @@ -65,6 +67,94 @@ namespace sw { namespace universal {
return nrOfFailedTestCases;
}

/// <summary>
/// one row of a multiplication table: the operands a and b, the computed product c,
/// the reference cref together with the double it was derived from (ref), and two
/// ordering indices (patternOrder, valueOrder) assigned by the code that builds the table
/// </summary>
template<typename DbnsType>
struct DbnsSample {
	DbnsType a, b, c, cref;
	double ref;
	int patternOrder;
	int valueOrder;

	DbnsSample(const DbnsType& lhs, const DbnsType& rhs, const DbnsType& product, const DbnsType& reference,
	           double referenceValue, int patternIndex, int valueIndex)
		: a{ lhs }
		, b{ rhs }
		, c{ product }
		, cref{ reference }
		, ref{ referenceValue }
		, patternOrder{ patternIndex }
		, valueOrder{ valueIndex }
	{}
};

/// <summary>
/// print one sample as a single table row:
///   patternOrder : bin(a) * bin(b) = bin(c) : c : ref = a * b : bin(cref) : valueOrder : verdict
/// The verdict is PASS when c is NaN or equals cref, FAIL otherwise.
/// </summary>
template<typename DbnsType>
std::ostream& operator<<(std::ostream& ostr, const DbnsSample<DbnsType>& s) {
	constexpr int COLUMN = 10;

	// encoding order and the multiplication in binary form
	ostr << std::setw(COLUMN) << s.patternOrder << " : ";
	ostr << to_binary(s.a) << " * " << to_binary(s.b) << " = " << to_binary(s.c);

	// computed product, then the reference equation in value form
	ostr << " : " << std::setw(COLUMN) << s.c;
	ostr << " : " << std::setw(COLUMN) << s.ref;
	ostr << " = " << std::setw(COLUMN) << s.a;
	ostr << " * " << std::setw(COLUMN) << s.b;

	// reference encoding and the position in value order
	ostr << " : " << to_binary(s.cref);
	ostr << " : " << std::setw(COLUMN) << s.valueOrder;

	// a NaN product is accepted without comparing it to the reference
	const bool pass = s.c.isnan() || (s.c == s.cref);
	ostr << (pass ? " : PASS" : " : FAIL");
	return ostr;
}

/// <summary>
/// enumerate every product a * b of a dbns configuration, order the samples by their
/// double-precision reference value (samples whose reference is NaN come first), record
/// each sample's position in that order, and print the resulting table to stdout
/// </summary>
/// <param name="reportTestCases">currently unused: the table is always printed</param>
/// <returns>number of failed test cases; nothing is tallied here, so this is always 0</returns>
template<typename DbnsType,
	std::enable_if_t< is_dbns<DbnsType>, bool> = true
>
int GenerateOrdered(bool reportTestCases) {
	constexpr size_t nbits = DbnsType::nbits;
	constexpr size_t NR_ENCODINGS = (1ull << nbits);
	int nrOfFailedTestCases = 0;

	// one sample per (a, b) encoding pair
	std::vector<DbnsSample<DbnsType>> v;
	v.reserve(NR_ENCODINGS * NR_ENCODINGS);
	DbnsType a{}, b{}, c{}, cref{};
	for (size_t i = 0; i < NR_ENCODINGS; ++i) {
		a.setbits(i);
		const double da = double(a);
		for (size_t j = 0; j < NR_ENCODINGS; ++j) {
			b.setbits(j);
			const double db = double(b);

			const double ref = da * db;
			c = a * b;
			cref = ref;
			// pattern order is the row-major index of the (a, b) encoding pair
			v.emplace_back(a, b, c, cref, ref, static_cast<int>(i * NR_ENCODINGS + j), 0);
		}
	}

	// std::sort requires a strict weak ordering, and NaN compares false against everything,
	// so NaN references are placed explicitly: all of them first and mutually equivalent,
	// followed by the remaining samples in ascending reference value
	std::sort(v.begin(), v.end(),
		[](const DbnsSample<DbnsType>& lhs, const DbnsSample<DbnsType>& rhs) {
			const bool lhsIsNaN = std::isnan(lhs.ref);
			const bool rhsIsNaN = std::isnan(rhs.ref);
			if (lhsIsNaN != rhsIsNaN) return lhsIsNaN;
			if (lhsIsNaN) return false;
			return lhs.ref < rhs.ref;
		});

	// assign the value order
	int valueOrder = 0;
	for (auto& sample : v) {
		sample.valueOrder = valueOrder++;
	}
	for (const auto& sample : v) {
		std::cout << sample << '\n';
	}
	return nrOfFailedTestCases;
}
} }

// Regression testing guards: typically set by the cmake configuration, but MANUAL_TESTING is an override
Expand Down Expand Up @@ -104,8 +194,11 @@ try {
//using DBNS9_4 = dbns<9, 4, std::uint8_t>;
using DBNS16_5 = dbns<16, 5, std::uint16_t>;

nrOfFailedTestCases += ReportTestResult(VerifyMultiplication<DBNS4_2>(true), "dbns<4,2, uint8_t>", test_tag);
nrOfFailedTestCases += ReportTestResult(VerifyMultiplication<DBNS5_2>(true), "dbns<5,2, uint8_t>", test_tag);
GenerateOrdered<DBNS5_2>(false);
return 0;

// nrOfFailedTestCases += ReportTestResult(VerifyMultiplication<DBNS4_2>(true), "dbns<4,2, uint8_t>", test_tag);
// nrOfFailedTestCases += ReportTestResult(VerifyMultiplication<DBNS5_2>(true), "dbns<5,2, uint8_t>", test_tag);

{
float d{ 0 };
Expand Down

0 comments on commit 9362ba7

Please sign in to comment.