Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Progress towards GCMGMultV8 and associated BitVec cleanup #243

Merged
merged 3 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
622 changes: 375 additions & 247 deletions Arm/BitVec.lean

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions Arm/Insts/DPSFP/Advanced_simd_three_different.lean
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ def polynomial_mult_aux (i : Nat) (result : BitVec (m+n))
polynomial_mult_aux (i+1) new_res op1 op2
termination_by (m - i)

/-
Ref.:
https://developer.arm.com/documentation/ddi0602/2024-09/Shared-Pseudocode/shared-functions-vector?lang=en#impl-shared.PolynomialMult.2
bits(M+N) PolynomialMult(bits(M) op1, bits(N) op2)
result = Zeros(M+N);
extended_op2 = ZeroExtend(op2, M+N);
for i=0 to M-1
if op1<i> == '1' then
result = result EOR LSL(extended_op2, i);
return result;
-/
def polynomial_mult (op1 : BitVec m) (op2 : BitVec n) : BitVec (m+n) :=
let result := 0#(m+n)
let extended_op2 := zeroExtend (m+n) op2
Expand Down
18 changes: 18 additions & 0 deletions Arm/Insts/DPSFP/Advanced_simd_three_same.lean
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,24 @@ def binary_vector_op_aux (e : Nat) (elems : Nat) (esize : Nat)
binary_vector_op_aux (e + 1) elems esize op x y result
termination_by (elems - e)

theorem binary_vector_op_aux_of_lt {n} {e elems} (h : e < elems) (esize op)
(x y result : BitVec n) :
binary_vector_op_aux e elems esize op x y result
= let element1 := elem_get x e esize
let element2 := elem_get y e esize
let elem_result := op element1 element2
let result := elem_set result e esize elem_result
binary_vector_op_aux (e + 1) elems esize op x y result := by
conv => { lhs; unfold binary_vector_op_aux }
have : ¬(elems ≤ e) := by omega
simp only [this, ↓reduceIte]

theorem binary_vector_op_aux_of_not_lt {n} {e elems} (h : ¬(e < elems))
(esize op) (x y result : BitVec n) :
binary_vector_op_aux e elems esize op x y result = result := by
unfold binary_vector_op_aux
simp only [ite_eq_left_iff, Nat.not_le, h, false_implies]

