Skip to content

Commit

Permalink
feat: adjust simp attributes on monad lemmas (#5464)
Browse files Browse the repository at this point in the history
  • Loading branch information
kim-em authored Sep 25, 2024
1 parent 1defa20 commit c2f6297
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 18 deletions.
23 changes: 14 additions & 9 deletions src/Init/Control/Lawful/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ attribute [simp] id_map
@[simp] theorem id_map' [Functor m] [LawfulFunctor m] (x : m α) : (fun a => a) <$> x = x :=
id_map x

theorem Functor.map_map [Functor f] [LawfulFunctor f] (m : α → β) (g : β → γ) (x : f α) :
g <$> m <$> x = (g ∘ m) <$> x :=
@[simp] theorem Functor.map_map [Functor f] [LawfulFunctor f] (m : α → β) (g : β → γ) (x : f α) :
g <$> m <$> x = (fun a => g (m a)) <$> x :=
(comp_map _ _ _).symm

/--
Expand Down Expand Up @@ -87,12 +87,16 @@ class LawfulMonad (m : Type u → Type v) [Monad m] extends LawfulApplicative m
seq_assoc x g h := (by simp [← bind_pure_comp, ← bind_map, bind_assoc, pure_bind])

export LawfulMonad (bind_pure_comp bind_map pure_bind bind_assoc)
attribute [simp] pure_bind bind_assoc
attribute [simp] pure_bind bind_assoc bind_pure_comp

@[simp] theorem bind_pure [Monad m] [LawfulMonad m] (x : m α) : x >>= pure = x := by
show x >>= (fun a => pure (id a)) = x
rw [bind_pure_comp, id_map]

/--
Use `simp [← bind_pure_comp]` rather than `simp [map_eq_pure_bind]`,
as `bind_pure_comp` is in the default simp set, so also using `map_eq_pure_bind` would cause a loop.
-/
theorem map_eq_pure_bind [Monad m] [LawfulMonad m] (f : α → β) (x : m α) : f <$> x = x >>= fun a => pure (f a) := by
rw [← bind_pure_comp]

Expand All @@ -113,20 +117,21 @@ theorem seq_eq_bind {α β : Type u} [Monad m] [LawfulMonad m] (mf : m (α →

theorem seqRight_eq_bind [Monad m] [LawfulMonad m] (x : m α) (y : m β) : x *> y = x >>= fun _ => y := by
rw [seqRight_eq]
simp [map_eq_pure_bind, seq_eq_bind_map, const]
simp only [map_eq_pure_bind, const, seq_eq_bind_map, bind_assoc, pure_bind, id_eq, bind_pure]

theorem seqLeft_eq_bind [Monad m] [LawfulMonad m] (x : m α) (y : m β) : x <* y = x >>= fun a => y >>= fun _ => pure a := by
rw [seqLeft_eq]; simp [map_eq_pure_bind, seq_eq_bind_map]
rw [seqLeft_eq]
simp only [map_eq_pure_bind, seq_eq_bind_map, bind_assoc, pure_bind, const_apply]

theorem map_bind [Monad m] [LawfulMonad m] (x : m α) {g : α → m β} {f : β → γ} :
f <$> (x >>= fun a => g a) = x >>= fun a => f <$> g a := by
@[simp] theorem map_bind [Monad m] [LawfulMonad m] (f : β → γ) (x : m α) (g : α → m β) :
f <$> (x >>= g) = x >>= fun a => f <$> g a := by
rw [← bind_pure_comp, LawfulMonad.bind_assoc]
simp [bind_pure_comp]

theorem bind_map_left [Monad m] [LawfulMonad m] (x : m α) (f : α → β) (g : β → m γ) :
@[simp] theorem bind_map_left [Monad m] [LawfulMonad m] (f : α → β) (x : m α) (g : β → m γ) :
((f <$> x) >>= fun b => g b) = (x >>= fun a => g (f a)) := by
rw [← bind_pure_comp]
simp [bind_assoc, pure_bind]
simp only [bind_assoc, pure_bind]

/--
An alternative constructor for `LawfulMonad` which has more
Expand Down
14 changes: 7 additions & 7 deletions src/Init/Control/Lawful/Instances.lean
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ theorem ext {x y : ExceptT ε m α} (h : x.run = y.run) : x = y := by
@[simp] theorem run_throw [Monad m] : run (throw e : ExceptT ε m β) = pure (Except.error e) := rfl

@[simp] theorem run_bind_lift [Monad m] [LawfulMonad m] (x : m α) (f : α → ExceptT ε m β) : run (ExceptT.lift x >>= f : ExceptT ε m β) = x >>= fun a => run (f a) := by
simp[ExceptT.run, ExceptT.lift, bind, ExceptT.bind, ExceptT.mk, ExceptT.bindCont, map_eq_pure_bind]
simp [ExceptT.run, ExceptT.lift, bind, ExceptT.bind, ExceptT.mk, ExceptT.bindCont]

@[simp] theorem bind_throw [Monad m] [LawfulMonad m] (f : α → ExceptT ε m β) : (throw e >>= f) = throw e := by
simp [throw, throwThe, MonadExceptOf.throw, bind, ExceptT.bind, ExceptT.bindCont, ExceptT.mk]
Expand All @@ -43,7 +43,7 @@ theorem run_bind [Monad m] (x : ExceptT ε m α)

@[simp] theorem run_map [Monad m] [LawfulMonad m] (f : α → β) (x : ExceptT ε m α)
: (f <$> x).run = Except.map f <$> x.run := by
simp [Functor.map, ExceptT.map, map_eq_pure_bind]
simp [Functor.map, ExceptT.map, ←bind_pure_comp]
apply bind_congr
intro a; cases a <;> simp [Except.map]

Expand All @@ -62,7 +62,7 @@ protected theorem seqLeft_eq {α β ε : Type u} {m : Type u → Type v} [Monad
intro
| Except.error _ => simp
| Except.ok _ =>
simp [map_eq_pure_bind]; apply bind_congr; intro b;
simp [←bind_pure_comp]; apply bind_congr; intro b;
cases b <;> simp [comp, Except.map, const]

protected theorem seqRight_eq [Monad m] [LawfulMonad m] (x : ExceptT ε m α) (y : ExceptT ε m β) : x *> y = const α id <$> x <*> y := by
Expand Down Expand Up @@ -175,7 +175,7 @@ theorem ext {x y : StateT σ m α} (h : ∀ s, x.run s = y.run s) : x = y :=
simp [bind, StateT.bind, run]

@[simp] theorem run_map {α β σ : Type u} [Monad m] [LawfulMonad m] (f : α → β) (x : StateT σ m α) (s : σ) : (f <$> x).run s = (fun (p : α × σ) => (f p.1, p.2)) <$> x.run s := by
simp [Functor.map, StateT.map, run, map_eq_pure_bind]
simp [Functor.map, StateT.map, run, ←bind_pure_comp]

@[simp] theorem run_get [Monad m] (s : σ) : (get : StateT σ m σ).run s = pure (s, s) := rfl

Expand Down Expand Up @@ -210,21 +210,21 @@ theorem run_bind_lift {α σ : Type u} [Monad m] [LawfulMonad m] (x : m α) (f :

theorem seqRight_eq [Monad m] [LawfulMonad m] (x : StateT σ m α) (y : StateT σ m β) : x *> y = const α id <$> x <*> y := by
apply ext; intro s
simp [map_eq_pure_bind, const]
simp [←bind_pure_comp, const]
apply bind_congr; intro p; cases p
simp [Prod.eta]

theorem seqLeft_eq [Monad m] [LawfulMonad m] (x : StateT σ m α) (y : StateT σ m β) : x <* y = const β <$> x <*> y := by
apply ext; intro s
simp [map_eq_pure_bind]
simp [←bind_pure_comp]

instance [Monad m] [LawfulMonad m] : LawfulMonad (StateT σ m) where
id_map := by intros; apply ext; intros; simp[Prod.eta]
map_const := by intros; rfl
seqLeft_eq := seqLeft_eq
seqRight_eq := seqRight_eq
pure_seq := by intros; apply ext; intros; simp
bind_pure_comp := by intros; apply ext; intros; simp; apply LawfulMonad.bind_pure_comp
bind_pure_comp := by intros; apply ext; intros; simp
bind_map := by intros; rfl
pure_bind := by intros; apply ext; intros; simp
bind_assoc := by intros; apply ext; intros; simp
Expand Down
2 changes: 1 addition & 1 deletion src/Init/Data/Array/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ theorem mapM_eq_mapM_toList [Monad m] [LawfulMonad m] (f : α → m β) (arr : A
conv => rhs; rw [← List.reverse_reverse arr.toList]
induction arr.toList.reverse with
| nil => simp
| cons a l ih => simp [ih]; simp [map_eq_pure_bind]
| cons a l ih => simp [ih]

@[deprecated mapM_eq_mapM_toList (since := "2024-09-09")]
abbrev mapM_eq_mapM_data := @mapM_eq_mapM_toList
Expand Down
5 changes: 5 additions & 0 deletions src/Init/Data/List/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1654,6 +1654,11 @@ theorem filterMap_eq_cons_iff {l} {b} {bs} :

/-! ### append -/

@[simp] theorem nil_append_fun : (([] : List α) ++ ·) = id := rfl

@[simp] theorem cons_append_fun (a : α) (as : List α) :
(fun bs => ((a :: as) ++ bs)) = fun bs => a :: (as ++ bs) := rfl

theorem getElem_append {l₁ l₂ : List α} (n : Nat) (h) :
(l₁ ++ l₂)[n] = if h' : n < l₁.length then l₁[n] else l₂[n - l₁.length]'(by simp at h h'; exact Nat.sub_lt_left_of_lt_add h' h) := by
split <;> rename_i h'
Expand Down
2 changes: 1 addition & 1 deletion tests/lean/run/do_eqv_proofs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ theorem ex1 [Monad m] [LawfulMonad m] (b : Bool) (ma : m α) (mb : α → m α)
(ma >>= fun x => if b then mb x else pure x) := by
cases b <;> simp

attribute [simp] map_eq_pure_bind seq_eq_bind_map
attribute [simp] seq_eq_bind_map

theorem ex2 [Monad m] [LawfulMonad m] (b : Bool) (ma : m α) (mb : α → m α) (a : α) :
(do let mut x ← ma
Expand Down

0 comments on commit c2f6297

Please sign in to comment.