Skip to content

Commit

Permalink
chore: ancilla cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
bollu committed Sep 4, 2024
1 parent 4eac2ff commit 01460c0
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 87 deletions.
130 changes: 43 additions & 87 deletions src/Init/Data/BitVec/Bitblast.lean
Original file line number Diff line number Diff line change
Expand Up @@ -525,11 +525,6 @@ Note that `DivModState` manipulates thw widths of the remainder and the dividend
-/
structure DivModState.Lawful (w wr wn : Nat) (n d : BitVec w)
(qr : DivModState w) : Prop where
-- TODO: make `hwr, hwn` corollaries of `hwrn`.
/-- The remainder width is at most `w`. -/
hwr : wr ≤ w
/-- The dividend width is at most `w`. -/
hwn : wn ≤ w
/-- The sum of widths of the dividend and remainder is `w`. -/
hwrn : wr + wn = w
/-- The divisor is positive. -/
Expand Down Expand Up @@ -560,10 +555,8 @@ def DivModState.init (w : Nat) : DivModState w := {
}

/-- Make an initial state of the DivModState, for a given choice of `n, d, q, r`. -/
def DivModState.Lawful.init (w : Nat) (n d : BitVec w) (hw : 0 < w) (hd : 0#w < d) :
def DivModState.Lawful.init (w : Nat) (n d : BitVec w) (hd : 0#w < d) :
DivModState.Lawful w 0 w n d (DivModState.init w) := {
hwr := by omega,
hwn := by omega,
hwrn := by omega,
hd := by assumption
hrd := by simp [BitVec.lt_def] at hd ⊢; assumption
Expand Down Expand Up @@ -606,7 +599,7 @@ input bit to perform shift subtraction on, and thus we need `0 < wn`.
-/
structure DivModState.LawfulShiftSubtract (w wr wn : Nat) (n d : BitVec w) (qr : DivModState w)
extends DivModState.Lawful w wr wn n d qr : Type where
/-- we can only call this function legally if we have dividend bits. -/
/-- Only perform a round of shift-subtract if we have dividend bits. -/
hwn_lt : 0 < wn


Expand All @@ -624,8 +617,6 @@ then the div rem input can be converted into a shift subtract input
to run a round of the shift subtracter. -/
def DivModState.Lawful.toLawfulShiftSubtract {qr : DivModState w}
(h : qr.Lawful w wr (wn + 1) n d) : qr.LawfulShiftSubtract wr (wn + 1) n d where
hwr := h.hwr
hwn := h.hwn
hwrn := by have := h.hwrn; omega
hd := h.hd
hrd := h.hrd
Expand All @@ -646,28 +637,13 @@ private theorem mul_two_add_lt_two_pow_of_lt_two_pow_of_lt_two {n b k w : Nat}
have : 2^(k + 1) ≤ 2 ^w := Nat.pow_le_pow_of_le_right (by decide) (by assumption)
omega

/--
This is used when proving the correctness of the divison algorithm,
where we know that `r < d`.
We then want to show that `r <<< 1 | b - d < d` as the loop invariant.
In arithmethic, this is the same as showing that
`r * 2 + 1 - d < d`, which this theorem establishes.
-/
private theorem two_mul_add_sub_lt_of_lt_of_lt_two (h : a < x) (hy : y < 2) :
2 * a + y - x < x := by omega

/--
If `n : Bitvec w` has only the low `k < w` bits set,
then `(n <<< 1 | b)` does not overflow, and we can compute its value
as a multiply and add.
-/
theorem toNat_shiftLeft_or_zeroExtend_ofBool_eq {w : Nat}
{r : BitVec w}
{b : Bool}
(hk : k < w)
(hr : r.toNat < 2 ^ k) :
(r <<< 1 ||| zeroExtend w (ofBool b)).toNat =
(r.toNat * 2 + b.toNat) := by
/-- `x.shiftConcat b` does not overflow if `x < 2^k` for `k < w`, and so
`x.shiftConcat b |>.toNat = x.toNat * 2 + b.toNat`. -/
theorem toNat_shiftConcat_eq_of_lt_of_lt_two_pow {x : BitVec w} {b : Bool} {k : Nat}
(hk : k < w) (hx : x.toNat < 2 ^ k) :
(x.shiftConcat b).toNat = x.toNat * 2 + b.toNat := by
simp only [shiftConcat]
have : b.toNat = if b then 1 else 0 := by rcases b <;> rfl
rw [this]
have hk' : 2^k < 2^w := by
Expand All @@ -682,7 +658,7 @@ theorem toNat_shiftLeft_or_zeroExtend_ofBool_eq {w : Nat}
rw [Nat.mod_eq_of_lt]
· rcases b with rfl | rfl <;> simp
· apply mul_two_add_lt_two_pow_of_lt_two_pow_of_lt_two
· exact hr
· exact hx
· rcases b <;> decide
· assumption
· ext i
Expand All @@ -692,56 +668,26 @@ theorem toNat_shiftLeft_or_zeroExtend_ofBool_eq {w : Nat}
intros hi _ hi'
omega

theorem toNat_shiftConcat_eq_of_lt_of_lt_two_two {x : BitVec w} {b : Bool} {k : Nat}
(hk : k < w) (hx : x.toNat < 2 ^ k) :
(shiftConcat x b).toNat = x.toNat * 2 + b.toNat := by
simp only [shiftConcat]
rw [toNat_shiftLeft_or_zeroExtend_ofBool_eq (k := k)]
· omega
theorem toNat_shiftConcat_lt_of_lt_of_lt_two_pow {x : BitVec w} {b : Bool} {k : Nat}
(hk : k < w) (hx : x.toNat < 2 ^ k) :
(x.shiftConcat b).toNat < 2 ^ (k + 1) := by
rw [toNat_shiftConcat_eq_of_lt_of_lt_two_pow hk hx]
apply mul_two_add_lt_two_pow_of_lt_two_pow_of_lt_two hx
· cases b <;> decide
· omega

@[bv_toNat]
theorem toNat_shiftConcat {x : BitVec w} {b : Bool} : (shiftConcat x b).toNat =
theorem toNat_shiftConcat {x : BitVec w} {b : Bool} : (x.shiftConcat b).toNat =
(x.toNat <<< 1 + b.toNat) % 2 ^ w := by
simp only [shiftConcat]
rw [← add_eq_or_of_and_eq_zero]
rw [← add_eq_or_of_and_eq_zero] -- Due to `add_eq_or_of_and_eq_zero`, this must live in `Bitblast`.
· simp
· ext i
simp only [getLsb_and, getLsb_shiftLeft, Fin.is_lt, decide_True, Bool.true_and,
getLsb_zeroExtend, getLsb_ofBool, getLsb_zero, and_eq_false_imp, and_eq_true, not_eq_true',
decide_eq_false_iff_not, Nat.not_lt, decide_eq_true_eq, and_imp]
omega

theorem toNat_shiftConcat_lt {x : BitVec w} {b : Bool} {k : Nat}
(hk : k < w) (hx : x.toNat < 2 ^ k) :
(shiftConcat x b).toNat < 2 ^ (k + 1) := by
rw [toNat_shiftConcat_eq_of_lt_of_lt_two_two hk hx]
apply mul_two_add_lt_two_pow_of_lt_two_pow_of_lt_two hx
· rcases b with rfl | rfl <;> decide
· omega

@[simp]
theorem getLsb_shiftConcat {x : BitVec w} {b : Bool} {i : Nat} :
(shiftConcat x b).getLsb i =
((decide (i < w) && !decide (i < 1) && x.getLsb (i - 1)) ||
decide (i < w) && (decide (i = 0) && b)) := by
simp [shiftConcat]

theorem shiftRight_sub_one_eq_shiftConcat_getLsb_of_lt {n : BitVec w} (hwn : 0 < wn) :
n >>> (wn - 1) = (n >>> wn).shiftConcat (n.getLsb (wn - 1)) := by
-- rw [shiftConcat]
-- have hwn_lt := h.hwn_lt
ext i
simp only [getLsb_ushiftRight, getLsb_or, getLsb_shiftLeft, Fin.is_lt, decide_True, Bool.true_and,
getLsb_zeroExtend, getLsb_ofBool]
by_cases (i : Nat) < 1
case pos h =>
simp [show (i : Nat) = 0 by omega]
omega
case neg h =>
have hi : (i : Nat) ≠ 0 := by omega
simp [shiftConcat, h, hi, show wn - 1 + ↑i = wn + (↑i - 1) by omega]

/-! ### Division shift subtractor -/

/--
Expand Down Expand Up @@ -769,7 +715,7 @@ theorem DivModState.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb
obtain hn : (n >>> (wn - 1)).toNat = ((n >>> wn).shiftConcat (n.getLsb (wn - 1))).toNat := by
simp [hn]
simp only [toNat_ushiftRight] at hn
rw [toNat_shiftConcat_eq_of_lt_of_lt_two_two (k := w - wn)] at hn
rw [toNat_shiftConcat_eq_of_lt_of_lt_two_pow (k := w - wn)] at hn
· rw [hn]
rw [toNat_ushiftRight]
· have := h.hwn_lt
Expand All @@ -779,53 +725,63 @@ theorem DivModState.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb
have := h.hwrn
omega

/--
This is used when proving the correctness of the divison algorithm,
where we know that `r < d`.
We then want to show that `((r.shiftConcat b) - d) < d` as the loop invariant.
In arithmethic, this is the same as showing that
`r * 2 + 1 - d < d`, which this theorem establishes.
-/
private theorem two_mul_add_sub_lt_of_lt_of_lt_two (h : a < x) (hy : y < 2) :
2 * a + y - x < x := by omega

/-- We show that the output of `divSubtractShift` is lawful, which tells us that it
obeys the division equation. -/
def divSubtractShiftProof (qr : DivModState w) (h : qr.LawfulShiftSubtract wr wn n d) :
DivModState.Lawful w (wr + 1) (wn - 1) n d (divSubtractShift n d wn qr) := by
simp only [divSubtractShift, decide_eq_true_eq]
-- We add these hypotheses for `omega` to find them later.
have ⟨⟨_hwr, _hwn, hrwn, hd, hrd, hr, hn, hrnd⟩, hwn_lt⟩ := h
have ⟨⟨hrwn, hd, hrd, hr, hn, hrnd⟩, hwn_lt⟩ := h
have : d.toNat * (qr.q.toNat * 2) = d.toNat * qr.q.toNat * 2 := by rw [Nat.mul_assoc]
by_cases rltd : shiftConcat qr.r (n.getLsb (wn - 1)) < d
· simp only [rltd, ↓reduceIte]
constructor <;> try bv_omega
case pos.hrwr => apply toNat_shiftConcat_lt <;> omega
case pos.hqwr => apply toNat_shiftConcat_lt <;> omega
case pos.hrwr => apply toNat_shiftConcat_lt_of_lt_of_lt_two_pow <;> omega
case pos.hqwr => apply toNat_shiftConcat_lt_of_lt_of_lt_two_pow <;> omega
case pos.hdiv =>
simp only [qr.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb h
, toNat_shiftConcat_eq_of_lt_of_lt_two_two
, toNat_shiftConcat_eq_of_lt_of_lt_two_pow
(hk := qr.wr_lt_w h)
(hx := h.hrwr)
, h.hdiv
, toNat_shiftConcat_eq_of_lt_of_lt_two_two (qr.wr_lt_w h) h.hqwr, Nat.add_mul, toNat_false, Nat.add_zero
, toNat_shiftConcat_eq_of_lt_of_lt_two_pow (qr.wr_lt_w h) h.hqwr, Nat.add_mul, toNat_false, Nat.add_zero
, this]
omega
· simp only [rltd, ↓reduceIte]
constructor <;> try bv_omega
case neg.hrd =>
simp only [lt_def, Nat.not_lt] at rltd
rw [BitVec.toNat_sub_eq_toNat_sub_toNat_of_le rltd,
toNat_shiftConcat_eq_of_lt_of_lt_two_two (hk := qr.wr_lt_w h) (hx := h.hrwr),
toNat_shiftConcat_eq_of_lt_of_lt_two_pow (hk := qr.wr_lt_w h) (hx := h.hrwr),
Nat.mul_comm]
apply two_mul_add_sub_lt_of_lt_of_lt_two <;> bv_omega
case neg.hrwr =>
simp only
have hdr' : d ≤ (qr.r.shiftConcat (n.getLsb (wn - 1))) :=
BitVec.le_iff_not_lt.mp rltd
have hr' : ((qr.r.shiftConcat (n.getLsb (wn - 1)))).toNat < 2 ^ (wr + 1) := by
apply toNat_shiftConcat_lt <;> bv_omega
apply toNat_shiftConcat_lt_of_lt_of_lt_two_pow <;> bv_omega
rw [BitVec.toNat_sub_eq_toNat_sub_toNat_of_le hdr']
omega
case neg.hqwr =>
apply toNat_shiftConcat_lt <;> omega
apply toNat_shiftConcat_lt_of_lt_of_lt_two_pow <;> omega
case neg.hdiv =>
have rltd' := (BitVec.le_iff_not_lt.mp rltd)
simp only [qr.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb h]
simp only [BitVec.toNat_sub_eq_toNat_sub_toNat_of_le rltd']
simp only [toNat_shiftConcat_eq_of_lt_of_lt_two_two (hk := qr.wr_lt_w h) (hx := h.hrwr)]
simp only [BitVec.le_def, toNat_shiftConcat_eq_of_lt_of_lt_two_two (hk := qr.wr_lt_w h) (hx := h.hrwr)] at rltd'
simp only [toNat_shiftConcat_eq_of_lt_of_lt_two_two (hk := qr.wr_lt_w h) (hx := h.hqwr),
simp only [toNat_shiftConcat_eq_of_lt_of_lt_two_pow (hk := qr.wr_lt_w h) (hx := h.hrwr)]
simp only [BitVec.le_def, toNat_shiftConcat_eq_of_lt_of_lt_two_pow (hk := qr.wr_lt_w h) (hx := h.hrwr)] at rltd'
simp only [toNat_shiftConcat_eq_of_lt_of_lt_two_pow (hk := qr.wr_lt_w h) (hx := h.hqwr),
h.hdiv, Nat.mul_add]
bv_omega

Expand Down Expand Up @@ -887,18 +843,18 @@ theorem divRec_correct {n d : BitVec w} (qr : DivModState w)
exact h.toLawfulShiftSubtract

/-- The result of `udiv` agrees with the result of the division recurrence. -/
theorem udiv_eq_divRec (hw : 0 < w) (hd : 0 < d) :
theorem udiv_eq_divRec (hd : 0#w < d) :
let out := divRec w 0 w n d (DivModState.init w)
n.udiv d = out.q := by
have := DivModState.Lawful.init w n d hw hd
have := DivModState.Lawful.init w n d hd
have := divRec_correct (DivModState.init w) this
apply DivModState.udiv_eq_of_lawful_zero this

/-- The result of `umod` agrees with the result of the division recurrence. -/
theorem umod_eq_divRec (hw : 0 < w) (hd : 0 < d) :
theorem umod_eq_divRec (hd : 0#w < d) :
let out := divRec w 0 w n d (DivModState.init w)
n.umod d = out.r := by
have := DivModState.Lawful.init w n d hw hd
have := DivModState.Lawful.init w n d hd
have := divRec_correct (DivModState.init w) this
apply DivModState.umod_eq_of_lawful_zero this

Expand Down
19 changes: 19 additions & 0 deletions src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1245,6 +1245,25 @@ theorem getLsb_concat (x : BitVec w) (b : Bool) (i : Nat) :

/-! ### shiftConcat -/

@[simp]
theorem getLsb_shiftConcat {x : BitVec w} {b : Bool} {i : Nat} :
(x.shiftConcat b).getLsb i =
((decide (i < w) && !decide (i < 1) && x.getLsb (i - 1)) ||
decide (i < w) && (decide (i = 0) && b)) := by
simp [shiftConcat]

theorem shiftRight_sub_one_eq_shiftConcat_getLsb_of_lt {n : BitVec w} (hwn : 0 < wn) :
n >>> (wn - 1) = (n >>> wn).shiftConcat (n.getLsb (wn - 1)) := by
ext i
simp only [getLsb_ushiftRight, getLsb_or, getLsb_shiftLeft, Fin.is_lt, decide_True, Bool.true_and,
getLsb_zeroExtend, getLsb_ofBool]
by_cases (i : Nat) < 1
case pos h =>
simp [show (i : Nat) = 0 by omega]
omega
case neg h =>
have hi : (i : Nat) ≠ 0 := by omega
simp [shiftConcat, h, hi, show wn - 1 + ↑i = wn + (↑i - 1) by omega]

/-! ### add -/

Expand Down

0 comments on commit 01460c0

Please sign in to comment.