Skip to content

Commit

Permalink
fix IndexType.reduce and revDerivProj rule for IndexType.foldl
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Oct 1, 2024
1 parent ae16b2b commit 43b2d1a
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 26 deletions.
66 changes: 47 additions & 19 deletions SciLean/Data/IndexType/Fold.lean
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import SciLean.Analysis.Calculus.RevFDeriv
import SciLean.Analysis.Calculus.FwdFDeriv
import SciLean.Analysis.Calculus.RevFDerivProj
import SciLean.Data.IndexType.Operations
import SciLean.Tactic.Autodiff
import SciLean.Data.DataArray.DataArray
Expand Down Expand Up @@ -184,51 +185,50 @@ theorem IndexType.Range.foldl.arg_opinit.revFDeriv_rule_closures (r : Range I)
dw + dw') := sorry_proof


/-- Reverse derivative of fold - version storing every point - use DataArray if possible -/
/-- Reverse derivative of fold - version storing every point - store in Array if DataArray is not
available for `X` -/
@[fun_trans]
theorem IndexType.Range.foldl.arg_opinit.revFDeriv_rule_data_array
{I: Type} [IndexType I]
{X : Type} [NormedAddCommGroup X] [AdjointSpace R X] [CompleteSpace X] [PlainDataType X]
theorem IndexType.Range.foldl.arg_opinit.revFDeriv_rule_array (r : Range I)
(op : W → X → I → X) (hop : ∀ i, Differentiable R (fun (w,x) => op w x i))
(init : W → X) (hinit : Differentiable R init) :
revFDeriv R (fun w => (.full : Range I).foldl (op w) (init w))
revFDeriv R (fun w => r.foldl (op w) (init w))
=
fun w =>
let idi := revFDeriv R init w
let xsx := (.full : Range I).foldl (fun (xs,x) i =>
let xs := xs.set i x
let xsx := r.foldl (fun (xs,x) i =>
let xs := xs.push (x,i)
let x := op w x i
(xs,x)) ((0 : X^[I]), idi.1)
(xs,x)) ((#[] : Array (X×I)), idi.1)
let xs := xsx.1
let x := xsx.2
(x, fun dx =>
let dwx := (.full : Range I).reverse.foldl (fun (dw,dx) i =>
let x := xs[i]
let dwx := xs.foldr (fun (x,i) (dw,dx) =>
let dwx := (revFDeriv R (fun (w,x) => op w x i) (w,x)).2 dx
(dw + dwx.1, dwx.2)) (0, dx)
let dw' := idi.2 dwx.2
dwx.1 + dw') := sorry_proof



/-- Reverse derivative of fold - version storing every point - store in Array if DataArray is not
available for `X` -/
/-- Reverse derivative of fold - version storing every point - use DataArray if possible -/
@[fun_trans]
theorem IndexType.Range.foldl.arg_opinit.revFDeriv_rule_array (r : Range I)
theorem IndexType.Range.foldl.arg_opinit.revFDeriv_rule_data_array
{I: Type} [IndexType I]
{X : Type} [NormedAddCommGroup X] [AdjointSpace R X] [CompleteSpace X] [PlainDataType X]
(op : W → X → I → X) (hop : ∀ i, Differentiable R (fun (w,x) => op w x i))
(init : W → X) (hinit : Differentiable R init) :
revFDeriv R (fun w => r.foldl (op w) (init w))
revFDeriv R (fun w => (.full : Range I).foldl (op w) (init w))
=
fun w =>
let idi := revFDeriv R init w
let xsx := r.foldl (fun (xs,x) i =>
let xs := xs.push (x,i)
let xsx := (.full : Range I).foldl (fun (xs,x) i =>
let xs := xs.set i x
let x := op w x i
(xs,x)) ((#[] : Array (X×I)), idi.1)
(xs,x)) ((0 : X^[I]), idi.1)
let xs := xsx.1
let x := xsx.2
(x, fun dx =>
let dwx := xs.foldr (fun (x,i) (dw,dx) =>
let dwx := (.full : Range I).reverse.foldl (fun (dw,dx) i =>
let x := xs[i]
let dwx := (revFDeriv R (fun (w,x) => op w x i) (w,x)).2 dx
(dw + dwx.1, dwx.2)) (0, dx)
let dw' := idi.2 dwx.2
Expand All @@ -255,4 +255,32 @@ theorem IndexType.Range.foldl.arg_opinit.revFDeriv_rule_linear (r : Range I)
let dw' := idi.2 dwx.2
dwx.1 + dw') := sorry_proof


@[fun_trans]
theorem IndexType.Range.foldl.arg_opinit.revFDerivProj_rule_data_array
{R : Type} [RCLike R]
{I : Type} [IndexType I]
{X : Type} [NormedAddCommGroup X] [AdjointSpace R X] [CompleteSpace X] [PlainDataType X]
{W : Type} [NormedAddCommGroup W] [AdjointSpace R W] [CompleteSpace W]
(op : W → X → I → X) (hop : ∀ i, Differentiable R (fun (w,x) => op w x i))
(init : W → X) (hinit : Differentiable R init) :
revFDerivProj R Unit (fun w => (.full : Range I).foldl (op w) (init w))
=
fun w =>
let idi := revFDeriv R init w
let xsx := (.full : Range I).foldl (fun (xs,x) i =>
let xs := xs.set i x
let x := op w x i
(xs,x)) ((0 : X^[I]), idi.1)
let xs := xsx.1
let x := xsx.2
(x, fun _ dx =>
let dwx := (.full : Range I).reverse.foldl (fun (dw,dx) i =>
let x := xs[i]
let dwx : W×X := (revFDerivProjUpdate R Unit (fun (w,x) => op w x i) (w,x)).2 () dx (dw,0)
dwx) (0, dx)
let dw' := idi.2 dwx.2
dwx.1 + dw') := sorry_proof


-- TODO: add checkpointing version
4 changes: 2 additions & 2 deletions SciLean/Data/IndexType/Iterator.lean
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,5 @@ def sprod (i : Iterator I) (j : Iterator J) [FirstLast I I] [FirstLast J J] :


-- todo: implement this and provide a better implementation of IndexType instance for Sum
private def ofSum [FirstLast α α] [FirstLast β β] (s : Iterator (α ⊕ β)) :
((Iterator α × Range β)) ⊕ ((Iterator β × Range α)) := sorry
-- private def ofSum [FirstLast α α] [FirstLast β β] (s : Iterator (α ⊕ β)) :
-- ((Iterator α × Range β)) ⊕ ((Iterator β × Range α)) := sorry
11 changes: 6 additions & 5 deletions SciLean/Data/IndexType/Operations.lean
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ def reduceMD {m} [Monad m] (r : Range ι) (f : ι → α) (op : α → α → m
match first? r with
| none => return default
| .some fst => do
r.foldlM (fun a i => op a (f i)) (f fst)
let mut a := (f fst)
for i in Iterator.val fst r do
a ← op a (f i)
return a

def reduceD (r : Range ι) (f : ι → α) (op : α → α → α) (default : α) : α := Id.run do
match first? r with
| none => return default
| .some fst => do r.foldl (fun a i => op a (f i)) (f fst)
r.reduceMD f (fun x y => pure (op x y)) default

abbrev reduce [Inhabited α] (r : Range ι) (f : ι → α) (op : α → α → α) : α :=
r.reduceD f op default
Expand All @@ -51,7 +52,7 @@ variable {ι : Type*} [IndexType ι]
abbrev foldlM {m} [Monad m] (op : α → ι → m α) (init : α) : m α :=
Range.full.foldlM op init

abbrev foldl (op : α → ι → α) (init : α) : α := Id.run do
abbrev foldl (op : α → ι → α) (init : α) : α :=
Range.full.foldl op init

abbrev reduceMD {m} [Monad m] (f : ι → α) (op : α → α → m α) (default : α) : m α :=
Expand Down

0 comments on commit 43b2d1a

Please sign in to comment.