Skip to content

Commit

Permalink
feat: handle (some) nested inductives in derive handlers.
Browse files Browse the repository at this point in the history
This code has broken other stuff, is badly documented, needs refactoring and cleaning.
  • Loading branch information
arthur-adjedj committed Jul 26, 2024
1 parent 2d69994 commit a00f7f2
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 29 deletions.
2 changes: 1 addition & 1 deletion src/Lean/Elab/Deriving/FromToJson.lean
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def mkFromJsonAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do
let type ← Term.elabTerm header.targetType none
let body ← mkFromJsonBody ctx header type xs
--TODO: make function non-partial
`(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Except String $(← ctx.typeInfos[i]!.mkAppN header.argNames) := $body:term)
`(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Except String $(← ctx.typeInfos[i]!.mkAppTerm header.argNames) := $body:term)

def mkToJsonMutualBlock (ctx : Context) : TermElabM Command := do
let mut auxDefs := #[]
Expand Down
143 changes: 115 additions & 28 deletions src/Lean/Elab/Deriving/Util.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import Lean.Elab.Term
import Lean.Elab.Binders
import Lean.PrettyPrinter
import Lean.Data.Options
import Lean.Meta.CollectFVars

namespace Lean.Elab.Deriving
open Meta
Expand Down Expand Up @@ -80,10 +81,10 @@ partial instance : BEq NestedOccurence := ⟨go⟩
where go
| leaf ind₁,.leaf ind₂ => ind₁.name == ind₂.name
| node ind₁ arr₁,.node ind₂ arr₂ => Id.run do
unless ind₁.name == ind₂.name && arr₁.size == arr₂.size do
unless ind₁.name == ind₂.name && arr₁.size == arr₂.size do
return false
for i in [:arr₁.size] do
unless @instBEqSum _ _ ⟨go⟩ inferInstance|>.beq arr₁[i]! arr₂[i]! do
unless @instBEqSum _ _ ⟨go⟩ inferInstance |>.beq arr₁[i]! arr₂[i]! do
return false
return true
| _,_ => false
Expand Down Expand Up @@ -126,7 +127,7 @@ partial def toListofNests (e : NestedOccurence) : List NestedOccurence :=
e::l

/-- Return the inductive declaration's type applied to the arguments in `argNames`. -/
partial def mkAppN (nestedOcc : NestedOccurence) (argNames : Array Name) : TermElabM Term := do
partial def mkAppTerm (nestedOcc : NestedOccurence) (argNames : Array Name) : TermElabM Term := do
go nestedOcc argNames
where
go (nestedOcc : NestedOccurence) (argNames : Array Name) : TermElabM Term := do
Expand Down Expand Up @@ -159,16 +160,76 @@ where
args := args.push <| ←`($tm)
`(@$f $args*)

/-- Return the inductive declaration's type applied to the arguments in `argNames`. -/
partial def mkAppExpr (nestedOcc : NestedOccurence) (argNames : Subarray Expr) : TermElabM Expr := do
-- logInfo $ s!"mkAppExpr {nestedOcc} {argNames}"
let res ← go nestedOcc argNames
-- logInfo $ s!"mkAppExpr {nestedOcc} {argNames} = {res}"
return res
where
go (nestedOcc : NestedOccurence) (argNames : Array Expr): TermElabM Expr := do
-- logInfo $ s!"mkAppExpr.go {nestedOcc} {argNames}"
match nestedOcc with
| leaf indVal => do
let numArgs := indVal.numParams + indVal.numIndices
unless argNames.size >= numArgs do
throwError s!"Expected {numArgs} arguments for {indVal.name}, got {argNames}"
let mut args := Array.mkArray numArgs default
for i in [:numArgs] do
let arg := argNames[i]!
args := args.modify i (fun _ => arg)
-- logInfo $ args
let name ← Meta.mkConstWithFreshMVarLevels indVal.name
-- logInfo $ name
let res := args.foldl mkApp name
return res
| node indVal arr =>
let mut args := #[]
for nestedOcc? in arr do
match nestedOcc? with
| .inl occ =>
let arg ← go occ argNames
args := args.push arg
| .inr (.bvar i) =>
let some argName := argNames[argNames.size-i-1]?
| throwError s!"Cannot instantiate {nestedOcc} : not enough arguments given"
let id := argName
args := args.push id
| .inr e =>
args := args.push e
let res ← Meta.mkAppOptM indVal.name (args.map some)
return res

