Skip to content

Commit

Permalink
Fixes bug in copying Tabulated1D instances (needed explicit copy cons…
Browse files Browse the repository at this point in the history
…tructors).
  • Loading branch information
HunterBelanger committed Jul 1, 2024
1 parent cbc1ee8 commit 0fa2364
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 40 deletions.
7 changes: 7 additions & 0 deletions include/PapillonNDL/tabulated_1d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ class Tabulated1D : public Function1D {
Tabulated1D(Interpolation interp, const std::vector<double>& x,
const std::vector<double>& y);

Tabulated1D(const Tabulated1D& other);
Tabulated1D& operator=(const Tabulated1D& other);
Tabulated1D(Tabulated1D&&) = default;
Tabulated1D& operator=(Tabulated1D&&) = default;

~Tabulated1D() = default;

double operator()(double x) const override final {
Expand Down Expand Up @@ -272,6 +277,8 @@ class Tabulated1D : public Function1D {
std::vector<double> x_;
std::vector<double> y_;
std::vector<InterpolationRange> regions_;

void make_regions();
};

} // namespace pndl
Expand Down
91 changes: 51 additions & 40 deletions src/tabulated_1d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,33 +54,7 @@ Tabulated1D::Tabulated1D(const std::vector<uint32_t>& NBT,
}

// Make 1D regions of all intervals
std::size_t low = 0;
std::size_t hi = 0;
for (std::size_t i = 0; i < breakpoints_.size(); i++) {
hi = breakpoints_[i];

try {
regions_.push_back(
InterpolationRange(interpolation_[i],
{x_.begin() + static_cast<std::ptrdiff_t>(low),
x_.begin() + static_cast<std::ptrdiff_t>(hi)},
{y_.begin() + static_cast<std::ptrdiff_t>(low),
y_.begin() + static_cast<std::ptrdiff_t>(hi)}));
} catch (PNDLException& error) {
std::string mssg = "The i = " + std::to_string(i) +
" InterpolationRange could not be constructed when "
"building Tabulated1D.";
error.add_to_exception(mssg);
throw error;
}

low = hi - 1;

// Check for discontinuity at region boundary
if (low < x.size() - 1) {
if (x[low] == x[low + 1]) low++;
}
}
make_regions();
}

Tabulated1D::Tabulated1D(Interpolation interp, const std::vector<double>& x,
Expand All @@ -102,19 +76,26 @@ Tabulated1D::Tabulated1D(Interpolation interp, const std::vector<double>& x,
}

// Make 1 1D region
const std::size_t hi = breakpoints_[0];
try {
regions_.push_back(InterpolationRange(
interpolation_[0],
{x_.begin(), x_.begin() + static_cast<std::ptrdiff_t>(hi)},
{y_.begin(), y_.begin() + static_cast<std::ptrdiff_t>(hi)}));
} catch (PNDLException& error) {
std::string mssg =
"The InterpolationRange could not be constructed when building "
"Tabulated1D.";
error.add_to_exception(mssg);
throw error;
}
make_regions();
}

Tabulated1D::Tabulated1D(const Tabulated1D& other)
: breakpoints_(other.breakpoints_),
interpolation_(other.interpolation_),
x_(other.x_),
y_(other.y_),
regions_() {
make_regions();
}

Tabulated1D& Tabulated1D::operator=(const Tabulated1D& other) {
this->breakpoints_ = other.breakpoints_;
this->interpolation_ = other.interpolation_;
this->x_ = other.x_;
this->y_ = other.y_;
this->regions_.clear();
make_regions();
return *this;
}

void Tabulated1D::linearize(double tolerance) {
Expand Down Expand Up @@ -163,4 +144,34 @@ Tabulated1D::InterpolationRange::InterpolationRange(Interpolation interp,
interpolator_.verify_y_grid(y_.begin(), y_.end());
}

void Tabulated1D::make_regions() {
std::size_t low = 0;
std::size_t hi = 0;
for (std::size_t i = 0; i < breakpoints_.size(); i++) {
hi = breakpoints_[i];

try {
regions_.push_back(
InterpolationRange(interpolation_[i],
{x_.begin() + static_cast<std::ptrdiff_t>(low),
x_.begin() + static_cast<std::ptrdiff_t>(hi)},
{y_.begin() + static_cast<std::ptrdiff_t>(low),
y_.begin() + static_cast<std::ptrdiff_t>(hi)}));
} catch (PNDLException& error) {
std::string mssg = "The i = " + std::to_string(i) +
" InterpolationRange could not be constructed when "
"building Tabulated1D.";
error.add_to_exception(mssg);
throw error;
}

low = hi - 1;

// Check for discontinuity at region boundary
if (low < x_.size() - 1) {
if (x_[low] == x_[low + 1]) low++;
}
}
}

} // namespace pndl

0 comments on commit 0fa2364

Please sign in to comment.