Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into broken_alive_autogene…
Browse files Browse the repository at this point in the history
…rated
  • Loading branch information
bollu committed Apr 18, 2024
2 parents 5e7e64d + 0e05c32 commit 38574a1
Show file tree
Hide file tree
Showing 18 changed files with 490 additions and 127 deletions.
3 changes: 0 additions & 3 deletions SSA/Core/Framework.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@ import SSA.Core.HVector
import Mathlib.Data.List.AList
import Mathlib.Data.Finset.Piecewise

import SSA.Projects.MLIRSyntax.AST -- TODO post-merge: bring into Core
import SSA.Projects.MLIRSyntax.EDSL -- TODO post-merge: bring into Core

open Ctxt (Var VarSet Valuation)
open TyDenote (toType)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import SSA.Core.Util.ConcreteOrMVar
open Lean PrettyPrinter

/-!
# MLIR Syntax AST
This file contains the AST datastructures for generic MLIR syntax
-/

namespace MLIR.AST

-- variable (MCtxt : Type*)
Expand Down
63 changes: 63 additions & 0 deletions SSA/Core/MLIRSyntax/EDSL.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import SSA.Core.MLIRSyntax.GenericParser
import SSA.Core.MLIRSyntax.Transform

/-!
# MLIR Dialect Domain Specific Language
This file sets up generic glue meta-code to tie together the generic MLIR parser with the
`Transform` mechanism, to obtain an easy way to specify a DSL that elaborates into `Com`/`Expr`
instances for a specific dialect.
-/

namespace SSA

open Qq Lean Meta Elab Term
open MLIR.AST

