Skip to content

Commit

Permalink
feat: @[simp] lemmas about List.toArray (#5472)
Browse files Browse the repository at this point in the history
We make sure that we can pull `List.toArray` out through all operations
(well, for now "most" rather than "all"). As we also push `Array.toList`
inwards, this hopefully has the effect of them cancelling as they meet,
and `simp` naturally rewriting Array operations into List operations
wherever possible.

This is not at all complete yet.
  • Loading branch information
kim-em authored Sep 26, 2024
1 parent 90cb6e5 commit 5dea30f
Show file tree
Hide file tree
Showing 2 changed files with 195 additions and 10 deletions.
172 changes: 162 additions & 10 deletions src/Init/Data/Array/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ theorem swap_def (a : Array α) (i j : Fin a.size) :
a.swap i j = (a.set i (a.get j)).set ⟨j.1, by simp [j.2]⟩ (a.get i) := by
simp [swap, fin_cast_val]

theorem toList_swap (a : Array α) (i j : Fin a.size) :
@[simp] theorem toList_swap (a : Array α) (i j : Fin a.size) :
(a.swap i j).toList = (a.toList.set i (a.get j)).set j (a.get i) := by simp [swap_def]

@[deprecated toList_swap (since := "2024-09-09")]
Expand Down Expand Up @@ -847,7 +847,7 @@ theorem get_modify {arr : Array α} {x i} (h : i < arr.size) :

/-! ### filter -/

@[simp] theorem filter_toList (p : α → Bool) (l : Array α) :
@[simp] theorem toList_filter (p : α → Bool) (l : Array α) :
(l.filter p).toList = l.toList.filter p := by
dsimp only [filter]
rw [foldl_eq_foldl_toList]
Expand All @@ -858,23 +858,23 @@ theorem get_modify {arr : Array α} {x i} (h : i < arr.size) :
induction l with simp
| cons => split <;> simp [*]

@[deprecated filter_toList (since := "2024-09-09")]
abbrev filter_data := @filter_toList
@[deprecated toList_filter (since := "2024-09-09")]
abbrev filter_data := @toList_filter

@[simp] theorem filter_filter (q) (l : Array α) :
filter p (filter q l) = filter (fun a => p a && q a) l := by
apply ext'
simp only [filter_toList, List.filter_filter]
simp only [toList_filter, List.filter_filter]

@[simp] theorem mem_filter : x ∈ filter p as ↔ x ∈ as ∧ p x := by
simp only [mem_def, filter_toList, List.mem_filter]
simp only [mem_def, toList_filter, List.mem_filter]

theorem mem_of_mem_filter {a : α} {l} (h : a ∈ filter p l) : a ∈ l :=
(mem_filter.mp h).1

/-! ### filterMap -/

@[simp] theorem filterMap_toList (f : α → Option β) (l : Array α) :
@[simp] theorem toList_filterMap (f : α → Option β) (l : Array α) :
(l.filterMap f).toList = l.toList.filterMap f := by
dsimp only [filterMap, filterMapM]
rw [foldlM_eq_foldlM_toList]
Expand All @@ -887,12 +887,12 @@ theorem mem_of_mem_filter {a : α} {l} (h : a ∈ filter p l) : a ∈ l :=
· simp_all [Id.run, List.filterMap_cons]
split <;> simp_all

@[deprecated filterMap_toList (since := "2024-09-09")]
abbrev filterMap_data := @filterMap_toList
@[deprecated toList_filterMap (since := "2024-09-09")]
abbrev filterMap_data := @toList_filterMap

@[simp] theorem mem_filterMap {f : α → Option β} {l : Array α} {b : β} :
b ∈ filterMap f l ↔ ∃ a, a ∈ l ∧ f a = some b := by
simp only [mem_def, filterMap_toList, List.mem_filterMap]
simp only [mem_def, toList_filterMap, List.mem_filterMap]

/-! ### empty -/

Expand Down Expand Up @@ -1071,6 +1071,33 @@ theorem extract_empty_of_size_le_start (as : Array α) {start stop : Nat} (h : a

/-! ### any -/

theorem anyM_loop_cons [Monad m] (p : α → m Bool) (a : α) (as : List α) (stop start : Nat) (h : stop + 1 ≤ (a :: as).length) :
anyM.loop p ⟨a :: as⟩ (stop + 1) h (start + 1) = anyM.loop p ⟨as⟩ stop (by simpa using h) start := by
rw [anyM.loop]
conv => rhs; rw [anyM.loop]
split <;> rename_i h'
· simp only [Nat.add_lt_add_iff_right] at h'
rw [dif_pos h']
rw [anyM_loop_cons]
simp
· rw [dif_neg]
omega

@[simp] theorem anyM_toList [Monad m] (p : α → m Bool) (as : Array α) :
as.toList.anyM p = as.anyM p :=
match as with
| ⟨[]⟩ => rfl
| ⟨a :: as⟩ => by
simp only [List.anyM, anyM, size_toArray, List.length_cons, Nat.le_refl, ↓reduceDIte]
rw [anyM.loop, dif_pos (by omega)]
congr 1
funext b
split
· simp
· simp only [Bool.false_eq_true, ↓reduceIte]
rw [anyM_loop_cons]
simpa [anyM] using anyM_toList p ⟨as⟩

-- Auxiliary for `any_iff_exists`.
theorem anyM_loop_iff_exists {p : α → Bool} {as : Array α} {start stop} (h : stop ≤ as.size) :
anyM.loop (m := Id) p as stop h start = true
Expand Down Expand Up @@ -1115,6 +1142,17 @@ theorem any_def {p : α → Bool} (as : Array α) : as.any p = as.toList.any p :

/-! ### all -/

theorem allM_eq_not_anyM_not [Monad m] [LawfulMonad m] (p : α → m Bool) (as : Array α) :
allM p as = (! ·) <$> anyM ((! ·) <$> p ·) as := by
dsimp [allM, anyM]
simp

@[simp] theorem allM_toList [Monad m] [LawfulMonad m] (p : α → m Bool) (as : Array α) :
as.toList.allM p = as.allM p := by
rw [allM_eq_not_anyM_not]
rw [← anyM_toList]
rw [List.allM_eq_not_anyM_not]

theorem all_eq_not_any_not (p : α → Bool) (as : Array α) (start stop) :
all as p start stop = !(any as (!p ·) start stop) := by
dsimp [all, allM]
Expand Down Expand Up @@ -1211,3 +1249,117 @@ theorem swap_comm (a : Array α) {i j : Fin a.size} : a.swap i j = a.swap j i :=
· split <;> simp_all

end Array


open Array

namespace List

/-!
### More theorems about `List.toArray`, followed by an `Array` operation.
Our goal is to have `simp` "pull `List.toArray` outwards" as much as possible.
-/

@[simp] theorem mem_toArray {a : α} {l : List α} : a ∈ l.toArray ↔ a ∈ l := by
simp [mem_def]

@[simp] theorem getElem?_toArray (l : List α) (i : Nat) : l.toArray[i]? = l[i]? := by
simp [getElem?_eq_getElem?_toList]

@[simp] theorem toListRev_toArray (l : List α) : l.toArray.toListRev = l.reverse := by
simp

@[simp] theorem push_append_toArray (as : Array α) (a : α) (l : List α) :
as.push a ++ l.toArray = as ++ (a :: l).toArray := by
apply ext'
simp

@[simp] theorem mapM_toArray [Monad m] [LawfulMonad m] (f : α → m β) (l : List α) :
l.toArray.mapM f = List.toArray <$> l.mapM f := by
simp only [← mapM'_eq_mapM, mapM_eq_foldlM]
suffices ∀ init : Array β,
foldlM (fun bs a => bs.push <$> f a) init l.toArray = (init ++ toArray ·) <$> mapM' f l by
simpa using this #[]
intro init
induction l generalizing init with
| nil => simp
| cons a l ih =>
simp only [foldlM_toArray] at ih
rw [size_toArray, mapM'_cons, foldlM_toArray]
simp [ih]

@[simp] theorem map_toArray (f : α → β) (l : List α) : l.toArray.map f = (l.map f).toArray := by
apply ext'
simp

@[simp] theorem toArray_appendList (l₁ l₂ : List α) :
l₁.toArray ++ l₂ = (l₁ ++ l₂).toArray := by
apply ext'
simp

@[simp] theorem set_toArray (l : List α) (i : Fin l.toArray.size) (a : α) :
l.toArray.set i a = (l.set i a).toArray := by
apply ext'
simp

@[simp] theorem uset_toArray (l : List α) (i : USize) (a : α) (h : i.toNat < l.toArray.size) :
l.toArray.uset i a h = (l.set i.toNat a).toArray := by
apply ext'
simp

@[simp] theorem setD_toArray (l : List α) (i : Nat) (a : α) :
l.toArray.setD i a = (l.set i a).toArray := by
apply ext'
simp only [setD]
split
· simp
· simp_all [List.set_eq_of_length_le]

@[simp] theorem anyM_toArray [Monad m] [LawfulMonad m] (p : α → m Bool) (l : List α) :
l.toArray.anyM p = l.anyM p := by
rw [← anyM_toList]

@[simp] theorem any_toArray (p : α → Bool) (l : List α) : l.toArray.any p = l.any p := by
rw [Array.any_def]

@[simp] theorem allM_toArray [Monad m] [LawfulMonad m] (p : α → m Bool) (l : List α) :
l.toArray.allM p = l.allM p := by
rw [← allM_toList]

@[simp] theorem all_toArray (p : α → Bool) (l : List α) : l.toArray.all p = l.all p := by
rw [Array.all_def]

@[simp] theorem swap_toArray (l : List α) (i j : Fin l.toArray.size) :
l.toArray.swap i j = ((l.set i l[j]).set j l[i]).toArray := by
apply ext'
simp

@[simp] theorem pop_toArray (l : List α) : l.toArray.pop = l.dropLast.toArray := by
apply ext'
simp

@[simp] theorem reverse_toArray (l : List α) : l.toArray.reverse = l.reverse.toArray := by
apply ext'
simp

@[simp] theorem filter_toArray (p : α → Bool) (l : List α) :
l.toArray.filter p = (l.filter p).toArray := by
apply ext'
erw [toList_filter] -- `erw` required to unify `l.length` with `l.toArray.size`.

@[simp] theorem filterMap_toArray (f : α → Option β) (l : List α) :
l.toArray.filterMap f = (l.filterMap f).toArray := by
apply ext'
erw [toList_filterMap] -- `erw` required to unify `l.length` with `l.toArray.size`.

@[simp] theorem append_toArray (l₁ l₂ : List α) :
l₁.toArray ++ l₂.toArray = (l₁ ++ l₂).toArray := by
apply ext'
simp

@[simp] theorem toArray_range (n : Nat) : (range n).toArray = Array.range n := by
apply ext'
simp

end List
33 changes: 33 additions & 0 deletions src/Init/Data/List/Monadic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,27 @@ theorem mapM'_eq_mapM [Monad m] [LawfulMonad m] (f : α → m β) (l : List α)
@[simp] theorem mapM_append [Monad m] [LawfulMonad m] (f : α → m β) {l₁ l₂ : List α} :
(l₁ ++ l₂).mapM f = (return (← l₁.mapM f) ++ (← l₂.mapM f)) := by induction l₁ <;> simp [*]

/-- Auxiliary lemma for `mapM_eq_reverse_foldlM_cons`. -/
theorem foldlM_cons_eq_append [Monad m] [LawfulMonad m] (f : α → m β) (as : List α) (b : β) (bs : List β) :
(as.foldlM (init := b :: bs) fun acc a => return ((← f a) :: acc)) =
(· ++ b :: bs) <$> as.foldlM (init := []) fun acc a => return ((← f a) :: acc) := by
induction as generalizing b bs with
| nil => simp
| cons a as ih =>
simp only [bind_pure_comp] at ih
simp [ih, _root_.map_bind, Functor.map_map, Function.comp_def]

theorem mapM_eq_reverse_foldlM_cons [Monad m] [LawfulMonad m] (f : α → m β) (l : List α) :
mapM f l = reverse <$> (l.foldlM (fun acc a => return ((← f a) :: acc)) []) := by
rw [← mapM'_eq_mapM]
induction l with
| nil => simp
| cons a as ih =>
simp only [mapM'_cons, ih, bind_map_left, foldlM_cons, LawfulMonad.bind_assoc, pure_bind,
foldlM_cons_eq_append, _root_.map_bind, Functor.map_map, Function.comp_def, reverse_append,
reverse_cons, reverse_nil, nil_append, singleton_append]
simp [bind_pure_comp]

/-! ### forM -/

-- We use `List.forM` as the simp normal form, rather that `ForM.forM`.
Expand All @@ -66,4 +87,16 @@ theorem mapM'_eq_mapM [Monad m] [LawfulMonad m] (f : α → m β) (l : List α)
(l₁ ++ l₂).forM f = (do l₁.forM f; l₂.forM f) := by
induction l₁ <;> simp [*]

/-! ### allM -/

theorem allM_eq_not_anyM_not [Monad m] [LawfulMonad m] (p : α → m Bool) (as : List α) :
allM p as = (! ·) <$> anyM ((! ·) <$> p ·) as := by
induction as with
| nil => simp
| cons a as ih =>
simp only [allM, anyM, bind_map_left, _root_.map_bind]
congr
funext b
split <;> simp_all

end List

0 comments on commit 5dea30f

Please sign in to comment.