Skip to content

Commit

Permalink
Aggressively compute as much of Zeta as possible during recursion
Browse files Browse the repository at this point in the history
  • Loading branch information
rmjarvis committed Jan 31, 2024
1 parent 6070b05 commit 7b3bd2f
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 64 deletions.
72 changes: 45 additions & 27 deletions include/Corr3.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,22 +105,24 @@ class BaseCorr3
BaseMultipoleScratch& mp2, BaseMultipoleScratch& mp3);

template <int B, int M, int C>
void splitC2CellsOrCalculateGn(
double splitC2CellsOrCalculateGn(
const BaseCell<C>& c1, const std::vector<const BaseCell<C>*>& c2list,
const MetricHelper<M,0>& metric,
std::vector<const BaseCell<C>*>& newc2list, bool& anysplit1,
BaseMultipoleScratch& mp);

template <int B, int M, int C>
void multipoleFinish(const BaseCell<C>& c1, const std::vector<const BaseCell<C>*>& c2list,
const MetricHelper<M,0>& metric, BaseMultipoleScratch& mp);
const MetricHelper<M,0>& metric, BaseMultipoleScratch& mp,
int mink_zeta);

template <int B, int M, int C>
void multipoleFinish(const BaseCell<C>& c1,
const std::vector<const BaseCell<C>*>& c2list,
const std::vector<const BaseCell<C>*>& c3list,
const MetricHelper<M,0>& metric, int ordered,
BaseMultipoleScratch& mp2, BaseMultipoleScratch& mp3);
BaseMultipoleScratch& mp2, BaseMultipoleScratch& mp3,
int mink_zeta);

template <int C>
void finishProcess(const BaseCell<C>& c1, const BaseCell<C>& c2, const BaseCell<C>& c3,
Expand All @@ -145,13 +147,15 @@ class BaseCorr3
{ doCalculateGn(c1, c2, rsq, r, logr, k, mp); }

template <int C>
void calculateZeta(const BaseCell<C>& c1, BaseMultipoleScratch& mp)
{ doCalculateZeta(c1, mp); }
void calculateZeta(const BaseCell<C>& c1, BaseMultipoleScratch& mp,
int kstart, int mink_zeta)
{ doCalculateZeta(c1, mp, kstart, mink_zeta); }

template <int C>
void calculateZeta(const BaseCell<C>& c1, int ordered,
BaseMultipoleScratch& mp2, BaseMultipoleScratch& mp3)
{ doCalculateZeta(c1, ordered, mp2, mp3); }
BaseMultipoleScratch& mp2, BaseMultipoleScratch& mp3,
int kstart, int mink_zeta)
{ doCalculateZeta(c1, ordered, mp2, mp3, kstart, mink_zeta); }

protected:

Expand Down Expand Up @@ -195,16 +199,22 @@ class BaseCorr3
const BaseCell<ThreeD>& c1, const BaseCell<ThreeD>& c2,
double rsq, double r, double logr, int k, BaseMultipoleScratch& mp) =0;

virtual void doCalculateZeta(const BaseCell<Flat>& c1, BaseMultipoleScratch& mp) =0;
virtual void doCalculateZeta(const BaseCell<Sphere>& c1, BaseMultipoleScratch& mp) =0;
virtual void doCalculateZeta(const BaseCell<ThreeD>& c1, BaseMultipoleScratch& mp) =0;
virtual void doCalculateZeta(const BaseCell<Flat>& c1, BaseMultipoleScratch& mp,
int kstart, int mink_zeta) =0;
virtual void doCalculateZeta(const BaseCell<Sphere>& c1, BaseMultipoleScratch& mp,
int kstart, int mink_zeta) =0;
virtual void doCalculateZeta(const BaseCell<ThreeD>& c1, BaseMultipoleScratch& mp,
int kstart, int mink_zeta) =0;