end NestedOccurence

structure NestedOccurence.Context where
indNames : List Name
res : List (Array Expr × Array Name × NestedOccurence)

abbrev NestedOccM := StateT NestedOccurence.Context TermElabM

def withIndNames (indNames : List Name) (f : NestedOccM Unit) : TermElabM NestedOccurence.Context := do
let ⟨(),ctx⟩ ← StateT.run f ⟨indNames,[]⟩
return ctx

def add_res (x : (Array Expr × Array Name × NestedOccurence)) : NestedOccM Unit := do
let ⟨names,res⟩ ← get
set (⟨names,x::res⟩ : NestedOccurence.Context)

def add_name (n : Name) : NestedOccM Unit := do
let ⟨names,res⟩ ← get
set (⟨n::names,res⟩ : NestedOccurence.Context)


partial def getNestedOccurencesOf (inds : List Name) (e: Expr) (fvars : Subarray Expr): MetaM (Option NestedOccurence) := do
-- logInfo $ s!"getNestedOccurencesOf {inds} {e} {fvars}"
let .inl occs ← go e | return none
return occs
where
go (e : Expr) : MetaM (NestedOccurence ⊕ Expr) := do
let hd := e.getAppFn
let fallback _ := pure <| .inr <| e.abstract fvars
let .const name .. := hd | fallback ()
let .const name _ := hd | fallback ()
if let some indName := inds.find? (· = name) then
let indVal ← getConstInfoInduct indName
return .inl <| .leaf indVal
Expand All @@ -177,49 +238,63 @@ where
let indVal ← getConstInfoInduct name
let args := e.getAppArgs
let args := args.map (·.abstract fvars)
let nestedOccsArgs ← args.mapM <| go
let nestedOccsArgs ← args.mapM go
if nestedOccsArgs.any Sum.isLeft then
return .inl <| .node indVal nestedOccsArgs
else fallback ()
catch _ => fallback ()

def getNestedOccurences (indNames : List Name) : TermElabM (List (Array Name × NestedOccurence)) := do
let l ← indNames.foldlM (fun l x => bind (go x) (return · ++ l)) []
-- We erase duplicates up to alpha-equivalence
return @List.eraseDups _ (⟨fun ⟨_,occ₁⟩ ⟨_,occ₂⟩=> BEq.beq occ₁ occ₂⟩) l
partial def getNestedOccurences (indNames : List Name) : TermElabM (List (Array Expr × Array Name × NestedOccurence)) := do
let ⟨_,l⟩ ← withIndNames indNames do
for name in indNames do
go name #[] #[]
return @List.eraseDups _ (⟨fun x y => x.2.2 == y.2.2⟩) l
where
go (indName : Name) := do
go (indName : Name) (args : Array Expr) (fvars : Array Expr): NestedOccM Unit := do
-- logInfo $ s!"getNestedOccurences.go {indName} {args} {fvars} {toString $ (← get).res.map (·.2.2)}"
let indVal ← getConstInfoInduct indName
if !indVal.isNested then
return []
let constrs ← indVal.ctors.mapM (getConstInfoCtor)
let mut res := []
if !indVal.isNested && args.size == 0 then
return
let constrs ← indVal.ctors.mapM getConstInfoCtor
for constInfo in constrs do
let constList ← forallTelescope constInfo.type fun xs _ => do
-- logInfo $ s!"constInfo: {constInfo.name}"
let instConstInfo ← forallBoundedTelescope constInfo.type args.size fun xs e =>
return e.abstract xs |>.instantiate args
-- logInfo $ s!"instConstInfo: {instConstInfo}"
forallTelescope instConstInfo fun xs _ => do
-- logInfo $ s!"xs : {xs}"
-- logInfo $ s!"numParams : {constInfo.numParams}"
let mut paramArgs := #[]
let mut l := []
for i in [:constInfo.numParams] do
let e := xs[i]!
let some e := xs[i]? | break
let localDecl ← e.fvarId!.getDecl
let paramName ← mkFreshUserName localDecl.userName.eraseMacroScopes
paramArgs := paramArgs.push paramName
let mut localArgs := #[]
for i in [constInfo.numParams:xs.size] do
for i in [:xs.size] do
let e := xs[i]!
let ty ← e.fvarId!.getType
let localDecl ← e.fvarId!.getDecl
let paramName ← mkFreshUserName localDecl.userName.eraseMacroScopes
let occs ← getNestedOccurencesOf indNames ty xs[:i]
let l' := if let .some x := occs then x.toListofNests else []
let l' := l'.map fun occ =>
-- logInfo $ s!"l' : {l'}"
for occ in l' do
let relevantLocalArgs := localArgs.filter (occ.containsFVar ⟨·⟩)
let args := paramArgs ++ relevantLocalArgs
(args,occ)
let new_args := paramArgs ++ relevantLocalArgs
add_res (xs[:i].toArray,new_args,occ)
let fvars := fvars ++ xs[:i]
-- logInfo $ s!"fvars : {fvars}"
let app ← occ.mkAppExpr fvars.toSubarray
let hd := app.getAppFn.constName!
if hd ∈ (← get).indNames then
continue
let args := app.getAppArgs
add_name hd
go hd args fvars
l := l ++ l'
localArgs := localArgs.push paramName
pure l
res := res ++ constList
return res

