Skip to content

Commit

Permalink
revFDerivProj rules for ArrayType.get
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Aug 20, 2024
1 parent fbff9a6 commit 845a8b0
Show file tree
Hide file tree
Showing 11 changed files with 796 additions and 47 deletions.
1 change: 1 addition & 0 deletions SciLean.lean
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ import SciLean.MeasureTheory.WeakIntegral
-- import SciLean.Meta.DerivingOp
import SciLean.Meta.GenerateAddGroupHomSimp
import SciLean.Meta.GenerateFunProp
import SciLean.Meta.GenerateFunTrans
import SciLean.Meta.GenerateLinearMapSimp
import SciLean.Meta.Notation.Do
import SciLean.Meta.SimpAttr
Expand Down
64 changes: 52 additions & 12 deletions SciLean/Analysis/Calculus/FwdFDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,33 @@ theorem pi_rule

open SciLean

-- Prod.mk -----------------------------------v---------------------------------
-- of linear function ----------------------------------------------------------
--------------------------------------------------------------------------------

@[fun_trans]
theorem fwdFDeriv_linear
(f : X → Y) (hf : IsContinuousLinearMap K f) :
fwdFDeriv K f
=
fun x dx => (f x, f dx) := by unfold fwdFDeriv; fun_trans


-- Prod.mk ---------------------------------------------------------------------
--------------------------------------------------------------------------------

@[fun_trans]
theorem Prod.mk.arg_fstsnd.fwdFDeriv_rule
(g : X → Y) (hg : Differentiable K g)
(f : X → Z) (hf : Differentiable K f) :
fwdFDeriv K (fun x => (g x, f x))
=
fun x dx =>
let ydy := fwdFDeriv K g x dx
let zdz := fwdFDeriv K f x dx
((ydy.1, zdz.1), (ydy.2, zdz.2)) := by
unfold fwdFDeriv; fun_trans


@[fun_trans]
theorem Prod.mk.arg_fstsnd.fwdFDeriv_rule_at (x : X)
(g : X → Y) (hg : DifferentiableAt K g x)
Expand All @@ -137,6 +161,13 @@ theorem Prod.mk.arg_fstsnd.fwdFDeriv_rule_at (x : X)
-- Prod.fst --------------------------------------------------------------------
--------------------------------------------------------------------------------

@[fun_trans]
theorem Prod.fst.arg_self.fwdFDeriv_rule :
fwdFDeriv K (fun xy : X×Y => xy.1)
=
fun xy dxy => (xy.1, dxy.1) := by
unfold fwdFDeriv; fun_trans

@[fun_trans]
theorem Prod.fst.arg_self.fwdFDeriv_rule_at (x : X)
(f : X → Y×Z) (hf : DifferentiableAt K f x) :
Expand All @@ -151,6 +182,14 @@ theorem Prod.fst.arg_self.fwdFDeriv_rule_at (x : X)
-- Prod.snd --------------------------------------------------------------------
--------------------------------------------------------------------------------

@[fun_trans]
theorem Prod.snd.arg_self.fwdFDeriv_rule :
fwdFDeriv K (fun xy : X×Y => xy.2)
=
fun xy dxy => (xy.2, dxy.2) := by
unfold fwdFDeriv; fun_trans


@[fun_trans]
theorem Prod.snd.arg_self.fwdFDeriv_rule_at (x : X)
(f : X → Y×Z) (hf : DifferentiableAt K f x) :
Expand Down Expand Up @@ -239,6 +278,13 @@ theorem HSMul.hSMul.arg_a0a1.fwdFDeriv_rule_at (x : X)
-- HDiv.hDiv -------------------------------------------------------------------
--------------------------------------------------------------------------------

@[fun_trans]
theorem HDiv.hDiv.arg_a0.fwdFDeriv_rule (y : K) :
(fwdFDeriv K fun x => x / y)
=
fun x dx => (x / y, dx / y) := by
unfold fwdFDeriv; fun_trans

@[fun_trans]
theorem HDiv.hDiv.arg_a0a1.fwdFDeriv_rule_at (x : X)
(f : X → K) (g : X → K)
Expand All @@ -262,18 +308,12 @@ theorem HDiv.hDiv.arg_a0a1.fwdFDeriv_rule_at (x : X)
--------------------------------------------------------------------------------

@[fun_trans]
def HPow.hPow.arg_a0.fwdFDeriv_rule_at (n : Nat) (x : X)
(f : X → K) (hf : DifferentiableAt K f x) :
fwdFDeriv K (fun x => f x ^ n) x
def HPow.hPow.arg_a0.fwdFDeriv_rule (n : Nat) :
fwdFDeriv K (fun x : K => x ^ n)
=
fun dx =>
let ydy := fwdFDeriv K f x dx
(ydy.1 ^ n, n * ydy.2 * (ydy.1 ^ (n-1))) := by
unfold fwdFDeriv;
funext dx; simp
induction n
case zero => simp
case h _ => simp[pow_succ]; fun_trans; sorry_proof
fun x dx : K =>
(x ^ n, n * dx * (x ^ (n-1))) := by
unfold fwdFDeriv; fun_trans


