diff --git a/SciLean/Data/IndexType/Fold.lean b/SciLean/Data/IndexType/Fold.lean index 2f6ea2f7..2f1e268d 100644 --- a/SciLean/Data/IndexType/Fold.lean +++ b/SciLean/Data/IndexType/Fold.lean @@ -1,5 +1,6 @@ import SciLean.Analysis.Calculus.RevFDeriv import SciLean.Analysis.Calculus.FwdFDeriv +import SciLean.Analysis.Calculus.RevFDerivProj import SciLean.Data.IndexType.Operations import SciLean.Tactic.Autodiff import SciLean.Data.DataArray.DataArray @@ -184,51 +185,50 @@ theorem IndexType.Range.foldl.arg_opinit.revFDeriv_rule_closures (r : Range I) dw + dw') := sorry_proof -/-- Reverse derivative of fold - version storing every point - use DataArray if possible -/ +/-- Reverse derivative of fold - version storing every point - store in Array if DataArray is not +available for `X` -/ @[fun_trans] -theorem IndexType.Range.foldl.arg_opinit.revFDeriv_rule_data_array - {I: Type} [IndexType I] - {X : Type} [NormedAddCommGroup X] [AdjointSpace R X] [CompleteSpace X] [PlainDataType X] +theorem IndexType.Range.foldl.arg_opinit.revFDeriv_rule_array (r : Range I) (op : W → X → I → X) (hop : ∀ i, Differentiable R (fun (w,x) => op w x i)) (init : W → X) (hinit : Differentiable R init) : - revFDeriv R (fun w => (.full : Range I).foldl (op w) (init w)) + revFDeriv R (fun w => r.foldl (op w) (init w)) = fun w => let idi := revFDeriv R init w - let xsx := (.full : Range I).foldl (fun (xs,x) i => - let xs := xs.set i x + let xsx := r.foldl (fun (xs,x) i => + let xs := xs.push (x,i) let x := op w x i - (xs,x)) ((0 : X^[I]), idi.1) + (xs,x)) ((#[] : Array (X×I)), idi.1) let xs := xsx.1 let x := xsx.2 (x, fun dx => - let dwx := (.full : Range I).reverse.foldl (fun (dw,dx) i => - let x := xs[i] + let dwx := xs.foldr (fun (x,i) (dw,dx) => let dwx := (revFDeriv R (fun (w,x) => op w x i) (w,x)).2 dx (dw + dwx.1, dwx.2)) (0, dx) let dw' := idi.2 dwx.2 dwx.1 + dw') := sorry_proof - -/-- Reverse derivative of fold - version storing every point - store in Array if DataArray is not -available for `X` -/ +/-- Reverse derivative of fold - version storing every point - use DataArray if possible -/ @[fun_trans] -theorem IndexType.Range.foldl.arg_opinit.revFDeriv_rule_array (r : Range I) +theorem IndexType.Range.foldl.arg_opinit.revFDeriv_rule_data_array + {I: Type} [IndexType I] + {X : Type} [NormedAddCommGroup X] [AdjointSpace R X] [CompleteSpace X] [PlainDataType X] (op : W → X → I → X) (hop : ∀ i, Differentiable R (fun (w,x) => op w x i)) (init : W → X) (hinit : Differentiable R init) : - revFDeriv R (fun w => r.foldl (op w) (init w)) + revFDeriv R (fun w => (.full : Range I).foldl (op w) (init w)) = fun w => let idi := revFDeriv R init w - let xsx := r.foldl (fun (xs,x) i => - let xs := xs.push (x,i) + let xsx := (.full : Range I).foldl (fun (xs,x) i => + let xs := xs.set i x let x := op w x i - (xs,x)) ((#[] : Array (X×I)), idi.1) + (xs,x)) ((0 : X^[I]), idi.1) let xs := xsx.1 let x := xsx.2 (x, fun dx => - let dwx := xs.foldr (fun (x,i) (dw,dx) => + let dwx := (.full : Range I).reverse.foldl (fun (dw,dx) i => + let x := xs[i] let dwx := (revFDeriv R (fun (w,x) => op w x i) (w,x)).2 dx (dw + dwx.1, dwx.2)) (0, dx) let dw' := idi.2 dwx.2 @@ -255,4 +255,32 @@ theorem IndexType.Range.foldl.arg_opinit.revFDeriv_rule_linear (r : Range I) let dw' := idi.2 dwx.2 dwx.1 + dw') := sorry_proof + +@[fun_trans] +theorem IndexType.Range.foldl.arg_opinit.revFDerivProj_rule_data_array + {R : Type} [RCLike R] + {I : Type} [IndexType I] + {X : Type} [NormedAddCommGroup X] [AdjointSpace R X] [CompleteSpace X] [PlainDataType X] + {W : Type} [NormedAddCommGroup W] [AdjointSpace R W] [CompleteSpace W] + (op : W → X → I → X) (hop : ∀ i, Differentiable R (fun (w,x) => op w x i)) + (init : W → X) (hinit : Differentiable R init) : + revFDerivProj R Unit (fun w => (.full : Range I).foldl (op w) (init w)) + = + fun w => + let idi := revFDeriv R init w + let xsx := (.full : Range I).foldl (fun (xs,x) i => + let xs := xs.set i x + let x := op w x i + (xs,x)) ((0 : X^[I]), idi.1) + let xs := xsx.1 + let x := xsx.2 + (x, fun _ dx => + let dwx := (.full : Range I).reverse.foldl (fun (dw,dx) i => + let x := xs[i] + let dwx : W×X := (revFDerivProjUpdate R Unit (fun (w,x) => op w x i) (w,x)).2 () dx (dw,0) + dwx) (0, dx) + let dw' := idi.2 dwx.2 + dwx.1 + dw') := sorry_proof + + -- TODO: add checkpointing version diff --git a/SciLean/Data/IndexType/Iterator.lean b/SciLean/Data/IndexType/Iterator.lean index 30b78d51..9a6653e3 100644 --- a/SciLean/Data/IndexType/Iterator.lean +++ b/SciLean/Data/IndexType/Iterator.lean @@ -62,5 +62,5 @@ def sprod (i : Iterator I) (j : Iterator J) [FirstLast I I] [FirstLast J J] : -- todo: implement this and provide a better implementation of IndexType instance for Sum -private def ofSum [FirstLast α α] [FirstLast β β] (s : Iterator (α ⊕ β)) : - ((Iterator α × Range β)) ⊕ ((Iterator β × Range α)) := sorry +-- private def ofSum [FirstLast α α] [FirstLast β β] (s : Iterator (α ⊕ β)) : +-- ((Iterator α × Range β)) ⊕ ((Iterator β × Range α)) := sorry diff --git a/SciLean/Data/IndexType/Operations.lean b/SciLean/Data/IndexType/Operations.lean index afebc936..5d5e5845 100644 --- a/SciLean/Data/IndexType/Operations.lean +++ b/SciLean/Data/IndexType/Operations.lean @@ -30,12 +30,13 @@ def reduceMD {m} [Monad m] (r : Range ι) (f : ι → α) (op : α → α → m match first? r with | none => return default | .some fst => do - r.foldlM (fun a i => op a (f i)) (f fst) + let mut a := (f fst) + for i in Iterator.val fst r do + a ← op a (f i) + return a def reduceD (r : Range ι) (f : ι → α) (op : α → α → α) (default : α) : α := Id.run do - match first? r with - | none => return default - | .some fst => do r.foldl (fun a i => op a (f i)) (f fst) + r.reduceMD f (fun x y => pure (op x y)) default abbrev reduce [Inhabited α] (r : Range ι) (f : ι → α) (op : α → α → α) : α := r.reduceD f op default @@ -51,7 +52,7 @@ variable {ι : Type*} [IndexType ι] abbrev foldlM {m} [Monad m] (op : α → ι → m α) (init : α) : m α := Range.full.foldlM op init -abbrev foldl (op : α → ι → α) (init : α) : α := Id.run do +abbrev foldl (op : α → ι → α) (init : α) : α := Range.full.foldl op init abbrev reduceMD {m} [Monad m] (f : ι → α) (op : α → α → m α) (default : α) : m α :=