Skip to content

Commit

Permalink
fix HISQ force
Browse files Browse the repository at this point in the history
make argument order consistent for overloads of fat7lDeriv
  • Loading branch information
jcosborn committed Dec 18, 2024
1 parent edfe03e commit 30500ea
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 15 deletions.
180 changes: 180 additions & 0 deletions src/examples/hisq_force.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# C.T. Peterson: force test inspired from conversation with Peter Boyle
# See Grid implementation here:
# -https://github.com/paboyle/Grid/blob/develop/tests/forces/Test_bdy.cc
import qex
import gauge/[hisqsmear]
import physics/[stagD,stagSolve]

qexInit()

defaultSetup()
var
sg = lo.newGauge()
sgl = lo.newGauge()
f = lo.newGauge()
ff = lo.newGauge()
p = lo.newGauge()
phi = lo.ColorVector()
psi = lo.ColorVector()
r = lo.newRNGField(RngMilc6,123456789)
mass = 0.1
eps = 0.001
spa = initSolverParams()
spf = initSolverParams()
info: PerfInfo
let
hisq = newHisq()
stag = newStag3(sg,sgl)
arsq = 1e-20
frsq = 1e-12

spa.r2req = arsq
spa.maxits = 10000
spf.r2req = frsq
spf.maxits = 10000
spf.verbosity = 1

# -- Generic

proc smearRephase(g: auto, sg,sgl: auto): auto {.discardable.} =
tic()
let smearedForce = hisq.smearGetForce(g,sg,sgl)
threads:
sg.setBC; sgl.setBC;
threadBarrier()
sg.stagPhase; sgl.stagPhase;
smearedForce

proc reTrMul(x,y:auto):auto =
var d: type(eval(toDouble(redot(x[0],y[0]))))
for ir in x: d += redot(x[ir].adj, y[ir])
result = simdSum(d)
x.l.threadRankSum(result)

# -- Action calculation

proc action(): float =
var s: float
stag.solve(psi, phi, -mass, spa)
threads:
var st = psi.norm2
threadMaster: s = st
result = 0.5*s

# -- Force calculation & momentum update

proc smearedOneAndThreeLinkForce(f: auto, smearedForce: proc, p: auto, g:auto) =
# reverse accumulation of the derivative
# 1. Dslash
var
f1 = f.newOneOf()
f3 = f.newOneOf()
ff = f.newOneOf()
t,t3: array[4,Shifter[typeof(p),typeof(p[0])]]
for mu in 0..<f.len:
t[mu] = newShifter(p,mu,1)
discard t[mu] ^* p
t3[mu] = newShifter(p,mu,3)
discard t3[mu] ^* p
const n = p[0].len
threads:
for mu in 0..<f.len:
for i in f[mu]:
forO a, 0, n-1:
forO b, 0, n-1:
f1[mu][i][a,b] := p[i][a] * t[mu].field[i][b].adj
f3[mu][i][a,b] := p[i][a] * t3[mu].field[i][b].adj

# 2. correcting phase
threads:
f1.setBC; f3.setBC;
threadBarrier()
f1.stagPhase; f3.stagPhase;
threadBarrier()
for mu in 0..<f.len:
for i in f[mu].odd:
f1[mu][i] *= -1
f3[mu][i] *= -1

# 3. smearing
ff.smearedForce(f1,f3)

# 4. Tₐ ReTr( Tₐ U F† )
threads:
for mu in 0..<f.len:
for i in f[mu]:
var s {.noinit.}: typeof(f[0][0])
s := ff[mu][i]*g[mu][i].adj
f[mu][i].projectTAH(s)

proc fforce(f: auto) =
tic()
let smearedForce = g.smearRephase(sg,sgl)
toc("fforce smear rephase")
stag.solve(psi, phi, mass, spf)
toc("fforce solve")
f.smearedOneAndThreeLinkForce(smearedForce, psi, g)
toc("fforce olf")

proc mdt() =
tic()
threads:
for mu in 0..<g.len:
for s in g[mu]:
g[mu][s] := exp(0.5*eps*p[mu][s])*g[mu][s]

proc mdv() =
let s = -0.5/mass
f.fforce()
threads:
for mu in 0..<f.len: f[mu] *= s

# -- Test

var p1: float
g.random
threads:
p.randomTAH r
psi.gaussian r
var p2t = 0.0
for i in 0..<p.len: p2t += p[i].norm2
threadMaster: p1 = 0.5*p2t
discard g.smearRephase(sg,sgl)
threads:
stag.D(phi, psi, -mass)
threadBarrier()
phi.odd := 0
psi := 0

# Calculate initial action
let s1 = action()
echo "ACTION 1: ", s1

# Update (leapfrog)
mdt(); mdv(); mdt();

# Calculate final action
var p2: float
discard g.smearRephase(sg,sgl)
let s2 = action()
echo "ACTION 2: ", s2
threads:
var p2t = 0.0
for i in 0..<p.len:
p2t += p[i].norm2
threadMaster: p2 = 0.5*p2t

# Calculate dS = P U dSdU
var dS: float
threads:
var dSt = 0.0
for mu in 0..<p.len:
dSt = dSt - reTrMul(p[mu],f[mu])
threadMaster: dS = dSt

