Skip to content

Commit

Permalink
feat: write theory in terms of toInt
Browse files Browse the repository at this point in the history
  • Loading branch information
bollu committed May 14, 2024
1 parent 7d4afe9 commit 6be4817
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 44 deletions.
96 changes: 53 additions & 43 deletions src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import Init.Data.Bool
import Init.Data.BitVec.Basic
import Init.Data.Fin.Lemmas
import Init.Data.Nat.Lemmas
import Init.Data.Int.Bitwise.Lemmas
import Init.Data.BitVec.Basic

namespace BitVec

Expand Down Expand Up @@ -147,6 +149,7 @@ theorem getLsb_ofNat (n : Nat) (x : Nat) (i : Nat) :
getLsb (x#n) i = (i < n && x.testBit i) := by
simp [getLsb, BitVec.ofNat, Fin.val_ofNat']


@[simp, deprecated toNat_ofNat] theorem toNat_zero (n : Nat) : (0#n).toNat = 0 := by trivial

@[simp] theorem getLsb_zero : (0#w).getLsb i = false := by simp [getLsb]
Expand Down Expand Up @@ -316,6 +319,19 @@ theorem ofInt_eq_ofNat_emod {n : Nat} (i : Int) :
apply Lean.Omega.Int.pos_pow_of_pos
decide

@[simp]
theorem Int.testBit_natCast (n : Nat) : (n : Int).testBit i = n.testBit i := rfl

theorem Int.natCast_mod_natCast (n m : Nat) : (n : Int) % (m : Int) = ((n % m : Nat) : Int) := rfl

theorem Int.negSucc_mod_natCast (n m : Nat) :
(Int.negSucc m) % (n : Int) = n - ((m % n) + 1) := rfl

@[simp] theorem Int.toNat_natCast (n : Nat) : (n : Int).toNat = n := rfl

@[simp] theorem ofInt_natCast (w n : Nat) :
BitVec.ofInt w n = BitVec.ofNat w n := rfl

/-! ### zeroExtend and truncate -/

@[simp, bv_toNat] theorem toNat_zeroExtend' {m n : Nat} (p : m ≤ n) (x : BitVec m) :
Expand Down Expand Up @@ -755,49 +771,6 @@ theorem msb_eq_true_iff_toInt_lt_zero (x : BitVec w)
simp
omega

#check Int.shiftRight
theorem getLsb_sshiftRight (x : BitVec n) (s i : Nat) :
getLsb (x.sshiftRight s) i = if i ≥ n then false
else if (s + i) < n then getLsb x (s + i)
else x.msb := by
simp
by_cases hxpos : (x.toInt ≥ 0)
-- If x ≥ 0, then arithmetic = logical shift right.
case pos =>
rw [sshiftRight_eq_ushiftRight_of_pos hxpos]
rw [(msb_eq_false_iff_toInt_geq_zero x).mpr hxpos]
by_cases h₁ : s + i < n;
· simp only [ushiftRight_eq, getLsb_ushiftRight, h₁, ↓reduceIte, Bool.iff_and_self,
Bool.not_eq_true', decide_eq_false_iff_not, Nat.not_le]; omega
· simp only [ushiftRight_eq, getLsb_ushiftRight, h₁, ↓reduceIte, Bool.and_false]
apply getLsb_ge
omega
case neg =>
-- if x < 0, then msb = true
rw [(msb_eq_true_iff_toInt_lt_zero x).mpr (by omega)]
simp
by_cases h₁ : n ≤ i <;> simp [h₁]
by_cases h₂ : s + i < n;
case pos =>
simp [h₂]
rw [sshiftRight_eq]
-- the meat and potatoes case.
sorry
case neg =>
simp [h₂]
-- n ≥ i
-- s + i > n
rw [sshiftRight_eq]
rw [BitVec.ofInt_eq_ofNat_emod]
rw [BitVec.getLsb_ofNat]
simp [h₁]
sorry

-- rw [sshiftRight_eq]
-- rw [ofInt_eq_ofNat_emod]
-- rw [getLsb_ofNat]


/-! ### append -/

theorem append_def (x : BitVec v) (y : BitVec w) :
Expand Down Expand Up @@ -1212,4 +1185,41 @@ theorem toNat_intMax_eq : (intMax w).toNat = 2^w - 1 := by
(ofBoolListLE bs).getMsb i = (decide (i < bs.length) && bs.getD (bs.length - 1 - i) false) := by
simp [getMsb_eq_getLsb]


@[simp] theorem ofInt_negSucc (w n : Nat) :
BitVec.ofInt w (Int.negSucc n) = ~~~ (BitVec.ofNat w n) := by
apply BitVec.eq_of_toNat_eq
simp
sorry

@[simp] theorem getLsb_ofInt (n : Nat) (x : Int) (i : Nat) :
getLsb (BitVec.ofInt n x) i = (i < n && x.testBit i) := by
cases x
case ofNat x =>
simp
simp [getLsb_ofNat]
case negSucc x =>
simp [Int.testBit]
simp [getLsb_ofNat]
cases decide (i < n) <;> simp

@[simp] theorem toInt_sshiftRight (x : BitVec n) (i : Nat) :
(x.sshiftRight i).toInt = (x.toInt >>> i).bmod (2^n) := by
rw [sshiftRight_eq, BitVec.toInt_ofInt]

-- theorem testBit_toInt (x : BitVec w) :
-- x.toInt.testBit i = x.getLsb i := rfl

#check Int.testBit_shiftRight
theorem getLsb_sshiftRight (x : BitVec n) (s i : Nat) :
getLsb (x.sshiftRight s) i = if i ≥ n then false
else if (s + i) < n then getLsb x (s + i)
else x.msb := by

rw [sshiftRight_eq]
rw [getLsb_ofInt]
rw [Int.testBit_shiftRight]
by_cases h₁:i < n <;> simp [h₁]


end BitVec
13 changes: 13 additions & 0 deletions src/Init/Data/Int/Bitwise.lean
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,17 @@ protected def shiftRight : Int → Nat → Int

instance : HShiftRight Int Nat Int := ⟨.shiftRight⟩

/-
### testBit
We define an operation for testing individual bits in the binary representation
of a number.
-/

-- -m = !m + 1
-- -(m + 1) = -m - 1 = !m
/-- `testBit m n` returns whether the `(n+1)` least significant bit is `1` or `0`-/
def testBit : Int → Nat → Bool
| .ofNat m, n => Nat.testBit m n
| .negSucc m, n => !(Nat.testBit m n)

end Int
27 changes: 26 additions & 1 deletion src/Init/Data/Int/Bitwise/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ import Init.Data.Nat.Lemmas
import Init.Omega.Int

namespace Int

theorem shiftRight_eq (n : Int) (s : Nat) : n >>> s = Int.shiftRight n s := rfl
theorem shiftRight_ofNat (n s : Nat) : Int.ofNat n >>> s = Int.ofNat (n >>> s) := rfl
@[simp]
theorem shiftRight_ofNat (n s : Nat) : (n : Int) >>> s = Int.ofNat (n >>> s) := rfl
theorem natCast_shiftRight (n s : Nat) : ((↑n) : Int) >>> s = n >>> s := rfl

@[simp]
Expand Down Expand Up @@ -49,3 +51,26 @@ theorem shiftRight_eq_div_pow (m : Int) (n : Nat) : m >>> n = m / ((((2 : Nat) ^
@[simp]
theorem zero_shiftRight (n : Nat) : (0 : Int) >>> n = 0 := by
simp [Int.shiftRight_eq_div_pow]

@[simp] theorem zero_testBit (i : Nat) : Int.testBit 0 i = false := by
simp only [testBit, zero_shiftRight, Nat.zero_and, bne_self_eq_false, Nat.zero_testBit i]

-- @[simp] theorem testBit_zero (x : Int) : Int.testBit x 0 = decide (x % 2 = 1) := by
-- unfold testBit
-- cases x <;> simp [Nat.testBit_zero]
-- case ofNat x =>
-- omega

@[simp] theorem testBit_succ (x : Int) (i : Nat) : Int.testBit x (Nat.succ i) = testBit (x/2) i := by
unfold testBit
cases x <;> simp <;> rfl

theorem toNat_testBit (x i : Nat) :
(x.testBit i).toNat = x / 2 ^ i % 2 := by
rw [Nat.testBit_to_div_mod]
rcases Nat.mod_two_eq_zero_or_one (x / 2^i) <;> simp_all

@[simp] theorem testBit_shiftRight (x : Int) (i j : Nat) : testBit (x >>> i) j = testBit x (i+j) := by
cases x <;> simp [testBit]

end Int

0 comments on commit 6be4817

Please sign in to comment.