virtual void doCalculateZeta(const BaseCell<Flat>& c1, int ordered,
BaseMultipoleScratch& mp2, BaseMultipoleScratch& mp3) =0;
BaseMultipoleScratch& mp2, BaseMultipoleScratch& mp3,
int kstart, int mink_zeta) =0;
virtual void doCalculateZeta(const BaseCell<Sphere>& c1, int ordered,
BaseMultipoleScratch& mp2, BaseMultipoleScratch& mp3) =0;
BaseMultipoleScratch& mp2, BaseMultipoleScratch& mp3,
int kstart, int mink_zeta) =0;
virtual void doCalculateZeta(const BaseCell<ThreeD>& c1, int ordered,
BaseMultipoleScratch& mp2, BaseMultipoleScratch& mp3) =0;
BaseMultipoleScratch& mp2, BaseMultipoleScratch& mp3,
int kstart, int mink_zeta) =0;

protected:

Expand Down Expand Up @@ -289,11 +299,13 @@ class Corr3 : public BaseCorr3
double rsq, double r, double logr, int k, BaseMultipoleScratch& mp);

template <int C>
void calculateZeta(const BaseCell<C>& c1, BaseMultipoleScratch& mp);
void calculateZeta(const BaseCell<C>& c1, BaseMultipoleScratch& mp,
int kstart, int mink_zeta);

template <int C>
void calculateZeta(const BaseCell<C>& c1, int ordered,
BaseMultipoleScratch& mp2, BaseMultipoleScratch& mp3);
BaseMultipoleScratch& mp2, BaseMultipoleScratch& mp3,
int kstart, int mink_zeta);

// Note: op= only copies _data. Not all the params.
void operator=(const Corr3<D1,D2,D3>& rhs);
Expand Down Expand Up @@ -346,22 +358,28 @@ class Corr3 : public BaseCorr3
double rsq, double r, double logr, int k, BaseMultipoleScratch& mp)
{ calculateGn(c1, c2, rsq, r, logr, k, mp); }

void doCalculateZeta(const BaseCell<Flat>& c1, BaseMultipoleScratch& mp)
{ calculateZeta(c1, mp); }
void doCalculateZeta(const BaseCell<Sphere>& c1, BaseMultipoleScratch& mp)
{ calculateZeta(c1, mp); }
void doCalculateZeta(const BaseCell<ThreeD>& c1, BaseMultipoleScratch& mp)
{ calculateZeta(c1, mp); }
void doCalculateZeta(const BaseCell<Flat>& c1, BaseMultipoleScratch& mp,
int kstart, int mink_zeta)
{ calculateZeta(c1, mp, kstart, mink_zeta); }
void doCalculateZeta(const BaseCell<Sphere>& c1, BaseMultipoleScratch& mp,
int kstart, int mink_zeta)
{ calculateZeta(c1, mp, kstart, mink_zeta); }
void doCalculateZeta(const BaseCell<ThreeD>& c1, BaseMultipoleScratch& mp,
int kstart, int mink_zeta)
{ calculateZeta(c1, mp, kstart, mink_zeta); }

void doCalculateZeta(const BaseCell<Flat>& c1, int ordered,
BaseMultipoleScratch& mp2, BaseMultipoleScratch& mp3)
{ calculateZeta(c1, ordered, mp2, mp3); }
BaseMultipoleScratch& mp2, BaseMultipoleScratch& mp3,
int kstart, int mink_zeta)
{ calculateZeta(c1, ordered, mp2, mp3, kstart, mink_zeta); }
void doCalculateZeta(const BaseCell<Sphere>& c1, int ordered,
BaseMultipoleScratch& mp2, BaseMultipoleScratch& mp3)
{ calculateZeta(c1, ordered, mp2, mp3); }
BaseMultipoleScratch& mp2, BaseMultipoleScratch& mp3,
int kstart, int mink_zeta)
{ calculateZeta(c1, ordered, mp2, mp3, kstart, mink_zeta); }
void doCalculateZeta(const BaseCell<ThreeD>& c1, int ordered,
BaseMultipoleScratch& mp2, BaseMultipoleScratch& mp3)
{ calculateZeta(c1, ordered, mp2, mp3); }
BaseMultipoleScratch& mp2, BaseMultipoleScratch& mp3,
int kstart, int mink_zeta)
{ calculateZeta(c1, ordered, mp2, mp3, kstart, mink_zeta); }

// These are usually allocated in the python layer and just built up here.
// So all we have here is a bare pointer for each of them.
Expand Down
Loading

0 comments on commit 7b3bd2f

Please sign in to comment.