Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Adding the integer conversion operations to the llvm dialect #721

Merged
merged 25 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
0fdee4a
Renamed the script files
lfrenot Oct 16, 2024
43a4918
re-ran scripts
lfrenot Oct 16, 2024
48b1be2
Edits to the gen files
lfrenot Oct 16, 2024
1011dc2
re-ran the tests
lfrenot Oct 16, 2024
21fdfde
Adding a log to the test-gen
lfrenot Oct 17, 2024
55789ff
Updated test-gen to create log fiels
lfrenot Oct 18, 2024
979785c
Updated proof-gen to log build failures
lfrenot Oct 18, 2024
4bf2f93
Re-ran the sccripts
lfrenot Oct 18, 2024
044021d
Merge remote-tracking branch 'origin/main' into instcombine-test-stats
lfrenot Oct 18, 2024
e92cf94
Basic script to read the logs
lfrenot Oct 21, 2024
32be146
Merge remote-tracking branch 'origin/main' into instcombine-test-stats
lfrenot Oct 23, 2024
edb64a3
Updated tests
lfrenot Oct 23, 2024
a3f8ce4
forgot to commit the changes to proof-gen
lfrenot Oct 23, 2024
e9c8757
updated read-logs
lfrenot Oct 23, 2024
3b3b157
edits to cfg
lfrenot Oct 23, 2024
5c249d7
Update SSA/Projects/InstCombine/scripts/cfg.py
lfrenot Oct 23, 2024
ee4aa76
removed commented-out code
lfrenot Oct 23, 2024
7e096bd
Update cfg.py
lfrenot Oct 23, 2024
04c4b1c
Added newlines at end of files
lfrenot Oct 23, 2024
ced1725
Semantics for Integer Conversion ops
lfrenot Oct 24, 2024
a8f0d16
First attempt (not working)
lfrenot Oct 24, 2024
393f9e3
Merge remote-tracking branch 'origin/main' into llvm-integer-cast-ops
lfrenot Oct 24, 2024
a634e40
Merge remote-tracking branch 'origin/main' into llvm-integer-cast-ops
lfrenot Oct 24, 2024
621171d
Quick cleanups
lfrenot Oct 25, 2024
d2008f2
Fix typing of llvm conversions
ineol Oct 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 60 additions & 22 deletions SSA/Projects/InstCombine/Base.lean
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,13 @@ instance : Repr (BitVec n) where
| v, n => reprPrec (BitVec.toInt v) n

