Skip to content

Commit

Permalink
fix: revert "monadic generalization of FindExpr" (#4125)
Browse files Browse the repository at this point in the history
This reverts commit 706a4cf introduced
in #3970

As explained in #4124, `findM?` can become a footgun if used in monads
which induce side-effects such as caching. This PR removes that
function, and fixes the code introduced by #3398 for which the function
was first added.

cc @semorrison.
  • Loading branch information
arthur-adjedj authored May 10, 2024
1 parent 228ff58 commit 6c6b56e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 33 deletions.
16 changes: 8 additions & 8 deletions src/Lean/Meta/Injective.lean
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ def elimOptParam (type : Expr) : CoreM Expr := do
Instead of checking the type of every subterm, we only need to check the type of free variables, since free variables introduced in
the constructor may only appear in the type of other free variables introduced after them.
-/
def occursOrInType (e : Expr) (t : Expr) : MetaM Bool := do
let_fun f (s : Expr) := do
if !s.isFVar then
return s == e
let ty ← inferType s
return s == e || e.occurs ty
return (← t.findM? f).isSome
def occursOrInType (lctx : LocalContext) (e : Expr) (t : Expr) : Bool :=
t.find? go |>.isSome
where
go s := Id.run do
let .fvar fvarId := s | s == e
let some decl := lctx.find? fvarId | s == e
return s == e || e.occurs decl.type

private partial def mkInjectiveTheoremTypeCore? (ctorVal : ConstructorVal) (useEq : Bool) : MetaM (Option Expr) := do
let us := ctorVal.levelParams.map mkLevelParam
Expand Down Expand Up @@ -87,7 +87,7 @@ private partial def mkInjectiveTheoremTypeCore? (ctorVal : ConstructorVal) (useE
match (← whnf type) with
| Expr.forallE n d b _ =>
let arg1 := args1.get ⟨i, h⟩
if ← occursOrInType arg1 resultType then
if occursOrInType (← getLCtx) arg1 resultType then
mkArgs2 (i + 1) (b.instantiate1 arg1) (args2.push arg1) args2New
else
withLocalDecl n (if useEq then BinderInfo.default else BinderInfo.implicit) d fun arg2 =>
Expand Down
44 changes: 19 additions & 25 deletions src/Lean/Util/FindExpr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@ namespace Lean
namespace Expr
namespace FindImpl

unsafe abbrev FindM (m) := StateT (PtrSet Expr) m
unsafe abbrev FindM := StateT (PtrSet Expr) Id

@[inline] unsafe def checkVisited [Monad m] (e : Expr) : OptionT (FindM m) Unit := do
@[inline] unsafe def checkVisited (e : Expr) : OptionT FindM Unit := do
if (← get).contains e then
failure
modify fun s => s.insert e

unsafe def findM? [Monad m] (p : Expr → m Bool) (e : Expr) : OptionT (FindM m) Expr :=
unsafe def findM? (p : Expr → Bool) (e : Expr) : OptionT FindM Expr :=
let rec visit (e : Expr) := do
checkVisited e
if p e then
if p e then
pure e
else match e with
| .forallE _ d b _ => visit d <|> visit b
Expand All @@ -33,35 +33,29 @@ unsafe def findM? [Monad m] (p : Expr → m Bool) (e : Expr) : OptionT (FindM m)
| _ => failure
visit e

unsafe def findUnsafeM? {m} [Monad m] (p : Expr → m Bool) (e : Expr) : m (Option Expr) :=
findM? p e |>.run' mkPtrSet

@[inline] unsafe def findUnsafe? (p : Expr → Bool) (e : Expr) : Option Expr := findUnsafeM? (m := Id) p e
unsafe def findUnsafe? (p : Expr → Bool) (e : Expr) : Option Expr :=
Id.run <| findM? p e |>.run' mkPtrSet

end FindImpl

@[implemented_by FindImpl.findUnsafeM?]
/- This is a reference implementation for the unsafe one above -/
def findM? [Monad m] (p : Expr → m Bool) (e : Expr) : m (Option Expr) := do
if ← p e then
return some e
else match e with
| .forallE _ d b _ => findM? p d <||> findM? p b
| .lam _ d b _ => findM? p d <||> findM? p b
| .mdata _ b => findM? p b
| .letE _ t v b _ => findM? p t <||> findM? p v <||> findM? p b
| .app f a => findM? p f <||> findM? p a
| .proj _ _ b => findM? p b
| _ => pure none

@[implemented_by FindImpl.findUnsafe?]
def find? (p : Expr → Bool) (e : Expr) : Option Expr := findM? (m := Id) p e
def find? (p : Expr → Bool) (e : Expr) : Option Expr :=
/- This is a reference implementation for the unsafe one above -/
if p e then
some e
else match e with
| .forallE _ d b _ => find? p d <|> find? p b
| .lam _ d b _ => find? p d <|> find? p b
| .mdata _ b => find? p b
| .letE _ t v b _ => find? p t <|> find? p v <|> find? p b
| .app f a => find? p f <|> find? p a
| .proj _ _ b => find? p b
| _ => none

/-- Return true if `e` occurs in `t` -/
def occurs (e : Expr) (t : Expr) : Bool :=
(t.find? fun s => s == e).isSome


/--
Return type for `findExt?` function argument.
-/
Expand All @@ -72,7 +66,7 @@ inductive FindStep where

namespace FindExtImpl

unsafe def findM? (p : Expr → FindStep) (e : Expr) : OptionT (FindImpl.FindM Id) Expr :=
unsafe def findM? (p : Expr → FindStep) (e : Expr) : OptionT FindImpl.FindM Expr :=
visit e
where
visitApp (e : Expr) :=
Expand Down

0 comments on commit 6c6b56e

Please sign in to comment.