Skip to content

Commit

Permalink
feat: Adding the integer conversion operations to the llvm dialect (#721
Browse files Browse the repository at this point in the history
)

Co-authored-by: Tobias Grosser <tobias@grosser.es>
Co-authored-by: Léo Stefanesco <leo.lveb@gmail.com>
  • Loading branch information
3 people authored Oct 25, 2024
1 parent a4f5c21 commit 160cdb8
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 31 deletions.
82 changes: 60 additions & 22 deletions SSA/Projects/InstCombine/Base.lean
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,13 @@ instance : Repr (BitVec n) where
| v, n => reprPrec (BitVec.toInt v) n

/-- Homogeneous, unary operations -/
inductive MOp.UnaryOp : Type
inductive MOp.UnaryOp (φ : Nat) : Type
| neg
| not
| copy
| trunc (w' : Width φ)
| zext (w' : Width φ)
| sext (w' : Width φ)
deriving Repr, DecidableEq, Inhabited

/-- Homogeneous, binary operations -/
Expand Down Expand Up @@ -142,7 +145,7 @@ instance : Repr (MOp.BinaryOp) where

-- See: https://releases.llvm.org/14.0.0/docs/LangRef.html#bitwise-binary-operations
inductive MOp (φ : Nat) : Type
| unary (w : Width φ) (op : MOp.UnaryOp) : MOp φ
| unary (w : Width φ) (op : MOp.UnaryOp φ) : MOp φ
| binary (w : Width φ) (op : MOp.BinaryOp) : MOp φ
| select (w : Width φ) : MOp φ
| icmp (c : IntPredicate) (w : Width φ) : MOp φ
Expand All @@ -155,6 +158,10 @@ namespace MOp
@[match_pattern] def neg (w : Width φ) : MOp φ := .unary w .neg
@[match_pattern] def not (w : Width φ) : MOp φ := .unary w .not
@[match_pattern] def copy (w : Width φ) : MOp φ := .unary w .copy
@[match_pattern] def trunc (w w' : Width φ) : MOp φ := .unary w (.trunc w')
@[match_pattern] def zext (w w' : Width φ) : MOp φ := .unary w (.zext w')
@[match_pattern] def sext (w w' : Width φ) : MOp φ := .unary w (.sext w')


@[match_pattern] def and (w : Width φ) : MOp φ := .binary w .and
@[match_pattern] def xor (w : Width φ) : MOp φ := .binary w .xor
Expand Down Expand Up @@ -198,6 +205,9 @@ namespace MOp
def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*}
(neg : ∀ {φ} {w : Width φ}, motive (neg w))
(not : ∀ {φ} {w : Width φ}, motive (not w))
(trunc : ∀ {φ} {w w' : Width φ}, motive (trunc w w'))
(zext : ∀ {φ} {w w' : Width φ}, motive (zext w w'))
(sext : ∀ {φ} {w w' : Width φ}, motive (sext w w'))
(copy : ∀ {φ} {w : Width φ}, motive (copy w))
(and : ∀ {φ} {w : Width φ}, motive (and w))
(or : ∀ {φ DisjointFlag} {w : Width φ}, motive (or w DisjointFlag))
Expand All @@ -218,6 +228,9 @@ def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*}
∀ {φ} (op : MOp φ), motive op
| _, .neg _ => neg
| _, .not _ => not
| _, .trunc _ _ => trunc
| _, .zext _ _ => zext
| _, .sext _ _ => sext
| _, .copy _ => copy
| _, .and _ => and
| _, .or _ _ => or
Expand Down Expand Up @@ -255,20 +268,34 @@ instance : ToString (MOp φ) where
| .sub _ _ => "sub"
| .neg _ => "neg"
| .copy _ => "copy"
| .trunc _ _ => "trunc"
| .zext _ _ => "zext"
| .sext _ _ => "sext"
| .sdiv _ _ => "sdiv"
| .udiv _ _ => "udiv"
| .icmp ty _ => s!"icmp {ty}"
| .const _ v => s!"const {v}"

abbrev Op := MOp 0

def MOp.UnaryOp.instantiate (as : Mathlib.Vector Nat φ) : MOp.UnaryOp φ → MOp.UnaryOp 0
| .trunc w' => .trunc (.concrete <| w'.instantiate as)
| .zext w' => .zext (.concrete <| w'.instantiate as)
| .sext w' => .sext (.concrete <| w'.instantiate as)
| .neg => .neg
| .not => .not
| .copy => .copy

namespace Op

