Skip to content

Commit

Permalink
refactor: state update step
Browse files Browse the repository at this point in the history
  • Loading branch information
fgdorais committed Oct 13, 2024
1 parent b5c9d8e commit 616ae86
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions Batteries/Data/Random/MersenneTwister.lean
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ structure State (cfg : Config) where
/-- Mersenne Twister initialization given an optional seed. -/
@[specialize cfg] protected def Config.init (cfg : MersenneTwister.Config)
(seed : BitVec cfg.wordSize := cfg.initSeed) : State cfg :=
⟨loop seed #[] (Nat.zero_le _), 0, cfg.zero_lt_stateSize⟩
⟨loop seed (.mkEmpty cfg.stateSize) (Nat.zero_le _), 0, cfg.zero_lt_stateSize⟩
where
/-- Inner loop for Mersenne Twister initalization. -/
loop (w : BitVec cfg.wordSize) (v : Array (BitVec cfg.wordSize)) (h : v.size ≤ cfg.stateSize) :=
Expand All @@ -81,24 +81,23 @@ where
let w := cfg.initMult * (w ^^^ (w >>> cfg.wordSize - 2)) + v.size
loop w v (by simp only [v, Array.size_push]; omega)

/-- Apply the twisting transformation to the given state. -/
@[specialize cfg] protected def State.twist (state : State cfg) : State cfg :=
let i := state.index
let i' : Fin cfg.stateSize :=
if h : i.val+1 < cfg.stateSize then ⟨i.val+1, h⟩ else0, cfg.zero_lt_stateSize⟩
let y := state.data[i] &&& cfg.uMask ||| state.data[i'] &&& cfg.lMask
let x := state.data[i+cfg.shiftSize] ^^^ bif y[0] then y >>> 1 ^^^ cfg.xorMask else y >>> 1
⟨state.data.set i x, i'⟩

/-- Update the state by a number of generation steps (default 1). -/
@[specialize cfg] protected def State.update (state : State cfg) (steps := 1) : State cfg :=
loop state steps
where
/-- Inner loop for Mersenne Twister update. -/
@[inline] loop (s : State cfg) (c : Nat) : State cfg :=
if c = 0 then s else
let i := s.index
let i' : Fin cfg.stateSize :=
if h : i.val+1 < cfg.stateSize then ⟨i.val+1, h⟩ else0, cfg.zero_lt_stateSize⟩
let y := s.data[i] &&& cfg.uMask ||| s.data[i'] &&& cfg.lMask
let x := s.data[i+cfg.shiftSize] ^^^ bif y[0] then y >>> 1 ^^^ cfg.xorMask else y >>> 1
loop ⟨s.data.set i x, i'⟩ (c-1)
@[inline] protected def State.update (state : State cfg) (steps := 1) : State cfg :=
if steps = 0 then state else state.twist.update (steps-1)

/-- Mersenne Twister iteration. -/
@[specialize cfg] protected def State.next (state : State cfg) : BitVec cfg.wordSize × State cfg :=
let i := state.index
let s := state.update
let s := state.twist
(temper s.data[i], s)
where
/-- Tempering step for Mersenne Twister. -/
Expand Down

0 comments on commit 616ae86

Please sign in to comment.