/-- Homogeneous, unary operations -/
inductive MOp.UnaryOp : Type
inductive MOp.UnaryOp (φ : Nat) : Type
| neg
| not
| copy
| trunc (w' : Width φ)
| zext (w' : Width φ)
| sext (w' : Width φ)
deriving Repr, DecidableEq, Inhabited

/-- Homogeneous, binary operations -/
Expand Down Expand Up @@ -142,7 +145,7 @@ instance : Repr (MOp.BinaryOp) where

-- See: https://releases.llvm.org/14.0.0/docs/LangRef.html#bitwise-binary-operations
inductive MOp (φ : Nat) : Type
| unary (w : Width φ) (op : MOp.UnaryOp) : MOp φ
| unary (w : Width φ) (op : MOp.UnaryOp φ) : MOp φ
| binary (w : Width φ) (op : MOp.BinaryOp) : MOp φ
| select (w : Width φ) : MOp φ
| icmp (c : IntPredicate) (w : Width φ) : MOp φ
Expand All @@ -155,6 +158,10 @@ namespace MOp
@[match_pattern] def neg (w : Width φ) : MOp φ := .unary w .neg
@[match_pattern] def not (w : Width φ) : MOp φ := .unary w .not
@[match_pattern] def copy (w : Width φ) : MOp φ := .unary w .copy
@[match_pattern] def trunc (w w' : Width φ) : MOp φ := .unary w (.trunc w')
@[match_pattern] def zext (w w' : Width φ) : MOp φ := .unary w (.zext w')
@[match_pattern] def sext (w w' : Width φ) : MOp φ := .unary w (.sext w')


@[match_pattern] def and (w : Width φ) : MOp φ := .binary w .and
@[match_pattern] def xor (w : Width φ) : MOp φ := .binary w .xor
Expand Down Expand Up @@ -198,6 +205,9 @@ namespace MOp
def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*}
(neg : ∀ {φ} {w : Width φ}, motive (neg w))
(not : ∀ {φ} {w : Width φ}, motive (not w))
(trunc : ∀ {φ} {w w' : Width φ}, motive (trunc w w'))
(zext : ∀ {φ} {w w' : Width φ}, motive (zext w w'))
(sext : ∀ {φ} {w w' : Width φ}, motive (sext w w'))
(copy : ∀ {φ} {w : Width φ}, motive (copy w))
(and : ∀ {φ} {w : Width φ}, motive (and w))
(or : ∀ {φ DisjointFlag} {w : Width φ}, motive (or w DisjointFlag))
Expand All @@ -218,6 +228,9 @@ def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*}
∀ {φ} (op : MOp φ), motive op
| _, .neg _ => neg
| _, .not _ => not
| _, .trunc _ _ => trunc
| _, .zext _ _ => zext
| _, .sext _ _ => sext
| _, .copy _ => copy
| _, .and _ => and
| _, .or _ _ => or
Expand Down Expand Up @@ -255,20 +268,34 @@ instance : ToString (MOp φ) where
| .sub _ _ => "sub"
| .neg _ => "neg"
| .copy _ => "copy"
| .trunc _ _ => "trunc"
| .zext _ _ => "zext"
| .sext _ _ => "sext"
| .sdiv _ _ => "sdiv"
| .udiv _ _ => "udiv"
| .icmp ty _ => s!"icmp {ty}"
| .const _ v => s!"const {v}"

abbrev Op := MOp 0

def MOp.UnaryOp.instantiate (as : Mathlib.Vector Nat φ) : MOp.UnaryOp φ → MOp.UnaryOp 0
| .trunc w' => .trunc (.concrete <| w'.instantiate as)
| .zext w' => .zext (.concrete <| w'.instantiate as)
| .sext w' => .sext (.concrete <| w'.instantiate as)
| .neg => .neg
| .not => .not
| .copy => .copy

namespace Op

@[match_pattern] abbrev unary (w : Nat) (op : MOp.UnaryOp) : Op := MOp.unary (.concrete w) op
@[match_pattern] abbrev unary (w : Nat) (op : MOp.UnaryOp 0) : Op := MOp.unary (.concrete w) op
@[match_pattern] abbrev binary (w : Nat) (op : MOp.BinaryOp) : Op := MOp.binary (.concrete w) op

@[match_pattern] abbrev and : Nat → Op := MOp.and ∘ .concrete
@[match_pattern] abbrev not : Nat → Op := MOp.not ∘ .concrete
@[match_pattern] abbrev trunc : Nat → Nat → Op := fun w w' => MOp.trunc (.concrete w) (.concrete w')
@[match_pattern] abbrev zext : Nat → Nat → Op := fun w w' => MOp.zext (.concrete w) (.concrete w')
@[match_pattern] abbrev sext : Nat → Nat → Op := fun w w' => MOp.sext (.concrete w) (.concrete w')
@[match_pattern] abbrev xor : Nat → Op := MOp.xor ∘ .concrete
@[match_pattern] abbrev urem : Nat → Op := MOp.urem ∘ .concrete
@[match_pattern] abbrev srem : Nat → Op := MOp.srem ∘ .concrete
Expand Down Expand Up @@ -311,10 +338,18 @@ def MOp.sig : MOp φ → List (MTy φ)
| .select w => [.bitvec 1, .bitvec w, .bitvec w]
| .const _ _ => []

@[simp, reducible]
def MOp.UnaryOp.outTy (w : Width φ) : MOp.UnaryOp φ → MTy φ
| .trunc w' => .bitvec w'
| .zext w' => .bitvec w'
| .sext w' => .bitvec w'
| _ => .bitvec w

