Skip to content

Commit

Permalink
clean up wave equation example
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Aug 21, 2024
1 parent aa69907 commit e81643b
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 17 deletions.
1 change: 1 addition & 0 deletions SciLean.lean
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ import SciLean.Data.DataArray
import SciLean.Data.DataArray.DataArray
import SciLean.Data.DataArray.Operations
import SciLean.Data.DataArray.PlainDataType
import SciLean.Data.DataArray.RevDeriv
import SciLean.Data.DataArray.VecN
import SciLean.Data.Function
import SciLean.Data.IndexType
Expand Down
31 changes: 31 additions & 0 deletions SciLean/Analysis/Calculus/RevFDerivProj.lean
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,37 @@ by
simp[revFDerivProjUpdate,revFDerivProj,add_assoc,neg_pull,mul_assoc,smul_push]


@[fun_trans]
theorem HDiv.hDiv.arg_a0.revFDerivProj_rule
(f : X → K) (y : K) (hf : Differentiable K f) :
(revFDerivProj K Unit fun x => f x / y)
=
fun x =>
let ydf := revFDerivProj K Unit f x
(ydf.1 / y,
fun _ dx' => (1 / (conj y)) • (ydf.2 () dx')) :=
by
unfold revFDerivProj
fun_trans (disch:=apply hx); simp[oneHot, structMake,revFDerivProjUpdate,revFDerivProj,smul_push]


@[fun_trans]
theorem HDiv.hDiv.arg_a0.revFDerivProjUpdate_rule
(f : X → K) (y : K) (hf : Differentiable K f) :
(revFDerivProjUpdate K Unit fun x => f x / y)
=
fun x =>
let ydf := revFDerivProjUpdate K Unit f x
(ydf.1 / y,
fun _ dx' dx =>
let c := (1 / (conj y))
((ydf.2 () (c • dx') dx))) :=
by
unfold revFDerivProjUpdate
fun_trans (disch:=assumption)
simp[revFDerivProjUpdate,revFDerivProj,add_assoc,neg_pull,mul_assoc,smul_push]


-- HPow.hPow -------------------------------------------------------------------
--------------------------------------------------------------------------------

Expand Down
2 changes: 1 addition & 1 deletion SciLean/Meta/SimpCore.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ attribute [simp_core] add_zero zero_add sub_zero zero_sub sub_self neg_zero mul_
attribute [simp_core] Nat.succ_sub_succ_eq_sub Nat.cast_ofNat

-- simp theorems for `Prod`
attribute [simp_core] Prod.mk.eta Prod.fst_zero Prod.snd_zero Prod.mk_add_mk Prod.mk_mul_mk Prod.mk_sub_mk Prod.neg_mk Prod.vadd_mk -- Prod.smul_mk
attribute [simp_core] Prod.mk.eta Prod.fst_zero Prod.snd_zero Prod.mk_add_mk Prod.mk_mul_mk Prod.mk_sub_mk Prod.neg_mk Prod.vadd_mk Prod.smul_mk

-- simp theorems for `Equiv`
attribute [simp_core] Equiv.invFun_as_coe Equiv.symm_symm
Expand Down
32 changes: 16 additions & 16 deletions examples/WaveEquation.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,60 +5,60 @@ open SciLean

set_default_scalar Float

def _root_.Fin.shift (i : Fin n) (j : Nat) : Fin n := ⟨(i.1+j)%j, sorry_proof⟩
variable {n : Nat}
set_option synthInstance.maxSize 1000

def _root_.Fin.shift (i : Fin n) (j : Nat) : Fin n := ⟨(i.1+j)%n, sorry_proof⟩

-- set_option trace.Meta.synthInstance true in
def H (m k : Float) (x p : Float^[n]) : Float :=
let Δx := 1.0/n.toFloat
(Δx/(2*m)) * ‖p‖₂² + (Δx * k/2) * (∑ i, ‖x[i.shift 1] - x[i]‖₂²)


approx solver (m k : Float)
:= odeSolve (λ t (x,p) => ( ∇ (p':=p), H (n:=n) m k x p',
-∇ (x':=x), H (n:=n) m k x' p))
by
-- Unfold Hamiltonian definition and compute gradients
-- Unfold Hamiltonian definition
unfold H

-- compute derivatives
autodiff

-- Apply RK4 method
conv =>
pattern (odeSolve _)
-- apply RK4 method
conv in odeSolve _ =>
rw[odeSolve_fixed_dt (R:=Float) rungeKutta4 sorry_proof]

-- approximate limit by picking concrete `n`
approx_limit steps sorry_proof



def main : IO Unit := do

let substeps := 1
let m := 1.0
let k := 10000.0
let k := 40000.0

let N : Nat := 100
-- have h : Nonempty (Fin N) := sorry

let Δt := 0.1
let x₀ : Float^[N] := ⊞ (i : Fin N) => (Scalar.sin (i.1.toFloat/10))
let p₀ : Float^[N] := ⊞ (i : Fin N) => (0 : Float)
let mut t := 0
let mut (x,p) := (x₀, p₀)

for i in [0:100] do
for i in [0:1000] do

(x, p) := solver m k (substeps,()) t (t+Δt) (x, p)
(x,p) := solver m k (substeps,()) t (t+Δt) (x, p)
t += Δt

let M : Nat := 20
for m in IndexType.univ (Fin M) do
for n in IndexType.univ (Fin N) do
for m in fullRange (Fin M) do
for n in fullRange (Fin N) do
let xi := x[n]
if (2*m.1.toFloat - M.toFloat)/(M.toFloat) - xi < 0 then
IO.print "x"
else
IO.print "."

IO.println ""


-- #eval main

0 comments on commit e81643b

Please sign in to comment.