@[match_pattern] abbrev unary (w : Nat) (op : MOp.UnaryOp) : Op := MOp.unary (.concrete w) op
@[match_pattern] abbrev unary (w : Nat) (op : MOp.UnaryOp 0) : Op := MOp.unary (.concrete w) op
@[match_pattern] abbrev binary (w : Nat) (op : MOp.BinaryOp) : Op := MOp.binary (.concrete w) op

@[match_pattern] abbrev and : Nat → Op := MOp.and ∘ .concrete
@[match_pattern] abbrev not : Nat → Op := MOp.not ∘ .concrete
@[match_pattern] abbrev trunc : Nat → Nat → Op := fun w w' => MOp.trunc (.concrete w) (.concrete w')
@[match_pattern] abbrev zext : Nat → Nat → Op := fun w w' => MOp.zext (.concrete w) (.concrete w')
@[match_pattern] abbrev sext : Nat → Nat → Op := fun w w' => MOp.sext (.concrete w) (.concrete w')
@[match_pattern] abbrev xor : Nat → Op := MOp.xor ∘ .concrete
@[match_pattern] abbrev urem : Nat → Op := MOp.urem ∘ .concrete
@[match_pattern] abbrev srem : Nat → Op := MOp.srem ∘ .concrete
Expand Down Expand Up @@ -311,10 +338,18 @@ def MOp.sig : MOp φ → List (MTy φ)
| .select w => [.bitvec 1, .bitvec w, .bitvec w]
| .const _ _ => []

@[simp, reducible]
def MOp.UnaryOp.outTy (w : Width φ) : MOp.UnaryOp φ → MTy φ
| .trunc w' => .bitvec w'
| .zext w' => .bitvec w'
| .sext w' => .bitvec w'
| _ => .bitvec w

@[simp, reducible]
def MOp.outTy : MOp φ → MTy φ
| .binary w _ | .unary w _ | .select w | .const w _ =>
| .binary w _ | .select w | .const w _ =>
.bitvec w
| .unary w op => op.outTy w
| .icmp _ _ => .bitvec 1

