Skip to content

Commit

Permalink
Merge pull request leanprover-community#27 from goens/egg-tags
Browse files Browse the repository at this point in the history
Add `@egg` tags
  • Loading branch information
marcusrossel authored Jun 15, 2024
2 parents 1ad7ef8 + b6bcf44 commit c25d9c9
Show file tree
Hide file tree
Showing 13 changed files with 363 additions and 6 deletions.
1 change: 1 addition & 0 deletions Lean/Egg.lean
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ import Egg.Tactic.Basic
import Egg.Tactic.Calc
import Egg.Tactic.Guides
import Egg.Tactic.Trace
import Egg.Tactic.Tags
1 change: 1 addition & 0 deletions Lean/Egg/Core/Config.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ structure Encoding extends Normalization where

structure Gen where
builtins := true
tagged? := some `egg
genTcProjRws := true
genTcSpecRws := true
genGoalTcSpec := true
Expand Down
2 changes: 2 additions & 0 deletions Lean/Egg/Core/Explanation/Parse.lean
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ syntax "*" noWs num : egg_basic_fwd_rw_src
syntax "⊢" : egg_basic_fwd_rw_src
syntax "↣" noWs num : egg_basic_fwd_rw_src
syntax "◯" noWs num : egg_basic_fwd_rw_src
syntax "□" noWs num (noWs "/" noWs num)? : egg_basic_fwd_rw_src

syntax "[" egg_tc_proj_loc num "," num "]" : egg_tc_proj

Expand Down Expand Up @@ -119,6 +120,7 @@ private def parseTcProjLocation : (TSyntax `egg_tc_proj_loc) → Source.TcProjLo