/--
`elabIntoCom` is a building block for defining a dialect-specific DSL based on the geneeric MLIR
syntax parser.
For example, if `FooOp` is the type of operations of a "Foo" dialect, we can build a term elaborator
for this dialect as follows:
```
elab "[foo_com| " reg:mlir_region "]" : term => SSA.elabIntoCom reg q(FooOp)
-- ^^^^^^^ ^^^^^
```
-/
def elabIntoCom (region : TSyntax `mlir_region) (Op : Q(Type)) {Ty : Q(Type)}
(_opSignature : Q(OpSignature $Op $Ty) := by exact q(by infer_instance))
(φ : Q(Nat) := q(0))
(_transformTy : Q(TransformTy $Op $Ty $φ) := by exact q(by infer_instance))
(_transformExpr : Q(TransformExpr $Op $Ty $φ) := by exact q(by infer_instance))
(_transformReturn : Q(TransformReturn $Op $Ty $φ) := by exact q(by infer_instance))
:
TermElabM Expr := do
let ast_stx ← `([mlir_region| $region])
let ast ← elabTermEnsuringTypeQ ast_stx q(Region $φ)
let com : Q(MLIR.AST.ExceptM $Op (Σ (Γ' : Ctxt $Ty) (ty : $Ty), Com $Op Γ' ty)) :=
q(MLIR.AST.mkCom $ast)
synthesizeSyntheticMVarsNoPostponing
/- Now reduce the term. We do this so that the resulting term will be of the form
`Com.lete _ <| Com.lete _ <| ... <| Com.ret _`,
rather than still containing the `Transform` machinery applied to a raw AST.
This has the side-effect of also fully reducing the expressions involved.
We reduce with mode `.default` so that a dialect can prevent reduction of specific parts
by marking those `irreducible` -/
let com : Q(MLIR.AST.ExceptM $Op (Σ (Γ' : Ctxt $Ty) (ty : $Ty), Com $Op Γ' ty)) ←
withTheReader Core.Context (fun ctx => { ctx with options := ctx.options.setBool `smartUnfolding false }) do
withTransparency (mode := .default) <|
return ←reduce com
let comExpr : Expr := com
trace[Meta] com
trace[Meta] comExpr

match comExpr.app3? ``Except.ok with
| .some (_εexpr, _αexpr, aexpr) =>
match aexpr.app4? ``Sigma.mk with
| .some (_αexpr, _βexpr, _fstexpr, sndexpr) =>
match sndexpr.app4? ``Sigma.mk with
| .some (_αexpr, _βexpr, _fstexpr, sndexpr) =>
return sndexpr
| .none => throwError "Found `Except.ok (Sigma.mk _ WRONG)`, Expected (Except.ok (Sigma.mk _ (Sigma.mk _ _))"
| .none => throwError "Found `Except.ok WRONG`, Expected (Except.ok (Sigma.mk _ _))"
| .none => throwError "Expected `Except.ok`, found {comExpr}"
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,16 @@ import Lean.PrettyPrinter
import Lean.PrettyPrinter.Formatter
import Lean.Parser
import Lean.Parser.Extra
import SSA.Projects.MLIRSyntax.AST
import SSA.Core.MLIRSyntax.AST

/-!
# MLIR Syntax Parsing
This file uses Lean's parser extensions to parse generic MLIR syntax into datastructures defined
in `MLIRSyntax.AST`.
Key definitions are the `[mlir_op| ...]` and `[mlir_region| ...]` term elaborators
-/

open Lean
open Lean.Parser
Expand Down Expand Up @@ -102,7 +111,7 @@ partial def balancedBracketsFnAux (startPos: String.Pos)
| ']' => consumeCloseBracket Bracket.Square startPos i input bs ctx s
| '>' => consumeCloseBracket Bracket.Angle startPos i input bs ctx s
| '}' => consumeCloseBracket Bracket.Curly startPos i input bs ctx s
| c => balancedBracketsFnAux startPos (input.next i) input bs ctx s
| _c => balancedBracketsFnAux startPos (input.next i) input bs ctx s

end

Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
-- should replace with Lean import once Pure is upstream
import SSA.Projects.MLIRSyntax.AST
import SSA.Projects.InstCombine.LLVM.Transform.NameMapping
import SSA.Projects.InstCombine.LLVM.Transform.TransformError
import SSA.Core.MLIRSyntax.AST
import SSA.Core.MLIRSyntax.Transform.NameMapping
import SSA.Core.MLIRSyntax.Transform.TransformError
import SSA.Core.Framework
import SSA.Core.ErasedContext

import Std.Data.BitVec
/-!
# `Transform*` typeclasses
This file defines `TransformTy`, `TransformExpr`, and `TransformReturn` typeclasses,
which dictate how generic MLIR syntax (as defined in `MLIRSyntax.AST`) can be transformed into
an instance of `Com` or `Expr` for a specific dialect.
-/

universe u

namespace MLIR.AST

open Std (BitVec)
open Ctxt

instance {Op Ty : Type} [OpSignature Op Ty] {t : Ty} {Γ : Ctxt Ty} {Γ' : DerivedCtxt Γ} : Coe (Expr Op Γ t) (Expr Op Γ'.ctxt t) where
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import SSA.Projects.MLIRSyntax.AST
import SSA.Core.MLIRSyntax.AST

namespace MLIR.AST

Expand Down
14 changes: 13 additions & 1 deletion SSA/Core/Tactic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,18 @@ section

open Lean Meta Elab.Tactic Qq

/-- `ctxtNf` reduces an expression of type `Ctxt _` to something in between whnf and normal form.
`ctxtNf` recursively calls `whnf` on the tail of the list, so that the result is of the form
`a₀ :: a₁ :: ... :: aₙ :: [] `
where each element `aᵢ` is not further reduced -/
private partial def ctxtNf {α : Q(Type)} (as : Q(Ctxt $α)) : MetaM Q(Ctxt $α) := do
let as : Q(List $α) ← whnf as
match as with
| ~q($a :: $as) =>
let as ← ctxtNf as
return q($a :: $as)
| as => return as

/-- Given a `V : Valuation Γ`, fully reduce the context `Γ` in the type of `V` -/
elab "change_mlir_context " V:ident : tactic => do
let V : Name := V.getId
Expand All @@ -28,7 +40,7 @@ elab "change_mlir_context " V:ident : tactic => do
let _ ← assertDefEqQ Vdecl.type q(@Ctxt.Valuation $Ty $G $Γ)

-- Reduce the context `Γ`
let Γr ← reduce Γ
let Γr ← ctxtNf Γ
let Γr : Q(Ctxt $Ty) := Γr

let goal ← getMainGoal
Expand Down
4 changes: 3 additions & 1 deletion SSA/Projects/FullyHomomorphicEncryption.lean
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
import SSA.Projects.FullyHomomorphicEncryption.Basic
import SSA.Projects.FullyHomomorphicEncryption.Statements
import SSA.Projects.FullyHomomorphicEncryption.Rewrites
import SSA.Projects.FullyHomomorphicEncryption.Statements
import SSA.Projects.FullyHomomorphicEncryption.Syntax
40 changes: 25 additions & 15 deletions SSA/Projects/FullyHomomorphicEncryption/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ For the rationale behind this, see:
Junfeng Fan and Frederik Vercauteren, Somewhat Practical Fully Homomorphic Encryption
https://eprint.iacr.org/2012/144
Authors: Andrés Goens<andres@goens.org>, Siddharth Bhat<siddu.druid@gmail.com>
-/
import Mathlib.RingTheory.Polynomial.Quotient
import Mathlib.RingTheory.Ideal.Quotient
Expand Down Expand Up @@ -639,7 +640,7 @@ inductive Ty (q : Nat) (n : Nat) [Fact (q > 1)]
| integer : Ty q n
| tensor : Ty q n
| polynomialLike : Ty q n
deriving DecidableEq
deriving DecidableEq, Repr

instance : Inhabited (Ty q n) := ⟨Ty.index⟩
instance : TyDenote (Ty q n) where
Expand All @@ -666,6 +667,9 @@ inductive Op (q : Nat) (n : Nat) [Fact (q > 1)]
| from_tensor : Op q n-- interpret values as coefficients of a representative
| to_tensor : Op q n-- give back coefficients from `R.representative`
| const (c : R q n) : Op q n
| const_int (c : Int) : Op q n
| const_idx (i : Nat) : Op q n


open TyDenote (toType)

Expand All @@ -682,12 +686,17 @@ def Op.sig : Op q n → List (Ty q n)
| Op.from_tensor => [Ty.tensor]
| Op.to_tensor => [Ty.polynomialLike]
| Op.const _ => []
| Op.const_int _ => []
| Op.const_idx _ => []


@[simp, reducible]
def Op.outTy : Op q n → Ty q n
| Op.add | Op.sub | Op.mul | Op.mul_constant | Op.leading_term | Op.monomial
| Op.monomial_mul | Op.from_tensor | Op.const _ => Ty.polynomialLike
| Op.to_tensor => Ty.tensor
| Op.const_int _ => Ty.integer
| Op.const_idx _ => Ty.index

@[simp, reducible]
def Op.signature : Op q n → Signature (Ty q n) :=
Expand All @@ -696,17 +705,18 @@ def Op.signature : Op q n → Signature (Ty q n) :=
instance : OpSignature (Op q n) (Ty q n) := ⟨Op.signature q n⟩

@[simp]
noncomputable def Op.denote (o : Op q n)
(arg : HVector toType (OpSignature.sig o))
: (toType <| OpSignature.outTy o) :=
match o with
| Op.add => (fun args : R q n × R q n => args.1 + args.2) arg.toPair
| Op.sub => (fun args : R q n × R q n => args.1 - args.2) arg.toPair
| Op.mul => (fun args : R q n × R q n => args.1 * args.2) arg.toPair
| Op.mul_constant => (fun args : R q n × Int => args.1 * ↑(args.2)) arg.toPair
| Op.leading_term => R.leadingTerm arg.toSingle
| Op.monomial => (fun args => R.monomial ↑(args.1) args.2) arg.toPair
| Op.monomial_mul => (fun args : R q n × Nat => args.1 * R.monomial 1 args.2) arg.toPair
| Op.from_tensor => R.fromTensor arg.toSingle
| Op.to_tensor => R.toTensor' arg.toSingle
| Op.const c => c
noncomputable instance : OpDenote (Op q n) (Ty q n) where
denote
| Op.add, arg, _ => (fun args : R q n × R q n => args.1 + args.2) arg.toPair
| Op.sub, arg, _ => (fun args : R q n × R q n => args.1 - args.2) arg.toPair
| Op.mul, arg, _ => (fun args : R q n × R q n => args.1 * args.2) arg.toPair
| Op.mul_constant, arg, _ => (fun args : R q n × Int => args.1 * ↑(args.2)) arg.toPair
| Op.leading_term, arg, _ => R.leadingTerm arg.toSingle
| Op.monomial, arg, _ => (fun args => R.monomial ↑(args.1) args.2) arg.toPair
| Op.monomial_mul, arg, _ => (fun args : R q n × Nat => args.1 * R.monomial 1 args.2) arg.toPair
| Op.from_tensor, arg, _ => R.fromTensor arg.toSingle
| Op.to_tensor, arg, _ => R.toTensor' arg.toSingle
| Op.const c, _arg, _ => c
| Op.const_int c, _, _ => c
| Op.const_idx c, _, _ => c

Loading

0 comments on commit 38574a1

Please sign in to comment.