Skip to content

Commit

Permalink
fix bug in fprop and started working on fwdDeriv
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Jul 25, 2023
1 parent 0b9c8b4 commit cb37e42
Show file tree
Hide file tree
Showing 9 changed files with 412 additions and 52 deletions.
1 change: 1 addition & 0 deletions SciLean.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import SciLean.Tactic.FTrans.Basic
import SciLean.Tactic.FProp.Notation
import SciLean.FTrans.FDeriv.Basic
import SciLean.FunctionSpaces.Differentiable.Basic
import SciLean.FTrans.CDeriv.Basic

/-!
Expand Down
60 changes: 51 additions & 9 deletions SciLean/FTrans/FDeriv/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import Mathlib.Analysis.Calculus.Deriv.Basic
import Mathlib.Analysis.Calculus.Deriv.Inv


import SciLean.FunctionSpaces.ContinuousLinearMap.Basic
import SciLean.FunctionSpaces.ContinuousLinearMap.Notation
import SciLean.FunctionSpaces.Differentiable.Basic
import SciLean.Tactic.FTrans.Basic

Expand All @@ -35,11 +35,43 @@ theorem fderiv.id_rule
: (fderiv K fun x : X => x) = fun _ => fun dx =>L[K] dx
:= by ext x dx; simp


theorem fderiv.const_rule (x : X)
: (fderiv K fun _ : Y => x) = fun _ => fun dx =>L[K] 0
:= by ext x dx; simp

theorem fderiv.comp_rule_at
(x : X)
(g : X → Y) (hg : DifferentiableAt K g x)
(f : Y → Z) (hf : DifferentiableAt K f (g x))
: (fderiv K fun x : X => f (g x)) x
=
let y := g x
fun dx =>L[K]
let dy := fderiv K g x dx
let dz := fderiv K f y dy
dz :=
by
rw[show (fun x => f (g x)) = f ∘ g by rfl]
rw[fderiv.comp x hf hg]
ext dx; simp

theorem fderiv.comp_rule
(g : X → Y) (hg : Differentiable K g)
(f : Y → Z) (hf : Differentiable K f)
: (fderiv K fun x : X => f (g x))
=
fun x =>
let y := g x
fun dx =>L[K]
let dy := fderiv K g x dx
let dz := fderiv K f y dy
dz :=
by
funext x;
rw[show (fun x => f (g x)) = f ∘ g by rfl]
rw[fderiv.comp x (hf (g x)) (hg x)]
ext dx; simp


theorem fderiv.let_rule_at
(x : X)
Expand All @@ -54,12 +86,14 @@ theorem fderiv.let_rule_at
fun dx =>L[K]
let dy := fderiv K g x dx
let dz := fderiv K (fun xy : X×Y => f xy.1 xy.2) (x,y) (dx, dy)
dz :=
by
dz :=
by
have h : (fun x => f x (g x)) = (fun xy : X×Y => f xy.1 xy.2) ∘ (fun x => (x, g x)) := by rfl
rw[h]
rw[fderiv.comp x hf (DifferentiableAt.prod (by simp) hg)]
rw[DifferentiableAt.fderiv_prod (by simp) hg]
conv =>
lhs
rw[h]
rw[fderiv.comp x hf (DifferentiableAt.prod (by simp) hg)]
rw[DifferentiableAt.fderiv_prod (by simp) hg]
ext dx; simp[ContinuousLinearMap.comp]
rfl

Expand Down Expand Up @@ -107,12 +141,19 @@ theorem fderiv.pi_rule

open Lean Meta Qq

