Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
alexkeizer committed May 14, 2024
1 parent e39d27c commit 738db83
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 59 deletions.
139 changes: 81 additions & 58 deletions src/Init/Data/BitVec/Bitblast.lean
Original file line number Diff line number Diff line change
Expand Up @@ -190,47 +190,112 @@ We implement [Booth's multiplication circuit](https://en.wikipedia.org/wiki/Boot
on bitvectors, and show that this circuit is equal to our straightforward `BitVec.mul` implementation.
-/

def mulAdd (a : BitVec (w+v)) (x : BitVec w) (y : BitVec v) : BitVec (w+v) :=
let x : BitVec (w+v) := x.zeroExtend' (le_add_right w v)
Prod.snd <| iunfoldr (s:=a) fun (i : Fin (w+v)) a =>
def mulAdd (a x y : BitVec w) : BitVec w :=
Prod.snd <| iunfoldr (s:=a) fun (i : Fin w) a =>
let a := if y.getLsb i = true then a + x else a
(a >>> 1, a.getLsb 0)

def mulAddAccumulator (a : BitVec (w+v)) (x : BitVec w) (y : BitVec v) (i : Nat) : BitVec (w+v) :=
(a + (x.zeroExtend' <| le_add_right w v) * ((y.extractLsb' 0 i).zeroExtend _) ) >>> i
def mulAddAccumulator (a x y : BitVec w) (i : Nat) : BitVec w :=
(a + x * (y.truncate i |>.zeroExtend _)) >>> i

@[simp] theorem truncate_zero : truncate 0 x = 0#0 := of_length_zero

@[simp] theorem mul_zero (x : BitVec w) : x * 0#w = 0#w := rfl
@[simp] theorem shiftRight_zero (x : BitVec w) : x >>> 0 = x := rfl
@[simp] theorem shiftLeft_zero (x : BitVec w) : x <<< 0 = x := by apply eq_of_toNat_eq; simp

theorem mulAddAccumulator_zero (a x y : BitVec w) : mulAddAccumulator a x y 0 = a := by
simp [mulAddAccumulator]

theorem Nat.shiftRight_add' (m n k : Nat) :
m >>> n + k = (m + (k <<< n)) >>> n := by
sorry

theorem shiftRight_add' (x y : BitVec w) (n : Nat) :
x >>> n + y = (x + (y <<< n)) >>> n := by
sorry

#check BitVec.shiftRight_shiftRight

theorem zeroExtend_truncate_eq_and (x : BitVec w) (i : Nat) :
zeroExtend w (x.truncate i) = x &&& ((-1 : BitVec _) >>> (w-i)) := by
sorry

theorem add_shiftRight (x y : BitVec w) (n : Nat) : (x + y) >>> n = (x >>> n) + (y >>> n) := by
sorry

@[simp] theorem zero_shiftRight (w n : Nat) : 0#w >>> n = 0#w := by
sorry

theorem mod_two_pow_shiftRight (x m n : Nat) : (x % 2^m) >>> n = (x >>> n) % (2^(m+n)) := by
induction n
case zero => rfl
case succ n ih =>
simp [shiftRight_succ]
sorry

theorem shiftLeft_shiftRight_eq_zeroExtend_truncate (x : BitVec w) (i : Nat) :
x <<< i >>> i = zeroExtend w (truncate (w-i) x) := by
apply eq_of_toNat_eq
simp only [toNat_ushiftRight, toNat_shiftLeft, toNat_truncate]
induction i
case a.zero => simp
case a.succ i ih =>
rw [mod_two_pow_shiftRight]
sorry

theorem mulAddAccumulator_succ (a x y : BitVec w) :
mulAddAccumulator a x y (i+1)
= (mulAddAccumulator a x y i >>> 1)
+ bif y.getLsb (i+1) then (x.truncate (i+1) |>.zeroExtend _) else 0#w := by
-- ext j
simp only [mulAddAccumulator, natCast_eq_ofNat, BitVec.shiftRight_shiftRight]
have :
x * zeroExtend w (truncate (i + 1) y)
= x * zeroExtend w (truncate i y) + (bif y.getLsb (i+1) then x <<< (i+1) else 0) := by
simp [← shiftLeft_shiftRight_eq_zeroExtend_truncate]
rw [this, ← BitVec.add_assoc, add_shiftRight]
congr
cases y.getLsb (i+1)
· simp
· simp; sorry




@[simp]
theorem zeroExtend_zero_width (x : BitVec 0) : zeroExtend w x = 0#w := by
sorry

@[simp] theorem shiftRight_zero (x : BitVec w) : x >>> 0 = x := rfl
@[simp] theorem mul_zero (x : BitVec w) : x * 0#w = 0#w := rfl
-- @[simp] theorem shiftRight_zero (x : BitVec w) : x >>> 0 = x := rfl
-- @[simp] theorem mul_zero (x : BitVec w) : x * 0#w = 0#w := rfl