/--
Perform pairwise op on esize-bit slices of x and y
-/
Expand Down
4 changes: 2 additions & 2 deletions Arm/Memory/MemoryProofs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ theorem read_mem_bytes_of_write_mem_bytes_subset_helper2
simp_all only [l1, decide_True, Bool.true_and, Nat.add_mod_mod]
rw [read_mem_bytes_of_write_mem_bytes_subset_helper1] <;> assumption
case neg =>
simp only [h₀, BitVec.bitvec_to_nat_of_nat, BitVec.toNat_append, Nat.testBit_or]
simp only [h₀, BitVec.toNat_ofNat, BitVec.toNat_append, Nat.testBit_or]
simp only [Nat.testBit_shiftLeft, Nat.testBit_mod_two_pow]
by_cases h₁ : (i < 8)
case pos => -- (i < 8)
Expand Down Expand Up @@ -948,7 +948,7 @@ private theorem write_mem_bytes_irrelevant_helper (h : n * 8 + 8 = (n + 1) * 8)
((BitVec.cast h (read_mem_bytes n (addr + 1#64) s ++ read_mem addr s)) >>> 8)) =
read_mem_bytes n (addr + 1#64) s := by
ext
simp [ushiftRight, ShiftRight.shiftRight, BitVec.bitvec_to_nat_of_nat]
simp [ushiftRight, ShiftRight.shiftRight, BitVec.toNat_ofNat]
have h_x_size := (read_mem_bytes n (addr + 1#64) s).isLt
have h_y_size := (read_mem addr s).isLt
generalize h_x : (BitVec.toNat (read_mem_bytes n (addr + 1#64) s)) = x
Expand Down
12 changes: 6 additions & 6 deletions Arm/Memory/SeparateProofs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ theorem n_minus_1_lt_2_64_1 (n : Nat)
(h1 : Nat.succ 0 ≤ n) (h2 : n < 2 ^ 64) :
(BitVec.ofNat 64 (n - 1)) < (BitVec.ofNat 64 (2^64 - 1)) := by
refine BitVec.val_bitvec_lt.mp ?a
simp [BitVec.bitvec_to_nat_of_nat]
simp [BitVec.toNat_ofNat]
have : n - 1 < 2 ^ 64 := by omega
simp_all [Nat.mod_eq_of_lt]
exact Nat.sub_lt_left_of_lt_add h1 h2
Expand Down Expand Up @@ -175,26 +175,26 @@ theorem first_addresses_add_one_preserves_subset_same_addr
rw [h3]
apply first_addresses_add_one_preserves_subset_same_addr_helper
rw [←BitVec.val_bitvec_lt]
simp [BitVec.bitvec_to_nat_of_nat]
simp [BitVec.toNat_ofNat]
simp_all [Nat.mod_eq_of_lt]
case inr =>
rename_i h3
have ⟨h3_0, h3_1⟩ := h3
rw [BitVec.add_sub_self_left_64] at h3_0
rw [BitVec.add_sub_self_left_64] at h3_0
rw [←BitVec.nat_bitvec_le] at h3_0
simp_all [BitVec.bitvec_to_nat_of_nat, Nat.mod_eq_of_lt]
simp_all [BitVec.toNat_ofNat, Nat.mod_eq_of_lt]
apply (BitVec.nat_bitvec_le ((BitVec.ofNat 64 m) - 1#64) ((BitVec.ofNat 64 n) - 1#64)).mp
rw [nat_bitvec_sub1]; rw [nat_bitvec_sub1]
simp [BitVec.bitvec_to_nat_of_nat, Nat.mod_eq_of_lt]
simp [BitVec.toNat_ofNat, Nat.mod_eq_of_lt]
· rw [Nat.mod_eq_of_lt h1u]
rw [Nat.mod_eq_of_lt h2u]
rw [Nat.mod_eq_of_lt (by omega)]
rw [Nat.mod_eq_of_lt (by omega)]
exact Nat.sub_le_sub_right h3_0 1
· simp [BitVec.bitvec_to_nat_of_nat, Nat.mod_eq_of_lt, h2u]
· simp [BitVec.toNat_ofNat, Nat.mod_eq_of_lt, h2u]
exact h2l
· simp [BitVec.bitvec_to_nat_of_nat, Nat.mod_eq_of_lt, h1u]
· simp [BitVec.toNat_ofNat, Nat.mod_eq_of_lt, h1u]
exact h1l
case right =>
rw [BitVec.add_sub_add_left]
Expand Down
188 changes: 165 additions & 23 deletions Proofs/AES-GCM/GCMGmultV8Sym.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,139 @@ import Tactics.CSE
import Tactics.ClearNamed
import Arm.Memory.SeparateAutomation
import Arm.Syntax
import Correctness.ArmSpec

namespace GCMGmultV8Program
open ArmStateNotation

#genStepEqTheorems gcm_gmult_v8_program

/-
theorem vrev128_64_8_in_terms_of_rev_elems (x : BitVec 128) :
DPSFP.vrev128_64_8 x =
rev_elems 128 8 ((BitVec.setWidth 64 x) ++ (BitVec.setWidth 64 (x >>> 64))) _p1 _p2 := by
simp only [DPSFP.vrev128_64_8]
unfold rev_vector
simp (config := {decide := true}) only [bitvec_rules, minimal_theory]
unfold rev_vector
simp (config := {decide := true}) only [bitvec_rules, minimal_theory]
rw [rev_elems_64_8_append_eq_rev_elems_128_8]
done
-/

theorem vrev128_64_8_in_terms_of_rev_elems (x : BitVec 128) :
DPSFP.vrev128_64_8 x =
-- rev_elems 64 8 (BitVec.setWidth 64 (x >>> 64)) _p1 _p2 ++
-- rev_elems 64 8 (BitVec.setWidth 64 x) _p3 _p4 := by
rev_elems 64 8 (BitVec.extractLsb' 64 64 x) _p1 _p2 ++
rev_elems 64 8 (BitVec.extractLsb' 0 64 x) _p3 _p4 := by
simp only [DPSFP.vrev128_64_8]
unfold rev_vector
simp (config := {decide := true}) only [bitvec_rules, minimal_theory]
unfold rev_vector
simp (config := {decide := true}) only [bitvec_rules, minimal_theory]
exact rfl
done

-- (TODO) Should we simply replace one function by the other here?
theorem gcm_polyval_mul_eq_polynomial_mult {x y : BitVec 128} :
GCMV8.gcm_polyval_mul x y = DPSFP.polynomial_mult x y := by
sorry

theorem eq_of_rev_elems_eq (x y : BitVec 128) (h : x = y) :
(rev_elems 128 8 x _p1 _p2 = rev_elems 128 8 y _p1 _p2) := by
congr

theorem pmull_op_e_0_eize_64_elements_1_size_128_eq (x y : BitVec 64) :
DPSFP.pmull_op 0 64 1 x y 0#128 =
DPSFP.polynomial_mult x y := by
unfold DPSFP.pmull_op
simp (config := {ground := true}) only [minimal_theory]
unfold DPSFP.pmull_op
simp (config := {ground := true}) only [minimal_theory]
simp only [state_simp_rules, bitvec_rules]
done

theorem rev_elems_128_8_eq_rev_elems_64_8_extractLsb' (x : BitVec 128) :
rev_elems 128 8 x _p1 _p2 =
rev_elems 64 8 (BitVec.extractLsb' 0 64 x) _p3 _p4 ++ rev_elems 64 8 (BitVec.extractLsb' 64 64 x) _p5 _p6 := by
repeat unfold rev_elems
simp (config := {decide := true, ground := true}) only [minimal_theory, BitVec.cast_eq]
bv_check
"lrat_files/GCMGmultV8Sym.lean-GCMGmultV8Program.rev_elems_128_8_eq_rev_elems_64_8_extractLsb'-51-2.lrat"
done

theorem rev_elems_64_8_append_eq_rev_elems_128_8 (x y : BitVec 64) :
rev_elems 64 8 x _p1 _p2 ++ rev_elems 64 8 y _p3 _p4 =
rev_elems 128 8 (y ++ x) _p5 _p6 := by
repeat unfold rev_elems
simp (config := {decide := true, ground := true}) only [minimal_theory, BitVec.cast_eq]
bv_check
"lrat_files/GCMGmultV8Sym.lean-GCMGmultV8Program.rev_elems_64_8_append_eq_rev_elems_128_8-60-2.lrat"
done

private theorem lsb_from_extractLsb'_of_append_self (x : BitVec 128) :
BitVec.extractLsb' 64 64 (BitVec.extractLsb' 64 128 (x ++ x)) =
BitVec.extractLsb' 0 64 x := by
bv_decide
rw [BitVec.extractLsb'_append]
simp_all (config := {ground := true}) only [bitvec_rules]
congr

private theorem msb_from_extractLsb'_of_append_self (x : BitVec 128) :
BitVec.extractLsb' 0 64 (BitVec.extractLsb' 64 128 (x ++ x)) =
BitVec.extractLsb' 64 64 x := by
rw [BitVec.extractLsb'_append]
simp_all (config := {ground := true}) only [bitvec_rules]
congr

private theorem zeroExtend_allOnes_lsh_64 :
~~~(BitVec.zeroExtend 128 (BitVec.allOnes 64) <<< 64)
= 0x0000000000000000ffffffffffffffff#128 := by
decide

private theorem zeroExtend_allOnes_lsh_0 :
~~~(BitVec.zeroExtend 128 (BitVec.allOnes 64) <<< 0) =
0xffffffffffffffff0000000000000000#128 := by
decide

private theorem BitVec.extractLsb'_64_128_of_appends (x y w z : BitVec 64) :
BitVec.extractLsb' 64 128 (x ++ y ++ (w ++ z)) =
y ++ w := by
bv_decide

private theorem BitVec.and_high_to_extractLsb'_concat (x : BitVec 128) :
x &&& 0xffffffffffffffff0000000000000000#128 = (BitVec.extractLsb' 64 64 x) ++ 0#64 := by
bv_decide

theorem extractLsb'_zero_extractLsb'_of_le (h : len1 ≤ len2) :
BitVec.extractLsb' 0 len1 (BitVec.extractLsb' start len2 x) =
BitVec.extractLsb' start len1 x := by
apply BitVec.eq_of_getLsbD_eq; intro i
simp only [BitVec.getLsbD_extractLsb', Fin.is_lt,
decide_True, Nat.zero_add, Bool.true_and,
Bool.and_iff_right_iff_imp, decide_eq_true_eq]
omega

theorem extractLsb'_extractLsb'_zero_of_le (h : start + len1 ≤ len2):
BitVec.extractLsb' start len1 (BitVec.extractLsb' 0 len2 x) =
BitVec.extractLsb' start len1 x := by
apply BitVec.eq_of_getLsbD_eq; intro i
simp only [BitVec.getLsbD_extractLsb', Fin.is_lt,
decide_True, Nat.zero_add, Bool.true_and,
Bool.and_iff_right_iff_imp, decide_eq_true_eq]
omega
theorem BitVec.extractLsb'_append_eq (x : BitVec (n + n)) :
BitVec.extractLsb' n n x ++ BitVec.extractLsb' 0 n x = x := by
have h1 := @BitVec.append_of_extract_general (n + n) n n x
simp only [Nat.reduceAdd, BitVec.extractLsb'_eq] at h1
have h3 : BitVec.setWidth n (x >>> n) = BitVec.extractLsb' n n x := by
apply BitVec.eq_of_getLsbD_eq; intro i
simp only [BitVec.getLsbD_setWidth, Fin.is_lt, decide_True, BitVec.getLsbD_ushiftRight,
Bool.true_and, BitVec.getLsbD_extractLsb']
simp_all only


/-
(TODO) Need a lemma like the following, which breaks up a polynomial
multiplication into four constituent ones, for normalization.
-/
example :
let p := 0b11#2
let q := 0b10#2
let w := 0b01#2
let z := 0b01#2
(DPSFP.polynomial_mult
(p ++ q)
(w ++ z))
=
((DPSFP.polynomial_mult p w) ++ 0#4) ^^^
(0#4 ++ (DPSFP.polynomial_mult q z)) ^^^
(0#2 ++ (DPSFP.polynomial_mult p z) ++ 0#2) ^^^
(0#2 ++ (DPSFP.polynomial_mult q w) ++ 0#2) := by native_decide


set_option pp.deepTerms false in
set_option pp.deepTerms.threshold 50 in
Expand Down Expand Up @@ -82,7 +182,12 @@ theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState)
(sf, s0) ∧
-- Memory frame condition.
MEM_UNCHANGED_EXCEPT [(r (.GPR 0) s0, 16)] (sf, s0) ∧
sf[r (.GPR 0) s0, 16] = GCMV8.GCMGmultV8_alt (HTable.extractLsb' 0 128) Xi := by
sf[r (.GPR 0) s0, 16] =
rev_elems 128 8
(GCMV8.GCMGmultV8_alt
(HTable.extractLsb' 0 128)
(rev_elems 128 8 Xi (by decide) (by decide)))
(by decide) (by decide) := by
-- Prelude
simp_all only [state_simp_rules, -h_run]
simp only [Nat.reduceMul] at Xi HTable
Expand Down Expand Up @@ -121,6 +226,7 @@ theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState)
32 (r (StateField.GPR 1#5) s0) HTable (r (StateField.GPR 1#5) s0 + 16#64) 16 _ h_HTable.symm]
repeat sorry
simp only [h_HTable_high, h_HTable_low, ←h_Xi]
clear h_mem_sep h_run
/-
simp/ground below to reduce
(BitVec.extractLsb' 0 64
Expand All @@ -136,12 +242,48 @@ theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState)
-- (FIXME @bollu) cse leaves the goal unchanged here, quietly, likely due to
-- subexpressions occurring in dep. contexts. Maybe a message here would be helpful.
generalize h_Xi_rev : DPSFP.vrev128_64_8 Xi = Xi_rev
rw [@vrev128_64_8_in_terms_of_rev_elems (by decide) (by decide) (by decide) (by decide)] at h_Xi_rev
generalize h_Xi_upper_rev : rev_elems 64 8 (BitVec.extractLsb' 64 64 Xi) (by decide) (by decide) = Xi_upper_rev
generalize h_Xi_lower_rev : rev_elems 64 8 (BitVec.extractLsb' 0 64 Xi) (by decide) (by decide) = Xi_lower_rev
-- Simplifying the RHS
simp only [←h_HTable, GCMV8.GCMGmultV8_alt,
simp only [GCMV8.GCMGmultV8_alt,
GCMV8.lo, GCMV8.hi,
GCMV8.gcm_polyval]
repeat rw [extractLsb'_zero_extractLsb'_of_le (by decide)]
repeat rw [extractLsb'_extractLsb'_zero_of_le (by decide)]
GCMV8.gcm_polyval,
←h_HTable, ←h_Xi_rev, h_Xi_lower_rev, h_Xi_upper_rev]
simp only [pmull_op_e_0_eize_64_elements_1_size_128_eq, gcm_polyval_mul_eq_polynomial_mult]
simp only [zeroExtend_allOnes_lsh_64, zeroExtend_allOnes_lsh_0]
rw [BitVec.extractLsb'_64_128_of_appends]
rw [BitVec.xor_append]
repeat rw [BitVec.extractLsb'_append_right]
repeat rw [BitVec.extractLsb'_append_left]
repeat rw [BitVec.extractLsb'_zero_extractLsb'_of_le (by decide)]
repeat rw [BitVec.extractLsb'_extractLsb'_zero_of_le (by decide)]
rw [BitVec.and_high_to_extractLsb'_concat]
generalize h_HTable_upper : (BitVec.extractLsb' 64 64 HTable) = HTable_upper
generalize h_HTable_lower : (BitVec.extractLsb' 0 64 HTable) = HTable_lower
generalize h_term_u0u1 : (DPSFP.polynomial_mult HTable_upper Xi_upper_rev) = u0u1 at *
generalize h_term_l0l1 : (DPSFP.polynomial_mult HTable_lower Xi_lower_rev) = l0l1 at *
generalize h_term_1 : (DPSFP.polynomial_mult (BitVec.extractLsb' 128 64 HTable) (Xi_lower_rev ^^^ Xi_upper_rev) ^^^
BitVec.extractLsb' 64 128 (l0l1 ++ u0u1) ^^^
(u0u1 ^^^ l0l1)) = term_1
generalize h_term_2 : ((term_1 &&& 0xffffffffffffffff#128 ||| BitVec.zeroExtend 128 (BitVec.setWidth 64 u0u1) <<< 64) ^^^
DPSFP.polynomial_mult (BitVec.extractLsb' 0 64 u0u1) 0xc200000000000000#64)
= term_2
generalize h_term_3 : (BitVec.extractLsb' 64 128 (term_2 ++ term_2) ^^^
(BitVec.extractLsb' 64 64 l0l1 ++ 0x0#64 |||
BitVec.zeroExtend 128 (BitVec.extractLsb' 64 64 term_1) <<< 0))
= term_3
rw [@vrev128_64_8_in_terms_of_rev_elems (by decide) (by decide) (by decide) (by decide)]
rw [BitVec.extractLsb'_64_128_of_appends]
rw [@rev_elems_64_8_append_eq_rev_elems_128_8 _ _ (by decide) (by decide) (by decide) (by decide)]
apply eq_of_rev_elems_eq
rw [@rev_elems_128_8_eq_rev_elems_64_8_extractLsb' _ (by decide) (by decide) (by decide) (by decide) (by decide)]
rw [h_Xi_upper_rev, h_Xi_lower_rev]
rw [BitVec.extractLsb'_append_eq]
simp [GCMV8.gcm_polyval_red]
-- have h_reduce : (GCMV8.reduce 0x100000000000000000000000000000087#129 0x1#129) = 1#129 := by native_decide
-- simp [GCMV8.gcm_polyval_red, GCMV8.irrepoly, GCMV8.pmod, h_reduce]
-- repeat (unfold GCMV8.pmod.pmodTR; simp)

sorry
done
Expand Down
2 changes: 1 addition & 1 deletion Proofs/AES-GCM/GCMInitV8Sym.lean
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ theorem gcm_init_v8_program_correct (s0 sf : ArmState)
Nat.zero_mod, Nat.zero_add, Nat.sub_zero, Nat.mul_one, Nat.zero_mul, Nat.one_mul,
Nat.reduceSub, BitVec.reduceMul, BitVec.reduceXOr, BitVec.mul_one, Nat.add_one_sub_one,
BitVec.one_mul]
-- bv_check "GCMInitV8Sym.lean-GCMInitV8Program.gcm_init_v8_program_correct-117-4.lrat"
-- bv_check "lrat_files/GCMInitV8Sym.lean-GCMInitV8Program.gcm_init_v8_program_correct-117-4.lrat"
-- TODO: proof works in vscode but timeout in the CI -- need to investigate further
-/

Binary file not shown.
Loading
Loading