Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: handwritten-proof AndOrXor #263

Merged
merged 13 commits into from
Apr 30, 2024
83 changes: 83 additions & 0 deletions SSA/Projects/InstCombine/AliveHandwrittenLargeExamples.lean
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import SSA.Projects.InstCombine.ComWrappers
import SSA.Projects.InstCombine.ForLean
import SSA.Projects.InstCombine.LLVM.EDSL
import SSA.Projects.InstCombine.Tactic

Expand Down Expand Up @@ -50,4 +52,85 @@ axioms: [propext, Classical.choice, Quot.sound] -/

end DivRemOfSelect

namespace AndOrXor
/-
Name: AndOrXor:2515 ((X^C1) >> C2)^C3 = (X>>C2) ^ ((C1>>C2)^C3)
%e1 = xor %x, C1
%op0 = lshr %e1, C2
%r = xor %op0, C3
=>
%o = lshr %x, C2 -- (X>>C2)
%p = lshr(%C1,%C2)
%q = xor %p, %C3 -- ((C1>>C2)^C3)
%r = xor %o, %q
-/

open ComWrappers

def AndOrXor2515_lhs (w : ℕ):
Com InstCombine.LLVM
[/- C1 -/ InstCombine.Ty.bitvec w,
/- C2 -/ InstCombine.Ty.bitvec w,
/- C3 -/ InstCombine.Ty.bitvec w,
/- %X -/ InstCombine.Ty.bitvec w] (InstCombine.Ty.bitvec w) :=
/- e1 = -/ Com.lete (xor w /-x-/ 0 /-C1-/ 3) <|
/- op0 = -/ Com.lete (lshr w /-e1-/ 0 /-C2-/ 3) <|
/- r = -/ Com.lete (xor w /-op0-/ 0 /-C3-/ 3) <|
Com.ret ⟨/-r-/0, by simp [Ctxt.snoc]⟩

def AndOrXor2515_rhs (w : ℕ) :
Com InstCombine.LLVM
[/- C1 -/ InstCombine.Ty.bitvec w,
/- C2 -/ InstCombine.Ty.bitvec w,
/- C3 -/ InstCombine.Ty.bitvec w,
/- %X -/ InstCombine.Ty.bitvec w] (InstCombine.Ty.bitvec w) :=
/- o = -/ Com.lete (lshr w /-X-/ 0 /-C2-/ 2) <|
/- p = -/ Com.lete (lshr w /-C1-/ 4 /-C2-/ 3) <|
/- q = -/ Com.lete (xor w /-p-/ 0 /-C3-/ 3) <|
/- r = -/ Com.lete (xor w /-o-/ 2 /-q-/ 0) <|
Com.ret ⟨/-r-/0, by simp [Ctxt.snoc]⟩

def alive_simplifyAndOrXor2515 (w : Nat) :
AndOrXor2515_lhs w ⊑ AndOrXor2515_rhs w := by
simp only [AndOrXor2515_lhs, AndOrXor2515_rhs]
simp only [simp_llvm_wrap]
simp_alive_ssa
simp_alive_undef
intros c1 c2 c3 x
rcases c1 with rfl | c1 <;> try (simp; done)
rcases c2 with rfl | c2 <;> try (simp; done)
rcases c3 with rfl | c3 <;> try (simp; done)
rcases x with rfl | x <;> try (simp; done)
simp_alive_ops
by_cases h : BitVec.toNat c2 ≥ w <;>
simp [h, ushr_xor_distrib, xor_assoc]

/-- info: 'AliveHandwritten.AndOrXor.alive_simplifyAndOrXor2515' depends on
axioms: [propext, Classical.choice, Quot.sound] -/
#guard_msgs in #print axioms alive_simplifyAndOrXor2515

/-
Proof:
------
bitwise reasoning.
LHS:
----
(((X^C1) >> C2)^C3))[i]
= ((X^C1) >> C2)[i] ^ C3[i] [bit-of-lsh r]
# NOTE: negative entries will be 0 because it is LOGICAL shift right. This is denoted by the []₀ operator.
= ((X^C1))[i - C2]₀ ^ C3[i] [bit-of-lshr]
= (X[i - C2]₀ ^ C1[i - C2]₀) ^ C3[i] [bit-of-xor]
= X[i - C2]₀ ^ C1[i - C2]₀ ^ C3[i] [assoc]


RHS:
----
((X>>C2) ^ ((C1 >> C2)^C3))[i]
= (X>>C2)[i] ^ (C1 >> C2)^C3)[i] [bit-of-xor]
# NOTE: negative entries will be 0 because it is LOGICAL shift right
= X[i - C2]₀ ^ ((C1 >> C2)[i] ^ C3[i]) [bit-of-lshr]
= X[i - C2]₀ ^ (C1[i-C2] ^ C3[i]) [bit-of-lshr]
-/
end AndOrXor

end AliveHandwritten
36 changes: 36 additions & 0 deletions SSA/Projects/InstCombine/ForLean.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
namespace BitVec