def indNameToFunName (indName : Name) : String :=
match indName.eraseMacroScopes with
Expand Down Expand Up @@ -281,10 +356,10 @@ def mkContext (fnPrefix : String) (typeName : Name) : TermElabM Context := do
for indName in indNames do
let indVal ← getConstInfoInduct indName
let args ← mkInductArgNames indVal
typeInfos' := (args,.leaf indVal)::typeInfos'
typeInfos' := (#[],args,.leaf indVal)::typeInfos'
typeInfos' := (← getNestedOccurences indVal.all) ++ typeInfos'
let typeArgNames := typeInfos'.map Prod.fst |>.toArray
let typeInfos := typeInfos'.map Prod.snd |>.toArray
let typeArgNames := typeInfos'.map (Prod.fst ∘ Prod.snd) |>.toArray
let typeInfos := typeInfos'.map (Prod.snd ∘ Prod.snd) |>.toArray
trace[Elab.Deriving] s!"typeInfos : {typeInfos}\nargNames : {typeArgNames}"
let auxFunNames ← typeInfos.mapM <|
fun occ => do return ← mkFreshUserName <| Name.mkSimple <| fnPrefix ++ mkInstName occ
Expand All @@ -310,7 +385,7 @@ def mkInstanceCmds (ctx : Context) (className : Name) (useAnonCtor := true) : Te
let argNames := ctx.typeArgNames[i]!
let binders ← mkImplicitBinders argNames
let binders := binders ++ (← mkInstImplicitBinders className nestedOcc argNames)
let indType ← nestedOcc.mkAppN argNames
let indType ← nestedOcc.mkAppTerm argNames
let type ← `($(mkCIdent className) $indType)
let mut val := mkIdent auxFunName
if useAnonCtor then
Expand All @@ -331,7 +406,7 @@ structure Header where
open TSyntax.Compat in
def mkHeader (className : Name) (arity : Nat) (argNames : Array Name) (nestedOcc : NestedOccurence) : TermElabM Header := do
let mut binders ← mkImplicitBinders argNames
let targetType ← nestedOcc.mkAppN argNames
let targetType ← nestedOcc.mkAppTerm argNames
let mut targetNames := #[]
for _ in [:arity] do
targetNames := targetNames.push (← mkFreshUserName `x)
Expand All @@ -357,4 +432,16 @@ def Context.getFunName? (ctx : Context) (header : Header) (ty : Expr) (xs : Arra
let recField := indValNum.map (ctx.auxFunNames[·]!)
return recField

-- set_option trace.Elab.Deriving true

-- #eval show TermElabM Unit from do
-- try
-- let l ← getNestedOccurences [``Tree]
-- -- logInfo $ "Occurences:"
-- for x in l do
-- -- logInfo $ toString x.2.2
-- catch | e => -- logInfo $ ← e.toMessageData.format



end Lean.Elab.Deriving
8 changes: 8 additions & 0 deletions tests/lean/decEqMutualInductives.lean
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,11 @@ inductive ComplexInductive (A B C : Type) (n : Nat) : Type
inductive NestedComplex (A C : Type) : Type
| constr : ComplexInductive A (NestedComplex A C) C 1 → NestedComplex A C
deriving DecidableEq

namespace nested

inductive Tree (α : Type) where
| node : Array (Tree α) → Tree α
deriving DecidableEq

end nested

0 comments on commit a00f7f2

Please sign in to comment.