Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: revert "monadic generalization of FindExpr" #4125

Merged
merged 1 commit into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/Lean/Meta/Injective.lean
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ def elimOptParam (type : Expr) : CoreM Expr := do
Instead of checking the type of every subterm, we only need to check the type of free variables, since free variables introduced in
the constructor may only appear in the type of other free variables introduced after them.
-/
def occursOrInType (e : Expr) (t : Expr) : MetaM Bool := do
let_fun f (s : Expr) := do
if !s.isFVar then
return s == e
let ty ← inferType s
return s == e || e.occurs ty
return (← t.findM? f).isSome
def occursOrInType (lctx : LocalContext) (e : Expr) (t : Expr) : Bool :=
t.find? go |>.isSome
where
go s := Id.run do
let .fvar fvarId := s | s == e
let some decl := lctx.find? fvarId | s == e
return s == e || e.occurs decl.type

private partial def mkInjectiveTheoremTypeCore? (ctorVal : ConstructorVal) (useEq : Bool) : MetaM (Option Expr) := do
let us := ctorVal.levelParams.map mkLevelParam
Expand Down Expand Up @@ -87,7 +87,7 @@ private partial def mkInjectiveTheoremTypeCore? (ctorVal : ConstructorVal) (useE
match (← whnf type) with
| Expr.forallE n d b _ =>
let arg1 := args1.get ⟨i, h⟩
if ← occursOrInType arg1 resultType then
if occursOrInType (← getLCtx) arg1 resultType then
mkArgs2 (i + 1) (b.instantiate1 arg1) (args2.push arg1) args2New
else
withLocalDecl n (if useEq then BinderInfo.default else BinderInfo.implicit) d fun arg2 =>
Expand Down
44 changes: 19 additions & 25 deletions src/Lean/Util/FindExpr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@ namespace Lean
namespace Expr
namespace FindImpl

unsafe abbrev FindM (m) := StateT (PtrSet Expr) m
unsafe abbrev FindM := StateT (PtrSet Expr) Id

@[inline] unsafe def checkVisited [Monad m] (e : Expr) : OptionT (FindM m) Unit := do
@[inline] unsafe def checkVisited (e : Expr) : OptionT FindM Unit := do
if (← get).contains e then
failure
modify fun s => s.insert e

unsafe def findM? [Monad m] (p : Expr → m Bool) (e : Expr) : OptionT (FindM m) Expr :=
unsafe def findM? (p : Expr → Bool) (e : Expr) : OptionT FindM Expr :=
let rec visit (e : Expr) := do
checkVisited e
if p e then
if p e then
pure e
else match e with
| .forallE _ d b _ => visit d <|> visit b
Expand All @@ -33,35 +33,29 @@ unsafe def findM? [Monad m] (p : Expr → m Bool) (e : Expr) : OptionT (FindM m)
| _ => failure
visit e

unsafe def findUnsafeM? {m} [Monad m] (p : Expr → m Bool) (e : Expr) : m (Option Expr) :=
findM? p e |>.run' mkPtrSet

@[inline] unsafe def findUnsafe? (p : Expr → Bool) (e : Expr) : Option Expr := findUnsafeM? (m := Id) p e
unsafe def findUnsafe? (p : Expr → Bool) (e : Expr) : Option Expr :=
Id.run <| findM? p e |>.run' mkPtrSet

end FindImpl

@[implemented_by FindImpl.findUnsafeM?]
/- This is a reference implementation for the unsafe one above -/
def findM? [Monad m] (p : Expr → m Bool) (e : Expr) : m (Option Expr) := do
if ← p e then
return some e
else match e with
| .forallE _ d b _ => findM? p d <||> findM? p b
| .lam _ d b _ => findM? p d <||> findM? p b
| .mdata _ b => findM? p b
| .letE _ t v b _ => findM? p t <||> findM? p v <||> findM? p b
| .app f a => findM? p f <||> findM? p a
| .proj _ _ b => findM? p b
| _ => pure none

@[implemented_by FindImpl.findUnsafe?]
def find? (p : Expr → Bool) (e : Expr) : Option Expr := findM? (m := Id) p e
def find? (p : Expr → Bool) (e : Expr) : Option Expr :=
/- This is a reference implementation for the unsafe one above -/
if p e then
some e
else match e with
| .forallE _ d b _ => find? p d <|> find? p b
| .lam _ d b _ => find? p d <|> find? p b
| .mdata _ b => find? p b
| .letE _ t v b _ => find? p t <|> find? p v <|> find? p b
| .app f a => find? p f <|> find? p a
| .proj _ _ b => find? p b
| _ => none

/-- Return true if `e` occurs in `t` -/
def occurs (e : Expr) (t : Expr) : Bool :=
(t.find? fun s => s == e).isSome


/--
Return type for `findExt?` function argument.
-/
Expand All @@ -72,7 +66,7 @@ inductive FindStep where

namespace FindExtImpl

unsafe def findM? (p : Expr → FindStep) (e : Expr) : OptionT (FindImpl.FindM Id) Expr :=
unsafe def findM? (p : Expr → FindStep) (e : Expr) : OptionT FindImpl.FindM Expr :=
visit e
where
visitApp (e : Expr) :=
Expand Down
Loading