-- IndexType.sum ----------------------------------------------------------------
Expand Down
9 changes: 5 additions & 4 deletions SciLean/Analysis/Calculus/Notation/Gradient.lean
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@ elab_rules (kind:=gradNotation1) : term
let XY ← mkArrow X Y
-- Y might also be infered by the function `f`
let fExpr ← withoutPostponing <| elabTermEnsuringType f XY false
let sX ← exprToSyntax X
let .some (_,Y) := (← inferType fExpr).arrow?
| return ← throwUnsupportedSyntax
let sX ← exprToSyntax X
let sK ← exprToSyntax K
let sY ← exprToSyntax Y
if (← isDefEq K Y) then
elabTerm (← `(fgradient (X:=$sX) $f $x $xs*)) none false
elabTerm (← `(fgradient (X:=$sX) (K:=$sK) $f $x $xs*)) none false
else
elabTerm (← `(adjointFDeriv (X:=$sX) defaultScalar% $f $x $xs*)) none false

elabTerm (← `(adjointFDeriv (X:=$sX) (Y:=$sY) defaultScalar% $f $x $xs*)) none false

| `(∇ $f) => do
let K ← elabTerm (← `(defaultScalar%)) none
Expand Down
11 changes: 11 additions & 0 deletions SciLean/Analysis/Calculus/RevFDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,17 @@ variable
{E : ι → Type _} [∀ i, NormedAddCommGroup (E i)] [∀ i, AdjointSpace K (E i)] [∀ i, CompleteSpace (E i)]


-- of linear function ----------------------------------------------------------
--------------------------------------------------------------------------------

@[fun_trans]
theorem revFDeriv_linear
(f : X → Y) (hf : IsContinuousLinearMap K f) :
revFDeriv K f
=
fun x => (f x, adjoint K f) := by unfold revFDeriv; fun_trans


-- Prod.mk ----------------------------------- ---------------------------------
--------------------------------------------------------------------------------

Expand Down
34 changes: 15 additions & 19 deletions SciLean/Analysis/Calculus/RevFDerivProj.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,18 @@ namespace SciLean
set_option deprecated.oldSectionVars true

variable
(K I : Type _) [RCLike K]
{X : Type _} [NormedAddCommGroup X] [AdjointSpace K X]
{Y : Type _} [NormedAddCommGroup Y] [AdjointSpace K Y]
{Z : Type _} [NormedAddCommGroup Z] [AdjointSpace K Z]
{W : Type _} [NormedAddCommGroup W] [AdjointSpace K W]
{ι : Type _} [IndexType ι] [DecidableEq ι]
{κ : Type _} [IndexType κ] [DecidableEq κ]
{E : Type _} {EI : I → Type _}
(K I : Type) [RCLike K]
{X : Type} [NormedAddCommGroup X] [AdjointSpace K X]
{Y : Type} [NormedAddCommGroup Y] [AdjointSpace K Y]
{Z : Type} [NormedAddCommGroup Z] [AdjointSpace K Z]
{W : Type} [NormedAddCommGroup W] [AdjointSpace K W]
{ι : Type} [IndexType ι] [DecidableEq ι]
{κ : Type} [IndexType κ] [DecidableEq κ]
{E : Type} {EI : I → Type}
[StructType E I EI] [IndexType I] [DecidableEq I]
[NormedAddCommGroup E] [AdjointSpace K E] [∀ i, NormedAddCommGroup (EI i)] [∀ i, AdjointSpace K (EI i)]
[VecStruct K E I EI] -- todo: define AdjointSpaceStruct
{F J : Type _} {FJ : J → Type _}
{F J : Type} {FJ : J → Type}
[StructType F J FJ] [IndexType J] [DecidableEq J]
[NormedAddCommGroup F] [AdjointSpace K F] [∀ j, NormedAddCommGroup (FJ j)] [∀ j, AdjointSpace K (FJ j)]
[VecStruct K F J FJ] -- todo: define AdjointSpaceStruct
Expand Down Expand Up @@ -329,6 +329,7 @@ set_option deprecated.oldSectionVars true

