Skip to content

Commit

Permalink
chore: upstream the once-per-file Cache for tactics (#213)
Browse files Browse the repository at this point in the history
  • Loading branch information
kim-em authored Nov 1, 2023
1 parent 026b17c commit d3ce3e8
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 10 deletions.
1 change: 1 addition & 0 deletions Std.lean
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ import Std.Tactic.TryThis
import Std.Tactic.Unreachable
import Std.Tactic.Where
import Std.Test.Internal.DummyLabelAttr
import Std.Util.Cache
import Std.Util.ExtendedBinder
import Std.Util.LibraryNote
import Std.Util.Pickle
Expand Down
45 changes: 35 additions & 10 deletions Std/Lean/Meta/DiscrTree.lean
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/-
Copyright (c) 2022 Jannis Limperg. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Jannis Limperg
Authors: Jannis Limperg, Scott Morrison
-/

import Lean.Meta.DiscrTree
Expand All @@ -26,8 +26,7 @@ protected def cmp : Key → Key → Ordering
s₁.quickCmp s₂ |>.then <| compare i₁ i₂ |>.then <| compare a₁ a₂
| k₁, k₂ => compare k₁.ctorIdx k₂.ctorIdx

instance : Ord Key :=
⟨Key.cmp⟩
instance : Ord Key := ⟨Key.cmp⟩

end Key

Expand Down Expand Up @@ -55,17 +54,15 @@ opaque foldM [Monad m] (initalKeys : Array Key)
Fold the keys and values stored in a `Trie`.
-/
@[inline]
def fold (initialKeys : Array Key) (f : σ → Array Key → α → σ)
(init : σ) (t : Trie α) : σ :=
def fold (initialKeys : Array Key) (f : σ → Array Key → α → σ) (init : σ) (t : Trie α) : σ :=
Id.run <| t.foldM initialKeys (init := init) fun s k a => return f s k a

-- This is just a partial function, but Lean doesn't realise that its type is
-- inhabited.
private unsafe def foldValuesMUnsafe [Monad m] (f : σ → α → m σ) (init : σ) :
Trie α → m σ
| node vs children => do
let s ← vs.foldlM (init := init) f
children.foldlM (init := s) fun s (_, c) => c.foldValuesMUnsafe (init := s) f
private unsafe def foldValuesMUnsafe [Monad m] (f : σ → α → m σ) (init : σ) : Trie α → m σ
| node vs children => do
let s ← vs.foldlM (init := init) f
children.foldlM (init := s) fun s (_, c) => c.foldValuesMUnsafe (init := s) f

/--
Monadically fold the values stored in a `Trie`.
Expand Down Expand Up @@ -162,3 +159,31 @@ Merge two `DiscrTree`s. Duplicate values are preserved.
def mergePreservingDuplicates (t u : DiscrTree α) : DiscrTree α :=
⟨t.root.mergeWith u.root fun _ trie₁ trie₂ =>
trie₁.mergePreservingDuplicates trie₂⟩

/--
Inserts a new key into a discrimination tree,
but only if it is not of the form `#[*]` or `#[=, *, *, *]`.
-/
def insertIfSpecific [BEq α] (d : DiscrTree α)
(keys : Array DiscrTree.Key) (v : α) (config : WhnfCoreConfig) : DiscrTree α :=
if keys == #[Key.star] || keys == #[Key.const `Eq 3, Key.star, Key.star, Key.star] then
d
else
d.insertCore keys v config

variable {m : TypeType} [Monad m]

/-- Apply a monadic function to the array of values at each node in a `DiscrTree`. -/
partial def Trie.mapArraysM (t : DiscrTree.Trie α) (f : Array α → m (Array β)) :
m (DiscrTree.Trie β) :=
match t with
| .node vs children =>
return .node (← f vs) (← children.mapM fun (k, t') => do pure (k, ← t'.mapArraysM f))

/-- Apply a monadic function to the array of values at each node in a `DiscrTree`. -/
def mapArraysM (d : DiscrTree α) (f : Array α → m (Array β)) : m (DiscrTree β) := do
pure { root := ← d.root.mapM (fun t => t.mapArraysM f) }

/-- Apply a function to the array of values at each node in a `DiscrTree`. -/
def mapArrays (d : DiscrTree α) (f : Array α → Array β) : DiscrTree β :=
Id.run <| d.mapArraysM fun A => pure (f A)
160 changes: 160 additions & 0 deletions Std/Util/Cache.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/-
Copyright (c) 2021 Gabriel Ebner. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Gabriel Ebner
-/
import Std.Lean.Meta.DiscrTree

/-!
# Once-per-file cache for tactics
This file defines cache data structures for tactics
that are initialized the first time they are accessed.
Since Lean 4 starts one process per file, these caches are once-per-file
and can for example be used to cache information about the imported modules.
The `Cache α` data structure is the most generic version we define.
It is created using `Cache.mk f` where `f : MetaM α` performs the initialization of the cache:
```
initialize numberOfImports : Cache Nat ← Cache.mk do
(← getEnv).imports.size
-- (does not work in the same module where the cache is defined)
#eval show MetaM Nat from numberOfImports.get
```
The `DeclCache α` data structure computes a fold over the environment's constants:
`DeclCache.mk empty f` constructs such a cache
where `empty : α` and `f : Name → ConstantInfo → α → MetaM α`.
The result of the constants in the imports is cached between tactic invocations,
while for constants defined in the same file `f` is evaluated again every time.
This kind of cache can be used e.g. to populate discrimination trees.
-/

open Lean Meta

namespace Mathlib.Tactic

/-- Once-per-file cache. -/
def Cache (α : Type) := IO.Ref <| MetaM α ⊕ Task (Except Exception α)

-- This instance is required as we use `Cache` with `initialize`.
-- One might expect an `Inhabited` instance here,
-- but there is no way to construct such without using choice anyway.
instance : Nonempty (Cache α) := inferInstanceAs <| Nonempty (IO.Ref _)

/-- Creates a cache with an initialization function. -/
def Cache.mk (init : MetaM α) : IO (Cache α) := IO.mkRef <| Sum.inl init

/--
Access the cache. Calling this function for the first time will initialize the cache
with the function provided in the constructor.
-/
def Cache.get [Monad m] [MonadEnv m] [MonadLog m] [MonadOptions m] [MonadLiftT BaseIO m]
[MonadExcept Exception m] (cache : Cache α) : m α := do
let t ← match ← ST.Ref.get (m := BaseIO) cache with
| .inr t => pure t
| .inl init =>
let env ← getEnv
let fileName ← getFileName
let fileMap ← getFileMap
let options ← getOptions -- TODO: sanitize options?
-- Default heartbeats to a reasonable value.
-- otherwise exact? times out on mathlib
-- TODO: add customization option
let options := Core.maxHeartbeats.set options <|
options.get? Core.maxHeartbeats.name |>.getD 1000000
let res ← EIO.asTask <|
init {} |>.run' {} { options, fileName, fileMap } |>.run' { env }
cache.set (m := BaseIO) (.inr res)
pure res
match t.get with
| Except.ok res => pure res
| Except.error err => throw err

/--
Cached fold over the environment's declarations,
where a given function is applied to `α` for every constant.
-/
structure DeclCache (α : Type) where mk' ::
/-- The cached data. -/
cache : Cache α
/-- Function for adding a declaration from the current file to the cache. -/
addDecl : Name → ConstantInfo → α → MetaM α
/-- Function for adding a declaration from the library to the cache.
Defaults to the same behaviour as adding a declaration from the current file. -/
addLibraryDecl : Name → ConstantInfo → α → MetaM α := addDecl
deriving Nonempty

/--
Creates a `DeclCache`.
The cached structure `α` is initialized with `empty`,
and then `addLibraryDecl` is called for every imported constant, and the result is cached.
After all imported constants have been added, we run `post`.
When `get` is called, `addDecl` is also called for every constant in the current file.
-/
def DeclCache.mk (profilingName : String) (empty : α)
(addDecl : Name → ConstantInfo → α → MetaM α)
(addLibraryDecl : Name → ConstantInfo → α → MetaM α := addDecl)
(post : α → MetaM α := fun a => pure a) : IO (DeclCache α) := do
let cache ← Cache.mk do
profileitM Exception profilingName (← getOptions) do
post <|← (← getEnv).constants.map₁.foldM (init := empty) fun a n c =>
addLibraryDecl n c a
pure { cache, addDecl }

/--
Access the cache. Calling this function for the first time will initialize the cache
with the function provided in the constructor.
-/
def DeclCache.get (cache : DeclCache α) : MetaM α := do
(← getEnv).constants.map₂.foldlM (init := ← cache.cache.get) fun a n c =>
cache.addDecl n c a

/--
A type synonym for a `DeclCache` containing a pair of discrimination trees.
The first will store declarations in the current file,
the second will store declarations from imports (and will hopefully be "read-only" after creation).
-/
@[reducible] def DiscrTreeCache (α : Type) : Type := DeclCache (DiscrTree α × DiscrTree α)

/-- Discrimation tree settings for the `DiscrTreeCache`. -/
def DiscrTreeCache.config : WhnfCoreConfig := {}

/--
Build a `DiscrTreeCache`,
from a function that returns a collection of keys and values for each declaration.
-/
def DiscrTreeCache.mk [BEq α] (profilingName : String)
(processDecl : Name → ConstantInfo → MetaM (Array (Array DiscrTree.Key × α)))
(post? : Option (Array α → Array α) := none)
(init : Option (DiscrTree α) := none) :
IO (DiscrTreeCache α) :=
let updateTree := fun name constInfo tree => do
return (← processDecl name constInfo).foldl (init := tree) fun t (k, v) =>
t.insertIfSpecific k v config
let addDecl := fun name constInfo (tree₁, tree₂) =>
return (← updateTree name constInfo tree₁, tree₂)
let addLibraryDecl := fun name constInfo (tree₁, tree₂) =>
return (tree₁, ← updateTree name constInfo tree₂)
let post := match post? with
| some f => fun (T₁, T₂) => return (T₁, T₂.mapArrays f)
| none => fun T => pure T
match init with
| some t => return ⟨← Cache.mk (pure ({}, t)), addDecl, addLibraryDecl⟩
| none => DeclCache.mk profilingName ({}, {}) addDecl addLibraryDecl (post := post)

/--
Get matches from both the discrimination tree for declarations in the current file,
and for the imports.
Note that if you are calling this multiple times with the same environment,
it will rebuild the discrimination tree for the current file multiple times,
and it would be more efficient to call `c.get` once,
and then call `DiscrTree.getMatch` multiple times.
-/
def DiscrTreeCache.getMatch (c : DiscrTreeCache α) (e : Expr) : MetaM (Array α) := do
let (locals, imports) ← c.get
-- `DiscrTree.getMatch` returns results in batches, with more specific lemmas coming later.
-- Hence we reverse this list, so we try out more specific lemmas earlier.
return (← locals.getMatch e config).reverse ++ (← imports.getMatch e config).reverse

0 comments on commit d3ce3e8

Please sign in to comment.