def fderiv.discharger (e : Expr) : SimpM (Option Expr) :=
FTrans.tacticToDischarge (Syntax.mkLit ``tacticDifferentiable "differentiable") e
def fderiv.discharger (e : Expr) : SimpM (Option Expr) := do
withTraceNode `fderiv_discharger (fun _ => return s!"discharge {← ppExpr e}") do
let cache := (← get).cache
let config : FProp.Config := {}
let state : FProp.State := { cache := cache }
let (proof?, state) ← FProp.fprop e |>.run config |>.run state
modify (fun simpState => { simpState with cache := state.cache })
return proof?

open Lean Elab Term FTrans
def fderiv.ftransExt : FTransExt where
ftransName := ``fderiv

getFTransFun? e :=
if e.isAppOf ``fderiv then

Expand All @@ -131,6 +172,7 @@ def fderiv.ftransExt : FTransExt where

identityRule := .some <| .thm ``fderiv.id_rule
constantRule := .some <| .thm ``fderiv.const_rule
compRule := .some <| .thm ``fderiv.comp_rule
lambdaLetRule := .some <| .thm ``fderiv.let_rule
lambdaLambdaRule := .some <| .thm ``fderiv.pi_rule

Expand Down
85 changes: 58 additions & 27 deletions SciLean/FTrans/FDeriv/Test.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ import SciLean.Profile

open SciLean

-- #profile_this_file

set_option profiler true

variable {K : Type _} [NontriviallyNormedField K]

Expand All @@ -14,24 +17,23 @@ variable {ι : Type _} [Fintype ι]

variable {E : ι → Type _} [∀ i, NormedAddCommGroup (E i)] [∀ i, NormedSpace K (E i)]

example : NormedCommRing K := by infer_instance
example : NormedAlgebra K K := by infer_instance

example
: fderiv K (fun (x : K) => x * x * x)
=
fun x => fun dx =>L[K] dx * x + dx * x :=
by
ftrans only
set_option trace.Meta.Tactic.simp.rewrite true in
set_option trace.Meta.Tactic.simp.discharge true in
set_option trace.Meta.Tactic.simp.unify true in
set_option trace.Meta.Tactic.lsimp.pre true in
set_option trace.Meta.Tactic.lsimp.step true in
set_option trace.Meta.Tactic.lsimp.post true in
ftrans only
ext x; simp
#exits
-- example
-- : fderiv K (fun (x : K) => x * x * x)
-- =
-- fun x => fun dx =>L[K] dx * x + dx * x :=
-- by
-- ftrans only
-- set_option trace.Meta.Tactic.simp.rewrite true in
-- set_option trace.Meta.Tactic.simp.discharge true in
-- set_option trace.Meta.Tactic.simp.unify true in
-- set_option trace.Meta.Tactic.lsimp.pre true in
-- set_option trace.Meta.Tactic.lsimp.step true in
-- set_option trace.Meta.Tactic.lsimp.post true in
-- ftrans only
-- ext x; simp

example : Differentiable K fun x : K => x := by fprop

example
: fderiv K (fun (x : K) => x + x)
Expand All @@ -42,30 +44,59 @@ by
ftrans only
ext x; simp


example
: fderiv K (fun (x : K) => x + x + x)
=
fun x => fun dx =>L[K]
dx + dx + dx :=
by
ftrans only
ftrans only;
ext x; simp

example
: fderiv K (fun (x : K) => x * x * x * x)
=
fun x => fun dx =>L[K] 0 :=
by
conv =>
lhs
ftrans only
sorry


set_option trace.Meta.Tactic.simp.rewrite true in
example
: fderiv K (fun (x : K) => x + x + x + x)
=
fun x => fun dx =>L[K]
dx + dx + dx + dx :=
by
ftrans only
ext x; simp
ftrans

example
: fderiv K (fun (x : K) => x + x + x + x + x)

variable {K : Type _} [NontriviallyNormedField K]

variable {E : Type _} [NormedAddCommGroup E] [NormedSpace K E]

variable {F : Type _} [NormedAddCommGroup F] [NormedSpace K F]

variable {G : Type _} [NormedAddCommGroup G] [NormedSpace K G]

variable {f f₀ f₁ g : E → F}

theorem fderiv_add'
(hf : Differentiable K f) (hg : Differentiable K g) :
fderiv K (fun y => f y + g y)
=
fun x =>
fun dx =>L[K]
fderiv K f x dx + fderiv K g x dx := sorry

example (x : K)
: fderiv K (fun (x : K) => x + x + x + x + x) x
=
fun x => fun dx =>L[K]
fun dx =>L[K]
dx + dx + dx + dx + dx :=
by
ftrans only
ext x; simp
by
simp (discharger:=fprop) only [fderiv_add', fderiv_id']
dsimp
Loading

0 comments on commit cb37e42

Please sign in to comment.