From 9e583efcea920afa13ee2a53069821a2297a94c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20G=2E=20Dorais?= Date: Thu, 12 Dec 2024 08:17:51 -0500 Subject: [PATCH] refactor: dependent fold for `Fin` (#1074) --- Batteries/Data/Fin/Basic.lean | 66 ++++++++-------- Batteries/Data/Fin/Fold.lean | 137 ++++++++++++++++------------------ 2 files changed, 97 insertions(+), 106 deletions(-) diff --git a/Batteries/Data/Fin/Basic.lean b/Batteries/Data/Fin/Basic.lean index 51fc68a28e..800d7c7c87 100644 --- a/Batteries/Data/Fin/Basic.lean +++ b/Batteries/Data/Fin/Basic.lean @@ -1,7 +1,7 @@ /- Copyright (c) 2017 Robert Y. Lewis. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. -Authors: Robert Y. Lewis, Keeley Hoek, Mario Carneiro +Authors: Robert Y. Lewis, Keeley Hoek, Mario Carneiro, François G. Dorais, Quang Dao -/ import Batteries.Tactic.Alias import Batteries.Data.Array.Basic @@ -18,20 +18,6 @@ alias enum := Array.finRange @[deprecated (since := "2024-11-15")] alias list := List.finRange -/-- Heterogeneous fold over `Fin n` from the right: `foldr 3 f x = f 0 (f 1 (f 2 x))`, where -`f 2 : α 3 → α 2`, `f 1 : α 2 → α 1`, etc. - -This is the dependent version of `Fin.foldr`. -/ -@[inline] def dfoldr (n : Nat) (α : Fin (n + 1) → Sort _) - (f : ∀ (i : Fin n), α i.succ → α i.castSucc) (init : α (last n)) : α 0 := - loop n (Nat.lt_succ_self n) init where - /-- Inner loop for `Fin.dfoldr`. - `Fin.dfoldr.loop n α f i h x = f 0 (f 1 (... (f i x)))` -/ - @[specialize] loop (i : Nat) (h : i < n + 1) (x : α ⟨i, h⟩) : α 0 := - match i with - | i + 1 => loop i (Nat.lt_of_succ_lt h) (f ⟨i, Nat.lt_of_succ_lt_succ h⟩ x) - | 0 => x - /-- Heterogeneous monadic fold over `Fin n` from right to left: ``` Fin.foldrM n f xₙ = do @@ -41,25 +27,33 @@ Fin.foldrM n f xₙ = do let x₀ : α 0 ← f 0 x₁ pure x₀ ``` -This is the dependent version of `Fin.foldrM`, defined using `Fin.dfoldr`. -/ -@[inline] def dfoldrM [Monad m] (n : Nat) (α : Fin (n + 1) → Sort _) +This is the dependent version of `Fin.foldrM`. -/ +@[inline] def dfoldrM [Monad m] (n : Nat) (α : Fin (n + 1) → Type _) (f : ∀ (i : Fin n), α i.succ → m (α i.castSucc)) (init : α (last n)) : m (α 0) := - dfoldr n (fun i => m (α i)) (fun i x => x >>= f i) (pure init) + loop n (Nat.lt_succ_self n) init where + /-- + Inner loop for `Fin.dfoldrM`. + ``` + Fin.dfoldrM.loop n f i h xᵢ = do + let xᵢ₋₁ ← f (i-1) xᵢ + ... + let x₁ ← f 1 x₂ + let x₀ ← f 0 x₁ + pure x₀ + ``` + -/ + @[specialize] loop (i : Nat) (h : i < n + 1) (x : α ⟨i, h⟩) : m (α 0) := + match i with + | i + 1 => (f ⟨i, Nat.lt_of_succ_lt_succ h⟩ x) >>= loop i (Nat.lt_of_succ_lt h) + | 0 => pure x -/-- Heterogeneous fold over `Fin n` from the left: `foldl 3 f x = f 0 (f 1 (f 2 x))`, where -`f 0 : α 0 → α 1`, `f 1 : α 1 → α 2`, etc. +/-- Heterogeneous fold over `Fin n` from the right: `foldr 3 f x = f 0 (f 1 (f 2 x))`, where +`f 2 : α 3 → α 2`, `f 1 : α 2 → α 1`, etc. -This is the dependent version of `Fin.foldl`. -/ -@[inline] def dfoldl (n : Nat) (α : Fin (n + 1) → Sort _) - (f : ∀ (i : Fin n), α i.castSucc → α i.succ) (init : α 0) : α (last n) := - loop 0 (Nat.zero_lt_succ n) init where - /-- Inner loop for `Fin.dfoldl`. `Fin.dfoldl.loop n α f i h x = f n (f (n-1) (... (f i x)))` -/ - @[semireducible, specialize] loop (i : Nat) (h : i < n + 1) (x : α ⟨i, h⟩) : α (last n) := - if h' : i < n then - loop (i + 1) (Nat.succ_lt_succ h') (f ⟨i, h'⟩ x) - else - haveI : ⟨i, h⟩ = last n := by ext; simp; omega - _root_.cast (congrArg α this) x +This is the dependent version of `Fin.foldr`. -/ +@[inline] def dfoldr (n : Nat) (α : Fin (n + 1) → Type _) + (f : ∀ (i : Fin n), α i.succ → α i.castSucc) (init : α (last n)) : α 0 := + dfoldrM (m := Id) n α f init /-- Heterogeneous monadic fold over `Fin n` from left to right: ``` @@ -71,7 +65,7 @@ Fin.foldlM n f x₀ = do pure xₙ ``` This is the dependent version of `Fin.foldlM`. -/ -@[inline] def dfoldlM [Monad m] (n : Nat) (α : Fin (n + 1) → Sort _) +@[inline] def dfoldlM [Monad m] (n : Nat) (α : Fin (n + 1) → Type _) (f : ∀ (i : Fin n), α i.castSucc → m (α i.succ)) (init : α 0) : m (α (last n)) := loop 0 (Nat.zero_lt_succ n) init where /-- Inner loop for `Fin.dfoldlM`. @@ -89,3 +83,11 @@ This is the dependent version of `Fin.foldlM`. -/ else haveI : ⟨i, h⟩ = last n := by ext; simp; omega _root_.cast (congrArg (fun i => m (α i)) this) (pure x) + +/-- Heterogeneous fold over `Fin n` from the left: `foldl 3 f x = f 0 (f 1 (f 2 x))`, where +`f 0 : α 0 → α 1`, `f 1 : α 1 → α 2`, etc. + +This is the dependent version of `Fin.foldl`. -/ +@[inline] def dfoldl (n : Nat) (α : Fin (n + 1) → Type _) + (f : ∀ (i : Fin n), α i.castSucc → α i.succ) (init : α 0) : α (last n) := + dfoldlM (m := Id) n α f init diff --git a/Batteries/Data/Fin/Fold.lean b/Batteries/Data/Fin/Fold.lean index 51d2f0c53c..9d2d821c2f 100644 --- a/Batteries/Data/Fin/Fold.lean +++ b/Batteries/Data/Fin/Fold.lean @@ -8,98 +8,61 @@ import Batteries.Data.Fin.Basic namespace Fin -/-! ### dfoldr -/ +/-! ### dfoldrM -/ -theorem dfoldr_loop_zero (f : (i : Fin n) → α i.succ → α i.castSucc) (x) : - dfoldr.loop n α f 0 (Nat.zero_lt_succ n) x = x := rfl +theorem dfoldrM_loop_zero [Monad m] (f : (i : Fin n) → α i.succ → m (α i.castSucc)) (x) : + dfoldrM.loop n α f 0 h x = pure x := rfl -theorem dfoldr_loop_succ (f : (i : Fin n) → α i.succ → α i.castSucc) (h : i < n) (x) : - dfoldr.loop n α f (i+1) (Nat.add_lt_add_right h 1) x = - dfoldr.loop n α f i (Nat.lt_add_right 1 h) (f ⟨i, h⟩ x) := rfl +theorem dfoldrM_loop_succ [Monad m] (f : (i : Fin n) → α i.succ → m (α i.castSucc)) (x) : + dfoldrM.loop n α f (i+1) h x = f ⟨i, by omega⟩ x >>= dfoldrM.loop n α f i (by omega) := rfl -theorem dfoldr_loop (f : (i : Fin (n+1)) → α i.succ → α i.castSucc) (h : i+1 ≤ n+1) (x) : - dfoldr.loop (n+1) α f (i+1) (Nat.add_lt_add_right h 1) x = - f 0 (dfoldr.loop n (α ∘ succ) (f ·.succ) i h x) := by +theorem dfoldrM_loop [Monad m] [LawfulMonad m] (f : (i : Fin (n+1)) → α i.succ → m (α i.castSucc)) + (x) : dfoldrM.loop (n+1) α f (i+1) h x = + dfoldrM.loop n (α ∘ succ) (f ·.succ) i (by omega) x >>= f 0 := by induction i with - | zero => rfl - | succ i ih => exact ih .. - -@[simp] theorem dfoldr_zero (f : (i : Fin 0) → α i.succ → α i.castSucc) (x) : - dfoldr 0 α f x = x := rfl - -theorem dfoldr_succ (f : (i : Fin (n+1)) → α i.succ → α i.castSucc) (x) : - dfoldr (n+1) α f x = f 0 (dfoldr n (α ∘ succ) (f ·.succ) x) := dfoldr_loop .. - -theorem dfoldr_succ_last (f : (i : Fin (n+1)) → α i.succ → α i.castSucc) (x) : - dfoldr (n+1) α f x = dfoldr n (α ∘ castSucc) (f ·.castSucc) (f (last n) x) := by - induction n with - | zero => simp only [dfoldr_succ, dfoldr_zero, last, zero_eta] - | succ n ih => rw [dfoldr_succ, ih (α := α ∘ succ) (f ·.succ), dfoldr_succ]; congr - -theorem dfoldr_eq_dfoldrM (f : (i : Fin n) → α i.succ → α i.castSucc) (x) : - dfoldr n α f x = dfoldrM (m:=Id) n α f x := rfl - -theorem dfoldr_eq_foldr (f : Fin n → α → α) (x : α) : dfoldr n (fun _ => α) f x = foldr n f x := by - induction n with - | zero => simp only [dfoldr_zero, foldr_zero] - | succ n ih => simp only [dfoldr_succ, foldr_succ, Function.comp_apply, Function.comp_def, ih] - -/-! ### dfoldrM -/ + | zero => + rw [dfoldrM_loop_zero, dfoldrM_loop_succ, pure_bind] + conv => rhs; rw [← bind_pure (f 0 x)] + rfl + | succ i ih => + rw [dfoldrM_loop_succ, dfoldrM_loop_succ, bind_assoc] + congr; funext; exact ih .. @[simp] theorem dfoldrM_zero [Monad m] (f : (i : Fin 0) → α i.succ → m (α i.castSucc)) (x) : dfoldrM 0 α f x = pure x := rfl -theorem dfoldrM_succ [Monad m] (f : (i : Fin (n+1)) → α i.succ → m (α i.castSucc)) - (x) : dfoldrM (n+1) α f x = dfoldrM n (α ∘ succ) (f ·.succ) x >>= f 0 := dfoldr_succ .. +theorem dfoldrM_succ [Monad m] [LawfulMonad m] (f : (i : Fin (n+1)) → α i.succ → m (α i.castSucc)) + (x) : dfoldrM (n+1) α f x = dfoldrM n (α ∘ succ) (f ·.succ) x >>= f 0 := dfoldrM_loop .. -theorem dfoldrM_eq_foldrM [Monad m] [LawfulMonad m] (f : (i : Fin n) → α → m α) (x : α) : +theorem dfoldrM_eq_foldrM [Monad m] [LawfulMonad m] (f : (i : Fin n) → α → m α) (x) : dfoldrM n (fun _ => α) f x = foldrM n f x := by - induction n generalizing x with + induction n with | zero => simp only [dfoldrM_zero, foldrM_zero] | succ n ih => simp only [dfoldrM_succ, foldrM_succ, Function.comp_def, ih] -/-! ### dfoldl -/ - -theorem dfoldl_loop_lt (f : ∀ (i : Fin n), α i.castSucc → α i.succ) (h : i < n) (x) : - dfoldl.loop n α f i (Nat.lt_add_right 1 h) x = - dfoldl.loop n α f (i+1) (Nat.add_lt_add_right h 1) (f ⟨i, h⟩ x) := by - rw [dfoldl.loop, dif_pos h] - -theorem dfoldl_loop_eq (f : ∀ (i : Fin n), α i.castSucc → α i.succ) (x) : - dfoldl.loop n α f n (Nat.le_refl _) x = x := by - rw [dfoldl.loop, dif_neg (Nat.lt_irrefl _), cast_eq] +theorem dfoldr_eq_dfoldrM (f : (i : Fin n) → α i.succ → α i.castSucc) (x) : + dfoldr n α f x = dfoldrM (m:=Id) n α f x := rfl -@[simp] theorem dfoldl_zero (f : (i : Fin 0) → α i.castSucc → α i.succ) (x) : - dfoldl 0 α f x = x := dfoldl_loop_eq .. +/-! ### dfoldr -/ -theorem dfoldl_loop (f : (i : Fin (n+1)) → α i.castSucc → α i.succ) (h : i < n+1) (x) : - dfoldl.loop (n+1) α f i (Nat.lt_add_right 1 h) x = - dfoldl.loop n (α ∘ succ) (f ·.succ ·) i h (f ⟨i, h⟩ x) := by - if h' : i < n then - rw [dfoldl_loop_lt _ h _] - rw [dfoldl_loop_lt _ h' _, dfoldl_loop]; rfl - else - cases Nat.le_antisymm (Nat.le_of_lt_succ h) (Nat.not_lt.1 h') - rw [dfoldl_loop_lt] - rw [dfoldl_loop_eq, dfoldl_loop_eq] +@[simp] theorem dfoldr_zero (f : (i : Fin 0) → α i.succ → α i.castSucc) (x) : + dfoldr 0 α f x = x := rfl -theorem dfoldl_succ (f : (i : Fin (n+1)) → α i.castSucc → α i.succ) (x) : - dfoldl (n+1) α f x = dfoldl n (α ∘ succ) (f ·.succ ·) (f 0 x) := dfoldl_loop .. +theorem dfoldr_succ (f : (i : Fin (n+1)) → α i.succ → α i.castSucc) (x) : + dfoldr (n+1) α f x = f 0 (dfoldr n (α ∘ succ) (f ·.succ) x) := dfoldrM_succ .. -theorem dfoldl_succ_last (f : (i : Fin (n+1)) → α i.castSucc → α i.succ) (x) : - dfoldl (n+1) α f x = f (last n) (dfoldl n (α ∘ castSucc) (f ·.castSucc ·) x) := by - rw [dfoldl_succ] +theorem dfoldr_succ_last {n : Nat} {α : Fin (n+2) → Sort _} + (f : (i : Fin (n+1)) → α i.succ → α i.castSucc) (x : α (last (n+1))) : + dfoldr (n+1) α f x = dfoldr n (α ∘ castSucc) (f ·.castSucc) (f (last n) x) := by induction n with - | zero => simp [dfoldl_succ, last] - | succ n ih => rw [dfoldl_succ, @ih (α ∘ succ) (f ·.succ ·), dfoldl_succ]; congr + | zero => simp only [dfoldr_succ, dfoldr_zero, last, zero_eta] + | succ n ih => rw [dfoldr_succ, ih (α := α ∘ succ) (f ·.succ), dfoldr_succ]; congr -theorem dfoldl_eq_foldl (f : Fin n → α → α) (x : α) : - dfoldl n (fun _ => α) f x = foldl n (fun x i => f i x) x := by - induction n generalizing x with - | zero => simp only [dfoldl_zero, foldl_zero] - | succ n ih => - simp only [dfoldl_succ, foldl_succ, Function.comp_apply, Function.comp_def] - congr; simp only [ih] +theorem dfoldr_eq_foldr (f : (i : Fin n) → α → α) (x) : + dfoldr n (fun _ => α) f x = foldr n f x := by + induction n with + | zero => simp only [dfoldr_zero, foldr_zero] + | succ n ih => simp only [dfoldr_succ, foldr_succ, Function.comp_def, ih] /-! ### dfoldlM -/ @@ -113,7 +76,7 @@ theorem dfoldlM_loop_eq [Monad m] (f : ∀ (i : Fin n), α i.castSucc → m (α rw [dfoldlM.loop, dif_neg (Nat.lt_irrefl _), cast_eq] @[simp] theorem dfoldlM_zero [Monad m] (f : (i : Fin 0) → α i.castSucc → m (α i.succ)) (x) : - dfoldlM 0 α f x = pure x := dfoldlM_loop_eq .. + dfoldlM 0 α f x = pure x := rfl theorem dfoldlM_loop [Monad m] (f : (i : Fin (n+1)) → α i.castSucc → m (α i.succ)) (h : i < n+1) (x) : dfoldlM.loop (n+1) α f i (Nat.lt_add_right 1 h) x = @@ -140,6 +103,32 @@ theorem dfoldlM_eq_foldlM [Monad m] (f : (i : Fin n) → α → m α) (x : α) : simp only [dfoldlM_succ, foldlM_succ, Function.comp_apply, Function.comp_def] congr; ext; simp only [ih] +/-! ### dfoldl -/ + +@[simp] theorem dfoldl_zero (f : (i : Fin 0) → α i.castSucc → α i.succ) (x) : + dfoldl 0 α f x = x := rfl + +theorem dfoldl_succ (f : (i : Fin (n+1)) → α i.castSucc → α i.succ) (x) : + dfoldl (n+1) α f x = dfoldl n (α ∘ succ) (f ·.succ ·) (f 0 x) := dfoldlM_succ .. + +theorem dfoldl_succ_last (f : (i : Fin (n+1)) → α i.castSucc → α i.succ) (x) : + dfoldl (n+1) α f x = f (last n) (dfoldl n (α ∘ castSucc) (f ·.castSucc ·) x) := by + rw [dfoldl_succ] + induction n with + | zero => simp [dfoldl_succ, last] + | succ n ih => rw [dfoldl_succ, @ih (α ∘ succ) (f ·.succ ·), dfoldl_succ]; congr + +theorem dfoldl_eq_dfoldlM (f : (i : Fin n) → α i.castSucc → α i.succ) (x) : + dfoldl n α f x = dfoldlM (m := Id) n α f x := rfl + +theorem dfoldl_eq_foldl (f : Fin n → α → α) (x : α) : + dfoldl n (fun _ => α) f x = foldl n (fun x i => f i x) x := by + induction n generalizing x with + | zero => simp only [dfoldl_zero, foldl_zero] + | succ n ih => + simp only [dfoldl_succ, foldl_succ, Function.comp_apply, Function.comp_def] + congr; simp only [ih] + /-! ### `Fin.fold{l/r}{M}` equals `List.fold{l/r}{M}` -/ theorem foldlM_eq_foldlM_finRange [Monad m] (f : α → Fin n → m α) (x) :