def ushr_xor_distrib (a b c : BitVec w) :
(a ^^^ b) >>> c = (a >>> c) ^^^ (b >>> c) := by
simp only [HShiftRight.hShiftRight]
ext
simp

def ushr_and_distrib (a b c : BitVec w) :
(a &&& b) >>> c = (a >>> c) &&& (b >>> c) := by
simp only [HShiftRight.hShiftRight]
ext
simp

def ushr_or_distrib (a b c : BitVec w) :
(a ||| b) >>> c = (a >>> c) ||| (b >>> c) := by
simp only [HShiftRight.hShiftRight]
ext
simp

def xor_assoc (a b c : BitVec w) :
a ^^^ b ^^^ c = a ^^^ (b ^^^ c) := by
ext i
simp [Bool.xor_assoc]

def and_assoc (a b c : BitVec w) :
a &&& b &&& c = a &&& (b &&& c) := by
ext i
simp [Bool.and_assoc]

def or_assoc (a b c : BitVec w) :
a ||| b ||| c = a ||| (b ||| c) := by
ext i
simp [Bool.or_assoc]

end BitVec
26 changes: 26 additions & 0 deletions SSA/Projects/InstCombine/LLVM/Semantics.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ The ‘and’ instruction returns the bitwise logical and of its two operands.
def and? {w : Nat} (x y : BitVec w) : IntW w :=
pure <| x &&& y

@[simp_llvm_option]
theorem and?_eq : LLVM.and? a b = .some (BitVec.and a b) := rfl

@[simp_llvm_option]
def and {w : Nat} (x y : IntW w) : IntW w := do
let x' ← x
Expand All @@ -35,6 +38,8 @@ operands.
def or? {w : Nat} (x y : BitVec w) : IntW w :=
pure <| x ||| y

@[simp_llvm_option]
theorem or?_eq : LLVM.or? a b = .some (BitVec.or a b) := rfl

@[simp_llvm_option]
def or {w : Nat} (x y : IntW w) : IntW w := do
Expand All @@ -51,6 +56,9 @@ is the “~” operator in C.
def xor? {w : Nat} (x y : BitVec w) : IntW w :=
pure <| x ^^^ y

@[simp_llvm_option]
theorem xor?_eq : LLVM.xor? a b = .some (BitVec.xor a b) := rfl

@[simp_llvm_option]
def xor {w : Nat} (x y : IntW w) : IntW w := do
let x' ← x
Expand All @@ -66,6 +74,9 @@ Because LLVM integers use a two’s complement representation, this instruction
def add? {w : Nat} (x y : BitVec w) : IntW w :=
pure <| x + y

@[simp_llvm_option]
theorem add?_eq : LLVM.add? a b = .some (BitVec.add a b) := rfl

@[simp_llvm_option]
def add {w : Nat} (x y : IntW w) : IntW w := do
let x' ← x
Expand All @@ -81,6 +92,9 @@ Because LLVM integers use a two’s complement representation, this instruction
def sub? {w : Nat} (x y : BitVec w) : IntW w :=
pure <| x - y

@[simp_llvm_option]
theorem sub?_eq : LLVM.sub? a b = .some (BitVec.sub a b) := rfl

@[simp_llvm_option]
def sub {w : Nat} (x y : IntW w) : IntW w := do
let x' ← x
Expand All @@ -102,6 +116,9 @@ sign-extended or zero-extended as appropriate to the width of the full product.
def mul? {w : Nat} (x y : BitVec w) : IntW w :=
pure <| x * y

@[simp_llvm_option]
theorem mul?_eq : LLVM.mul? a b = .some (BitVec.mul a b) := rfl

@[simp_llvm_option]
def mul {w : Nat} (x y : IntW w) : IntW w := do
let x' ← x
Expand Down Expand Up @@ -411,10 +428,16 @@ TODO: double-check that truncating works the same as MLIR (signedness, overflow,
def const? (i : Int): IntW w :=
pure <| BitVec.ofInt w i

@[simp_llvm_option]
theorem LLVM.const?_eq : LLVM.const? i = .some (BitVec.ofInt w i) := rfl

@[simp_llvm]
def not? {w : Nat} (x : BitVec w) : IntW w := do
pure (~~~x)

@[simp_llvm_option]
theorem LLVM.not?_eq : LLVM.not? a = .some (BitVec.not a) := rfl

@[simp_llvm_option]
def not {w : Nat} (x : IntW w) : IntW w := do
let x' ← x
Expand All @@ -424,6 +447,9 @@ def not {w : Nat} (x : IntW w) : IntW w := do
def neg? {w : Nat} (x : BitVec w) : IntW w := do
pure <| (-.) x

@[simp_llvm_option]
theorem LLVM.neg?_eq : LLVM.neg? a = .some (BitVec.neg a) := rfl

@[simp_llvm_option]
def neg {w : Nat} (x : IntW w) : IntW w := do
let x' ← x
Expand Down
Loading