Skip to content

Commit

Permalink
adding assignment and conversion skeletons to dbns
Browse files Browse the repository at this point in the history
  • Loading branch information
Ravenwater committed Aug 18, 2023
1 parent bc76813 commit 98b9dc1
Show file tree
Hide file tree
Showing 5 changed files with 671 additions and 200 deletions.
261 changes: 88 additions & 173 deletions include/universal/number/dbns/dbns_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ dbns<nbits, fbbits, bt, xtra...>& maxneg(dbns<nbits, fbbits, bt, xtra...>& lmaxn
return lmaxneg;
}


// double-base logarithmic number system: bases 2^-1 and 3
template<unsigned _nbits, unsigned _fbbits, typename bt = uint8_t, auto... xtra>
class dbns {
static_assert(_nbits > _fbbits, "configuration not supported: too many first base bits leaving no bits for second base");
Expand Down Expand Up @@ -81,12 +81,13 @@ class dbns {
static constexpr int64_t min_exponent = (maxShift > 0) ? (-(1ll << leftShift)) : 0;
static constexpr int64_t max_exponent = (maxShift > 0) ? (1ll << leftShift) - 1 : 0;
static constexpr int rightShift = (fbbits == 0 ? 0 : (64 - fbbits));
static constexpr uint64_t FB_MASK = (rightShift > 0 ? (0xFFFF'FFFF'FFFF'FFFFull >> rightShift) : 0ull);
static constexpr uint64_t FB_MASK = (rightShift > 0 ? ((0xFFFF'FFFF'FFFF'FFFFull >> rightShift) << sbbits) : 0ull);
static constexpr uint64_t SB_MASK = (0xFFFF'FFFF'FFFF'FFFFull >> (64 - (nbits - fbbits - 1)));

using BlockBinary = blockbinary<nbits, bt, BinaryNumberType::Signed>; // sign + dbns exponent
using ExponentBlockBinary = blockbinary<nbits-1, bt, BinaryNumberType::Signed>; // just the dbns exponent

// the smallest value with this base set and the assumption that exponents are positive is 0b0.111.0000
static constexpr float base[2] = { 0.5f, 3.0f };

/// trivial constructor
Expand Down Expand Up @@ -233,7 +234,7 @@ class dbns {
return *this;
}
if (rhs.iszero()) {
#if LNS_THROW_ARITHMETIC_EXCEPTION
#if DBNS_THROW_ARITHMETIC_EXCEPTION
throw dbns_divide_by_zero();
#else
setnan();
Expand Down Expand Up @@ -337,94 +338,69 @@ class dbns {
return *this;
}
constexpr dbns& minpos() noexcept {
// minimum positive value has this bit pattern: 0-11...11-00...00, that is, sign = 0, first base = 11..11, second base = 00..00
// minimum positive value has this bit pattern: 0-11...10-00...00, that is, sign = 0, first base = 11..10, second base = 00..00
clear();
for (unsigned i = nbits - fbbits - 1; i < nbits - 1; ++i) {
for (unsigned i = sbbits + 1; i < nbits - 1; ++i) {
setbit(i, true);
}
return *this;
}
constexpr dbns& zero() noexcept {
// the zero value has this bit pattern: 0-100..00-00..000, sign = 0, msb = 1, rest 0
// the zero value has this bit pattern: 0-11..11-00..000, sign = 0, fbbits all 1, rest 0
clear();
setbit(nbits - 2, true); // msb = 1
if constexpr (1 == nrBlocks) {
// single storage block: apply the first-base mask in one call
// NOTE(review): setbits() is defined outside this view; presumably it writes the raw bit pattern -- confirm
setbits(FB_MASK);
}
else {
// multiple storage blocks: raise the first-base bits [sbbits, nbits - 2] one at a time;
// the sign bit at nbits - 1 is not touched by this loop
for (unsigned i = sbbits; i < nbits - 1; ++i) {
setbit(i, true);
}
}
return *this;
}
constexpr dbns& minneg() noexcept {
// minimum negative value has this bit pattern: 1-11...11-00...00, that is, sign = 1, first base = 11..11, second base = 00..00
clear();
for (unsigned i = nbits - fbbits - 1; i < nbits; ++i) {
setbit(i, true);
}
// minimum negative value has this bit pattern: 1-11...10-00...00, that is, sign = 1, first base = 11..10, second base = 00..00
minpos();
setbit(nbits - 1ull, true); // sign = 1
return *this;
}
constexpr dbns& maxneg() noexcept {
// maximum negative value has this bit pattern: 1-00..00-11...11, that is, sign = 1, first base = 00..00, second base = 11..11
clear();
for (unsigned i = 0; i < nbits - fbbits - 1; ++i) {
setbit(i, true);
}
maxpos();
setbit(nbits - 1ull, true); // sign = 1
return *this;
}

// selectors
constexpr bool iszero() const noexcept { // special encoding: 0.1000.0000
if constexpr (nrBlocks == 1) {
return (_block[MSB_UNIT] == MSU_ZERO);
}
else if constexpr (nrBlocks == 2) {
if constexpr (SPECIAL_BITS_TOGETHER) {
return (_block[0] == 0 && _block[1] == MSU_ZERO);
}
else {
return !sign() && _block[0] == MSB_BIT_MASK;
}
constexpr bool iszero() const noexcept { // special encoding: 0.11..11.0000
if constexpr (1 == nrBlocks) {
if (!at(nbits - 1) && ((_block[MSU] & FB_MASK) == FB_MASK) && ((_block[MSU] & SB_MASK) == 0)) return true;
}
else {
if constexpr (SPECIAL_BITS_TOGETHER) {
for (unsigned i = 0; i < nrBlocks - 1; ++i) {
if (_block[i] != 0) return false;
}
return (_block[MSB_UNIT] == MSU_ZERO); // this will cover the sign != 1 condition
for (unsigned i = 0; i < sbbits; ++i) {
if (at(i)) return false;
}
else {
for (unsigned i = 0; i < nrBlocks - 2; ++i) {
if (_block[i] != 0) return false;
}
return !sign() && _block[MSB_UNIT] == BLOCK_MSB_MASK;
for (unsigned i = sbbits; i < nbits - 1; ++i) {
if (!at(i)) return false;
}
// zero has the sign bit off, NaN has the sign bit on
return !at(nbits - 1);
}
return false;
}
// true when sign() reports the sign bit as set
constexpr bool isneg() const noexcept { return sign(); }
// true when sign() reports the sign bit as clear
constexpr bool ispos() const noexcept { return !sign(); }
// always false: this selector never reports an infinity
constexpr bool isinf() const noexcept { return false; }
constexpr bool isnan() const noexcept { // special encoding
if constexpr (nrBlocks == 1) {
return (_block[MSB_UNIT] == MSU_NAN); // 1.1000.0000 is NaN
// 1.1111.0000 is NaN
for (unsigned i = 0; i < sbbits; ++i) {
if (at(i)) return false;
}
else if constexpr (nrBlocks == 2) {
if constexpr (SPECIAL_BITS_TOGETHER) {
return (_block[0] == 0 && _block[1] == MSU_NAN);
}
else {
return sign() && (_block[MSU - 1] == BLOCK_MSB_MASK);
}
}
else {
if constexpr (SPECIAL_BITS_TOGETHER) {
for (unsigned i = 0; i < nrBlocks - 1; ++i) {
if (_block[i] != 0) return false;
}
return (_block[MSB_UNIT] == MSU_NAN);
}
else {
for (unsigned i = 0; i < nrBlocks - 2; ++i) {
if (_block[i] != 0) return false;
}
return sign() && (_block[MSU - 1] == BLOCK_MSB_MASK);
}
for (unsigned i = sbbits; i < nbits - 1; ++i) {
if (!at(i)) return false;
}
// zero has the sign bit off, NaN has the sign bit on
return at(nbits - 1);
}
constexpr bool sign() const noexcept {
return (SIGN_BIT_MASK & _block[MSU]) != 0;
Expand Down Expand Up @@ -460,15 +436,32 @@ class dbns {
}

constexpr uint64_t extractExponent(int base) const noexcept {
uint64_t bits = uint64_t(_block);
if (base == 0) {
bits >>= (nbits - fbbits - 1); // normalize the value
bits &= FB_MASK; // null the sign bit
if constexpr (1 == nrBlocks) {
uint64_t bits = static_cast<uint64_t>(_block[MSU]);
if (base == 0) {
bits &= FB_MASK;
bits >>= sbbits; // normalize the value
}
else if (base == 1) {
bits &= SB_MASK; // value is already normalized
}
return bits;
}
else if (base == 1) {
bits &= SB_MASK; // normalize the value
else {
uint64_t bits{ 0 };
if (0 == base) {
for (unsigned i = sbbits; i < nbits - 1; ++i) {
bits |= (at(i) ? (1 << (i - sbbits)) : 0);
}
}
else {
for (unsigned i = 0; i < sbbits; ++i) {
bits |= (at(i) ? (1 << i) : 0);
}
}
return bits;
}
return bits;
return 0;
}
// explicit conversion to int, delegating to to_signed<int>()
explicit operator int() const noexcept { return to_signed<int>(); }
// explicit conversion to long, delegating to to_signed<long>()
explicit operator long() const noexcept { return to_signed<long>(); }
Expand Down Expand Up @@ -545,6 +538,9 @@ class dbns {
}
template<typename Real>
CONSTEXPRESSION dbns& convert_ieee754(Real v) noexcept {
using std::abs;
using std::log2;
using std::round;
bool s{ false };
uint64_t unbiasedExponent{ 0 };
uint64_t rawFraction{ 0 };
Expand Down Expand Up @@ -581,114 +577,33 @@ class dbns {
}

// check if the value is in the representable range
// NOTE: this is required to protect the rounding code below, which only works for values between [minpos, maxpos]
// TODO: this is all incredibly slow as we are creating special values and converting them to Real to compare
if constexpr (behavior == Behavior::Saturating) {
dbns maxpos(SpecificValue::maxpos);
dbns maxneg(SpecificValue::maxneg);
Real absoluteValue = std::abs(v);
//std::cout << "maxpos : " << to_binary(maxpos) << " : " << maxpos << '\n';
if (v > 0 && v >= Real(maxpos)) {
return *this = maxpos;
}
if (v < 0 && v <= Real(maxneg)) {
return *this = maxneg;
// if (abs(v) < minpos()) {
// setzero();
// return *this;
// }
double fulle = -log2(abs(v));
// std::cout << "fulle : " << fulle << '\n';
double best_err = 1.0e10;
int32_t best_e0 = 500;
int32_t best_e1 = 500;
int32_t b0{ 1 }, b1{ 1 }; // exponent biases
double err{ 0.0 };
int32_t e0{ 0 }, e1{ 0 };
for (e1 = 0; e1 < SB_MASK; ++e1) {
e0 = static_cast<int32_t>(round((fulle - e1 * b1) / b0));
err = abs(fulle - (e0 * b0 + e1 * b1));
// std::cout << "e0 : " << e0 << " e1 : " << e1 << " err : " << err << '\n';
if (err < best_err) {
best_err = err;
best_e0 = e0;
best_e1 = e1;
}
dbns minpos(SpecificValue::minpos);
dbns<nbits + 1, fbbits + 1, bt, xtra...> halfMinpos(SpecificValue::minpos); // in log space
//std::cout << "minpos : " << minpos << '\n';
//std::cout << "halfMinpos : " << halfMinpos << '\n';
if (absoluteValue <= Real(halfMinpos)) {
setzero();
return *this;
}
else if (absoluteValue <= Real(minpos)) {
return *this = (v > 0 ? minpos : -minpos);
}
}

bool negative = (v < Real(0.0f));
v = (negative ? -v : v);
Real logv = std::log2(v);
if (logv == 0.0) {
_block.clear();
_block.setbit(nbits - 1, negative);
return *this;
}


ExponentBlockBinary dbnsExponent{ 0 };
extractFields(logv, s, unbiasedExponent, rawFraction, bits); // use native conversion
if (unbiasedExponent > 0) rawFraction |= (1ull << ieee754_parameter<Real>::fbits);
int radixPoint = ieee754_parameter<Real>::fbits - (static_cast<int>(unbiasedExponent) - ieee754_parameter<Real>::bias);

// our fixed-point has its radixPoint at fbbits
int shiftRight = radixPoint - int(fbbits);
if (shiftRight > 0) {
if (shiftRight > 63) {
// this shift degree would be undefined behavior, but the intended transformation is that we have no bits
rawFraction = 0;
}
else {
// we need to round the raw bits
// collect guard, round, and sticky bits
// this same logic will work for the case where
// we only have a guard bit and no round and/or sticky bits
// because the mask logic will make round and sticky both 0
// so no need to special case it
uint64_t mask = (1ull << (shiftRight - 1));
bool guard = (mask & rawFraction);
mask >>= 1;
bool round = (mask & rawFraction);
if (shiftRight > 1) {
mask = (0xFFFF'FFFF'FFFF'FFFFull << (shiftRight - 2));
mask = ~mask;
}
else {
mask = 0;
}
bool sticky = (mask & rawFraction);

rawFraction >>= shiftRight; // shift out the bits we are rounding away
bool lsb = (rawFraction & 0x1ul);
// ... lsb | guard round sticky round
// x 0 x x down
// 0 1 0 0 down round to even
// 1 1 0 0 up round to even
// x 1 0 1 up
// x 1 1 0 up
// x 1 1 1 up
if (guard) {
if (lsb && (!round && !sticky)) ++rawFraction; // round to even
if (round || sticky) ++rawFraction;
}
rawFraction = (s ? (~rawFraction + 1) : rawFraction); // if negative, map to two's complement
}
dbnsExponent.setbits(rawFraction);
}
else {
int shiftLeft = -shiftRight;
if (shiftLeft < (64 - ieee754_parameter<Real>::fbits)) { // what is the distance between the MSB and 64?
// no need to round, just shift the bits in place
rawFraction <<= shiftLeft;
rawFraction = (s ? (~rawFraction + 1) : rawFraction); // if negative, map to two's complement
dbnsExponent.setbits(rawFraction);
}
else {
// we need to project the bits we have on the fixpnt
for (unsigned i = 0; i < ieee754_parameter<Real>::fbits + 1; ++i) {
if (rawFraction & 0x01) {
dbnsExponent.setbit(i + shiftLeft);
}
rawFraction >>= 1;
}
if (s) dbnsExponent.twosComplement();
}
}
// std::cout << "dbns exponent : " << to_binary(dbnsExponent) << " : " << dbnsExponent << '\n';
_block = dbnsExponent;
setsign(negative);

e0 = best_e0;
e1 = best_e1;
// std::cout << "e0 : " << e0 << " e1 : " << e1 << " err : " << err << '\n';
e0 <<= sbbits;
_block.setblock(MSU, (s ? SIGN_BIT_MASK : 0) | e0 | e1);
return *this;
}

Expand Down
7 changes: 4 additions & 3 deletions linalg/data/serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ void TestVectorSerialization() {
df.save(s, false); // decimal format
df.clear();
df.restore(s);
df.save(std::cout, false);
df.save(std::cout, true);
}

template<typename Scalar>
Expand All @@ -146,7 +146,7 @@ void TestMatrixSerialization() {
df.save(s, false); // decimal format
df.clear();
df.restore(s);
df.save(std::cout, false);
df.save(std::cout, true);
}

void TestSerialization() {
Expand Down Expand Up @@ -230,7 +230,8 @@ try {
// ReportNativeHexFormats();
// ReportNumberSystemFormats();

//TestVectorSerialization<float>();
TestVectorSerialization<float>();
TestVectorSerialization<dbns<8, 3>>();
TestMatrixSerialization<float>();
// TestVectorSerialization<half>();

Expand Down
Loading

0 comments on commit 98b9dc1

Please sign in to comment.