Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: handwritten-proof AndOrXor #263

Merged
merged 13 commits into from
Apr 30, 2024
109 changes: 109 additions & 0 deletions SSA/Projects/InstCombine/AliveHandwrittenLargeExamples.lean
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import SSA.Projects.InstCombine.ComWrappers
import SSA.Projects.InstCombine.LLVM.EDSL
import SSA.Projects.InstCombine.Tactic

Expand All @@ -6,6 +7,16 @@ open MLIR AST

namespace AliveHandwritten

namespace LLVMTheory

@[simp]
theorem LLVM.const?_eq : LLVM.const? i = .some (BitVec.ofInt w i) := rfl

@[simp]
theorem LLVM.xor?_eq : LLVM.xor? a b = .some (BitVec.xor a b) := rfl

end LLVMTheory

tobiasgrosser marked this conversation as resolved.
Show resolved Hide resolved
namespace DivRemOfSelect

/--
Expand Down Expand Up @@ -50,4 +61,102 @@ axioms: [propext, Classical.choice, Quot.sound] -/

end DivRemOfSelect

namespace AndOrXor

open LLVMTheory

/-
Name: AndOrXor:2515 ((X^C1) >> C2)^C3 = (X>>C2) ^ ((C1>>C2)^C3)
%e1 = xor %x, C1
%op0 = lshr %e1, C2
%r = xor %op0, C3
=>
%o = lshr %x, C2 -- (X>>C2)
%p = lshr(%C1,%C2)
%q = xor %p, %C3 -- ((C1>>C2)^C3)
%r = xor %o, %q
-/

open ComWrappers
def AndOrXor2515_lhs (w : ℕ):
Com InstCombine.LLVM
[/- C1 -/ InstCombine.Ty.bitvec w,
/- C2 -/ InstCombine.Ty.bitvec w,
/- C3 -/ InstCombine.Ty.bitvec w,
/- %X -/ InstCombine.Ty.bitvec w] (InstCombine.Ty.bitvec w) :=
/- e1 = -/ Com.lete (xor w /-x-/ 0 /-C1-/ 3) <|
/- op0 = -/ Com.lete (lshr w /-e1-/ 0 /-C2-/ 3) <|
/- r = -/ Com.lete (xor w /-op0-/ 0 /-C3-/ 3) <|
Com.ret ⟨/-r-/0, by simp [Ctxt.snoc]⟩

def AndOrXor2515_rhs (w : ℕ) :
Com InstCombine.LLVM
[/- C1 -/ InstCombine.Ty.bitvec w,
/- C2 -/ InstCombine.Ty.bitvec w,
/- C3 -/ InstCombine.Ty.bitvec w,
/- %X -/ InstCombine.Ty.bitvec w] (InstCombine.Ty.bitvec w) :=
/- o = -/ Com.lete (lshr w /-X-/ 0 /-C2-/ 2) <|
/- p = -/ Com.lete (lshr w /-C1-/ 4 /-C2-/ 3) <|
/- q = -/ Com.lete (xor w /-p-/ 0 /-C3-/ 3) <|
/- r = -/ Com.lete (xor w /-o-/ 2 /-q-/ 0) <|
Com.ret ⟨/-r-/0, by simp [Ctxt.snoc]⟩

def ushr_xor_right_distrib (c1 c2 c3 : BitVec w): (c1 ^^^ c2) >>> c3 = (c1 >>> c3) ^^^ (c2 >>> c3) := by
unfold HShiftRight.hShiftRight instHShiftRightBitVec
ext
simp [getLsb_ushiftRight]

def xor_assoc (c1 c2 c3 : BitVec w): c1 ^^^ c2 ^^^ c3 = c1 ^^^ (c2 ^^^ c3) := by
ext i
simp

def help'' {a b : Nat } : a = b ↔ (↑a : Int) = ↑b := by
simp only [Nat.cast_inj]

def alive_simplifyAndOrXor2515 (w : Nat) :
AndOrXor2515_lhs w ⊑ AndOrXor2515_rhs w := by
simp only [AndOrXor2515_lhs, AndOrXor2515_rhs]
simp only [simp_llvm_wrap]
simp_alive_ssa
simp_alive_undef
intros c1 c2 c3 x
rcases c1 with none | c1 <;>
rcases c2 with none | c2 <;>
rcases c3 with none | c3 <;>
rcases x with none | x <;>
simp only [LLVM.xor?_eq, xor_eq, Option.bind_eq_bind, Option.none_bind, Option.bind_none,
Option.some_bind, Refinement.refl]
rw [←Option.bind_eq_bind]
simp_alive_ops
by_cases h : w ≤ BitVec.toNat c2 <;>
simp only [ge_iff_le, h, ↓reduceIte, Option.bind_eq_bind, Option.none_bind, Option.bind_none,
Refinement.refl, Option.some_bind, h, Option.pure_def, Option.some_bind, Refinement.some_some]
simp only [ushr_xor_right_distrib, xor_assoc]
tobiasgrosser marked this conversation as resolved.
Show resolved Hide resolved

#print axioms alive_simplifyAndOrXor2515
tobiasgrosser marked this conversation as resolved.
Show resolved Hide resolved

/-
Proof:
------
bitwise reasoning.
LHS:
----
(((X^C1) >> C2)^C3))[i]
= ((X^C1) >> C2)[i] ^ C3[i] [bit-of-lsh r]
# NOTE: negative entries will be 0 because it is LOGICAL shift right. This is denoted by the []₀ operator.
= ((X^C1))[i - C2]₀ ^ C3[i] [bit-of-lshr]
= (X[i - C2]₀ ^ C1[i - C2]₀) ^ C3[i] [bit-of-xor]
= X[i - C2]₀ ^ C1[i - C2]₀ ^ C3[i] [assoc]


RHS:
----
((X>>C2) ^ ((C1 >> C2)^C3))[i]
= (X>>C2)[i] ^ (C1 >> C2)^C3)[i] [bit-of-xor]
# NOTE: negative entries will be 0 because it is LOGICAL shift right
= X[i - C2]₀ ^ ((C1 >> C2)[i] ^ C3[i]) [bit-of-lshr]
= X[i - C2]₀ ^ (C1[i-C2] ^ C3[i]) [bit-of-lshr]
-/
end AndOrXor

end AliveHandwritten
Loading