variable
{K : Type} [RCLike K]
{ι : Type} [IndexType ι] [DecidableEq ι]
{X : Type} [NormedAddCommGroup X] [AdjointSpace K X]
{Y : Type} [NormedAddCommGroup Y] [AdjointSpace K Y]
{Z : Type} [NormedAddCommGroup Z] [AdjointSpace K Z]
Expand All @@ -339,7 +340,7 @@ variable
[NormedAddCommGroup Y'] [AdjointSpace K Y'] [∀ i, NormedAddCommGroup (YI i)] [∀ i, AdjointSpace K (YI i)] [VecStruct K Y' Yi YI]
[NormedAddCommGroup Z'] [AdjointSpace K Z'] [∀ i, NormedAddCommGroup (ZI i)] [∀ i, AdjointSpace K (ZI i)] [VecStruct K Z' Zi ZI]
{W : Type} [NormedAddCommGroup W] [AdjointSpace K W]
{ι : Type} [IndexType ι]




Expand Down Expand Up @@ -760,18 +761,13 @@ def HPow.hPow.arg_a0.revFDerivProjUpdate_rule

section IndexTypeSum

variable {ι : Type} [IndexType ι]


@[fun_trans]
theorem IndexType.sum.arg_f.revFDerivProj_rule [DecidableEq ι]
theorem IndexType.sum.arg_f.revFDerivProj_rule
(f : X → ι → Y') (hf : ∀ i, Differentiable K (fun x => f x i)) :
revFDerivProj K Yi (fun x => ∑ i, f x i)
=
fun x =>
-- this is not optimal
-- we should have but right now there is no appropriate StrucLike instance
-- let ydf := revFDerivProj K Yi f x
let ydf := revFDerivProjUpdate K (ι×Yi) f x
(∑ i, ydf.1 i,
fun j dy =>
Expand Down Expand Up @@ -855,10 +851,10 @@ theorem dite.arg_te.revFDerivProjUpdate_rule
section InnerProductSpace

variable
{R : Type _} [RealScalar R]
{R : Type} [RealScalar R]
-- {K : Type _} [Scalar R K]
{X : Type _} [NormedAddCommGroup X] [AdjointSpace R X]
{Y : Type _} [NormedAddCommGroup Y] [AdjointSpace R Y]
{X : Type} [NormedAddCommGroup X] [AdjointSpace R X]
{Y : Type} [NormedAddCommGroup Y] [AdjointSpace R Y]

-- Inner -----------------------------------------------------------------------
--------------------------------------------------------------------------------
Expand Down
45 changes: 45 additions & 0 deletions SciLean/Data/ArrayType/Properties.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import SciLean.Data.ArrayType.Algebra
import SciLean.Analysis.Convenient.HasAdjDiff
import SciLean.Analysis.AdjointSpace.Adjoint
import SciLean.Analysis.Calculus.RevFDerivProj

import SciLean.Meta.GenerateAddGroupHomSimp

Expand Down Expand Up @@ -250,6 +251,50 @@ theorem ArrayType.ofFn.arg_f.adjoint_rule :
end OnAdjointSpace


section OnAdjointSpace

variable
[NormedAddCommGroup Elem] [AdjointSpace K Elem] [CompleteSpace Elem]
{I : Type} [IndexType I] [DecidableEq I]
{E : I → Type} [∀ i, NormedAddCommGroup (E i)] [∀ i, AdjointSpace K (E i)]
[∀ i, CompleteSpace (E i)] [StructType Elem I E] [VecStruct K Elem I E]
{W : Type} [NormedAddCommGroup W] [AdjointSpace K W] [CompleteSpace W]


@[fun_trans]
theorem ArrayType.get.arg_cont.revFDerivProj_rule (i : Idx)
(cont : W → Cont) (hf : Differentiable K cont) :
revFDerivProj K I (fun w => ArrayType.get (cont w) i)
=
fun w : W =>
let xi := revFDerivProj K (Idx×I) cont w
(ArrayType.get xi.1 i, fun (j : I) (de : E j) =>
xi.2 (i,j) de) := by
unfold revFDerivProj; fun_trans[oneHot]
funext x
fun_trans
funext i de
congr
funext i
split_ifs
· congr; funext i; aesop
· aesop


@[fun_trans]
theorem ArrayType.get.arg_cont.revFDerivProjUpdate_rule (i : Idx)
(cont : W → Cont) (hf : Differentiable K cont) :
revFDerivProjUpdate K I (fun w => ArrayType.get (cont w) i)
=
fun w : W =>
let xi := revFDerivProjUpdate K (Idx×I) cont w
(ArrayType.get xi.1 i, fun (j : I) (de : E j) dw =>
xi.2 (i,j) de dw) := by unfold revFDerivProjUpdate; fun_trans


end OnAdjointSpace


#exit

@[fun_trans]
Expand Down
3 changes: 0 additions & 3 deletions SciLean/Data/StructType/Algebra.lean
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,6 @@ instance [∀ i, MetricSpace (EI i)] [∀ j, MetricSpace (FJ j)] (i : I ⊕ J) :
dist_self := sorry_proof
dist_comm := sorry_proof
dist_triangle := sorry_proof
edist := match i with
| .inl _ => PseudoMetricSpace.edist
| .inr _ => PseudoMetricSpace.edist
edist_dist := sorry_proof
toUniformSpace := by infer_instance
uniformity_dist := sorry_proof
Expand Down
11 changes: 6 additions & 5 deletions SciLean/Tactic/FunTrans/Core.lean
Original file line number Diff line number Diff line change
Expand Up @@ -234,15 +234,16 @@ def applyApplyRule (funTransDecl : FunTransDecl) (e : Expr) : SimpM (Option Simp

return none

def applyPiRule (funTransDecl : FunTransDecl) (e : Expr) : SimpM (Option Simp.Result) := do
def applyPiRule (funTransDecl : FunTransDecl) (e f : Expr) : SimpM (Option Simp.Result) := do
let thms ← getLambdaTheorems funTransDecl.funTransName .pi e.getAppNumArgs

if thms.size = 0 then
trace[Meta.Tactic.fun_trans] "missing pi rule to transform `{← ppExpr e}`"
return none

for thm in thms do
if let .some r ← tryTheorem? e (.decl thm.thmName) then
let .pi id_f := thm.thmArgs | continue
if let .some r ← tryTheoremWithHint e (.decl thm.thmName) #[(id_f, f)] then
return r

return none
Expand All @@ -253,7 +254,7 @@ def applyMorTheorems (funTransDecl : FunTransDecl) (e : Expr) (fData : FunProp.F
match ← fData.isMorApplication with
| .none => return none
| .underApplied =>
applyPiRule funTransDecl e
applyPiRule funTransDecl e (← fData.toExpr)
| .overApplied =>
let .some (f,g) ← fData.peeloffArgDecomposition | return none
applyCompRule funTransDecl e f g
Expand Down Expand Up @@ -359,7 +360,7 @@ def tryTheorems (funTransDecl : FunTransDecl) (e : Expr) (fData : FunProp.Functi
return none
| .gt =>
trace[Meta.Tactic.fun_trans] s!"adding argument to later use {← ppOrigin' thm.thmOrigin}"
if let .some r ← applyPiRule funTransDecl e then
if let .some r ← applyPiRule funTransDecl e (← fData.toExpr) then
return r
continue
| .eq =>
Expand Down Expand Up @@ -574,7 +575,7 @@ partial def funTrans (e : Expr) : SimpM Simp.Step := do
| .lam f =>
trace[Meta.Tactic.fun_trans.step] "lam case on {← ppExpr f}"
let e := e.setArg funTransDecl.funArgId f -- update e with reduced f
toStep <| applyPiRule funTransDecl e
toStep <| applyPiRule funTransDecl e f
| .data fData =>
let e := e.setArg funTransDecl.funArgId (← fData.toExpr) -- update e with reduced f

Expand Down
9 changes: 5 additions & 4 deletions SciLean/Tactic/FunTrans/Theorems.lean
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ inductive LambdaTheoremArgs
| letE (fArgId gArgId : Nat)

/-- Pi theorem e.g. `fderiv ℝ fun x y => f x y = ...` -/
| pi
| pi (fArgId : Nat)
deriving Inhabited, BEq, Repr, Hashable

/-- Tag for one of the 5 basic lambda theorems -/
Expand All @@ -66,7 +66,7 @@ def LambdaTheoremArgs.type (t : LambdaTheoremArgs) : LambdaTheoremType :=
| .comp .. => .comp
| .letE .. => .letE
| .apply => .apply
| .pi => .pi
| .pi .. => .pi

/-- Decides whether `f` is a function corresponding to one of the lambda theorems. -/
def detectLambdaTheoremArgs (f : Expr) (ctxVars : Array Expr) :
Expand All @@ -91,8 +91,9 @@ def detectLambdaTheoremArgs (f : Expr) (ctxVars : Array Expr) :
let .some argId_f := ctxVars.findIdx? (fun x => x == (.fvar fId)) | return none
let .some argId_g := ctxVars.findIdx? (fun x => x == (.fvar gId)) | return none
return .some <| .letE argId_f argId_g
| .lam _ _ (.app (.app (.fvar _) (.bvar 1)) (.bvar 0)) _ =>
return .some .pi
| .lam _ _ (.app (.app (.fvar fId) (.bvar 1)) (.bvar 0)) _ =>
let .some argId_f := ctxVars.findIdx? (fun x => x == (.fvar fId)) | return none
return .some <| .pi argId_f
| _ => return none
| _ => return none

Expand Down
Loading

0 comments on commit 845a8b0

Please sign in to comment.