diff --git a/Lean/Egg.lean b/Lean/Egg.lean index e9dcad22d7..3ad813d3fe 100644 --- a/Lean/Egg.lean +++ b/Lean/Egg.lean @@ -32,3 +32,4 @@ import Egg.Tactic.Basic import Egg.Tactic.Calc import Egg.Tactic.Guides import Egg.Tactic.Trace +import Egg.Tactic.Tags diff --git a/Lean/Egg/Core/Config.lean b/Lean/Egg/Core/Config.lean index db516c688e..f3524c42d7 100644 --- a/Lean/Egg/Core/Config.lean +++ b/Lean/Egg/Core/Config.lean @@ -17,6 +17,7 @@ structure Encoding extends Normalization where structure Gen where builtins := true + tagged? := some `egg genTcProjRws := true genTcSpecRws := true genGoalTcSpec := true diff --git a/Lean/Egg/Core/Explanation/Parse.lean b/Lean/Egg/Core/Explanation/Parse.lean index e8a7d6f1b3..2a347dceb4 100644 --- a/Lean/Egg/Core/Explanation/Parse.lean +++ b/Lean/Egg/Core/Explanation/Parse.lean @@ -36,6 +36,7 @@ syntax "*" noWs num : egg_basic_fwd_rw_src syntax "⊢" : egg_basic_fwd_rw_src syntax "↣" noWs num : egg_basic_fwd_rw_src syntax "◯" noWs num : egg_basic_fwd_rw_src +syntax "□" noWs num (noWs "/" noWs num)? : egg_basic_fwd_rw_src syntax "[" egg_tc_proj_loc num "," num "]" : egg_tc_proj @@ -119,6 +120,7 @@ private def parseTcProjLocation : (TSyntax `egg_tc_proj_loc) → Source.TcProjLo private def parseBasicFwdRwSrc : (TSyntax `egg_basic_fwd_rw_src) → Source | `(egg_basic_fwd_rw_src|#$idx$[/$eqn?]?) => .explicit idx.getNat (eqn?.map TSyntax.getNat) + | `(egg_basic_fwd_rw_src|□$idx$[/$eqn?]?) => .tagged idx.getNat (eqn?.map TSyntax.getNat) | `(egg_basic_fwd_rw_src|*$idx) => .star (.fromUniqueIdx idx.getNat) | `(egg_basic_fwd_rw_src|⊢) => .goal | `(egg_basic_fwd_rw_src|↣$idx) => .guide idx.getNat diff --git a/Lean/Egg/Tactic/Premises/Gen.lean b/Lean/Egg/Tactic/Premises/Gen.lean index 95c5faea22..0da7f63d80 100644 --- a/Lean/Egg/Tactic/Premises/Gen.lean +++ b/Lean/Egg/Tactic/Premises/Gen.lean @@ -12,11 +12,12 @@ namespace Egg.Premises -- TODO: Perform pruning during generation, not after. private def tracePremises - (basic : WithSyntax Rewrites) (builtins tc pruned : Rewrites) (facts : WithSyntax Facts) + (basic : WithSyntax Rewrites) (tagged builtins tc pruned : Rewrites) (facts : WithSyntax Facts) (cfg : Config.Gen) : TacticM Unit := do let cls := `egg.rewrites withTraceNode cls (fun _ => return "Rewrites") do withTraceNode cls (fun _ => return m!"Basic ({basic.elems.size})") do basic.elems.trace basic.stxs cls + withTraceNode cls (fun _ => return m!"Tagged ({tagged.size})") do tagged.trace #[] cls withTraceNode cls (fun _ => return m!"Generated ({tc.size})") do tc.trace #[] cls withTraceNode cls (fun _ => return m!"Builtin ({builtins.size})") do builtins.trace #[] cls withTraceNode cls (fun _ => return m!"Hypotheses ({facts.elems.size})") do @@ -30,14 +31,15 @@ private def tracePremises partial def gen (goal : Congr) (ps : TSyntax `egg_premises) (guides : Guides) (cfg : Config) (amb : MVars.Ambient) : TacticM (Rewrites × Facts) := do + let tagged ← Premises.buildTagged cfg amb let ⟨⟨basic, basicStxs⟩, facts⟩ ← Premises.elab { norm? := cfg, amb } ps - let (basic, basicStxs, pruned₁) ← prune basic basicStxs (remove := #[]) + let (basic, basicStxs, pruned₁) ← prune basic basicStxs (remove := tagged) let builtins ← if cfg.builtins then Rewrites.builtins { norm? := cfg, amb } else pure #[] - let (builtins, _, pruned₂) ← prune builtins (remove := basic) + let (builtins, _, pruned₂) ← prune builtins (remove := tagged ++ basic) let tc ← genTcRws (basic ++ builtins) facts.elems - let (tc, _, pruned₃) ← prune tc (remove := basic ++ builtins) - tracePremises ⟨basic, basicStxs⟩ builtins tc (pruned₁ ++ pruned₂ ++ pruned₃) facts cfg - let rws := basic ++ builtins ++ tc + let (tc, _, pruned₃) ← prune tc (remove := tagged ++ basic ++ builtins) + tracePremises ⟨basic, basicStxs⟩ tagged builtins tc (pruned₁ ++ pruned₂ ++ pruned₃) facts cfg + let rws := tagged ++ basic ++ builtins ++ tc catchInvalidConditionals rws return (rws, facts.elems) where diff --git a/Lean/Egg/Tactic/Premises/Parse.lean b/Lean/Egg/Tactic/Premises/Parse.lean index fd0a2bd254..89e7129cfb 100644 --- a/Lean/Egg/Tactic/Premises/Parse.lean +++ b/Lean/Egg/Tactic/Premises/Parse.lean @@ -1,5 +1,7 @@ import Egg.Core.Premise.Rewrites import Egg.Core.Premise.Facts +import Egg.Tactic.Premises.Validate +import Egg.Tactic.Tags import Lean open Lean Meta Elab Tactic @@ -117,6 +119,19 @@ private def Premises.taggedRw (prem : Name) (idx : Nat) (cfg : Rewrite.Config) : let rws ← Premises.explicit ident idx mk .tagged return rws.elems +private def Premises.elabTagged (prems : Array Name) (cfg : Rewrite.Config) : TacticM Rewrites := do + let mut rws : Rewrites := #[] + for prem in prems, idx in [:prems.size] do + rws := rws ++ (← taggedRw prem idx cfg) + return rws + +def Premises.buildTagged (cfg : Config) (amb : MVars.Ambient ): TacticM Rewrites := + match cfg.tagged? with + | none => return {} + | some _ => do -- This should later use this `Name` to find the proper extension + let prems := eggXtension.getState (← getEnv) + elabTagged prems { norm? := cfg, amb} + -- Note: This function is expected to be called with the lctx which contains the desired premises. -- -- Note: We need to filter out auxiliary declaration and implementation details, as they are not diff --git a/Lean/Egg/Tactic/Premises/Validate.lean b/Lean/Egg/Tactic/Premises/Validate.lean new file mode 100644 index 0000000000..9deca2a111 --- /dev/null +++ b/Lean/Egg/Tactic/Premises/Validate.lean @@ -0,0 +1,53 @@ +import Egg.Core.Premise.Rewrites + +open Lean Meta Elab Tactic + +inductive Premise.Raw where + | single (expr : Expr) (type? : Option Expr := none) + | eqns (exprs : Array Expr) + +inductive Premise.RawRaw where + | eqns (exprs : Array Name) + | single (expr : Expr) (type? : Option Expr := none) + | prem (prem : Term) + + +def Premise.Raw.validate (prem : Term) : MetaM Premise.RawRaw := do + if let some const ← optional (resolveGlobalConstNoOverload prem) then + if let some eqs ← getEqnsFor? const (nonRec := true) then + -- `prem` is a global definition. + return .eqns eqs + else + -- `prem` is an global constant which is not a definition with equations. + let env ← getEnv + let some info := env.find? const | throwErrorAt prem m!"unknown constant '{mkConst const}'" + match info with + | .defnInfo _ | .axiomInfo _ | .thmInfo _ | .opaqueInfo _ => + let lvlMVars ← List.replicateM info.numLevelParams mkFreshLevelMVar + let val := if info.hasValue then info.instantiateValueLevelParams! lvlMVars else .const info.name lvlMVars + let type := info.instantiateTypeLevelParams lvlMVars + return .single val type + | _ => throwErrorAt prem "egg requires arguments to be theorems, definitions or axioms" + else + -- `prem` is an invalid identifier or a term which is not an identifier. + -- We must use `Tactic.elabTerm`, not `Term.elabTerm`. Otherwise elaborating `‹...›` doesn't + -- work correctly. See https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Elaborate.20.E2.80.B9.2E.2E.2E.E2.80.BA + return .prem prem + +-- We don't just elaborate premises directly as: +-- (1) this can cause problems for global constants with typeclass arguments, as Lean sometimes +-- tries to synthesize the arguments and fails if it can't (instead of inserting mvars). +-- (2) global constants which are definitions with equations (cf. `getEqnsFor?`) are supposed to +-- be replaced by their defining equations. +partial def Premise.Raw.elab (prem : Term) : TacticM Premise.Raw := do + if let some hyp ← optional (getFVarId prem) then + -- `prem` is a local declaration. + let decl ← hyp.getDecl + if decl.isImplementationDetail || decl.isAuxDecl then + throwErrorAt prem "egg does not support using auxiliary declarations" + else + return .single (.fvar hyp) (← hyp.getType) + match (← validate prem) with + | .eqns eqs => return .eqns <| ← eqs.mapM fun eqn => Tactic.elabTerm (mkIdent eqn) none + | .single val type? => return .single val type? + | .prem prem => return .single (← Tactic.elabTerm prem none) diff --git a/Lean/Egg/Tactic/Tags.lean b/Lean/Egg/Tactic/Tags.lean new file mode 100644 index 0000000000..b85e6b9ca2 --- /dev/null +++ b/Lean/Egg/Tactic/Tags.lean @@ -0,0 +1,110 @@ +import Egg.Tactic.Premises.Validate +import Lean + +open Lean Elab Tactic Term + +namespace Egg + +/-- +This validates that a theorem can be used by the `egg` tactic (it ultimately boils down to an equality.) + +Unimplemented: Currently, this is a noop. +-/ +private def validateEggTheorem (thm : Term) : MetaM Unit := do + let _ ← Premise.Raw.validate thm + return () + +-- Ideally this should be at some point a discrimination tree +abbrev EggTheorems := Array Name + +abbrev EggEntry := Name -- later: Lean.Meta.SimpEntry + +def addEggTheoremEntry (d : EggTheorems) (e : EggEntry) : EggTheorems := + d.push e + +abbrev EggXtension := SimpleScopedEnvExtension EggEntry EggTheorems + +open Lean.Elab +open Lean.Elab.Command + +def EggXtension.getTheorems (ext : EggXtension) : CoreM EggTheorems := + return ext.getState (← getEnv) + +/-- +This function does the appropriate preprocessing from a `Name` to record a theorem as +an `egg` theorem. + +For now, this preprocessing is nothing (just store the name in a singleton `Array`). +However, in the future this can be used like simp to do more efficient preprocessing +and deal with other kinds of `egg` lemmas (or even import `simp` lemmas). +-/ +private def mkEggTheoremsFromConst (declName : Name) : MetaM EggTheorems := + pure #[declName] + +def addEggTheorem (ext : EggXtension) (declName : Name) (attrKind : AttributeKind) : MetaM Unit := do + let _ ← validateEggTheorem { raw := Syntax.ident default default declName []} -- ugly! + let eggThms ← mkEggTheoremsFromConst declName + for eggThm in eggThms do + ext.add eggThm attrKind + +def mkEggXt (name : Name := by exact decl_name%) : IO EggXtension := + registerSimpleScopedEnvExtension { + name := name + initial := {} + addEntry := fun d e => addEggTheoremEntry d e + } + +def mkEggAttr (attrName : Name) (attrDescr : String) (ext : EggXtension) + (ref : Name := by exact decl_name%) : IO Unit := + registerBuiltinAttribute { + ref := ref + name := attrName + descr := attrDescr + applicationTime := AttributeApplicationTime.afterCompilation + add := fun declName _stx attrKind => do + let go : MetaM Unit := do + let info ← getConstInfo declName + if (← Meta.isProp info.type) then + addEggTheorem ext declName attrKind + else + throwError "invalid 'egg', it is not a proposition" + discard <| go.run {} {} + erase := fun declName => do + let s := ext.getState (← getEnv) + let s := s.erase (declName) + modifyEnv fun env => ext.modifyState env fun _ => s + } + + +abbrev EggXtensionMap := HashMap Name EggXtension + +initialize eggXtensionMapRef : IO.Ref EggXtensionMap ← IO.mkRef {} + +def registerEggAttr (attrName : Name) (attrDescr : String) + (ref : Name := by exact decl_name%) : IO EggXtension := do + let ext ← mkEggXt ref + mkEggAttr attrName attrDescr ext ref -- Remark: it will fail if it is not performed during initialization + eggXtensionMapRef.modify fun map => map.insert attrName ext + return ext + +initialize eggXtension : EggXtension ← registerEggAttr `egg "equality saturation theorem theorem" + + +syntax (name := showEgg) "#show_egg_set" : command + +-- +-- +--#check Lean.Meta.mkSimpAttr +-- +--@[command_elab insertEgg] def elabInsertEgg : CommandElab := fun stx => do +-- IO.println s!"inserting {stx[1].getId}" +-- modifyEnv fun env => eggXtension.addEntry env stx[1].getId +-- +@[command_elab showEgg] def elabShowEgg : CommandElab := fun _ => do + logInfo m!"egg set: {eggXtension.getState (← getEnv) |>.toList}" +-- +-- +--initialize eggTag : TagAttribute ← +-- registerTagAttribute `egg "Tag for marking lemmas used for equality saturation" (validate := validateEggTheorem) + +end Egg diff --git a/Lean/Egg/Tests/Conditional.lean b/Lean/Egg/Tests/Conditional.lean index 888dde74a1..7b213226c5 100644 --- a/Lean/Egg/Tests/Conditional.lean +++ b/Lean/Egg/Tests/Conditional.lean @@ -95,6 +95,7 @@ info: [egg.rewrites] Rewrites expr: [?l₂] class: [] level: [] + [egg.rewrites] Tagged (0) [egg.rewrites] Generated (0) [egg.rewrites] Builtin (0) [egg.rewrites] Hypotheses (0) diff --git a/Lean/Egg/Tests/FreshmanTags.lean b/Lean/Egg/Tests/FreshmanTags.lean new file mode 100644 index 0000000000..45d3863567 --- /dev/null +++ b/Lean/Egg/Tests/FreshmanTags.lean @@ -0,0 +1,71 @@ +import Egg + +class Inv (α) where inv : α → α +postfix:max "⁻¹" => Inv.inv + +class Zero (α) where zero : α +instance [Zero α] : OfNat α 0 where ofNat := Zero.zero + +class One (α) where one : α +instance [One α] : OfNat α 1 where ofNat := One.one + +class CommRing (α) extends Zero α, One α, Add α, Sub α, Mul α, Div α, Pow α Nat, Inv α, Neg α where + comm_add (a b : α) : a + b = b + a + comm_mul (a b : α) : a * b = b * a + add_assoc (a b c : α) : a + (b + c) = (a + b) + c + mul_assoc (a b c : α) : a * (b * c) = (a * b) * c + sub_canon (a b : α) : a - b = a + -b + neg_add (a : α) : a + -a = 0 + div_canon (a b : α) : a / b = a * b⁻¹ + zero_add (a : α) : a + 0 = a + zero_mul (a : α) : a * 0 = 0 + one_mul (a : α) : a * 1 = a + distrib (a b c : α) : a * (b + c) = (a * b) + (a * c) + pow_zero (a : α) : a ^ 0 = 1 + pow_one (a : α) : a ^ 1 = a + pow_two (a : α) : a ^ 2 = (a ^ 1) * a + pow_three (a : α) : a ^ 3 = (a ^ 2) * a + +attribute [egg] CommRing.comm_add +attribute [egg] CommRing.comm_mul +attribute [egg] CommRing.add_assoc +attribute [egg] CommRing.mul_assoc +attribute [egg] CommRing.sub_canon +attribute [egg] CommRing.neg_add +attribute [egg] CommRing.div_canon +attribute [egg] CommRing.zero_add +attribute [egg] CommRing.zero_mul +attribute [egg] CommRing.one_mul +attribute [egg] CommRing.distrib +attribute [egg] CommRing.pow_zero +attribute [egg] CommRing.pow_one +attribute [egg] CommRing.pow_two +attribute [egg] CommRing.pow_three + +class CharTwoRing (α) extends CommRing α where + char_two (a : α) : a + a = 0 + +variable [CharTwoRing α] (x y : α) + +theorem freshmans_dream₂ : (x + y) ^ 2 = (x ^ 2) + (y ^ 2) := by + egg calc (x + y) ^ 2 + _ = (x + y) * (x + y) + _ = x * (x + y) + y * (x + y) + _ = x ^ 2 + x * y + y * x + y ^ 2 + _ = x ^ 2 + y ^ 2 with [CharTwoRing.char_two] + +theorem freshmans_dream₂' : (x + y) ^ 2 = (x ^ 2) + (y ^ 2) := by + egg [CharTwoRing.char_two] + +theorem freshmans_dream₃ : (x + y) ^ 3 = x ^ 3 + x * y ^ 2 + x ^ 2 * y + y ^ 3 := by + egg calc [CharTwoRing.char_two] (x + y) ^ 3 + _ = (x + y) * (x + y) * (x + y) + _ = (x + y) * (x * (x + y) + y * (x + y)) + _ = (x + y) * (x ^ 2 + x * y + y * x + y ^ 2) + _ = (x + y) * (x ^ 2 + y ^ 2) + _ = x * (x ^ 2 + y ^ 2) + y * (x ^ 2 + y ^ 2) + _ = (x * x ^ 2) + x * y ^ 2 + y * x ^ 2 + y * y ^ 2 + _ = x ^ 3 + x * y ^ 2 + x ^ 2 * y + y ^ 3 + +theorem freshmans_dream₃' : (x + y) ^ 3 = x ^ 3 + x * y ^ 2 + x ^ 2 * y + y ^ 3 := by + egg [CharTwoRing.char_two] diff --git a/Lean/Egg/Tests/Prune.lean b/Lean/Egg/Tests/Prune.lean index b991ce5369..2dfbde90cb 100644 --- a/Lean/Egg/Tests/Prune.lean +++ b/Lean/Egg/Tests/Prune.lean @@ -21,6 +21,7 @@ info: [egg.rewrites] Rewrites expr: [] class: [] level: [] + [egg.rewrites] Tagged (0) [egg.rewrites] Generated (0) [egg.rewrites] Builtin (0) [egg.rewrites] Hypotheses (0) @@ -55,6 +56,7 @@ info: [egg.rewrites] Rewrites expr: [?n, ?m] class: [] level: [] + [egg.rewrites] Tagged (0) [egg.rewrites] Generated (0) [egg.rewrites] Builtin (0) [egg.rewrites] Hypotheses (0) diff --git a/Lean/Egg/Tests/TC Proj Binders.lean b/Lean/Egg/Tests/TC Proj Binders.lean index ade179f82a..0e6badc36b 100644 --- a/Lean/Egg/Tests/TC Proj Binders.lean +++ b/Lean/Egg/Tests/TC Proj Binders.lean @@ -25,6 +25,7 @@ info: [egg.rewrites] Rewrites expr: [] class: [] level: [] + [egg.rewrites] Tagged (0) [egg.rewrites] Generated (2) [egg.rewrites] #0[0?69632,0](⇔) [egg.rewrites] HMul.hMul = Mul.mul diff --git a/Lean/Egg/Tests/Tags.lean b/Lean/Egg/Tests/Tags.lean new file mode 100644 index 0000000000..dc154d1e61 --- /dev/null +++ b/Lean/Egg/Tests/Tags.lean @@ -0,0 +1,44 @@ +import Egg + +class One (α) where one : α +instance [One α] : OfNat α 1 where ofNat := One.one + +class Inv (α) where inv : α → α +postfix:max "⁻¹" => Inv.inv + +class Group (α) extends Mul α, One α, Inv α where + mul_assoc (a b c : α) : (a * b) * c = a * (b * c) + one_mul (a : α) : 1 * a = a + mul_one (a : α) : a * 1 = a + inv_mul_self (a : α) : a⁻¹ * a = 1 + mul_inv_self (a : α) : a * a⁻¹ = 1 + +variable [Group α] (a b x y : α) + +attribute [egg] Group.mul_assoc +attribute [egg] Group.one_mul +attribute [egg] Group.mul_one +attribute [egg] Group.inv_mul_self +attribute [egg] Group.mul_inv_self + +/-- info: egg set: [Group.mul_assoc, Group.one_mul, Group.mul_one, Group.inv_mul_self, Group.mul_inv_self] -/ + #guard_msgs(info) in +#show_egg_set + +@[egg] +theorem inv_mul_cancel_left : a⁻¹ * (a * b) = b := by + egg + +/-- info: egg set: [Group.mul_assoc, Group.one_mul, Group.mul_one, Group.inv_mul_self, Group.mul_inv_self, inv_mul_cancel_left] -/ + #guard_msgs(info) in +#show_egg_set + +@[egg] +theorem mul_inv_cancel_left : a * (a⁻¹ * b) = b := by + egg + +-- TODO: should egg (or a version of egg) do the intros automatically to get an eq goal? +@[egg] +theorem mul_eq_de_eq_inv_mul : x = a⁻¹ * y → a * x = y := by + intros h + egg [h] diff --git a/Lean/Egg/Tests/WIP_ValidateTagged.lean b/Lean/Egg/Tests/WIP_ValidateTagged.lean new file mode 100644 index 0000000000..3b4ecfc9ca --- /dev/null +++ b/Lean/Egg/Tests/WIP_ValidateTagged.lean @@ -0,0 +1,54 @@ +import Egg + +class One (α) where one : α +instance [One α] : OfNat α 1 where ofNat := One.one + +class Inv (α) where inv : α → α +postfix:max "⁻¹" => Inv.inv + +class Group (α) extends Mul α, One α, Inv α where + mul_assoc (a b c : α) : (a * b) * c = a * (b * c) + one_mul (a : α) : 1 * a = a + mul_one (a : α) : a * 1 = a + inv_mul_self (a : α) : a⁻¹ * a = 1 + mul_inv_self (a : α) : a * a⁻¹ = 1 + +variable [Group α] (a b x y : α) + +attribute [egg] Group.mul_assoc +attribute [egg] Group.one_mul +attribute [egg] Group.mul_one +attribute [egg] Group.inv_mul_self +attribute [egg] Group.mul_inv_self + +def hPow : α → Nat → α + | _, 0 => 1 + | a, (n+1) => a * hPow a n + +instance [Group α] : HPow α Nat α := ⟨hPow⟩ + +def OrderN (n : Nat) (a : α) : Prop := a^n = 1 + +-- not defining the cardinality here for space reasons +def card (α) [Group α] : Nat := sorry + +-- This one should not go through! not an equality +@[egg] +theorem ex_min_order : ∃ n : Nat, OrderN n a ∧ (∀ n', n' < n → ¬ OrderN n a) := sorry + +-- This should also be recognized as an equality +@[egg] +theorem card_order : OrderN (card α) a := by + sorry + +def Abelian (α) [Group α] : Prop := ∀ a b : α, a * b = b * a + +def commutator := a*b*a⁻¹*b⁻¹ + +-- Ideally, egg can see through this prop that there's an equality? +@[egg] +theorem all_commutators_trivial_abelian : (∀ a b : α, commutator a b = 1) → Abelian α := by sorry + +/-- We should not break egg after adding these lemmas -/ +example : a * b = b * a := by + egg