diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index a9f7463b2e02..2a5e03430a10 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -190,47 +190,112 @@ We implement [Booth's multiplication circuit](https://en.wikipedia.org/wiki/Boot on bitvectors, and show that this circuit is equal to our straightforward `BitVec.mul` implementation. -/ -def mulAdd (a : BitVec (w+v)) (x : BitVec w) (y : BitVec v) : BitVec (w+v) := - let x : BitVec (w+v) := x.zeroExtend' (le_add_right w v) - Prod.snd <| iunfoldr (s:=a) fun (i : Fin (w+v)) a => +def mulAdd (a x y : BitVec w) : BitVec w := + Prod.snd <| iunfoldr (s:=a) fun (i : Fin w) a => let a := if y.getLsb i = true then a + x else a (a >>> 1, a.getLsb 0) -def mulAddAccumulator (a : BitVec (w+v)) (x : BitVec w) (y : BitVec v) (i : Nat) : BitVec (w+v) := - (a + (x.zeroExtend' <| le_add_right w v) * ((y.extractLsb' 0 i).zeroExtend _) ) >>> i +def mulAddAccumulator (a x y : BitVec w) (i : Nat) : BitVec w := + (a + x * (y.truncate i |>.zeroExtend _)) >>> i + +@[simp] theorem truncate_zero : truncate 0 x = 0#0 := of_length_zero + +@[simp] theorem mul_zero (x : BitVec w) : x * 0#w = 0#w := rfl +@[simp] theorem shiftRight_zero (x : BitVec w) : x >>> 0 = x := rfl +@[simp] theorem shiftLeft_zero (x : BitVec w) : x <<< 0 = x := by apply eq_of_toNat_eq; simp + +theorem mulAddAccumulator_zero (a x y : BitVec w) : mulAddAccumulator a x y 0 = a := by + simp [mulAddAccumulator] + +theorem Nat.shiftRight_add' (m n k : Nat) : + m >>> n + k = (m + (k <<< n)) >>> n := by + sorry + +theorem shiftRight_add' (x y : BitVec w) (n : Nat) : + x >>> n + y = (x + (y <<< n)) >>> n := by + sorry + +#check BitVec.shiftRight_shiftRight + +theorem zeroExtend_truncate_eq_and (x : BitVec w) (i : Nat) : + zeroExtend w (x.truncate i) = x &&& ((-1 : BitVec _) >>> (w-i)) := by + sorry + +theorem add_shiftRight (x y : BitVec w) (n : Nat) : (x + y) >>> n = (x >>> n) + (y >>> n) := by + sorry + +@[simp] theorem zero_shiftRight (w n : Nat) : 0#w >>> n = 0#w := by + sorry + +theorem mod_two_pow_shiftRight (x m n : Nat) : (x % 2^m) >>> n = (x >>> n) % (2^(m+n)) := by + induction n + case zero => rfl + case succ n ih => + simp [shiftRight_succ] + sorry + +theorem shiftLeft_shiftRight_eq_zeroExtend_truncate (x : BitVec w) (i : Nat) : + x <<< i >>> i = zeroExtend w (truncate (w-i) x) := by + apply eq_of_toNat_eq + simp only [toNat_ushiftRight, toNat_shiftLeft, toNat_truncate] + induction i + case a.zero => simp + case a.succ i ih => + rw [mod_two_pow_shiftRight] + sorry + +theorem mulAddAccumulator_succ (a x y : BitVec w) : + mulAddAccumulator a x y (i+1) + = (mulAddAccumulator a x y i >>> 1) + + bif y.getLsb (i+1) then (x.truncate (i+1) |>.zeroExtend _) else 0#w := by + -- ext j + simp only [mulAddAccumulator, natCast_eq_ofNat, BitVec.shiftRight_shiftRight] + have : + x * zeroExtend w (truncate (i + 1) y) + = x * zeroExtend w (truncate i y) + (bif y.getLsb (i+1) then x <<< (i+1) else 0) := by + simp [← shiftLeft_shiftRight_eq_zeroExtend_truncate] + rw [this, ← BitVec.add_assoc, add_shiftRight] + congr + cases y.getLsb (i+1) + · simp + · simp; sorry + + + @[simp] theorem zeroExtend_zero_width (x : BitVec 0) : zeroExtend w x = 0#w := by sorry -@[simp] theorem shiftRight_zero (x : BitVec w) : x >>> 0 = x := rfl -@[simp] theorem mul_zero (x : BitVec w) : x * 0#w = 0#w := rfl +-- @[simp] theorem shiftRight_zero (x : BitVec w) : x >>> 0 = x := rfl +-- @[simp] theorem mul_zero (x : BitVec w) : x * 0#w = 0#w := rfl theorem extractLsb'_succ_eq_concat (x : BitVec w) (s n : Nat) : x.extractLsb' s (n+1) = cons (x.getLsb (s+n)) (x.extractLsb' s n) := by sorry -theorem mulAdd_spec (a : BitVec (w+v)) (x : BitVec w) (y : BitVec v) : - mulAdd a x y = a + (x.zeroExtend' <| le_add_right w v) * (y.zeroExtend' <| le_add_left v w) := by +theorem mulAdd_spec (a x y : BitVec w): + mulAdd a x y = a + x * y := by simp only [mulAdd] rw [iunfoldr_replace (state := mulAddAccumulator a x y)] - · simp [mulAddAccumulator] + · simp [mulAddAccumulator, Nat.mod_one] · intro i - simp only [mulAddAccumulator, Prod.mk.injEq] - simp only [extractLsb'_succ_eq_concat y 0 i, Nat.zero_add] + simp only [mulAddAccumulator, Prod.mk.injEq, natCast_eq_ofNat] cases y.getLsb i <;> simp · sorry · sorry -@[simp] theorem zeroExtend'_mul_zeroExtend' (x y : BitVec w) (h : w ≤ v) : - x.zeroExtend' h * y.zeroExtend' h = (x * y).zeroExtend' h := by +theorem getLsb_mul (x y : BitVec w) (i : Fin w) : + (x * y).getLsb i = Bool.xor (x.getLsb i && y.getLsb i) ((mulAddAccumulator 0 x y i).getLsb 0) := by sorry +theorem zeroExtend'_mul_zeroExtend' (x y : BitVec w) (h : w ≤ v) : + x.zeroExtend' h * y.zeroExtend' h = (x * y).zeroExtend' h := by + sorry @[simp] theorem zeroExtend'_rfl (x : BitVec w) (h : w ≤ w := by rfl) : x.zeroExtend' h = x := rfl -@[simp] -theorem truncate_zeroExtend' (x : BitVec w) (h : w ≤ v) : truncate w (x.zeroExtend' h) = x := by +@[simp] theorem truncate_zeroExtend' (x : BitVec w) (h : w ≤ v) : truncate w (x.zeroExtend' h) = x := by simp [truncate, zeroExtend] intro h' have h_eq : w = v := Nat.le_antisymm h h' @@ -241,46 +306,4 @@ theorem mul_eq_mulAdd (x y : BitVec w) : x * y = (mulAdd 0 x y).truncate w := by simp [mulAdd_spec] -@[simp] -theorem extractLsb'_zero (x : BitVec w) : extractLsb' 0 n x = truncate n x := by - simp [extractLsb'] - -@[simp] -theorem extractLsb'_succ_concat : extractLsb' (start+1) n (concat x a) = extractLsb' start n x := by - simp [extractLsb'] - sorry - --- theorem mulAdd_eq - -theorem mul_eq_mulAdd (x y : BitVec w) : - x * y = (mulAdd 0 x y).truncate _ := by - suffices ∀ {v w} (x : BitVec (w+v)) (y : BitVec w) (z : BitVec v), - x * (y ++ z) = (mulAdd (x*z) x y).truncate _ - by simpa using @this 0 w x y 0 - induction w - case zero => - sorry - case succ w ih => - have ⟨x, x₀, hx⟩ : ∃ (x' : BitVec w) (x₀ : Bool), x = BitVec.concat x' x₀ := sorry - have ⟨y, y₀, hy⟩ : ∃ (y' : BitVec w) (y₀ : Bool), y = BitVec.concat y' y₀ := sorry - subst hx hy - cases y₀ <;> simp [mulAdd] - · simp [extractLsb'] - rw [show 0#w = 0 from rfl, ← ih] - · sorry - -def mulC (x y : BitVec w) : BitVec w := - go _ - where go (acc : BitVec w) (x y : BitVec w) : BitVec w - --- def boothMul (x y : BitVec w) : BitVec w := --- let a : BitVec (w+w+1) := x ++ (0 : BitVec (w+1)) --- let s : BitVec (w+w+1) := -x ++ (0 : BitVec (w+1)) --- let p : BitVec (w+w+1) := (0 : BitVec w) ++ (y : BitVec w) ++ (0 : BitVec 1) --- go a s p w --- where --- go (a s p : BitVec (w+w+1)) : Nat → BitVec w --- | 0 => p --- | n+1 => - end BitVec diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 13f4bf8a8a95..7c151fd24536 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -704,7 +704,7 @@ theorem msb_append {x : BitVec w} {y : BitVec v} : simp only [getLsb_append, cond_eq_if] split <;> simp [*] -theorem BitVec.shiftRight_shiftRight (w : Nat) (x : BitVec w) (n m : Nat) : +theorem shiftRight_shiftRight (w : Nat) (x : BitVec w) (n m : Nat) : (x >>> n) >>> m = x >>> (n + m) := by ext i simp [Nat.add_assoc n m i]