# Compare differences
let dH = s2+p2-s1-p1
let (dSdt1,dSdt2) = (dS*eps,s2-s1)
echo "dt*dS/dt, dS, difference = ", dSdt1,", ", dSdt2, ", ", dSdt1-dSdt2

qexFinalize()
12 changes: 6 additions & 6 deletions src/gauge/fat7lderiv.nim
Original file line number Diff line number Diff line change
Expand Up @@ -216,16 +216,16 @@ proc fat7lDeriv*(deriv: auto, gauge: auto, mid: auto, coef: Fat7lCoefs,

proc fat7lDeriv*(
deriv: auto,
mid: auto,
gauge: auto,
mid: auto,
coef: Fat7lCoefs,
lmid: auto,
lgauge: auto,
lmid: auto,
naik: float,
perf: var PerfInfo
) =
var (fx,fxl) = (newOneOf(mid),newOneOf(lmid))
fat7lderiv(fx,gauge,mid,coef,fxl,lgauge,lmid,naik,perf)
fat7lDeriv(fx,gauge,mid,coef,fxl,lgauge,lmid,naik,perf)
threads:
for mu in 0..<deriv.len:
for s in deriv[mu]:
Expand Down Expand Up @@ -313,7 +313,7 @@ when isMainModule:
let a = gc.gaugeAction2(fl)
let a2 = gc.gaugeAction2(fl2)
gc.gaugeDeriv2(fl, ch)
fat7lderiv(fd, g, ch, coef, info)
fat7lDeriv(fd, g, ch, coef, info)
check(a2-a, tol)

checkS("oneLink", 1)
Expand Down Expand Up @@ -354,13 +354,13 @@ when isMainModule:
let a2 = gc.gaugeAction2(fl2) + gc.gaugeAction2(ll2)
gc.gaugeDeriv2(fl, ch)
gc.gaugeDeriv2(ll, ld)
fat7lderiv(fd, ch, g, coef, ld, g, naik, info)
fat7lDeriv(fd, g, ch, coef, g, ld, naik, info)
#let bias = fl.norm2 / (fl.len*lo.physVol)
#let lbias = ll.norm2 / (ll.len*lo.physVol)
#let a = 0.5*(fl.norm2subtract(bias) + ll.norm2subtract(lbias))
#let a2 = 0.5*(fl2.norm2subtract(bias) + ll2.norm2subtract(lbias))
#echo a, " ", a2
#fat7lderiv(fd, fl, g, coef, ll, g, naik, info)
#fat7lDeriv(fd, g, fl, coef, g, ll, naik, info)
check(a2-a, tol)

coef.oneLink = 0.0
Expand Down
18 changes: 9 additions & 9 deletions src/gauge/hisqsmear.nim
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ proc newHISQ*(lepage: float = 0.0; naik: float = 1.0): HisqCoefs =
result.fat7second.setHisqFat7(2.0-lepage,naik)

proc smearGetForce*[T](
self: HisqCoefs;
u: T;
self: HisqCoefs;
u: T;
su,sul: T;
displayPerformance: bool = false
): proc(dsdu: var T; dsdsu,dsdsul: T) =
Expand All @@ -26,25 +26,25 @@ proc smearGetForce*[T](
v = newOneOf(u)
w = newOneOf(u)
info: PerfInfo

# Smear
v.makeImpLinks(u,fat7l1,info) # First fat7
threads: # Unitary projection
for mu in 0..<w.len:
for mu in 0..<w.len:
for s in w[mu]: w[mu][s].projectU(v[mu][s])
makeImpLinks(su,w,fat7l2,sul,w,naik,info) # Second fat7

# Chain rule - retains a reference to u,su,sul
proc smearedForce(dsdu: var T; dsdsu,dsdsul: T) =
var
var
dsdx_dxdw = newOneOf(dsdu)
dsdx_dxdw_dwdv = newOneOf(dsdu)
dsdx_dxdw.fat7lderiv(dsdsu,su,fat7l2,dsdsul,sul,naik,info) # Second fat7
dsdx_dxdw.fat7lDeriv(su,dsdsu,fat7l2,sul,dsdsul,naik,info) # Second fat7
threads: # Unitary projection
for mu in 0..<dsdx_dxdw_dwdv.len:
for s in dsdx_dxdw_dwdv[mu]:
dsdx_dxdw_dwdv[mu][s].projectUderiv(w[mu][s],v[mu][s],dsdx_dxdw[mu][s])
dsdu.fat7lderiv(dsdx_dxdw_dwdv,u,fat7l1,info) # First fat7
dsdx_dxdw_dwdv[mu][s].projectUderiv(w[mu][s],v[mu][s],dsdx_dxdw[mu][s])
dsdu.fat7lDeriv(u,dsdx_dxdw_dwdv,fat7l1,info) # First fat7

if displayPerformance: echo $(info)
return smearedForce
Expand All @@ -54,8 +54,8 @@ if isMainModule:
let
defaultLat = @[8,8,8,8]
hisq = newHISQ()
defaultSetup()
var
(lo, g, r) = setupLattice(defaultLat)
sg = lo.newGauge()
sgl = lo.newGauge()
f = lo.newGauge()
Expand Down

0 comments on commit 30500ea

Please sign in to comment.