Skip to content

Commit

Permalink
save/restore for a cfloat, but a fixed one
Browse files Browse the repository at this point in the history
  • Loading branch information
Ravenwater committed Aug 16, 2023
1 parent d08cd31 commit e68bf50
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 36 deletions.
75 changes: 43 additions & 32 deletions include/universal/blas/serialization/datafile.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,28 +373,49 @@ namespace sw { namespace universal { namespace blas {
}

template<typename Scalar>
void restoreVector(std::istream& istr, uint32_t nrOfElements) {
void restoreVector(std::istream& istr, uint32_t nrElements) {
sw::universal::blas::vector<Scalar>* v = new sw::universal::blas::vector<Scalar>;
add<sw::universal::blas::vector<Scalar>>(*v);
float item{ 0 };
for (unsigned i = 0; i < nrOfElements; ++i) {
for (unsigned i = 0; i < nrElements; ++i) {
istr >> item;
v->push_back(item);
}
// std::cout << "restored vector is : " << *v << '\n';
}
template<typename Scalar>
void restoreMatrix(std::istream& istr, uint32_t nrOfElements) { // we know that blas::matrix uses a vector for storage
void restoreMatrix(std::istream& istr, uint32_t nrElements) { // we know that blas::matrix uses a vector for storage
sw::universal::blas::matrix<Scalar>* v = new sw::universal::blas::matrix<Scalar>;
add<sw::universal::blas::matrix<Scalar>>(*v);
float item{ 0 };
for (unsigned i = 0; i < nrOfElements; ++i) {
for (unsigned i = 0; i < nrElements; ++i) {
istr >> item;
v->push_back(item);
}
// std::cout << "restored matrix is : " << *v << '\n';
}

template<typename Scalar>
void restoreCollection(std::istream& istr, uint32_t aggregationType, uint32_t nrElements) {
switch (aggregationType) {
case UNIVERSAL_AGGREGATE_SCALAR:
std::cout << "Creating a scalar\n";
break;
case UNIVERSAL_AGGREGATE_VECTOR:
std::cout << "Creating a vector\n";
restoreVector<Scalar>(istr, nrElements);
break;
case UNIVERSAL_AGGREGATE_MATRIX:
std::cout << "Creating a matrix\n";
restoreMatrix<Scalar>(istr, nrElements);
break;
case UNIVERSAL_AGGREGATE_TENSOR:
std::cout << "Creating a tensor\n";
break;
default:
std::cout << "unknown aggregate\n";
}
}
bool restore(std::istream& istr) {
uint32_t magic_number;
istr >> magic_number;
Expand All @@ -413,45 +434,35 @@ namespace sw { namespace universal { namespace blas {
for (uint32_t i = 0; i < nrParameters; ++i) {
istr >> parameter[i];
}
// std::cout << "typeId : " << typeId << std::endl;
// std::cout << "nr parameters : " << nrParameters << std::endl;
// for (uint32_t i = 0; i < nrParameters; ++i) {
// std::cout << "parameter[" << i << "] : " << parameter[i] << std::endl;
// }
std::cout << "typeId : " << typeId << std::endl;
std::cout << "nr parameters : " << nrParameters << std::endl;
for (uint32_t i = 0; i < nrParameters; ++i) {
std::cout << "parameter[" << i << "] : " << parameter[i] << std::endl;
}
// read the mandatory comment line
std::string aggregationTypeComment;
std::string token;
istr >> token; // pick up the comment token
std::getline(istr, aggregationTypeComment);
// std::cout << "comment line : " << aggregationTypeComment << std::endl;
uint32_t aggregationType, nrOfElements;
istr >> aggregationType >> nrOfElements;
std::cout << "comment line : " << aggregationTypeComment << std::endl;
uint32_t aggregationType, nrElements;
istr >> aggregationType >> nrElements;
// std::cout << "aggregationType : " << aggregationType << std::endl;
// std::cout << "nr of elements : " << nrOfElements << std::endl;
// std::cout << "nr of elements : " << nrElements << std::endl;
switch (typeId) {
case UNIVERSAL_NATIVE_INT8_TYPE:
create<int8_t>(aggregationType);
break;
case UNIVERSAL_NATIVE_FP32_TYPE:
// create<float>(aggregationType);
switch (aggregationType) {
case UNIVERSAL_AGGREGATE_SCALAR:
std::cout << "Creating a scalar\n";
break;
case UNIVERSAL_AGGREGATE_VECTOR:
std::cout << "Creating a vector\n";
restoreVector<float>(istr, nrOfElements);
break;
case UNIVERSAL_AGGREGATE_MATRIX:
std::cout << "Creating a matrix\n";
restoreMatrix<float>(istr, nrOfElements);
break;
case UNIVERSAL_AGGREGATE_TENSOR:
std::cout << "Creating a tensor\n";
break;
default:
std::cout << "unknown aggregate\n";
}
restoreCollection<float>(istr, aggregationType, nrElements);
break;
case UNIVERSAL_NATIVE_FP64_TYPE:
restoreCollection<double>(istr, aggregationType, nrElements);
break;
case UNIVERSAL_CFLOAT_TYPE:
// todo: is there a good way to synthesize a type from dynamic data?
using onecfloat = cfloat<16, 5, uint16_t, true, false, false>;
restoreCollection<onecfloat>(istr, aggregationType, nrElements);
break;
default:
std::cout << "unknown typeId : " << typeId << '\n';
Expand Down
10 changes: 6 additions & 4 deletions linalg/data/serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,13 @@ void TestSaveTypeId() {
blas::saveTypeId(std::cout, h);
}

template<typename Scalar>
void TestVectorSerialization() {
using namespace sw::universal;
sw::universal::blas::vector<float> xfp32(5);
gaussian_random(xfp32, 0.0, 0.1);
sw::universal::blas::vector<Scalar> v(5);
gaussian_random(v, 0.0, 0.1);
blas::datafile<blas::TextFormat> df;
df.add(xfp32);
df.add(v);
df.save(std::cout, false); // decimal format
std::stringstream s;
df.save(s, false); // decimal format
Expand Down Expand Up @@ -214,7 +215,8 @@ try {
// ReportNativeHexFormats();
// ReportNumberSystemFormats();

TestVectorSerialization();
//TestVectorSerialization<float>();
TestVectorSerialization<half>();
return 0;

TestSaveTypeId();
Expand Down

0 comments on commit e68bf50

Please sign in to comment.