/-- `MetaLLVM φ` is the `LLVM` dialect with at most `φ` metavariables -/
Expand All @@ -336,24 +371,27 @@ def Op.denote (o : LLVM.Op) (op : HVector TyDenote.toType (DialectSignature.sig
(TyDenote.toType <| DialectSignature.outTy o) :=
match o with
| Op.const _ val => const? val
| Op.copy _ => (op.getN 0)
| Op.not _ => LLVM.not (op.getN 0)
| Op.neg _ => LLVM.neg (op.getN 0)
| Op.and _ => LLVM.and (op.getN 0) (op.getN 1)
| Op.or _ flag => LLVM.or (op.getN 0) (op.getN 1) flag
| Op.xor _ => LLVM.xor (op.getN 0) (op.getN 1)
| Op.shl _ flags => LLVM.shl (op.getN 0) (op.getN 1) flags
| Op.lshr _ flag => LLVM.lshr (op.getN 0) (op.getN 1) flag
| Op.ashr _ flag => LLVM.ashr (op.getN 0) (op.getN 1) flag
| Op.sub _ flags => LLVM.sub (op.getN 0) (op.getN 1) flags
| Op.add _ flags => LLVM.add (op.getN 0) (op.getN 1) flags
| Op.mul _ flags => LLVM.mul (op.getN 0) (op.getN 1) flags
| Op.sdiv _ flag => LLVM.sdiv (op.getN 0) (op.getN 1) flag
| Op.udiv _ flag => LLVM.udiv (op.getN 0) (op.getN 1) flag
| Op.urem _ => LLVM.urem (op.getN 0) (op.getN 1)
| Op.srem _ => LLVM.srem (op.getN 0) (op.getN 1)
| Op.icmp c _ => LLVM.icmp c (op.getN 0) (op.getN 1)
| Op.select _ => LLVM.select (op.getN 0) (op.getN 1) (op.getN 2)
| Op.copy _ => (op.getN 0)
| Op.not _ => LLVM.not (op.getN 0)
| Op.neg _ => LLVM.neg (op.getN 0)
| Op.trunc w w' => LLVM.trunc w' (op.getN 0)
| Op.zext w w' => LLVM.zext w' (op.getN 0)
| Op.sext w w' => LLVM.sext w' (op.getN 0)
| Op.and _ => LLVM.and (op.getN 0) (op.getN 1)
| Op.or _ flag => LLVM.or (op.getN 0) (op.getN 1) flag
| Op.xor _ => LLVM.xor (op.getN 0) (op.getN 1)
| Op.shl _ flags => LLVM.shl (op.getN 0) (op.getN 1) flags
| Op.lshr _ flag => LLVM.lshr (op.getN 0) (op.getN 1) flag
| Op.ashr _ flag => LLVM.ashr (op.getN 0) (op.getN 1) flag
| Op.sub _ flags => LLVM.sub (op.getN 0) (op.getN 1) flags
| Op.add _ flags => LLVM.add (op.getN 0) (op.getN 1) flags
| Op.mul _ flags => LLVM.mul (op.getN 0) (op.getN 1) flags
| Op.sdiv _ flag => LLVM.sdiv (op.getN 0) (op.getN 1) flag
| Op.udiv _ flag => LLVM.udiv (op.getN 0) (op.getN 1) flag
| Op.urem _ => LLVM.urem (op.getN 0) (op.getN 1)
| Op.srem _ => LLVM.srem (op.getN 0) (op.getN 1)
| Op.icmp c _ => LLVM.icmp c (op.getN 0) (op.getN 1)
| Op.select _ => LLVM.select (op.getN 0) (op.getN 1) (op.getN 2)

instance : DialectDenote LLVM := ⟨
fun o args _ => Op.denote o args
Expand Down
27 changes: 20 additions & 7 deletions SSA/Projects/InstCombine/LLVM/EDSL.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ open MLIR

namespace InstcombineTransformDialect

def mkUnaryOp {Γ : Ctxt (MetaLLVM φ).Ty} {w : Width φ} (op : MOp.UnaryOp)
(e : Ctxt.Var Γ (.bitvec w)) : Expr (MetaLLVM φ) Γ .pure (.bitvec w) :=
def mkUnaryOp {Γ : Ctxt (MetaLLVM φ).Ty} {w : Width φ} (op : MOp.UnaryOp φ)
(e : Ctxt.Var Γ (.bitvec w)) : Expr (MetaLLVM φ) Γ .pure (op.outTy w) :=
.unary w op,
rfl,
Expand Down Expand Up @@ -70,6 +70,15 @@ def mkTy : MLIR.AST.MLIRType φ → MLIR.AST.ExceptM (MetaLLVM φ) ((MetaLLVM φ
instance instTransformTy : MLIR.AST.TransformTy (MetaLLVM φ) φ where
mkTy := mkTy

def getOutputWidth (opStx : MLIR.AST.Op φ) (op : String) :
AST.ReaderM (MetaLLVM φ) (Width φ) := do
match opStx.res with
| res::[] =>
match res.2 with
| .int _ w => pure w
| _ => throw <| .generic s!"The operation {op} must output an integer type"
| _ => throw <| .generic s!"The operation {op} must have a single output"

def mkExpr (Γ : Ctxt (MetaLLVM φ).Ty) (opStx : MLIR.AST.Op φ) :
AST.ReaderM (MetaLLVM φ) (Σ eff ty, Expr (MetaLLVM φ) Γ eff ty) := do
match opStx.args with
Expand Down Expand Up @@ -169,9 +178,12 @@ def mkExpr (Γ : Ctxt (MetaLLVM φ).Ty) (opStx : MLIR.AST.Op φ) :
| vStx::[] =>
let ⟨.bitvec w, v⟩ ← MLIR.AST.TypedSSAVal.mkVal Γ vStx
let op ← match opStx.name with
| "llvm.not" => pure .not
| "llvm.neg" => pure .neg
| "llvm.copy" => pure .copy
| "llvm.not" => pure .not
| "llvm.neg" => pure .neg
| "llvm.copy" => pure .copy
| "llvm.trunc" => pure <| .trunc (← getOutputWidth opStx "trunc")
| "llvm.zext" => pure <| .zext (← getOutputWidth opStx "zext")
| "llvm.sext" => pure <| .sext (← getOutputWidth opStx "sext")
| _ => throw <| .generic s!"Unknown (unary) operation syntax {opStx.name}"
return ⟨_, _, mkUnaryOp op v⟩
| [] =>
Expand Down Expand Up @@ -222,7 +234,7 @@ def instantiateMTy (vals : Mathlib.Vector Nat φ) : (MetaLLVM φ).Ty → LLVM.Ty

def instantiateMOp (vals : Mathlib.Vector Nat φ) : (MetaLLVM φ).Op → LLVM.Op
| .binary w binOp => .binary (w.instantiate vals) binOp
| .unary w unOp => .unary (w.instantiate vals) unOp
| .unary w unOp => .unary (w.instantiate vals) (unOp.instantiate vals)
| .select w => .select (w.instantiate vals)
| .icmp c w => .icmp c (w.instantiate vals)
| .const w val => .const (w.instantiate vals) val
Expand All @@ -238,12 +250,13 @@ def MOp.instantiateCom (vals : Mathlib.Vector Nat φ) : DialectMorphism (MetaLLV
preserves_signature op := by
have h1 : ∀ (φ : Nat), 1 = ConcreteOrMVar.concrete (φ := φ) 1 := by intros φ; rfl
cases op <;>
(try casesm MOp.UnaryOp _) <;>
simp only [instantiateMTy, instantiateMOp, ConcreteOrMVar.instantiate, (· <$> ·), signature,
InstCombine.MOp.sig, InstCombine.MOp.outTy, Function.comp_apply, List.map,
Signature.mk, Signature.mkEffectful.injEq,
List.map_cons, List.map_nil, and_self, MTy.bitvec,
List.cons.injEq, MTy.bitvec.injEq, and_true, true_and,
RegionSignature.map, Signature.map, h1]
RegionSignature.map, Signature.map, MOp.UnaryOp.instantiate, MOp.UnaryOp.outTy, h1]

open InstCombine in
def mkComInstantiate (reg : MLIR.AST.Region φ) :
Expand Down
20 changes: 20 additions & 0 deletions SSA/Projects/InstCombine/LLVM/PrettyEDSL.lean
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,18 @@ macro_rules
let t ← t.getDM `(mlir_type| _)
`(mlir_op| $resName:mlir_op_operand = $opName ($x, $y) : ($t, $t) -> (i1) )

declare_syntax_cat InstCombine.int_cast_op
syntax "llvm.trunc" : InstCombine.int_cast_op
syntax "llvm.zext" : InstCombine.int_cast_op
syntax "llvm.sext" : InstCombine.int_cast_op

syntax mlir_op_operand " = " InstCombine.int_cast_op mlir_op_operand " : " mlir_type " to " mlir_type : mlir_op
macro_rules
| `(mlir_op| $resName:mlir_op_operand = $name:InstCombine.int_cast_op $x : $t to $t') => do
let some opName := extractOpName name.raw
| Macro.throwUnsupported
`(mlir_op| $resName:mlir_op_operand = $opName ($x) : ($t) -> $t')

open MLIR.AST
syntax mlir_op_operand " = " "llvm.mlir.constant" "(" neg_num (" : " mlir_type)? ")"
(" : " mlir_type)? : mlir_op
Expand Down Expand Up @@ -145,6 +157,14 @@ private def pretty_test_overflow :=
llvm.return %3 : i32
}]

private def pretty_test_trunc :=
[llvm ()|{
^bb0(%arg0: i64):
%0 = llvm.trunc %arg0 : i64 to i32
%1 = llvm.zext %0 : i32 to i64
llvm.return %1 : i64
}]

example : pretty_test = prettier_test_generic 32 := by
unfold pretty_test prettier_test_generic
simp_alive_meta
Expand Down
30 changes: 28 additions & 2 deletions SSA/Projects/InstCombine/LLVM/Semantics.lean
Original file line number Diff line number Diff line change
Expand Up @@ -404,8 +404,6 @@ def ashr {w : Nat} (x y : IntW w) (flag : ExactFlag := {exact := false}) : IntW

/--
If the condition is an i1 and it evaluates to 1, the instruction returns the first value argument; otherwise, it returns the second value argument.
If the condition is an i1 and it evaluates to 1, the instruction returns the first value argument; otherwise, it returns the second value argument.
-/
@[simp_llvm_option]
def select {w : Nat} (c? : IntW 1) (x? y? : IntW w ) : IntW w :=
Expand Down Expand Up @@ -564,4 +562,32 @@ def neg {w : Nat} (x : IntW w) : IntW w := do
let x' ← x
neg? x'


@[simp_llvm]
def trunc? {w: Nat} (w': Nat) (x: BitVec w) : IntW w' := do
pure <| (BitVec.truncate w' x)

@[simp_llvm_option]
def trunc {w: Nat} (w': Nat) (x: IntW w) : IntW w' := do
let x' <- x
trunc? w' x'

@[simp_llvm]
def zext? {w: Nat} (w': Nat) (x: BitVec w) : IntW w' := do
pure <| (BitVec.zeroExtend w' x)

@[simp_llvm_option]
def zext {w: Nat} (w': Nat) (x: IntW w) : IntW w' := do
let x' <- x
zext? w' x'

@[simp_llvm]
def sext? {w: Nat} (w': Nat) (x: BitVec w) : IntW w' := do
pure <| (BitVec.signExtend w' x)

@[simp_llvm_option]
def sext {w: Nat} (w': Nat) (x: IntW w) : IntW w' := do
let x' <- x
sext? w' x'

end LLVM

0 comments on commit 160cdb8

Please sign in to comment.