@[simp, reducible]
def MOp.outTy : MOp φ → MTy φ
| .binary w _ | .unary w _ | .select w | .const w _ =>
| .binary w _ | .select w | .const w _ =>
.bitvec w
| .unary w op => op.outTy w
| .icmp _ _ => .bitvec 1

/-- `MetaLLVM φ` is the `LLVM` dialect with at most `φ` metavariables -/
Expand All @@ -336,24 +371,27 @@ def Op.denote (o : LLVM.Op) (op : HVector TyDenote.toType (DialectSignature.sig
(TyDenote.toType <| DialectSignature.outTy o) :=
match o with
| Op.const _ val => const? val
| Op.copy _ => (op.getN 0)
| Op.not _ => LLVM.not (op.getN 0)
| Op.neg _ => LLVM.neg (op.getN 0)
| Op.and _ => LLVM.and (op.getN 0) (op.getN 1)
| Op.or _ flag => LLVM.or (op.getN 0) (op.getN 1) flag
| Op.xor _ => LLVM.xor (op.getN 0) (op.getN 1)
| Op.shl _ flags => LLVM.shl (op.getN 0) (op.getN 1) flags
| Op.lshr _ flag => LLVM.lshr (op.getN 0) (op.getN 1) flag
| Op.ashr _ flag => LLVM.ashr (op.getN 0) (op.getN 1) flag
| Op.sub _ flags => LLVM.sub (op.getN 0) (op.getN 1) flags
| Op.add _ flags => LLVM.add (op.getN 0) (op.getN 1) flags
| Op.mul _ flags => LLVM.mul (op.getN 0) (op.getN 1) flags
| Op.sdiv _ flag => LLVM.sdiv (op.getN 0) (op.getN 1) flag
| Op.udiv _ flag => LLVM.udiv (op.getN 0) (op.getN 1) flag
| Op.urem _ => LLVM.urem (op.getN 0) (op.getN 1)
| Op.srem _ => LLVM.srem (op.getN 0) (op.getN 1)
| Op.icmp c _ => LLVM.icmp c (op.getN 0) (op.getN 1)
| Op.select _ => LLVM.select (op.getN 0) (op.getN 1) (op.getN 2)
| Op.copy _ => (op.getN 0)
| Op.not _ => LLVM.not (op.getN 0)
| Op.neg _ => LLVM.neg (op.getN 0)
| Op.trunc w w' => LLVM.trunc w' (op.getN 0)
| Op.zext w w' => LLVM.zext w' (op.getN 0)
| Op.sext w w' => LLVM.sext w' (op.getN 0)
| Op.and _ => LLVM.and (op.getN 0) (op.getN 1)
| Op.or _ flag => LLVM.or (op.getN 0) (op.getN 1) flag
| Op.xor _ => LLVM.xor (op.getN 0) (op.getN 1)
| Op.shl _ flags => LLVM.shl (op.getN 0) (op.getN 1) flags
| Op.lshr _ flag => LLVM.lshr (op.getN 0) (op.getN 1) flag
| Op.ashr _ flag => LLVM.ashr (op.getN 0) (op.getN 1) flag
| Op.sub _ flags => LLVM.sub (op.getN 0) (op.getN 1) flags
| Op.add _ flags => LLVM.add (op.getN 0) (op.getN 1) flags
| Op.mul _ flags => LLVM.mul (op.getN 0) (op.getN 1) flags
| Op.sdiv _ flag => LLVM.sdiv (op.getN 0) (op.getN 1) flag
| Op.udiv _ flag => LLVM.udiv (op.getN 0) (op.getN 1) flag
| Op.urem _ => LLVM.urem (op.getN 0) (op.getN 1)
| Op.srem _ => LLVM.srem (op.getN 0) (op.getN 1)
| Op.icmp c _ => LLVM.icmp c (op.getN 0) (op.getN 1)
| Op.select _ => LLVM.select (op.getN 0) (op.getN 1) (op.getN 2)

instance : DialectDenote LLVM := ⟨
fun o args _ => Op.denote o args
Expand Down
27 changes: 20 additions & 7 deletions SSA/Projects/InstCombine/LLVM/EDSL.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ open MLIR

namespace InstcombineTransformDialect

def mkUnaryOp {Γ : Ctxt (MetaLLVM φ).Ty} {w : Width φ} (op : MOp.UnaryOp)
(e : Ctxt.Var Γ (.bitvec w)) : Expr (MetaLLVM φ) Γ .pure (.bitvec w) :=
def mkUnaryOp {Γ : Ctxt (MetaLLVM φ).Ty} {w : Width φ} (op : MOp.UnaryOp φ)
(e : Ctxt.Var Γ (.bitvec w)) : Expr (MetaLLVM φ) Γ .pure (op.outTy w) :=
.unary w op,
rfl,
Expand Down Expand Up @@ -70,6 +70,15 @@ def mkTy : MLIR.AST.MLIRType φ → MLIR.AST.ExceptM (MetaLLVM φ) ((MetaLLVM φ
instance instTransformTy : MLIR.AST.TransformTy (MetaLLVM φ) φ where
mkTy := mkTy

def getOutputWidth (opStx : MLIR.AST.Op φ) (op : String) :
AST.ReaderM (MetaLLVM φ) (Width φ) := do
match opStx.res with
| res::[] =>
match res.2 with
| .int _ w => pure w
| _ => throw <| .generic s!"The operation {op} must output an integer type"
| _ => throw <| .generic s!"The operation {op} must have a single output"

def mkExpr (Γ : Ctxt (MetaLLVM φ).Ty) (opStx : MLIR.AST.Op φ) :
AST.ReaderM (MetaLLVM φ) (Σ eff ty, Expr (MetaLLVM φ) Γ eff ty) := do
match opStx.args with
Expand Down Expand Up @@ -169,9 +178,12 @@ def mkExpr (Γ : Ctxt (MetaLLVM φ).Ty) (opStx : MLIR.AST.Op φ) :
| vStx::[] =>
let ⟨.bitvec w, v⟩ ← MLIR.AST.TypedSSAVal.mkVal Γ vStx
let op ← match opStx.name with
| "llvm.not" => pure .not
| "llvm.neg" => pure .neg
| "llvm.copy" => pure .copy
| "llvm.not" => pure .not
| "llvm.neg" => pure .neg
| "llvm.copy" => pure .copy
| "llvm.trunc" => pure <| .trunc (← getOutputWidth opStx "trunc")
| "llvm.zext" => pure <| .zext (← getOutputWidth opStx "zext")
| "llvm.sext" => pure <| .sext (← getOutputWidth opStx "sext")
| _ => throw <| .generic s!"Unknown (unary) operation syntax {opStx.name}"
return ⟨_, _, mkUnaryOp op v⟩
| [] =>
Expand Down Expand Up @@ -222,7 +234,7 @@ def instantiateMTy (vals : Mathlib.Vector Nat φ) : (MetaLLVM φ).Ty → LLVM.Ty

def instantiateMOp (vals : Mathlib.Vector Nat φ) : (MetaLLVM φ).Op → LLVM.Op
| .binary w binOp => .binary (w.instantiate vals) binOp
| .unary w unOp => .unary (w.instantiate vals) unOp
| .unary w unOp => .unary (w.instantiate vals) (unOp.instantiate vals)
| .select w => .select (w.instantiate vals)
| .icmp c w => .icmp c (w.instantiate vals)
| .const w val => .const (w.instantiate vals) val
Expand All @@ -238,12 +250,13 @@ def MOp.instantiateCom (vals : Mathlib.Vector Nat φ) : DialectMorphism (MetaLLV
preserves_signature op := by
have h1 : ∀ (φ : Nat), 1 = ConcreteOrMVar.concrete (φ := φ) 1 := by intros φ; rfl
cases op <;>
(try casesm MOp.UnaryOp _) <;>
simp only [instantiateMTy, instantiateMOp, ConcreteOrMVar.instantiate, (· <$> ·), signature,
InstCombine.MOp.sig, InstCombine.MOp.outTy, Function.comp_apply, List.map,
Signature.mk, Signature.mkEffectful.injEq,
List.map_cons, List.map_nil, and_self, MTy.bitvec,
List.cons.injEq, MTy.bitvec.injEq, and_true, true_and,
RegionSignature.map, Signature.map, h1]
RegionSignature.map, Signature.map, MOp.UnaryOp.instantiate, MOp.UnaryOp.outTy, h1]

open InstCombine in
def mkComInstantiate (reg : MLIR.AST.Region φ) :
Expand Down
20 changes: 20 additions & 0 deletions SSA/Projects/InstCombine/LLVM/PrettyEDSL.lean
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,18 @@ macro_rules
let t ← t.getDM `(mlir_type| _)
`(mlir_op| $resName:mlir_op_operand = $opName ($x, $y) : ($t, $t) -> (i1) )

declare_syntax_cat InstCombine.int_cast_op
syntax "llvm.trunc" : InstCombine.int_cast_op
syntax "llvm.zext" : InstCombine.int_cast_op
syntax "llvm.sext" : InstCombine.int_cast_op

syntax mlir_op_operand " = " InstCombine.int_cast_op mlir_op_operand " : " mlir_type " to " mlir_type : mlir_op
macro_rules
| `(mlir_op| $resName:mlir_op_operand = $name:InstCombine.int_cast_op $x : $t to $t') => do
let some opName := extractOpName name.raw
| Macro.throwUnsupported
`(mlir_op| $resName:mlir_op_operand = $opName ($x) : ($t) -> $t')

open MLIR.AST
syntax mlir_op_operand " = " "llvm.mlir.constant" "(" neg_num (" : " mlir_type)? ")"
(" : " mlir_type)? : mlir_op
Expand Down Expand Up @@ -145,6 +157,14 @@ private def pretty_test_overflow :=
llvm.return %3 : i32
}]

private def pretty_test_trunc :=
[llvm ()|{
^bb0(%arg0: i64):
%0 = llvm.trunc %arg0 : i64 to i32
%1 = llvm.zext %0 : i32 to i64
llvm.return %1 : i64
}]

example : pretty_test = prettier_test_generic 32 := by
unfold pretty_test prettier_test_generic
simp_alive_meta
Expand Down
30 changes: 28 additions & 2 deletions SSA/Projects/InstCombine/LLVM/Semantics.lean
Original file line number Diff line number Diff line change
Expand Up @@ -404,8 +404,6 @@ def ashr {w : Nat} (x y : IntW w) (flag : ExactFlag := {exact := false}) : IntW

/--
If the condition is an i1 and it evaluates to 1, the instruction returns the first value argument; otherwise, it returns the second value argument.

If the condition is an i1 and it evaluates to 1, the instruction returns the first value argument; otherwise, it returns the second value argument.
-/
@[simp_llvm_option]
def select {w : Nat} (c? : IntW 1) (x? y? : IntW w ) : IntW w :=
Expand Down Expand Up @@ -564,4 +562,32 @@ def neg {w : Nat} (x : IntW w) : IntW w := do
let x' ← x
neg? x'


@[simp_llvm]
def trunc? {w: Nat} (w': Nat) (x: BitVec w) : IntW w' := do
pure <| (BitVec.truncate w' x)

@[simp_llvm_option]
def trunc {w: Nat} (w': Nat) (x: IntW w) : IntW w' := do
let x' <- x
trunc? w' x'

@[simp_llvm]
def zext? {w: Nat} (w': Nat) (x: BitVec w) : IntW w' := do
pure <| (BitVec.zeroExtend w' x)

@[simp_llvm_option]
def zext {w: Nat} (w': Nat) (x: IntW w) : IntW w' := do
let x' <- x
zext? w' x'

@[simp_llvm]
def sext? {w: Nat} (w': Nat) (x: BitVec w) : IntW w' := do
pure <| (BitVec.signExtend w' x)

@[simp_llvm_option]
def sext {w: Nat} (w': Nat) (x: IntW w) : IntW w' := do
let x' <- x
sext? w' x'

end LLVM
Loading