From 160cdb8e83eaf003558fd3c581ea51ce7c0f9fbd Mon Sep 17 00:00:00 2001 From: lfrenot Date: Fri, 25 Oct 2024 13:08:52 +0100 Subject: [PATCH] feat: Adding the integer conversion operations to the llvm dialect (#721) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Tobias Grosser Co-authored-by: Léo Stefanesco --- SSA/Projects/InstCombine/Base.lean | 82 ++++++++++++++----- SSA/Projects/InstCombine/LLVM/EDSL.lean | 27 ++++-- SSA/Projects/InstCombine/LLVM/PrettyEDSL.lean | 20 +++++ SSA/Projects/InstCombine/LLVM/Semantics.lean | 30 ++++++- 4 files changed, 128 insertions(+), 31 deletions(-) diff --git a/SSA/Projects/InstCombine/Base.lean b/SSA/Projects/InstCombine/Base.lean index bbdeff5e7..4c7049595 100644 --- a/SSA/Projects/InstCombine/Base.lean +++ b/SSA/Projects/InstCombine/Base.lean @@ -80,10 +80,13 @@ instance : Repr (BitVec n) where | v, n => reprPrec (BitVec.toInt v) n /-- Homogeneous, unary operations -/ -inductive MOp.UnaryOp : Type +inductive MOp.UnaryOp (φ : Nat) : Type | neg | not | copy + | trunc (w' : Width φ) + | zext (w' : Width φ) + | sext (w' : Width φ) deriving Repr, DecidableEq, Inhabited /-- Homogeneous, binary operations -/ @@ -142,7 +145,7 @@ instance : Repr (MOp.BinaryOp) where -- See: https://releases.llvm.org/14.0.0/docs/LangRef.html#bitwise-binary-operations inductive MOp (φ : Nat) : Type - | unary (w : Width φ) (op : MOp.UnaryOp) : MOp φ + | unary (w : Width φ) (op : MOp.UnaryOp φ) : MOp φ | binary (w : Width φ) (op : MOp.BinaryOp) : MOp φ | select (w : Width φ) : MOp φ | icmp (c : IntPredicate) (w : Width φ) : MOp φ @@ -155,6 +158,10 @@ namespace MOp @[match_pattern] def neg (w : Width φ) : MOp φ := .unary w .neg @[match_pattern] def not (w : Width φ) : MOp φ := .unary w .not @[match_pattern] def copy (w : Width φ) : MOp φ := .unary w .copy +@[match_pattern] def trunc (w w' : Width φ) : MOp φ := .unary w (.trunc w') +@[match_pattern] def zext (w w' : Width φ) : MOp φ := .unary w (.zext w') +@[match_pattern] def sext (w w' : Width φ) : MOp φ := .unary w (.sext w') + @[match_pattern] def and (w : Width φ) : MOp φ := .binary w .and @[match_pattern] def xor (w : Width φ) : MOp φ := .binary w .xor @@ -198,6 +205,9 @@ namespace MOp def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*} (neg : ∀ {φ} {w : Width φ}, motive (neg w)) (not : ∀ {φ} {w : Width φ}, motive (not w)) + (trunc : ∀ {φ} {w w' : Width φ}, motive (trunc w w')) + (zext : ∀ {φ} {w w' : Width φ}, motive (zext w w')) + (sext : ∀ {φ} {w w' : Width φ}, motive (sext w w')) (copy : ∀ {φ} {w : Width φ}, motive (copy w)) (and : ∀ {φ} {w : Width φ}, motive (and w)) (or : ∀ {φ DisjointFlag} {w : Width φ}, motive (or w DisjointFlag)) @@ -218,6 +228,9 @@ def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*} ∀ {φ} (op : MOp φ), motive op | _, .neg _ => neg | _, .not _ => not + | _, .trunc _ _ => trunc + | _, .zext _ _ => zext + | _, .sext _ _ => sext | _, .copy _ => copy | _, .and _ => and | _, .or _ _ => or @@ -255,6 +268,9 @@ instance : ToString (MOp φ) where | .sub _ _ => "sub" | .neg _ => "neg" | .copy _ => "copy" + | .trunc _ _ => "trunc" + | .zext _ _ => "zext" + | .sext _ _ => "sext" | .sdiv _ _ => "sdiv" | .udiv _ _ => "udiv" | .icmp ty _ => s!"icmp {ty}" @@ -262,13 +278,24 @@ instance : ToString (MOp φ) where abbrev Op := MOp 0 +def MOp.UnaryOp.instantiate (as : Mathlib.Vector Nat φ) : MOp.UnaryOp φ → MOp.UnaryOp 0 +| .trunc w' => .trunc (.concrete <| w'.instantiate as) +| .zext w' => .zext (.concrete <| w'.instantiate as) +| .sext w' => .sext (.concrete <| w'.instantiate as) +| .neg => .neg +| .not => .not +| .copy => .copy + namespace Op -@[match_pattern] abbrev unary (w : Nat) (op : MOp.UnaryOp) : Op := MOp.unary (.concrete w) op +@[match_pattern] abbrev unary (w : Nat) (op : MOp.UnaryOp 0) : Op := MOp.unary (.concrete w) op @[match_pattern] abbrev binary (w : Nat) (op : MOp.BinaryOp) : Op := MOp.binary (.concrete w) op @[match_pattern] abbrev and : Nat → Op := MOp.and ∘ .concrete @[match_pattern] abbrev not : Nat → Op := MOp.not ∘ .concrete +@[match_pattern] abbrev trunc : Nat → Nat → Op := fun w w' => MOp.trunc (.concrete w) (.concrete w') +@[match_pattern] abbrev zext : Nat → Nat → Op := fun w w' => MOp.zext (.concrete w) (.concrete w') +@[match_pattern] abbrev sext : Nat → Nat → Op := fun w w' => MOp.sext (.concrete w) (.concrete w') @[match_pattern] abbrev xor : Nat → Op := MOp.xor ∘ .concrete @[match_pattern] abbrev urem : Nat → Op := MOp.urem ∘ .concrete @[match_pattern] abbrev srem : Nat → Op := MOp.srem ∘ .concrete @@ -311,10 +338,18 @@ def MOp.sig : MOp φ → List (MTy φ) | .select w => [.bitvec 1, .bitvec w, .bitvec w] | .const _ _ => [] +@[simp, reducible] +def MOp.UnaryOp.outTy (w : Width φ) : MOp.UnaryOp φ → MTy φ +| .trunc w' => .bitvec w' +| .zext w' => .bitvec w' +| .sext w' => .bitvec w' +| _ => .bitvec w + @[simp, reducible] def MOp.outTy : MOp φ → MTy φ -| .binary w _ | .unary w _ | .select w | .const w _ => +| .binary w _ | .select w | .const w _ => .bitvec w +| .unary w op => op.outTy w | .icmp _ _ => .bitvec 1 /-- `MetaLLVM φ` is the `LLVM` dialect with at most `φ` metavariables -/ @@ -336,24 +371,27 @@ def Op.denote (o : LLVM.Op) (op : HVector TyDenote.toType (DialectSignature.sig (TyDenote.toType <| DialectSignature.outTy o) := match o with | Op.const _ val => const? val - | Op.copy _ => (op.getN 0) - | Op.not _ => LLVM.not (op.getN 0) - | Op.neg _ => LLVM.neg (op.getN 0) - | Op.and _ => LLVM.and (op.getN 0) (op.getN 1) - | Op.or _ flag => LLVM.or (op.getN 0) (op.getN 1) flag - | Op.xor _ => LLVM.xor (op.getN 0) (op.getN 1) - | Op.shl _ flags => LLVM.shl (op.getN 0) (op.getN 1) flags - | Op.lshr _ flag => LLVM.lshr (op.getN 0) (op.getN 1) flag - | Op.ashr _ flag => LLVM.ashr (op.getN 0) (op.getN 1) flag - | Op.sub _ flags => LLVM.sub (op.getN 0) (op.getN 1) flags - | Op.add _ flags => LLVM.add (op.getN 0) (op.getN 1) flags - | Op.mul _ flags => LLVM.mul (op.getN 0) (op.getN 1) flags - | Op.sdiv _ flag => LLVM.sdiv (op.getN 0) (op.getN 1) flag - | Op.udiv _ flag => LLVM.udiv (op.getN 0) (op.getN 1) flag - | Op.urem _ => LLVM.urem (op.getN 0) (op.getN 1) - | Op.srem _ => LLVM.srem (op.getN 0) (op.getN 1) - | Op.icmp c _ => LLVM.icmp c (op.getN 0) (op.getN 1) - | Op.select _ => LLVM.select (op.getN 0) (op.getN 1) (op.getN 2) + | Op.copy _ => (op.getN 0) + | Op.not _ => LLVM.not (op.getN 0) + | Op.neg _ => LLVM.neg (op.getN 0) + | Op.trunc w w' => LLVM.trunc w' (op.getN 0) + | Op.zext w w' => LLVM.zext w' (op.getN 0) + | Op.sext w w' => LLVM.sext w' (op.getN 0) + | Op.and _ => LLVM.and (op.getN 0) (op.getN 1) + | Op.or _ flag => LLVM.or (op.getN 0) (op.getN 1) flag + | Op.xor _ => LLVM.xor (op.getN 0) (op.getN 1) + | Op.shl _ flags => LLVM.shl (op.getN 0) (op.getN 1) flags + | Op.lshr _ flag => LLVM.lshr (op.getN 0) (op.getN 1) flag + | Op.ashr _ flag => LLVM.ashr (op.getN 0) (op.getN 1) flag + | Op.sub _ flags => LLVM.sub (op.getN 0) (op.getN 1) flags + | Op.add _ flags => LLVM.add (op.getN 0) (op.getN 1) flags + | Op.mul _ flags => LLVM.mul (op.getN 0) (op.getN 1) flags + | Op.sdiv _ flag => LLVM.sdiv (op.getN 0) (op.getN 1) flag + | Op.udiv _ flag => LLVM.udiv (op.getN 0) (op.getN 1) flag + | Op.urem _ => LLVM.urem (op.getN 0) (op.getN 1) + | Op.srem _ => LLVM.srem (op.getN 0) (op.getN 1) + | Op.icmp c _ => LLVM.icmp c (op.getN 0) (op.getN 1) + | Op.select _ => LLVM.select (op.getN 0) (op.getN 1) (op.getN 2) instance : DialectDenote LLVM := ⟨ fun o args _ => Op.denote o args diff --git a/SSA/Projects/InstCombine/LLVM/EDSL.lean b/SSA/Projects/InstCombine/LLVM/EDSL.lean index 53a967e4a..3ca08ba2d 100644 --- a/SSA/Projects/InstCombine/LLVM/EDSL.lean +++ b/SSA/Projects/InstCombine/LLVM/EDSL.lean @@ -13,8 +13,8 @@ open MLIR namespace InstcombineTransformDialect -def mkUnaryOp {Γ : Ctxt (MetaLLVM φ).Ty} {w : Width φ} (op : MOp.UnaryOp) - (e : Ctxt.Var Γ (.bitvec w)) : Expr (MetaLLVM φ) Γ .pure (.bitvec w) := +def mkUnaryOp {Γ : Ctxt (MetaLLVM φ).Ty} {w : Width φ} (op : MOp.UnaryOp φ) + (e : Ctxt.Var Γ (.bitvec w)) : Expr (MetaLLVM φ) Γ .pure (op.outTy w) := ⟨ .unary w op, rfl, @@ -70,6 +70,15 @@ def mkTy : MLIR.AST.MLIRType φ → MLIR.AST.ExceptM (MetaLLVM φ) ((MetaLLVM φ instance instTransformTy : MLIR.AST.TransformTy (MetaLLVM φ) φ where mkTy := mkTy +def getOutputWidth (opStx : MLIR.AST.Op φ) (op : String) : + AST.ReaderM (MetaLLVM φ) (Width φ) := do + match opStx.res with + | res::[] => + match res.2 with + | .int _ w => pure w + | _ => throw <| .generic s!"The operation {op} must output an integer type" + | _ => throw <| .generic s!"The operation {op} must have a single output" + def mkExpr (Γ : Ctxt (MetaLLVM φ).Ty) (opStx : MLIR.AST.Op φ) : AST.ReaderM (MetaLLVM φ) (Σ eff ty, Expr (MetaLLVM φ) Γ eff ty) := do match opStx.args with @@ -169,9 +178,12 @@ def mkExpr (Γ : Ctxt (MetaLLVM φ).Ty) (opStx : MLIR.AST.Op φ) : | vStx::[] => let ⟨.bitvec w, v⟩ ← MLIR.AST.TypedSSAVal.mkVal Γ vStx let op ← match opStx.name with - | "llvm.not" => pure .not - | "llvm.neg" => pure .neg - | "llvm.copy" => pure .copy + | "llvm.not" => pure .not + | "llvm.neg" => pure .neg + | "llvm.copy" => pure .copy + | "llvm.trunc" => pure <| .trunc (← getOutputWidth opStx "trunc") + | "llvm.zext" => pure <| .zext (← getOutputWidth opStx "zext") + | "llvm.sext" => pure <| .sext (← getOutputWidth opStx "sext") | _ => throw <| .generic s!"Unknown (unary) operation syntax {opStx.name}" return ⟨_, _, mkUnaryOp op v⟩ | [] => @@ -222,7 +234,7 @@ def instantiateMTy (vals : Mathlib.Vector Nat φ) : (MetaLLVM φ).Ty → LLVM.Ty def instantiateMOp (vals : Mathlib.Vector Nat φ) : (MetaLLVM φ).Op → LLVM.Op | .binary w binOp => .binary (w.instantiate vals) binOp - | .unary w unOp => .unary (w.instantiate vals) unOp + | .unary w unOp => .unary (w.instantiate vals) (unOp.instantiate vals) | .select w => .select (w.instantiate vals) | .icmp c w => .icmp c (w.instantiate vals) | .const w val => .const (w.instantiate vals) val @@ -238,12 +250,13 @@ def MOp.instantiateCom (vals : Mathlib.Vector Nat φ) : DialectMorphism (MetaLLV preserves_signature op := by have h1 : ∀ (φ : Nat), 1 = ConcreteOrMVar.concrete (φ := φ) 1 := by intros φ; rfl cases op <;> + (try casesm MOp.UnaryOp _) <;> simp only [instantiateMTy, instantiateMOp, ConcreteOrMVar.instantiate, (· <$> ·), signature, InstCombine.MOp.sig, InstCombine.MOp.outTy, Function.comp_apply, List.map, Signature.mk, Signature.mkEffectful.injEq, List.map_cons, List.map_nil, and_self, MTy.bitvec, List.cons.injEq, MTy.bitvec.injEq, and_true, true_and, - RegionSignature.map, Signature.map, h1] + RegionSignature.map, Signature.map, MOp.UnaryOp.instantiate, MOp.UnaryOp.outTy, h1] open InstCombine in def mkComInstantiate (reg : MLIR.AST.Region φ) : diff --git a/SSA/Projects/InstCombine/LLVM/PrettyEDSL.lean b/SSA/Projects/InstCombine/LLVM/PrettyEDSL.lean index 1ddff2ad6..6f3f1a812 100644 --- a/SSA/Projects/InstCombine/LLVM/PrettyEDSL.lean +++ b/SSA/Projects/InstCombine/LLVM/PrettyEDSL.lean @@ -47,6 +47,18 @@ macro_rules let t ← t.getDM `(mlir_type| _) `(mlir_op| $resName:mlir_op_operand = $opName ($x, $y) : ($t, $t) -> (i1) ) +declare_syntax_cat InstCombine.int_cast_op +syntax "llvm.trunc" : InstCombine.int_cast_op +syntax "llvm.zext" : InstCombine.int_cast_op +syntax "llvm.sext" : InstCombine.int_cast_op + +syntax mlir_op_operand " = " InstCombine.int_cast_op mlir_op_operand " : " mlir_type " to " mlir_type : mlir_op +macro_rules + | `(mlir_op| $resName:mlir_op_operand = $name:InstCombine.int_cast_op $x : $t to $t') => do + let some opName := extractOpName name.raw + | Macro.throwUnsupported + `(mlir_op| $resName:mlir_op_operand = $opName ($x) : ($t) -> $t') + open MLIR.AST syntax mlir_op_operand " = " "llvm.mlir.constant" "(" neg_num (" : " mlir_type)? ")" (" : " mlir_type)? : mlir_op @@ -145,6 +157,14 @@ private def pretty_test_overflow := llvm.return %3 : i32 }] +private def pretty_test_trunc := + [llvm ()|{ + ^bb0(%arg0: i64): + %0 = llvm.trunc %arg0 : i64 to i32 + %1 = llvm.zext %0 : i32 to i64 + llvm.return %1 : i64 + }] + example : pretty_test = prettier_test_generic 32 := by unfold pretty_test prettier_test_generic simp_alive_meta diff --git a/SSA/Projects/InstCombine/LLVM/Semantics.lean b/SSA/Projects/InstCombine/LLVM/Semantics.lean index ba2e9971a..0f1303e13 100644 --- a/SSA/Projects/InstCombine/LLVM/Semantics.lean +++ b/SSA/Projects/InstCombine/LLVM/Semantics.lean @@ -404,8 +404,6 @@ def ashr {w : Nat} (x y : IntW w) (flag : ExactFlag := {exact := false}) : IntW /-- If the condition is an i1 and it evaluates to 1, the instruction returns the first value argument; otherwise, it returns the second value argument. - - If the condition is an i1 and it evaluates to 1, the instruction returns the first value argument; otherwise, it returns the second value argument. -/ @[simp_llvm_option] def select {w : Nat} (c? : IntW 1) (x? y? : IntW w ) : IntW w := @@ -564,4 +562,32 @@ def neg {w : Nat} (x : IntW w) : IntW w := do let x' ← x neg? x' + +@[simp_llvm] +def trunc? {w: Nat} (w': Nat) (x: BitVec w) : IntW w' := do + pure <| (BitVec.truncate w' x) + +@[simp_llvm_option] +def trunc {w: Nat} (w': Nat) (x: IntW w) : IntW w' := do + let x' <- x + trunc? w' x' + +@[simp_llvm] +def zext? {w: Nat} (w': Nat) (x: BitVec w) : IntW w' := do + pure <| (BitVec.zeroExtend w' x) + +@[simp_llvm_option] +def zext {w: Nat} (w': Nat) (x: IntW w) : IntW w' := do + let x' <- x + zext? w' x' + +@[simp_llvm] +def sext? {w: Nat} (w': Nat) (x: BitVec w) : IntW w' := do + pure <| (BitVec.signExtend w' x) + +@[simp_llvm_option] +def sext {w: Nat} (w': Nat) (x: IntW w) : IntW w' := do + let x' <- x + sext? w' x' + end LLVM