private def parseBasicFwdRwSrc : (TSyntax `egg_basic_fwd_rw_src) → Source
| `(egg_basic_fwd_rw_src|#$idx$[/$eqn?]?) => .explicit idx.getNat (eqn?.map TSyntax.getNat)
| `(egg_basic_fwd_rw_src|□$idx$[/$eqn?]?) => .tagged idx.getNat (eqn?.map TSyntax.getNat)
| `(egg_basic_fwd_rw_src|*$idx) => .star (.fromUniqueIdx idx.getNat)
| `(egg_basic_fwd_rw_src|⊢) => .goal
| `(egg_basic_fwd_rw_src|↣$idx) => .guide idx.getNat
Expand Down
14 changes: 8 additions & 6 deletions Lean/Egg/Tactic/Premises/Gen.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ namespace Egg.Premises
-- TODO: Perform pruning during generation, not after.

private def tracePremises
(basic : WithSyntax Rewrites) (builtins tc pruned : Rewrites) (facts : WithSyntax Facts)
(basic : WithSyntax Rewrites) (tagged builtins tc pruned : Rewrites) (facts : WithSyntax Facts)
(cfg : Config.Gen) : TacticM Unit := do
let cls := `egg.rewrites
withTraceNode cls (fun _ => return "Rewrites") do
withTraceNode cls (fun _ => return m!"Basic ({basic.elems.size})") do basic.elems.trace basic.stxs cls
withTraceNode cls (fun _ => return m!"Tagged ({tagged.size})") do tagged.trace #[] cls
withTraceNode cls (fun _ => return m!"Generated ({tc.size})") do tc.trace #[] cls
withTraceNode cls (fun _ => return m!"Builtin ({builtins.size})") do builtins.trace #[] cls
withTraceNode cls (fun _ => return m!"Hypotheses ({facts.elems.size})") do
Expand All @@ -30,14 +31,15 @@ private def tracePremises
partial def gen
(goal : Congr) (ps : TSyntax `egg_premises) (guides : Guides) (cfg : Config)
(amb : MVars.Ambient) : TacticM (Rewrites × Facts) := do
let tagged ← Premises.buildTagged cfg amb
let ⟨⟨basic, basicStxs⟩, facts⟩ ← Premises.elab { norm? := cfg, amb } ps
let (basic, basicStxs, pruned₁) ← prune basic basicStxs (remove := #[])
let (basic, basicStxs, pruned₁) ← prune basic basicStxs (remove := tagged)
let builtins ← if cfg.builtins then Rewrites.builtins { norm? := cfg, amb } else pure #[]
let (builtins, _, pruned₂) ← prune builtins (remove := basic)
let (builtins, _, pruned₂) ← prune builtins (remove := tagged ++ basic)
let tc ← genTcRws (basic ++ builtins) facts.elems
let (tc, _, pruned₃) ← prune tc (remove := basic ++ builtins)
tracePremises ⟨basic, basicStxs⟩ builtins tc (pruned₁ ++ pruned₂ ++ pruned₃) facts cfg
let rws := basic ++ builtins ++ tc
let (tc, _, pruned₃) ← prune tc (remove := tagged ++ basic ++ builtins)
tracePremises ⟨basic, basicStxs⟩ tagged builtins tc (pruned₁ ++ pruned₂ ++ pruned₃) facts cfg
let rws := tagged ++ basic ++ builtins ++ tc
catchInvalidConditionals rws
return (rws, facts.elems)
where
Expand Down
15 changes: 15 additions & 0 deletions Lean/Egg/Tactic/Premises/Parse.lean
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import Egg.Core.Premise.Rewrites
import Egg.Core.Premise.Facts
import Egg.Tactic.Premises.Validate
import Egg.Tactic.Tags
import Lean

open Lean Meta Elab Tactic
Expand Down Expand Up @@ -117,6 +119,19 @@ private def Premises.taggedRw (prem : Name) (idx : Nat) (cfg : Rewrite.Config) :
let rws ← Premises.explicit ident idx mk .tagged
return rws.elems

private def Premises.elabTagged (prems : Array Name) (cfg : Rewrite.Config) : TacticM Rewrites := do
let mut rws : Rewrites := #[]
for prem in prems, idx in [:prems.size] do
rws := rws ++ (← taggedRw prem idx cfg)
return rws

def Premises.buildTagged (cfg : Config) (amb : MVars.Ambient ): TacticM Rewrites :=
match cfg.tagged? with
| none => return {}
| some _ => do -- This should later use this `Name` to find the proper extension
let prems := eggXtension.getState (← getEnv)
elabTagged prems { norm? := cfg, amb}

-- Note: This function is expected to be called with the lctx which contains the desired premises.
--
-- Note: We need to filter out auxiliary declaration and implementation details, as they are not
Expand Down
53 changes: 53 additions & 0 deletions Lean/Egg/Tactic/Premises/Validate.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import Egg.Core.Premise.Rewrites

open Lean Meta Elab Tactic

inductive Premise.Raw where
| single (expr : Expr) (type? : Option Expr := none)
| eqns (exprs : Array Expr)

inductive Premise.RawRaw where
| eqns (exprs : Array Name)
| single (expr : Expr) (type? : Option Expr := none)
| prem (prem : Term)


def Premise.Raw.validate (prem : Term) : MetaM Premise.RawRaw := do
if let some const ← optional (resolveGlobalConstNoOverload prem) then
if let some eqs ← getEqnsFor? const (nonRec := true) then
-- `prem` is a global definition.
return .eqns eqs
else
-- `prem` is an global constant which is not a definition with equations.
let env ← getEnv
let some info := env.find? const | throwErrorAt prem m!"unknown constant '{mkConst const}'"
match info with
| .defnInfo _ | .axiomInfo _ | .thmInfo _ | .opaqueInfo _ =>
let lvlMVars ← List.replicateM info.numLevelParams mkFreshLevelMVar
let val := if info.hasValue then info.instantiateValueLevelParams! lvlMVars else .const info.name lvlMVars
let type := info.instantiateTypeLevelParams lvlMVars
return .single val type
| _ => throwErrorAt prem "egg requires arguments to be theorems, definitions or axioms"
else
-- `prem` is an invalid identifier or a term which is not an identifier.
-- We must use `Tactic.elabTerm`, not `Term.elabTerm`. Otherwise elaborating `‹...›` doesn't
-- work correctly. See https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Elaborate.20.E2.80.B9.2E.2E.2E.E2.80.BA
return .prem prem

-- We don't just elaborate premises directly as:
-- (1) this can cause problems for global constants with typeclass arguments, as Lean sometimes
-- tries to synthesize the arguments and fails if it can't (instead of inserting mvars).
-- (2) global constants which are definitions with equations (cf. `getEqnsFor?`) are supposed to
-- be replaced by their defining equations.
partial def Premise.Raw.elab (prem : Term) : TacticM Premise.Raw := do
if let some hyp ← optional (getFVarId prem) then
-- `prem` is a local declaration.
let decl ← hyp.getDecl
if decl.isImplementationDetail || decl.isAuxDecl then
throwErrorAt prem "egg does not support using auxiliary declarations"
else
return .single (.fvar hyp) (← hyp.getType)
match (← validate prem) with
| .eqns eqs => return .eqns <| ← eqs.mapM fun eqn => Tactic.elabTerm (mkIdent eqn) none
| .single val type? => return .single val type?
| .prem prem => return .single (← Tactic.elabTerm prem none)
110 changes: 110 additions & 0 deletions Lean/Egg/Tactic/Tags.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import Egg.Tactic.Premises.Validate
import Lean

open Lean Elab Tactic Term

namespace Egg

/--
This validates that a theorem can be used by the `egg` tactic (it ultimately boils down to an equality.)
Unimplemented: Currently, this is a noop.
-/
private def validateEggTheorem (thm : Term) : MetaM Unit := do
let _ ← Premise.Raw.validate thm
return ()

-- Ideally this should be at some point a discrimination tree
abbrev EggTheorems := Array Name

abbrev EggEntry := Name -- later: Lean.Meta.SimpEntry

def addEggTheoremEntry (d : EggTheorems) (e : EggEntry) : EggTheorems :=
d.push e

abbrev EggXtension := SimpleScopedEnvExtension EggEntry EggTheorems

open Lean.Elab
open Lean.Elab.Command

def EggXtension.getTheorems (ext : EggXtension) : CoreM EggTheorems :=
return ext.getState (← getEnv)

/--
This function does the appropriate preprocessing from a `Name` to record a theorem as
an `egg` theorem.
For now, this preprocessing is nothing (just store the name in a singleton `Array`).
However, in the future this can be used like simp to do more efficient preprocessing
and deal with other kinds of `egg` lemmas (or even import `simp` lemmas).
-/
private def mkEggTheoremsFromConst (declName : Name) : MetaM EggTheorems :=
pure #[declName]

def addEggTheorem (ext : EggXtension) (declName : Name) (attrKind : AttributeKind) : MetaM Unit := do
let _ ← validateEggTheorem { raw := Syntax.ident default default declName []} -- ugly!
let eggThms ← mkEggTheoremsFromConst declName
for eggThm in eggThms do
ext.add eggThm attrKind

def mkEggXt (name : Name := by exact decl_name%) : IO EggXtension :=
registerSimpleScopedEnvExtension {
name := name
initial := {}
addEntry := fun d e => addEggTheoremEntry d e
}

def mkEggAttr (attrName : Name) (attrDescr : String) (ext : EggXtension)
(ref : Name := by exact decl_name%) : IO Unit :=
registerBuiltinAttribute {
ref := ref
name := attrName
descr := attrDescr
applicationTime := AttributeApplicationTime.afterCompilation
add := fun declName _stx attrKind => do
let go : MetaM Unit := do
let info ← getConstInfo declName
if (← Meta.isProp info.type) then
addEggTheorem ext declName attrKind
else
throwError "invalid 'egg', it is not a proposition"
discard <| go.run {} {}
erase := fun declName => do
let s := ext.getState (← getEnv)
let s := s.erase (declName)
modifyEnv fun env => ext.modifyState env fun _ => s
}


abbrev EggXtensionMap := HashMap Name EggXtension

initialize eggXtensionMapRef : IO.Ref EggXtensionMap ← IO.mkRef {}

def registerEggAttr (attrName : Name) (attrDescr : String)
(ref : Name := by exact decl_name%) : IO EggXtension := do
let ext ← mkEggXt ref
mkEggAttr attrName attrDescr ext ref -- Remark: it will fail if it is not performed during initialization
eggXtensionMapRef.modify fun map => map.insert attrName ext
return ext

initialize eggXtension : EggXtension ← registerEggAttr `egg "equality saturation theorem theorem"


syntax (name := showEgg) "#show_egg_set" : command

--
--
--#check Lean.Meta.mkSimpAttr
--
--@[command_elab insertEgg] def elabInsertEgg : CommandElab := fun stx => do
-- IO.println s!"inserting {stx[1].getId}"
-- modifyEnv fun env => eggXtension.addEntry env stx[1].getId
--
@[command_elab showEgg] def elabShowEgg : CommandElab := fun _ => do
logInfo m!"egg set: {eggXtension.getState (← getEnv) |>.toList}"
--
--
--initialize eggTag : TagAttribute ←
-- registerTagAttribute `egg "Tag for marking lemmas used for equality saturation" (validate := validateEggTheorem)

end Egg
1 change: 1 addition & 0 deletions Lean/Egg/Tests/Conditional.lean
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ info: [egg.rewrites] Rewrites
expr: [?l₂]
class: []
level: []
[egg.rewrites] Tagged (0)
[egg.rewrites] Generated (0)
[egg.rewrites] Builtin (0)
[egg.rewrites] Hypotheses (0)
Expand Down
71 changes: 71 additions & 0 deletions Lean/Egg/Tests/FreshmanTags.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import Egg

class Inv (α) where inv : α → α
postfix:max "⁻¹" => Inv.inv

class Zero (α) where zero : α
instance [Zero α] : OfNat α 0 where ofNat := Zero.zero

class One (α) where one : α
instance [One α] : OfNat α 1 where ofNat := One.one

class CommRing (α) extends Zero α, One α, Add α, Sub α, Mul α, Div α, Pow α Nat, Inv α, Neg α where
comm_add (a b : α) : a + b = b + a
comm_mul (a b : α) : a * b = b * a
add_assoc (a b c : α) : a + (b + c) = (a + b) + c
mul_assoc (a b c : α) : a * (b * c) = (a * b) * c
sub_canon (a b : α) : a - b = a + -b
neg_add (a : α) : a + -a = 0
div_canon (a b : α) : a / b = a * b⁻¹
zero_add (a : α) : a + 0 = a
zero_mul (a : α) : a * 0 = 0
one_mul (a : α) : a * 1 = a
distrib (a b c : α) : a * (b + c) = (a * b) + (a * c)
pow_zero (a : α) : a ^ 0 = 1
pow_one (a : α) : a ^ 1 = a
pow_two (a : α) : a ^ 2 = (a ^ 1) * a
pow_three (a : α) : a ^ 3 = (a ^ 2) * a

attribute [egg] CommRing.comm_add
attribute [egg] CommRing.comm_mul
attribute [egg] CommRing.add_assoc
attribute [egg] CommRing.mul_assoc
attribute [egg] CommRing.sub_canon
attribute [egg] CommRing.neg_add
attribute [egg] CommRing.div_canon
attribute [egg] CommRing.zero_add
attribute [egg] CommRing.zero_mul
attribute [egg] CommRing.one_mul
attribute [egg] CommRing.distrib
attribute [egg] CommRing.pow_zero
attribute [egg] CommRing.pow_one
attribute [egg] CommRing.pow_two
attribute [egg] CommRing.pow_three

class CharTwoRing (α) extends CommRing α where
char_two (a : α) : a + a = 0

variable [CharTwoRing α] (x y : α)

theorem freshmans_dream₂ : (x + y) ^ 2 = (x ^ 2) + (y ^ 2) := by
egg calc (x + y) ^ 2
_ = (x + y) * (x + y)
_ = x * (x + y) + y * (x + y)
_ = x ^ 2 + x * y + y * x + y ^ 2
_ = x ^ 2 + y ^ 2 with [CharTwoRing.char_two]

theorem freshmans_dream₂' : (x + y) ^ 2 = (x ^ 2) + (y ^ 2) := by
egg [CharTwoRing.char_two]

theorem freshmans_dream₃ : (x + y) ^ 3 = x ^ 3 + x * y ^ 2 + x ^ 2 * y + y ^ 3 := by
egg calc [CharTwoRing.char_two] (x + y) ^ 3
_ = (x + y) * (x + y) * (x + y)
_ = (x + y) * (x * (x + y) + y * (x + y))
_ = (x + y) * (x ^ 2 + x * y + y * x + y ^ 2)
_ = (x + y) * (x ^ 2 + y ^ 2)
_ = x * (x ^ 2 + y ^ 2) + y * (x ^ 2 + y ^ 2)
_ = (x * x ^ 2) + x * y ^ 2 + y * x ^ 2 + y * y ^ 2
_ = x ^ 3 + x * y ^ 2 + x ^ 2 * y + y ^ 3

theorem freshmans_dream₃' : (x + y) ^ 3 = x ^ 3 + x * y ^ 2 + x ^ 2 * y + y ^ 3 := by
egg [CharTwoRing.char_two]
2 changes: 2 additions & 0 deletions Lean/Egg/Tests/Prune.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ info: [egg.rewrites] Rewrites
expr: []
class: []
level: []
[egg.rewrites] Tagged (0)
[egg.rewrites] Generated (0)
[egg.rewrites] Builtin (0)
[egg.rewrites] Hypotheses (0)
Expand Down Expand Up @@ -55,6 +56,7 @@ info: [egg.rewrites] Rewrites
expr: [?n, ?m]
class: []
level: []
[egg.rewrites] Tagged (0)
[egg.rewrites] Generated (0)
[egg.rewrites] Builtin (0)
[egg.rewrites] Hypotheses (0)
Expand Down
1 change: 1 addition & 0 deletions Lean/Egg/Tests/TC Proj Binders.lean
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ info: [egg.rewrites] Rewrites
expr: []
class: []
level: []
[egg.rewrites] Tagged (0)
[egg.rewrites] Generated (2)
[egg.rewrites] #0[0?69632,0](⇔)
[egg.rewrites] HMul.hMul = Mul.mul
Expand Down
Loading

0 comments on commit c25d9c9

Please sign in to comment.