Skip to content

Commit

Permalink
support for uncurrying function on meta level using something else th…
Browse files Browse the repository at this point in the history
…an default Prod
  • Loading branch information
lecopivo committed Jul 28, 2023
1 parent 7c8d5ae commit 7d99812
Showing 1 changed file with 12 additions and 15 deletions.
27 changes: 12 additions & 15 deletions SciLean/Lean/Meta/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,7 @@ def mkAppFoldlM (const : Name) (xs : Array Expr) : MetaM Expr := do
/--
For `#[x₁, .., xₙ]` create `(x₁, .., xₙ)`.
-/
def mkProdElem (xs : Array Expr) : MetaM Expr := mkAppFoldrM ``Prod.mk xs

def mkProdFst (x : Expr) : MetaM Expr := mkAppM ``Prod.fst #[x]
def mkProdSnd (x : Expr) : MetaM Expr := mkAppM ``Prod.snd #[x]
def mkProdElem (xs : Array Expr) (mk := ``Prod.mk) : MetaM Expr := mkAppFoldrM mk xs

/--
For `(x₀, .., xₙ₋₁)` return `xᵢ` but as a product projection.
Expand All @@ -128,38 +125,38 @@ For example for `xyz : X × Y × Z`
- `mkProdProj xyz 1 3` returns `xyz.snd.fst`.
- `mkProdProj xyz 1 2` returns `xyz.snd`.
-/
def mkProdProj (x : Expr) (i : Nat) (n : Nat) : MetaM Expr := do
def mkProdProj (x : Expr) (i : Nat) (n : Nat) (fst := ``Prod.fst) (snd := ``Prod.snd) : MetaM Expr := do
let X ← inferType x
if X.isAppOfArity ``Prod 2 then
match i, n with
| _, 0 => pure x
| _, 1 => pure x
| 0, _ => mkAppM ``Prod.fst #[x]
| i'+1, n'+1 => mkProdProj (← mkAppM ``Prod.snd #[x]) i' n'
| 0, _ => mkAppM fst #[x]
| i'+1, n'+1 => mkProdProj (← mkAppM snd #[x]) i' n'
else
if i = 0 then
return x
else
throwError "Failed `mkProdProj`, can't take {i}-th element of {← ppExpr x}. It has type {← ppExpr X} which is not a product type!"


def mkProdSplitElem (xs : Expr) (n : Nat) : MetaM (Array Expr) :=
def mkProdSplitElem (xs : Expr) (n : Nat) (fst := ``Prod.fst) (snd := ``Prod.snd) : MetaM (Array Expr) :=
(Array.mkArray n 0)
|>.mapIdx (λ i _ => i.1)
|>.mapM (λ i => mkProdProj xs i n)
|>.mapM (λ i => mkProdProj xs i n fst snd)

def mkUncurryFun (n : Nat) (f : Expr) : MetaM Expr := do
def mkUncurryFun (n : Nat) (f : Expr) (mk := ``Prod.mk) (fst := ``Prod.fst) (snd := ``Prod.snd) : MetaM Expr := do
if n ≤ 1 then
return f
forallTelescope (← inferType f) λ xs _ => do
let xs := xs[0:n]

let xProdName : String ← xs.foldlM (init:="") λ n x =>
do return (n ++ toString (← x.fvarId!.getUserName).eraseMacroScopes)
let xProdType ← inferType (← mkProdElem xs)
let xProdType ← inferType (← mkProdElem xs mk)

withLocalDecl xProdName default xProdType λ xProd => do
let xs' ← mkProdSplitElem xProd n
let xs' ← mkProdSplitElem xProd n fst snd
mkLambdaFVars #[xProd] (← mkAppM' f xs').headBeta


Expand All @@ -170,7 +167,7 @@ def mkUncurryFun (n : Nat) (f : Expr) : MetaM Expr := do
fun x => f x + c ==> (fun y => y + c) ∘ f
fun x => f x + g x ==> (fun (y₁,y₂) => y₁ + y₂) ∘ (fun x => (f x, g x))
-/
def splitLambdaToComp (e : Expr) : MetaM (Expr × Expr) := do
def splitLambdaToComp (e : Expr) (mk := ``Prod.mk) (fst := ``Prod.fst) (snd := ``Prod.snd) : MetaM (Expr × Expr) := do
match e with
| .lam name type b bi =>
withLocalDecl name bi type fun x => do
Expand All @@ -197,11 +194,11 @@ def splitLambdaToComp (e : Expr) : MetaM (Expr × Expr) := do
else
f := f.app y

let y' ← mkProdElem ys'
let y' ← mkProdElem ys' mk
let g ← mkLambdaFVars #[.fvar xId] y'

f ← withLCtx lctx instances (mkLambdaFVars zs f)
f ← mkUncurryFun zs.size f
f ← mkUncurryFun zs.size f mk fst snd

return (f, g)

Expand Down

0 comments on commit 7d99812

Please sign in to comment.