theorem extractLsb'_succ_eq_concat (x : BitVec w) (s n : Nat) :
x.extractLsb' s (n+1) = cons (x.getLsb (s+n)) (x.extractLsb' s n) := by
sorry

theorem mulAdd_spec (a : BitVec (w+v)) (x : BitVec w) (y : BitVec v) :
mulAdd a x y = a + (x.zeroExtend' <| le_add_right w v) * (y.zeroExtend' <| le_add_left v w) := by
theorem mulAdd_spec (a x y : BitVec w):
mulAdd a x y = a + x * y := by
simp only [mulAdd]
rw [iunfoldr_replace (state := mulAddAccumulator a x y)]
· simp [mulAddAccumulator]
· simp [mulAddAccumulator, Nat.mod_one]
· intro i
simp only [mulAddAccumulator, Prod.mk.injEq]
simp only [extractLsb'_succ_eq_concat y 0 i, Nat.zero_add]
simp only [mulAddAccumulator, Prod.mk.injEq, natCast_eq_ofNat]
cases y.getLsb i <;> simp
· sorry
· sorry

@[simp] theorem zeroExtend'_mul_zeroExtend' (x y : BitVec w) (h : w ≤ v) :
x.zeroExtend' h * y.zeroExtend' h = (x * y).zeroExtend' h := by
theorem getLsb_mul (x y : BitVec w) (i : Fin w) :
(x * y).getLsb i = Bool.xor (x.getLsb i && y.getLsb i) ((mulAddAccumulator 0 x y i).getLsb 0) := by
sorry

theorem zeroExtend'_mul_zeroExtend' (x y : BitVec w) (h : w ≤ v) :
x.zeroExtend' h * y.zeroExtend' h = (x * y).zeroExtend' h := by
sorry

@[simp] theorem zeroExtend'_rfl (x : BitVec w) (h : w ≤ w := by rfl) : x.zeroExtend' h = x := rfl

@[simp]
theorem truncate_zeroExtend' (x : BitVec w) (h : w ≤ v) : truncate w (x.zeroExtend' h) = x := by
@[simp] theorem truncate_zeroExtend' (x : BitVec w) (h : w ≤ v) : truncate w (x.zeroExtend' h) = x := by
simp [truncate, zeroExtend]
intro h'
have h_eq : w = v := Nat.le_antisymm h h'
Expand All @@ -241,46 +306,4 @@ theorem mul_eq_mulAdd (x y : BitVec w) :
x * y = (mulAdd 0 x y).truncate w := by
simp [mulAdd_spec]

@[simp]
theorem extractLsb'_zero (x : BitVec w) : extractLsb' 0 n x = truncate n x := by
simp [extractLsb']

@[simp]
theorem extractLsb'_succ_concat : extractLsb' (start+1) n (concat x a) = extractLsb' start n x := by
simp [extractLsb']
sorry

-- theorem mulAdd_eq

theorem mul_eq_mulAdd (x y : BitVec w) :
x * y = (mulAdd 0 x y).truncate _ := by
suffices ∀ {v w} (x : BitVec (w+v)) (y : BitVec w) (z : BitVec v),
x * (y ++ z) = (mulAdd (x*z) x y).truncate _
by simpa using @this 0 w x y 0
induction w
case zero =>
sorry
case succ w ih =>
have ⟨x, x₀, hx⟩ : ∃ (x' : BitVec w) (x₀ : Bool), x = BitVec.concat x' x₀ := sorry
have ⟨y, y₀, hy⟩ : ∃ (y' : BitVec w) (y₀ : Bool), y = BitVec.concat y' y₀ := sorry
subst hx hy
cases y₀ <;> simp [mulAdd]
· simp [extractLsb']
rw [show 0#w = 0 from rfl, ← ih]
· sorry

def mulC (x y : BitVec w) : BitVec w :=
go _
where go (acc : BitVec w) (x y : BitVec w) : BitVec w

-- def boothMul (x y : BitVec w) : BitVec w :=
-- let a : BitVec (w+w+1) := x ++ (0 : BitVec (w+1))
-- let s : BitVec (w+w+1) := -x ++ (0 : BitVec (w+1))
-- let p : BitVec (w+w+1) := (0 : BitVec w) ++ (y : BitVec w) ++ (0 : BitVec 1)
-- go a s p w
-- where
-- go (a s p : BitVec (w+w+1)) : Nat → BitVec w
-- | 0 => p
-- | n+1 =>

end BitVec
2 changes: 1 addition & 1 deletion src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ theorem msb_append {x : BitVec w} {y : BitVec v} :
simp only [getLsb_append, cond_eq_if]
split <;> simp [*]

theorem BitVec.shiftRight_shiftRight (w : Nat) (x : BitVec w) (n m : Nat) :
theorem shiftRight_shiftRight (w : Nat) (x : BitVec w) (n m : Nat) :
(x >>> n) >>> m = x >>> (n + m) := by
ext i
simp [Nat.add_assoc n m i]
Expand Down

0 comments on commit 738db83

Please sign in to comment.