diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index b8ff07d5dc21..7f9b4969b7e6 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -446,7 +446,7 @@ It's easy to cook up such examples, by chosing `(q, r)` for a fixed `(d, n)` such that `(d * q + r)` overflows. This tells us that the division algorithm must have more restrictions that just the ones -we have for natural numbers. These restrictions are captured in `DivRemInput.Lawful`, +we have for natural numbers. These restrictions are captured in `DivRemState.Lawful`, which captures the relationship necessary between `n, d, q, r`. The key idea is to state the relationship in terms of the `{n, d, q, r}.toNat` values, and then prove that the relationship holds for the bitvector values. @@ -469,46 +469,63 @@ private theorem Nat.div_add_eq_left_of_lt {x y z : Nat} (hx : z ∣ x) (hy : y < · exact hy /-- If the division equation `d.toNat * q.toNat + r.toNat = n.toNat` holds, -then `n.udiv d = q` and `n.umod d = rm` -/ -theorem udiv_umod_characterized_of_mul_add_toNat {d n q r : BitVec w} (hd : 0 < d) +then `n.udiv d = q`. -/ +theorem udiv_eq_of_mul_add_toNat {d n q r : BitVec w} (hd : 0 < d) (hrd : r < d) (hdqnr : d.toNat * q.toNat + r.toNat = n.toNat) : - (n.udiv d = q ∧ n.umod d = r) := by - constructor - · apply BitVec.eq_of_toNat_eq - rw [toNat_udiv] - replace hdqnr : (d.toNat * q.toNat + r.toNat) / d.toNat = n.toNat / d.toNat := by - simp [hdqnr] - rw [Nat.div_add_eq_left_of_lt] at hdqnr - · rw [← hdqnr] - exact mul_div_right q.toNat hd - · exact Nat.dvd_mul_right d.toNat q.toNat - · exact hrd - · exact hd - · apply BitVec.eq_of_toNat_eq - rw [toNat_umod] - replace hdqnr : (d.toNat * q.toNat + r.toNat) % d.toNat = n.toNat % d.toNat := by - simp [hdqnr] - rw [Nat.add_mod, Nat.mul_mod_right] at hdqnr - simp only [Nat.zero_add, mod_mod] at hdqnr - replace hrd : r.toNat < d.toNat := by - rw [BitVec.lt_def] at hrd - exact hrd -- TODO: golf - rw [Nat.mod_eq_of_lt hrd] at hdqnr + n.udiv d = q := by + apply BitVec.eq_of_toNat_eq + rw [toNat_udiv] + replace hdqnr : (d.toNat * q.toNat + r.toNat) / d.toNat = n.toNat / d.toNat := by simp [hdqnr] + rw [Nat.div_add_eq_left_of_lt] at hdqnr + · rw [← hdqnr] + exact mul_div_right q.toNat hd + · exact Nat.dvd_mul_right d.toNat q.toNat + · exact hrd + · exact hd -/-! ### DivRemInput -/ - -/-- Structure that maintains the input to `divrem`.-/ -structure DivRemInput (w : Nat) : Type where +/-- If the division equation `d.toNat * q.toNat + r.toNat = n.toNat` holds, +then `n.umod d = r` -/ +theorem umod_eq_of_mul_add_toNat {d n q r : BitVec w} (hrd : r < d) + (hdqnr : d.toNat * q.toNat + r.toNat = n.toNat) : + n.umod d = r := by + apply BitVec.eq_of_toNat_eq + rw [toNat_umod] + replace hdqnr : (d.toNat * q.toNat + r.toNat) % d.toNat = n.toNat % d.toNat := by + simp [hdqnr] + rw [Nat.add_mod, Nat.mul_mod_right] at hdqnr + simp only [Nat.zero_add, mod_mod] at hdqnr + replace hrd : r.toNat < d.toNat := by + rw [BitVec.lt_def] at hrd + exact hrd -- TODO: golf + rw [Nat.mod_eq_of_lt hrd] at hdqnr + simp [hdqnr] + +/-! ### DivRemState -/ + +/-- Structure that maintains the state of recursive `divrem` calls. -/ +structure DivRemState (w : Nat) : Type where /-- The current quotient. -/ q : BitVec w /-- The current remainder. -/ r : BitVec w -/-- A lawful instance of `DivRemInput w`. -/ -structure DivRemInput.Lawful (w wr wn : Nat) (n d : BitVec w) - (qr : DivRemInput w) : Prop where +/-- A `DivRemState` is lawful if the remainder width `wr` plus the dividiend width `wn` equals `w`, +and that the bitvectors `r` and `n` have values in these bounds given these bitwidths. + +This is a proof engineering choice: An alternative world could have +`r : BitVec wr` and `n : BitVec wn`, but this required much more dependent typing coercions. + +Instead, we choose to declare all involved bitvectors as length `w`, and then prove that +the values of `r` and `n` are within the bounds of `wr` and `wn` respectively. + +Note that `DivRemState` manipulates thw widths of the remainder and the dividend, + +-/ +structure DivRemState.Lawful (w wr wn : Nat) (n d : BitVec w) + (qr : DivRemState w) : Prop where + -- TODO: make `hwr, hwn` corollaries of `hwrn`. /-- The remainder width is at most `w`. -/ hwr : wr ≤ w /-- The dividend width is at most `w`. -/ @@ -526,9 +543,9 @@ structure DivRemInput.Lawful (w wr wn : Nat) (n d : BitVec w) /-- The low n bits of `n` obey the fundamental division equation. -/ hdiv : n.toNat >>> wn = d.toNat * qr.q.toNat + qr.r.toNat -/-- A lawful DivRemInput implies `w > 0`. -/ -def DivRemInput.Lawful.hw {qr : DivRemInput w} - {h : DivRemInput.Lawful w wr wn n d qr} : 0 < w := by +/-- A lawful DivRemState implies `w > 0`. -/ +def DivRemState.Lawful.hw {qr : DivRemState w} + {h : DivRemState.Lawful w wr wn n d qr} : 0 < w := by have hd := h.hd rcases w with rfl | w · have hcontra : d = 0#0 := by apply Subsingleton.elim @@ -537,70 +554,83 @@ def DivRemInput.Lawful.hw {qr : DivRemInput w} · omega /-- An initial value with both `q, r = 0`. -/ -def DivRemInput.init (w : Nat) : DivRemInput w := { +def DivRemState.init (w : Nat) : DivRemState w := { q := 0#w r := 0#w } -/-- Make an initial state of the DivRemInput, for a given choice of `n, d, q, r`. -/ -def DivRemInput.Lawful.init (w : Nat) (n d : BitVec w) (hw : 0 < w) (hd : 0#w < d) : - DivRemInput.Lawful w 0 w n d (DivRemInput.init w) := { +/-- Make an initial state of the DivRemState, for a given choice of `n, d, q, r`. -/ +def DivRemState.Lawful.init (w : Nat) (n d : BitVec w) (hw : 0 < w) (hd : 0#w < d) : + DivRemState.Lawful w 0 w n d (DivRemState.init w) := { hwr := by omega, hwn := by omega, hwrn := by omega, hd := by assumption hrd := by simp [BitVec.lt_def] at hd ⊢; assumption - hrwr := by simp [DivRemInput.init], - hqwr := by simp [DivRemInput.init], + hrwr := by simp [DivRemState.init], + hqwr := by simp [DivRemState.init], hdiv := by - simp only [DivRemInput.init, toNat_ofNat, zero_mod, Nat.mul_zero, Nat.add_zero]; + simp only [DivRemState.init, toNat_ofNat, zero_mod, Nat.mul_zero, Nat.add_zero]; rw [Nat.shiftRight_eq_div_pow] apply Nat.div_eq_of_lt n.isLt } -/-- A lawful DivRemInput gives us the `udiv/umod` relations we want. -/ -theorem udiv_urem_eqn_of_DivRemInput.Lawful - (qr : DivRemInput w) (h : DivRemInput.Lawful w w 0 n d qr): - n.udiv d = qr.q ∧ n.umod d = qr.r := by - apply udiv_umod_characterized_of_mul_add_toNat - (n := n) (d := d) (q := qr.q) (r := qr.r) - (h.hd) - (h.hrd) - (by - have hdiv := h.hdiv - simp at hdiv - omega - ) +/-- +A lawful DivRemState with a fully consumed dividend (`wn = 0`) witneses that the +quotient has been correctly computed. +-/ +theorem DivRemState.udiv_eq_of_lawful_zero {qr : DivRemState w} + (h : DivRemState.Lawful w w 0 n d qr) : + n.udiv d = qr.q := by + apply udiv_eq_of_mul_add_toNat h.hd h.hrd + have hdiv := h.hdiv + omega + +/-- +A lawful DivRemState with a fully consumed dividend (`wn = 0`) witneses that the +remainder has been correctly computed. +-/ +theorem DivRemState.umod_eq_of_lawful_zero {qr : DivRemState w} + (h : DivRemState.Lawful w w 0 n d qr) : + n.umod d = qr.r := by + apply umod_eq_of_mul_add_toNat h.hrd + have hdiv := h.hdiv + simp only [shiftRight_zero] at hdiv + exact hdiv.symm -/-! ### ShiftSubtractInput -/ +/-! ### ShiftSubtractState -/ /-- -An input that is given to the shift-subtractor. -We create a type alias for this, as the shift subtractor has a different `Lawful` instance. +Internal state of the shift-subtractor. +The state is the same as the internal state of the `divrem` algorithm, +but possesses a stronger invariant that asserts that there are dividend bits to be consumed. + +There, we make a type alias for the different `Lawful` instance. -/ -def ShiftSubtractInput w := DivRemInput w +def ShiftSubtractState w := DivRemState w /-- Forget that the input is to a shift subtractor, and make it a DivRem input. -/ -def ShiftSubtractInput.toDivRemInput (qr : ShiftSubtractInput w) : - DivRemInput w := qr +def ShiftSubtractState.toDivRemState (qr : ShiftSubtractState w) : DivRemState w := qr /-- Forget that the input is to a shift subtractor, and make it a DivRem input. -/ -def ShiftSubtractInput.ofDivRemInput (qr : DivRemInput w) : - ShiftSubtractInput w := qr +def ShiftSubtractState.ofDivRemState (qr : DivRemState w) : + ShiftSubtractState w := qr /-- The input to the shift subtractor is a legal input to `divrem`, and we also need to have an input bit to perform shift subtraction on, and thus we need `0 < wn`. + +Refactor into `LawfulShiftSubtractState`, `LawfulDivRemState`. -/ -structure ShiftSubtractInput.Lawful (w wr wn : Nat) (n d : BitVec w) (qr : ShiftSubtractInput w) - extends DivRemInput.Lawful w wr wn n d qr : Type where +structure ShiftSubtractState.Lawful (w wr wn : Nat) (n d : BitVec w) (qr : ShiftSubtractState w) + extends DivRemState.Lawful w wr wn n d qr : Type where /-- we can only call this function legally if we have dividend bits. -/ hwn_lt : 0 < wn /-- In the shift subtract input, we have one more bit to spare, so we do not overflow. -/ -def ShiftSubtractInput.wr_add_one_le_w {qr : ShiftSubtractInput w} +def ShiftSubtractState.wr_add_one_le_w {qr : ShiftSubtractState w} (h : qr.Lawful wr wn n d) : wr + 1 ≤ w := by have hwrn := h.hwrn have hwn_lt := h.hwn_lt @@ -609,15 +639,17 @@ def ShiftSubtractInput.wr_add_one_le_w {qr : ShiftSubtractInput w} /-- In the shift subtract input, we have one more bit to spare, so we still have remainder bits to be computed. Thus, `wr < w`. + +TODO: delete the others (wr_add_one_le_w, wr_le_wr_sub_one), and then `omega` should infer the others. -/ -def ShiftSubtractInput.wr_lt_w {qr : ShiftSubtractInput w} (h : qr.Lawful wr wn n d) : wr < w := by +def ShiftSubtractState.wr_lt_w {qr : ShiftSubtractState w} (h : qr.Lawful wr wn n d) : wr < w := by have hwr := qr.wr_add_one_le_w h omega /-- In the shift subtract input, we have one more bit to spare, so we do not overflow. -/ -def ShiftSubtractInput.wr_le_wr_sub_one {qr : ShiftSubtractInput w} +def ShiftSubtractState.wr_le_wr_sub_one {qr : ShiftSubtractState w} (h : qr.Lawful wr wn n d) : wr ≤ w - 1 := by have hw := h.hw have hwrn := h.hwrn @@ -627,43 +659,18 @@ def ShiftSubtractInput.wr_le_wr_sub_one {qr : ShiftSubtractInput w} /-- If we have extra bits to spare in `n`, then the div rem input can be converted into a shift subtract input to run a round of the shift subtracter. -/ -def DivRemInput.Lawful.toShiftSubtractInputLawful (qr : DivRemInput w) - (h : DivRemInput.Lawful w wr (wn + 1) n d qr) : (ShiftSubtractInput.ofDivRemInput qr).Lawful wr (wn + 1) n d := { - hwr := h.hwr, - hwn := h.hwn, - hwrn := by have := h.hwrn; omega, - hd := h.hd, - hrd := h.hrd, - hrwr := h.hrwr, - hqwr := h.hqwr, - hdiv := h.hdiv, - hwn_lt := by omega - } - -private theorem shiftLeft_sub_eq_shiftLeft_shiftRight_or_zeroExtend_getLsb - {x : BitVec w} {k : Nat} (hk' : 0 < k) : - x >>> (k - 1) = (((x >>> k) <<< 1) ||| ((BitVec.ofBool (x.getLsb (k - 1))).zeroExtend w)) := by - ext i - simp only [getLsb_ushiftRight, getLsb_or, getLsb_shiftLeft, Fin.is_lt, decide_True, Bool.true_and, - getLsb_zeroExtend, getLsb_ofBool] - by_cases (i : Nat) < 1 - case pos h => - have hi : (i : Nat) = 0 := by omega - simp [hi] - case neg h => - have hi : (i : Nat) ≠ 0 := by omega - simp only [hi, decide_False, Bool.false_and, Bool.or_false, - show ¬ (i : Nat) < 1 by omega] - congr 1 - omega - -theorem ShiftSubtractInput.n_shiftr_wl_minus_one_eq_n_shiftr_wl_or_nmsb - {qr : ShiftSubtractInput w} (h : qr.Lawful wr wn n d): - n >>> (wn - 1) = (n >>> wn).shiftConcat (n.getLsb (wn - 1)) := by - rw [shiftConcat] - rw [BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_or_zeroExtend_getLsb] - have hwn_lt := h.hwn_lt - omega +def DivRemState.Lawful.toShiftSubtractStateLawful (qr : DivRemState w) + (h : DivRemState.Lawful w wr (wn + 1) n d qr) : + (ShiftSubtractState.ofDivRemState qr).Lawful wr (wn + 1) n d where + hwr := h.hwr + hwn := h.hwn + hwrn := by have := h.hwrn; omega + hd := h.hd + hrd := h.hrd + hrwr := h.hrwr + hqwr := h.hqwr + hdiv := h.hdiv + hwn_lt := by omega /-! ### shiftConcat -/ @@ -731,6 +738,18 @@ theorem toNat_shiftConcat_eq_of_lt_of_lt_two_two {x : BitVec w} {b : Bool} {k : · omega · omega +@[bv_toNat] +theorem toNat_shiftConcat {x : BitVec w} {b : Bool} : (shiftConcat x b).toNat = + (x.toNat <<< 1 + b.toNat) % 2 ^ w := by + simp only [shiftConcat] + rw [← add_eq_or_of_and_eq_zero] + · simp + · ext i + simp only [getLsb_and, getLsb_shiftLeft, Fin.is_lt, decide_True, Bool.true_and, + getLsb_zeroExtend, getLsb_ofBool, getLsb_zero, and_eq_false_imp, and_eq_true, not_eq_true', + decide_eq_false_iff_not, Nat.not_lt, decide_eq_true_eq, and_imp] + omega + theorem toNat_shiftConcat_lt {x : BitVec w} {b : Bool} {k : Nat} (hk : k < w) (hx : x.toNat < 2 ^ k) : (shiftConcat x b).toNat < 2 ^ (k + 1) := by @@ -739,11 +758,32 @@ theorem toNat_shiftConcat_lt {x : BitVec w} {b : Bool} {k : Nat} · rcases b with rfl | rfl <;> decide · omega +-- TODO: find out what hypotheses I need from `qr.Lawful` so it stands alone. +theorem ShiftSubtractState.shiftRight_sub_one_eq_shiftConcat + {qr : ShiftSubtractState w} (h : qr.Lawful wr wn n d): + n >>> (wn - 1) = (n >>> wn).shiftConcat (n.getLsb (wn - 1)) := by + rw [shiftConcat] + have hwn_lt := h.hwn_lt + ext i + simp only [getLsb_ushiftRight, getLsb_or, getLsb_shiftLeft, Fin.is_lt, decide_True, Bool.true_and, + getLsb_zeroExtend, getLsb_ofBool] + by_cases (i : Nat) < 1 + case pos h => + have hi : (i : Nat) = 0 := by omega + simp [hi] + case neg h => + have hi : (i : Nat) ≠ 0 := by omega + simp only [hi, decide_False, Bool.false_and, Bool.or_false, + show ¬ (i : Nat) < 1 by omega] + congr 1 + omega + +-- TODO: find out what hypotheses I need from `qr.Lawful` so it stands alone. /-- The value of shifting by `wn - 1` equals shifting by `wn` and grabbing the lsb at `(wn - 1)`. -/ -theorem ShiftSubtractInput.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb - (qr : ShiftSubtractInput w) (h : qr.Lawful wr wn n d): +theorem ShiftSubtractState.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb + (qr : ShiftSubtractState w) (h : qr.Lawful wr wn n d): n.toNat >>> (wn - 1) = (n.toNat >>> wn) * 2 + (n.getLsb (wn - 1)).toNat := by - have hn := ShiftSubtractInput.n_shiftr_wl_minus_one_eq_n_shiftr_wl_or_nmsb (qr := qr) h + have hn := ShiftSubtractState.shiftRight_sub_one_eq_shiftConcat (qr := qr) h obtain hn : (n >>> (wn - 1)).toNat = ((n >>> wn).shiftConcat (n.getLsb (wn - 1))).toNat := by simp [hn] simp only [toNat_ushiftRight] at hn @@ -762,137 +802,69 @@ theorem ShiftSubtractInput.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb /-- One round of the division algorithm, that tries to perform a subtract shift. Note that this is only called when `r.msb = false`, so we will not overflow. -This means that `r'.toNat = r.toNat *2 + q.toNat`. -/ -def divSubtractShift (n : BitVec w) (d : BitVec w) (wn : Nat) (qr : ShiftSubtractInput w) : - DivRemInput w := - let r' := shiftConcat qr.r (n.getLsb (wn - 1)) - let rltd : Bool := r' < d -- true if r' < d. In this case, we don't have a quotient bit. - let q := qr.q.shiftConcat !rltd -- if r ≥ d, then we have a quotient bit. - if rltd +def divSubtractShift (n : BitVec w) (d : BitVec w) (wn : Nat) (qr : ShiftSubtractState w) : + DivRemState w := + let r' := shiftConcat qr.r (n.getLsb (wn - 1)) -- if r ≥ d, then we have a quotient bit. + if r' < d then { - q := q, + q := qr.q.shiftConcat false, -- If `r' < d`, then we do not have a quotient bit. r := r' } else { - q := q, - r := r' - d + q := qr.q.shiftConcat true, -- If `r' ≥ d`, then we have a quotient bit. + r := r' - d -- we subtract to maintain the invariant that `r < d`. } /-- We show that the output of `divSubtractShift` is lawful, which tells us that it obeys the division equation. -/ -def divSubtractShiftProof (qr : ShiftSubtractInput w) (h : qr.Lawful wr wn n d) : - DivRemInput.Lawful w (wr + 1) (wn - 1) n d (divSubtractShift n d wn qr) := by +def divSubtractShiftProof (qr : ShiftSubtractState w) (h : qr.Lawful wr wn n d) : + DivRemState.Lawful w (wr + 1) (wn - 1) n d (divSubtractShift n d wn qr) := by simp only [divSubtractShift, decide_eq_true_eq] + -- We add these hypotheses for `omega` to find them later. + have ⟨⟨hp, hq, hr, hs, ht, hu, hw, hx⟩, hb⟩ := h + have : d.toNat * (qr.q.toNat * 2) = d.toNat * qr.q.toNat * 2 := by rw [Nat.mul_assoc] by_cases rltd : shiftConcat qr.r (n.getLsb (wn - 1)) < d · simp only [rltd, ↓reduceIte] - constructor - case pos.hwr => - have := h.hwr - have := qr.wr_add_one_le_w h - omega - case pos.hwrn => - have := h.hwrn - have := qr.wr_add_one_le_w h - omega - case pos.hd => - have := h.hd - assumption - case pos.hrd => - simp only - simpa using rltd - case pos.hrwr => - simp [rltd] - apply toNat_shiftConcat_lt - · exact qr.wr_add_one_le_w h - · exact h.hrwr - case pos.hqwr => - apply toNat_shiftConcat_lt - · exact qr.wr_add_one_le_w h - · exact h.hqwr + constructor <;> try bv_omega + case pos.hrwr => apply toNat_shiftConcat_lt <;> omega + case pos.hqwr => apply toNat_shiftConcat_lt <;> omega case pos.hdiv => - rw [qr.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb h] - rw [toNat_shiftConcat_eq_of_lt_of_lt_two_two (x := qr.r) (k := wr) (b := (n.getLsb (wn - 1))) + simp only [qr.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb h + , toNat_shiftConcat_eq_of_lt_of_lt_two_two (hk := qr.wr_lt_w h) - (hx := h.hrwr)] - rw [h.hdiv] - simp only [decide_True, Bool.not_true] - rw [toNat_shiftConcat_eq_of_lt_of_lt_two_two (qr.wr_lt_w h) h.hqwr, Nat.add_mul, toNat_false, Nat.add_zero] - simp [show d.toNat * (qr.q.toNat * 2) = d.toNat * qr.q.toNat * 2 by - rw [Nat.mul_assoc] - ] + (hx := h.hrwr) + , h.hdiv + , toNat_shiftConcat_eq_of_lt_of_lt_two_two (qr.wr_lt_w h) h.hqwr, Nat.add_mul, toNat_false, Nat.add_zero + , this] omega - case pos.hwn => - have := h.hwn_lt - have := h.hwn - omega - · simp [rltd] - constructor - case neg.hwr => - have := h.hwr - have := qr.wr_add_one_le_w h - omega - case neg.hwrn => - have := h.hwrn - have := qr.wr_add_one_le_w h - omega - case neg.hd => - have := h.hd - assumption + · simp only [rltd, ↓reduceIte] + constructor <;> try bv_omega case neg.hrd => - simp only simp only [lt_def, Nat.not_lt] at rltd - have hr := h.hrd - have hr' : qr.r < d := by simp only [lt_def]; exact hr - rw [BitVec.toNat_sub_eq_toNat_sub_toNat_of_le rltd] - rw [toNat_shiftConcat_eq_of_lt_of_lt_two_two (x := qr.r) - (k := wr) - (hk := qr.wr_lt_w h) -- TODO: refactor wr_lt_w to be on `h`. - (hx := h.hrwr)] - rw [Nat.mul_comm] - apply two_mul_add_sub_lt_of_lt_of_lt_two - · exact hr' - · apply Bool.toNat_lt + rw [BitVec.toNat_sub_eq_toNat_sub_toNat_of_le rltd, + toNat_shiftConcat_eq_of_lt_of_lt_two_two (hk := qr.wr_lt_w h) (hx := h.hrwr), + Nat.mul_comm] + apply two_mul_add_sub_lt_of_lt_of_lt_two <;> bv_omega case neg.hrwr => simp only have hdr' : d ≤ (qr.r.shiftConcat (n.getLsb (wn - 1))) := BitVec.le_iff_not_lt.mp rltd have hr' : ((qr.r.shiftConcat (n.getLsb (wn - 1)))).toNat < 2 ^ (wr + 1) := by - apply toNat_shiftConcat_lt - · exact qr.wr_add_one_le_w h - · exact h.hrwr + apply toNat_shiftConcat_lt <;> bv_omega rw [BitVec.toNat_sub_eq_toNat_sub_toNat_of_le hdr'] omega case neg.hqwr => - apply toNat_shiftConcat_lt - · exact qr.wr_add_one_le_w h - · exact h.hqwr + apply toNat_shiftConcat_lt <;> omega case neg.hdiv => - simp only - rw [qr.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb h] have rltd' := (BitVec.le_iff_not_lt.mp rltd) - rw [BitVec.toNat_sub_eq_toNat_sub_toNat_of_le - rltd'] - rw [toNat_shiftConcat_eq_of_lt_of_lt_two_two - (hk := qr.wr_lt_w h) - (hx := h.hrwr)] + simp only [qr.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb h] + simp only [BitVec.toNat_sub_eq_toNat_sub_toNat_of_le rltd'] + simp only [toNat_shiftConcat_eq_of_lt_of_lt_two_two (hk := qr.wr_lt_w h) (hx := h.hrwr)] simp only [BitVec.le_def] at rltd' - rw [toNat_shiftConcat_eq_of_lt_of_lt_two_two (x := qr.r) (k := wr) (b := (n.getLsb (wn - 1))) - (hk := qr.wr_lt_w h) - (hx := h.hrwr)] at rltd' - rw [toNat_shiftConcat_eq_of_lt_of_lt_two_two - (hk := qr.wr_lt_w h) - (hx := h.hqwr)] - rw [h.hdiv, Nat.mul_add] - simp only [toNat_true, Nat.mul_one] - rw [Nat.add_mul] - simp [show d.toNat * (qr.q.toNat * 2) = d.toNat * qr.q.toNat * 2 by - rw [Nat.mul_assoc] - ] - omega - case neg.hwn => - have := h.hwn_lt - have := h.hwn - omega + simp only [toNat_shiftConcat_eq_of_lt_of_lt_two_two (hk := qr.wr_lt_w h) (hx := h.hrwr)] at rltd' + simp only [toNat_shiftConcat_eq_of_lt_of_lt_two_two (hk := qr.wr_lt_w h) (hx := h.hqwr)] + simp only [h.hdiv, Nat.mul_add] + bv_omega /-! ### Core divsion algorithm @@ -904,7 +876,7 @@ We have three widths at play: We have the invariant that wn + wr = w. -See that when `divRec'` is called with a `DivRemInput.Lawful h`, we know that: +See that when `divRec` is called with a `DivRemState.Lawful h`, we know that: - r < [2^wr = 2^(w - wn)] which allows us to safely shift left, since it is of length n. @@ -915,56 +887,70 @@ See that when `divRec'` is called with a `DivRemInput.Lawful h`, we know that: So, the remainder is morally of length `w - wn`. - d > 0 - r < d. + + + +TODO: link to previous definition for proof engineering notes. -/ -/-- Core divison recurrence. -/ -def divRec' (w wr wn : Nat) (n d : BitVec w) (qr : DivRemInput w) : - DivRemInput w := + +/-- A recursive definition of division for bitblasting, in terms of a shift-subtraction circuit. -/ +def divRec (w wr wn : Nat) (n d : BitVec w) (qr : DivRemState w) : + DivRemState w := match wn with | 0 => qr | wn + 1 => - divRec' w (wr + 1) wn n d <| divSubtractShift n d (wn + 1) (ShiftSubtractInput.ofDivRemInput qr) + divRec w (wr + 1) wn n d <| divSubtractShift n d (wn + 1) (ShiftSubtractState.ofDivRemState qr) @[simp] -theorem divRec'_zero (qr : DivRemInput w) : - divRec' w w 0 n d qr = qr := rfl +theorem divRec_zero (qr : DivRemState w) : + divRec w w 0 n d qr = qr := rfl @[simp] -theorem divRec'_succ (wn : Nat) (qr : DivRemInput w) : - divRec' w wr (wn + 1) n d qr = - divRec' w (wr + 1) wn n d - (divSubtractShift n d (wn + 1) (ShiftSubtractInput.ofDivRemInput qr)) := rfl +theorem divRec_succ (wn : Nat) (qr : DivRemState w) : + divRec w wr (wn + 1) n d qr = + divRec w (wr + 1) wn n d + (divSubtractShift n d (wn + 1) (ShiftSubtractState.ofDivRemState qr)) := rfl -theorem divRec'_correct {n d : BitVec w} (qr : DivRemInput w) - (h : DivRemInput.Lawful w wr wn n d qr) : DivRemInput.Lawful w w 0 n d (divRec' w wr wn n d qr) := by +theorem divRec_correct {n d : BitVec w} (qr : DivRemState w) + (h : DivRemState.Lawful w wr wn n d qr) : DivRemState.Lawful w w 0 n d (divRec w wr wn n d qr) := by induction wn generalizing wr qr case zero => - unfold divRec' + unfold divRec simp [← h.hwrn, h] case succ wn' ih => - simp only [divRec'] + simp only [divRec] apply ih apply divSubtractShiftProof (w := w) (wr := wr) (wn := wn' + 1) - exact DivRemInput.Lawful.toShiftSubtractInputLawful qr h - -theorem divRec'_eq (hw : 0 < w) (hd : 0 < d) : - let out := divRec' w 0 w n d (DivRemInput.init w) - n.udiv d = out.q ∧ n.umod d = out.r := by - have := DivRemInput.Lawful.init w n d hw hd - have := divRec'_correct (DivRemInput.init w) this - exact udiv_urem_eqn_of_DivRemInput.Lawful _ this + exact DivRemState.Lawful.toShiftSubtractStateLawful qr h + +/-- The result of `udiv` agrees with the result of the division recurrence. -/ +theorem udiv_eq_divRec (hw : 0 < w) (hd : 0 < d) : + let out := divRec w 0 w n d (DivRemState.init w) + n.udiv d = out.q := by + have := DivRemState.Lawful.init w n d hw hd + have := divRec_correct (DivRemState.init w) this + apply DivRemState.udiv_eq_of_lawful_zero this + +/-- The result of `umod` agrees with the result of the division recurrence. -/ +theorem umod_eq_divRec (hw : 0 < w) (hd : 0 < d) : + let out := divRec w 0 w n d (DivRemState.init w) + n.umod d = out.r := by + have := DivRemState.Lawful.init w n d hw hd + have := divRec_correct (DivRemState.init w) this + apply DivRemState.umod_eq_of_lawful_zero this @[simp] -theorem divRec'_succ' (wn : Nat) (qr : DivRemInput w) : - divRec' w wr (wn + 1) n d qr = +theorem divRec_succ' (wn : Nat) (qr : DivRemState w) : + divRec w wr (wn + 1) n d qr = let r' := shiftConcat qr.r (n.getLsb wn) - let rltd : Bool := r' < d -- true if r' < d. In this case, we don't have a quotient bit. - let q := qr.q.shiftConcat !rltd -- if r ≥ d, then we have a quotient bit. - let input : DivRemInput w := - if rltd then ⟨q, r'⟩ else ⟨q, r' - d⟩ - divRec' w (wr + 1) wn n d input := by rfl + let input : DivRemState w := + if r' < d then ⟨qr.q.shiftConcat false, r'⟩ else ⟨qr.q.shiftConcat true, r' - d⟩ + divRec w (wr + 1) wn n d input := by + simp only [divRec_succ, divSubtractShift, Nat.add_one_sub_one, decide_eq_true_eq] + congr /- ### Arithmetic shift right (sshiftRight) recurrence -/ diff --git a/src/Init/Data/Bool.lean b/src/Init/Data/Bool.lean index 3eb9b4c1cf7f..b3ea62c0d910 100644 --- a/src/Init/Data/Bool.lean +++ b/src/Init/Data/Bool.lean @@ -353,9 +353,9 @@ theorem and_or_inj_left_iff : /-- convert a `Bool` to a `Nat`, `false -> 0`, `true -> 1` -/ def toNat (b : Bool) : Nat := cond b 1 0 -@[simp] theorem toNat_false : false.toNat = 0 := rfl +@[simp, bv_toNat] theorem toNat_false : false.toNat = 0 := rfl -@[simp] theorem toNat_true : true.toNat = 1 := rfl +@[simp, bv_toNat] theorem toNat_true : true.toNat = 1 := rfl theorem toNat_le (c : Bool) : c.toNat ≤ 1 := by cases c <;> trivial @@ -363,6 +363,7 @@ theorem toNat_le (c : Bool) : c.toNat ≤ 1 := by @[deprecated toNat_le (since := "2024-02-23")] abbrev toNat_le_one := toNat_le +@[bv_toNat] theorem toNat_lt (b : Bool) : b.toNat < 2 := Nat.lt_succ_of_le (toNat_le _)