From 6c25d630fff6bf4341c77aee76fd5fdd398965da Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Sat, 21 Sep 2024 14:09:26 -1000 Subject: [PATCH] Update toolchain --- .github/workflows/benchmarks-pages.yaml | 6 +- .github/workflows/ci.yaml | 34 +- .github/workflows/coverage.yaml | 4 +- Cargo.lock | 7 - crates/prover/Cargo.toml | 1 - crates/prover/benches/fft.rs | 14 +- crates/prover/benches/merkle.rs | 2 +- .../src/constraint_framework/component.rs | 5 +- crates/prover/src/core/air/accumulation.rs | 1 + crates/prover/src/core/air/mod.rs | 8 +- .../prover/src/core/backend/cpu/quotients.rs | 13 +- .../src/core/backend/simd/bit_reverse.rs | 2 +- .../prover/src/core/backend/simd/blake2s.rs | 8 +- crates/prover/src/core/backend/simd/circle.rs | 11 +- .../prover/src/core/backend/simd/fft/ifft.rs | 13 +- .../prover/src/core/backend/simd/fft/mod.rs | 14 +- .../prover/src/core/backend/simd/fft/rfft.rs | 19 +- crates/prover/src/core/backend/simd/fri.rs | 15 +- .../src/core/backend/simd/lookups/gkr.rs | 2 +- .../src/core/backend/simd/lookups/mle.rs | 3 +- crates/prover/src/core/backend/simd/m31.rs | 576 +++++++++--------- .../prover/src/core/backend/simd/quotients.rs | 4 +- crates/prover/src/core/backend/simd/utils.rs | 82 +-- crates/prover/src/core/channel/blake2s.rs | 2 +- crates/prover/src/core/constraints.rs | 10 +- crates/prover/src/core/fri.rs | 7 +- crates/prover/src/core/lookups/gkr_prover.rs | 2 +- .../prover/src/core/lookups/gkr_verifier.rs | 2 +- crates/prover/src/core/pcs/mod.rs | 4 +- crates/prover/src/core/pcs/verifier.rs | 2 +- crates/prover/src/core/poly/circle/canonic.rs | 14 +- crates/prover/src/core/poly/circle/domain.rs | 5 +- crates/prover/src/core/poly/twiddles.rs | 1 + crates/prover/src/core/vcs/blake2_merkle.rs | 13 +- crates/prover/src/core/vcs/blake2s_ref.rs | 9 +- crates/prover/src/core/vcs/ops.rs | 11 +- crates/prover/src/core/vcs/prover.rs | 3 +- crates/prover/src/core/vcs/verifier.rs | 4 +- crates/prover/src/examples/poseidon/mod.rs | 2 +- crates/prover/src/lib.rs | 12 +- rust-toolchain.toml | 2 +- scripts/clippy.sh | 2 +- scripts/rust_fmt.sh | 2 +- scripts/test_avx.sh | 2 +- 44 files changed, 491 insertions(+), 464 deletions(-) diff --git a/.github/workflows/benchmarks-pages.yaml b/.github/workflows/benchmarks-pages.yaml index 3672b6c2c..780273862 100644 --- a/.github/workflows/benchmarks-pages.yaml +++ b/.github/workflows/benchmarks-pages.yaml @@ -1,4 +1,4 @@ -name: +name: on: push: @@ -18,7 +18,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-09-21 - name: Run benchmark run: ./scripts/bench.sh -- --output-format bencher | tee output.txt - name: Download previous benchmark data @@ -29,7 +29,7 @@ jobs: - name: Store benchmark result uses: benchmark-action/github-action-benchmark@v1 with: - tool: 'cargo' + tool: "cargo" output-file-path: output.txt github-token: ${{ secrets.GITHUB_TOKEN }} auto-push: true diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 9e46888f3..2f692e751 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -25,7 +25,7 @@ jobs: - uses: dtolnay/rust-toolchain@master with: components: rustfmt - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-09-21 - uses: Swatinem/rust-cache@v2 - run: scripts/rust_fmt.sh --check @@ -36,7 +36,7 @@ jobs: - uses: dtolnay/rust-toolchain@master with: components: clippy - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-09-21 - uses: Swatinem/rust-cache@v2 - run: scripts/clippy.sh @@ -46,9 +46,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-09-21 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-01-04 doc + - run: cargo +nightly-2024-09-21 doc run-wasm32-wasi-tests: runs-on: ubuntu-latest @@ -56,7 +56,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-09-21 targets: wasm32-wasi - uses: taiki-e/install-action@v2 with: @@ -73,7 +73,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-09-21 targets: wasm32-unknown-unknown - uses: Swatinem/rust-cache@v2 - uses: jetli/wasm-pack-action@v0.4.0 @@ -89,9 +89,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-09-21 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-01-04 test + - run: cargo +nightly-2024-09-21 test env: RUSTFLAGS: -C target-feature=+neon @@ -104,9 +104,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-09-21 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-01-04 test + - run: cargo +nightly-2024-09-21 test env: RUSTFLAGS: -C target-cpu=native -C target-feature=+${{ matrix.target-feature }} @@ -116,7 +116,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-09-21 - name: Run benchmark run: ./scripts/bench.sh -- --output-format bencher | tee output.txt - name: Download previous benchmark data @@ -142,9 +142,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-09-21 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-01-04 test + - run: cargo +nightly-2024-09-21 test run-slow-tests: runs-on: ubuntu-latest @@ -152,9 +152,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-09-21 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-01-04 test --release --features="slow-tests" + - run: cargo +nightly-2024-09-21 test --release --features="slow-tests" udeps: runs-on: ubuntu-latest @@ -163,7 +163,7 @@ jobs: - uses: dtolnay/rust-toolchain@master name: "Rust Toolchain Setup" with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-09-21 - uses: Swatinem/rust-cache@v2 id: "cache-cargo" - if: ${{ steps.cache-cargo.outputs.cache-hit != 'true' }} @@ -172,7 +172,7 @@ jobs: wget -O - -c https://github.com/est31/cargo-udeps/releases/download/v0.1.35/cargo-udeps-v0.1.35-x86_64-unknown-linux-gnu.tar.gz | tar -xz cargo-udeps-*/cargo-udeps udeps env: - RUSTUP_TOOLCHAIN: nightly-2024-01-04 + RUSTUP_TOOLCHAIN: nightly-2024-09-21 all-tests: runs-on: ubuntu-latest diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml index 504cd67bb..402df96da 100644 --- a/.github/workflows/coverage.yaml +++ b/.github/workflows/coverage.yaml @@ -12,14 +12,14 @@ jobs: - uses: dtolnay/rust-toolchain@master with: components: rustfmt - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-09-21 - uses: Swatinem/rust-cache@v2 - name: Install cargo-llvm-cov uses: taiki-e/install-action@cargo-llvm-cov # TODO: Merge coverage reports for tests on different architectures. # - name: Generate code coverage - run: cargo +nightly-2024-01-04 llvm-cov --codecov --output-path codecov.json + run: cargo +nightly-2024-09-21 llvm-cov --codecov --output-path codecov.json env: RUSTFLAGS: "-C target-feature=+avx512f" - name: Upload coverage to Codecov diff --git a/Cargo.lock b/Cargo.lock index c14183025..b71aed62e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -397,12 +397,6 @@ dependencies = [ "subtle", ] -[[package]] -name = "downcast-rs" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2" - [[package]] name = "educe" version = "0.5.11" @@ -984,7 +978,6 @@ dependencies = [ "bytemuck", "cfg-if", "criterion", - "downcast-rs", "educe", "hex", "itertools 0.12.1", diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index 62c516253..b45f4d7a4 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -14,7 +14,6 @@ blake2.workspace = true blake3.workspace = true bytemuck = { workspace = true, features = ["derive", "extern_crate_alloc"] } cfg-if = "1.0.0" -downcast-rs = "1.2" educe.workspace = true hex.workspace = true itertools.workspace = true diff --git a/crates/prover/benches/fft.rs b/crates/prover/benches/fft.rs index 35841d7e8..cbb0c9e80 100644 --- a/crates/prover/benches/fft.rs +++ b/crates/prover/benches/fft.rs @@ -29,7 +29,7 @@ pub fn simd_ifft(c: &mut Criterion) { || values.clone().data, |mut data| unsafe { ifft( - transmute(data.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(data.as_mut_ptr()), black_box(&twiddle_dbls_refs), black_box(log_size as usize), ); @@ -58,7 +58,7 @@ pub fn simd_ifft_parts(c: &mut Criterion) { || values.clone().data, |mut values| unsafe { ifft_vecwise_loop( - transmute(values.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(values.as_mut_ptr()), black_box(&twiddle_dbls_refs), black_box(9), black_box(0), @@ -72,7 +72,7 @@ pub fn simd_ifft_parts(c: &mut Criterion) { || values.clone().data, |mut values| unsafe { ifft3_loop( - transmute(values.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(values.as_mut_ptr()), black_box(&twiddle_dbls_refs[3..]), black_box(7), black_box(4), @@ -91,7 +91,7 @@ pub fn simd_ifft_parts(c: &mut Criterion) { || transpose_values.clone().data, |mut values| unsafe { transpose_vecs( - transmute(values.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(values.as_mut_ptr()), black_box(TRANSPOSE_LOG_SIZE as usize - 4), ) }, @@ -115,8 +115,10 @@ pub fn simd_rfft(c: &mut Criterion) { target.set_len(values.data.len()); fft( - black_box(transmute(values.data.as_ptr())), - transmute(target.as_mut_ptr()), + black_box(transmute::<*const PackedBaseField, *const u32>( + values.data.as_ptr(), + )), + transmute::<*mut PackedBaseField, *mut u32>(target.as_mut_ptr()), black_box(&twiddle_dbls_refs), black_box(LOG_SIZE as usize), ) diff --git a/crates/prover/benches/merkle.rs b/crates/prover/benches/merkle.rs index c039be77e..9a63a3c38 100644 --- a/crates/prover/benches/merkle.rs +++ b/crates/prover/benches/merkle.rs @@ -21,7 +21,7 @@ fn bench_blake2s_merkle>(c: &mut Criterion, id let n_elements = 1 << (LOG_N_COLS + LOG_N_ROWS); group.throughput(Throughput::Elements(n_elements)); group.throughput(Throughput::Bytes(N_BYTES_FELT as u64 * n_elements)); - group.bench_function(&format!("{id} merkle"), |b| { + group.bench_function(format!("{id} merkle"), |b| { b.iter_with_large_drop(|| B::commit_on_layer(LOG_N_ROWS, None, &col_refs)) }); } diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index c0d8319fa..48fe3561d 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -55,9 +55,10 @@ impl TraceLocationAllocator { } /// A component defined solely in means of the constraints framework. +/// /// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for -/// the SIMD backend. -/// Note that the constraint framework only support components with columns of the same size. +/// the SIMD backend. Note that the constraint framework only supports components with columns of +/// the same size. pub trait FrameworkEval { fn log_size(&self) -> u32; diff --git a/crates/prover/src/core/air/accumulation.rs b/crates/prover/src/core/air/accumulation.rs index 8fcf57549..0b7ce43df 100644 --- a/crates/prover/src/core/air/accumulation.rs +++ b/crates/prover/src/core/air/accumulation.rs @@ -1,4 +1,5 @@ //! Accumulators for a random linear combination of circle polynomials. +//! //! Given N polynomials, u_0(P), ... u_{N-1}(P), and a random alpha, the combined polynomial is //! defined as //! f(p) = sum_i alpha^{N-1-i} u_i(P). diff --git a/crates/prover/src/core/air/mod.rs b/crates/prover/src/core/air/mod.rs index fcdd4d5f8..df0f38ce0 100644 --- a/crates/prover/src/core/air/mod.rs +++ b/crates/prover/src/core/air/mod.rs @@ -15,10 +15,10 @@ mod components; pub mod mask; /// Arithmetic Intermediate Representation (AIR). -/// An Air instance is assumed to already contain all the information needed to -/// evaluate the constraints. -/// For instance, all interaction elements are assumed to be present in it. -/// Therefore, an AIR is generated only after the initial trace commitment phase. +/// +/// An Air instance is assumed to already contain all the information needed to evaluate the +/// constraints. For instance, all interaction elements are assumed to be present in it. Therefore, +/// an AIR is generated only after the initial trace commitment phase. // TODO(spapini): consider renaming this struct. pub trait Air { fn components(&self) -> Vec<&dyn Component>; diff --git a/crates/prover/src/core/backend/cpu/quotients.rs b/crates/prover/src/core/backend/cpu/quotients.rs index 17cc007e1..4b356bb79 100644 --- a/crates/prover/src/core/backend/cpu/quotients.rs +++ b/crates/prover/src/core/backend/cpu/quotients.rs @@ -75,10 +75,10 @@ pub fn accumulate_row_quotients( row_accumulator } -/// Precompute the complex conjugate line coefficients for each column in each sample batch. -/// Specifically, for the i-th (in a sample batch) column's numerator term -/// `alpha^i * (c * F(p) - (a * p.y + b))`, we precompute and return the constants: -/// (`alpha^i * a`, `alpha^i * b`, `alpha^i * c`). +/// Precomputes the complex conjugate line coefficients for each column in each sample batch. +/// +/// For the `i`-th (in a sample batch) column's numerator term `alpha^i * (c * F(p) - (a * p.y + +/// b))`, we precompute and return the constants: (`alpha^i * a`, `alpha^i * b`, `alpha^i * c`). pub fn column_line_coeffs( sample_batches: &[ColumnSampleBatch], random_coeff: SecureField, @@ -103,8 +103,9 @@ pub fn column_line_coeffs( .collect() } -/// Precompute the random coefficients used to linearly combine the batched quotients. -/// Specifically, for each sample batch we compute random_coeff^(number of columns in the batch), +/// Precomputes the random coefficients used to linearly combine the batched quotients. +/// +/// For each sample batch we compute random_coeff^(number of columns in the batch), /// which is used to linearly combine the batch with the next one. pub fn batch_random_coeffs( sample_batches: &[ColumnSampleBatch], diff --git a/crates/prover/src/core/backend/simd/bit_reverse.rs b/crates/prover/src/core/backend/simd/bit_reverse.rs index 13d6585de..9bef5b31b 100644 --- a/crates/prover/src/core/backend/simd/bit_reverse.rs +++ b/crates/prover/src/core/backend/simd/bit_reverse.rs @@ -159,7 +159,7 @@ mod tests { let res = bit_reverse16(values.data.try_into().unwrap()); - assert_eq!(res.map(PackedM31::to_array).flatten(), expected); + assert_eq!(res.map(PackedM31::to_array).as_flattened(), expected); } #[test] diff --git a/crates/prover/src/core/backend/simd/blake2s.rs b/crates/prover/src/core/backend/simd/blake2s.rs index fbcfe89e2..d360e4f50 100644 --- a/crates/prover/src/core/backend/simd/blake2s.rs +++ b/crates/prover/src/core/backend/simd/blake2s.rs @@ -369,8 +369,12 @@ mod tests { let res_vectorized: [[u32; 8]; 16] = unsafe { transmute(untranspose_states(compress16( - transpose_states(transmute(states)), - transpose_msgs(transmute(msgs)), + transpose_states(transmute::, [u32x16; 8]>( + states, + )), + transpose_msgs(transmute::, [u32x16; 16]>( + msgs, + )), u32x16::splat(count_low), u32x16::splat(count_high), u32x16::splat(lastblock), diff --git a/crates/prover/src/core/backend/simd/circle.rs b/crates/prover/src/core/backend/simd/circle.rs index e930f77b2..cf069bb24 100644 --- a/crates/prover/src/core/backend/simd/circle.rs +++ b/crates/prover/src/core/backend/simd/circle.rs @@ -85,10 +85,7 @@ impl SimdBackend { // Generates twiddle steps for efficiently computing the twiddles. // steps[i] = t_i/(t_0*t_1*...*t_i-1). - fn twiddle_steps(mappings: &[F]) -> Vec - where - F: FieldExpOps, - { + fn twiddle_steps(mappings: &[F]) -> Vec { let mut denominators: Vec = vec![mappings[0]]; for i in 1..mappings.len() { @@ -151,7 +148,7 @@ impl PolyOps for SimdBackend { // Safe because [PackedBaseField] is aligned on 64 bytes. unsafe { ifft::ifft( - transmute(values.data.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(values.data.as_mut_ptr()), &twiddles, log_size as usize, ); @@ -254,8 +251,8 @@ impl PolyOps for SimdBackend { // FFT from the coefficients buffer to the values chunk. unsafe { rfft::fft( - transmute(poly.coeffs.data.as_ptr()), - transmute( + transmute::<*const PackedBaseField, *const u32>(poly.coeffs.data.as_ptr()), + transmute::<*mut PackedBaseField, *mut u32>( values[i << (fft_log_size - LOG_N_LANES) ..(i + 1) << (fft_log_size - LOG_N_LANES)] .as_mut_ptr(), diff --git a/crates/prover/src/core/backend/simd/fft/ifft.rs b/crates/prover/src/core/backend/simd/fft/ifft.rs index eb34da490..881e2166a 100644 --- a/crates/prover/src/core/backend/simd/fft/ifft.rs +++ b/crates/prover/src/core/backend/simd/fft/ifft.rs @@ -589,7 +589,7 @@ mod tests { let mut res = values; unsafe { ifft3( - transmute(res.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.as_mut_ptr()), 0, LOG_N_LANES as usize, twiddles0_dbl, @@ -655,7 +655,7 @@ mod tests { [val0.to_array(), val1.to_array()].concat() }; - assert_eq!(res, ground_truth_ifft(domain, values.flatten())); + assert_eq!(res, ground_truth_ifft(domain, values.as_flattened())); } #[test] @@ -669,7 +669,7 @@ mod tests { let mut res = values.iter().copied().collect::(); unsafe { ifft_lower_with_vecwise( - transmute(res.data.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), log_size as usize, log_size as usize, @@ -691,11 +691,14 @@ mod tests { let mut res = values.iter().copied().collect::(); unsafe { ifft( - transmute(res.data.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), log_size as usize, ); - transpose_vecs(transmute(res.data.as_mut_ptr()), log_size as usize - 4); + transpose_vecs( + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), + log_size as usize - 4, + ); } assert_eq!(res.to_cpu(), ground_truth_ifft(domain, &values)); diff --git a/crates/prover/src/core/backend/simd/fft/mod.rs b/crates/prover/src/core/backend/simd/fft/mod.rs index ca44979e8..5308ba6f9 100644 --- a/crates/prover/src/core/backend/simd/fft/mod.rs +++ b/crates/prover/src/core/backend/simd/fft/mod.rs @@ -102,19 +102,19 @@ fn mul_twiddle(v: PackedBaseField, twiddle_dbl: u32x16) -> PackedBaseField { // TODO: Come up with a better approach than `cfg`ing on target_feature. // TODO: Ensure all these branches get tested in the CI. cfg_if::cfg_if! { - if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] { + if #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] { // TODO: For architectures that when multiplying require doubling then the twiddles // should be precomputed as double. For other architectures, the twiddle should be // precomputed without doubling. - crate::core::backend::simd::m31::_mul_doubled_neon(v, twiddle_dbl) - } else if #[cfg(all(target_feature = "simd128", target_arch = "wasm32"))] { - crate::core::backend::simd::m31::_mul_doubled_wasm(v, twiddle_dbl) + crate::core::backend::simd::m31::mul_doubled_neon(v, twiddle_dbl) + } else if #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))] { + crate::core::backend::simd::m31::mul_doubled_wasm(v, twiddle_dbl) } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] { - crate::core::backend::simd::m31::_mul_doubled_avx512(v, twiddle_dbl) + crate::core::backend::simd::m31::mul_doubled_avx512(v, twiddle_dbl) } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] { - crate::core::backend::simd::m31::_mul_doubled_avx2(v, twiddle_dbl) + crate::core::backend::simd::m31::mul_doubled_avx2(v, twiddle_dbl) } else { - crate::core::backend::simd::m31::_mul_doubled_simd(v, twiddle_dbl) + crate::core::backend::simd::m31::mul_doubled_simd(v, twiddle_dbl) } } } diff --git a/crates/prover/src/core/backend/simd/fft/rfft.rs b/crates/prover/src/core/backend/simd/fft/rfft.rs index 6d51fd09d..64c221ff3 100644 --- a/crates/prover/src/core/backend/simd/fft/rfft.rs +++ b/crates/prover/src/core/backend/simd/fft/rfft.rs @@ -615,8 +615,8 @@ mod tests { let mut res = values; unsafe { fft3( - transmute(res.as_ptr()), - transmute(res.as_mut_ptr()), + transmute::<*const PackedBaseField, *const u32>(res.as_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.as_mut_ptr()), 0, LOG_N_LANES as usize, twiddles0_dbl, @@ -686,7 +686,7 @@ mod tests { [val0.to_array(), val1.to_array()].concat() }; - assert_eq!(res, ground_truth_fft(domain, values.flatten())); + assert_eq!(res, ground_truth_fft(domain, values.as_flattened())); } #[test] @@ -700,8 +700,8 @@ mod tests { let mut res = values.iter().copied().collect::(); unsafe { fft_lower_with_vecwise( - transmute(res.data.as_ptr()), - transmute(res.data.as_mut_ptr()), + transmute::<*const PackedBaseField, *const u32>(res.data.as_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), log_size as usize, log_size as usize, @@ -722,10 +722,13 @@ mod tests { let mut res = values.iter().copied().collect::(); unsafe { - transpose_vecs(transmute(res.data.as_mut_ptr()), log_size as usize - 4); + transpose_vecs( + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), + log_size as usize - 4, + ); fft( - transmute(res.data.as_ptr()), - transmute(res.data.as_mut_ptr()), + transmute::<*const PackedBaseField, *const u32>(res.data.as_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), log_size as usize, ); diff --git a/crates/prover/src/core/backend/simd/fri.rs b/crates/prover/src/core/backend/simd/fri.rs index 97212497b..626e9ef95 100644 --- a/crates/prover/src/core/backend/simd/fri.rs +++ b/crates/prover/src/core/backend/simd/fri.rs @@ -1,5 +1,5 @@ use std::array; -use std::simd::u32x8; +use std::simd::{u32x16, u32x8}; use num_traits::Zero; @@ -37,14 +37,15 @@ impl FriOps for SimdBackend { let mut folded_values = SecureColumnByCoords::::zeros(1 << (log_size - 1)); for vec_index in 0..(1 << (log_size - 1 - LOG_N_LANES)) { - let value = unsafe { - let twiddle_dbl: [u32; 16] = - array::from_fn(|i| *itwiddles.get_unchecked(vec_index * 16 + i)); - let val0 = eval.values.packed_at(vec_index * 2).into_packed_m31s(); - let val1 = eval.values.packed_at(vec_index * 2 + 1).into_packed_m31s(); + let value = { + let twiddle_dbl = u32x16::from_array(array::from_fn(|i| unsafe { + *itwiddles.get_unchecked(vec_index * 16 + i) + })); + let val0 = unsafe { eval.values.packed_at(vec_index * 2) }.into_packed_m31s(); + let val1 = unsafe { eval.values.packed_at(vec_index * 2 + 1) }.into_packed_m31s(); let pairs: [_; 4] = array::from_fn(|i| { let (a, b) = val0[i].deinterleave(val1[i]); - simd_ibutterfly(a, b, std::mem::transmute(twiddle_dbl)) + simd_ibutterfly(a, b, twiddle_dbl) }); let val0 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].0)); let val1 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].1)); diff --git a/crates/prover/src/core/backend/simd/lookups/gkr.rs b/crates/prover/src/core/backend/simd/lookups/gkr.rs index 017948dee..74d7f7c43 100644 --- a/crates/prover/src/core/backend/simd/lookups/gkr.rs +++ b/crates/prover/src/core/backend/simd/lookups/gkr.rs @@ -25,7 +25,7 @@ impl GkrOps for SimdBackend { } // Start DP with CPU backend to avoid dealing with instances smaller than a SIMD vector. - let (y_last_chunk, y_rem) = y.split_last_chunk::<{ LOG_N_LANES as usize }>().unwrap(); + let (y_rem, y_last_chunk) = y.split_last_chunk::<{ LOG_N_LANES as usize }>().unwrap(); let initial = SecureColumn::from_iter(cpu_gen_eq_evals(y_last_chunk, v)); assert_eq!(initial.len(), N_LANES); diff --git a/crates/prover/src/core/backend/simd/lookups/mle.rs b/crates/prover/src/core/backend/simd/lookups/mle.rs index 0e2fe73f7..07f175bbc 100644 --- a/crates/prover/src/core/backend/simd/lookups/mle.rs +++ b/crates/prover/src/core/backend/simd/lookups/mle.rs @@ -30,9 +30,8 @@ impl MleOps for SimdBackend { let (evals_at_0x, evals_at_1x) = mle.data.split_at(packed_midpoint); let res = zip(evals_at_0x, evals_at_1x) - .enumerate() // MLE at points `({0, 1}, rev(bits(i)), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. - .map(|(_i, (&packed_eval_at_0iv, &packed_eval_at_1iv))| { + .map(|(&packed_eval_at_0iv, &packed_eval_at_1iv)| { fold_packed_mle_evals(packed_assignment, packed_eval_at_0iv, packed_eval_at_1iv) }) .collect(); diff --git a/crates/prover/src/core/backend/simd/m31.rs b/crates/prover/src/core/backend/simd/m31.rs index f6291626b..babe7cad7 100644 --- a/crates/prover/src/core/backend/simd/m31.rs +++ b/crates/prover/src/core/backend/simd/m31.rs @@ -3,14 +3,13 @@ use std::mem::transmute; use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use std::ptr; use std::simd::cmp::SimdOrd; -use std::simd::{u32x16, Simd, Swizzle}; +use std::simd::{u32x16, Simd}; use bytemuck::{Pod, Zeroable}; use num_traits::{One, Zero}; use rand::distributions::{Distribution, Standard}; use super::qm31::PackedQM31; -use crate::core::backend::simd::utils::{InterleaveEvens, InterleaveOdds}; use crate::core::fields::m31::{pow2147483645, BaseField, M31, P}; use crate::core::fields::qm31::QM31; use crate::core::fields::FieldExpOps; @@ -142,16 +141,16 @@ impl Mul for PackedM31 { // TODO: Come up with a better approach than `cfg`ing on target_feature. // TODO: Ensure all these branches get tested in the CI. cfg_if::cfg_if! { - if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] { - _mul_neon(self, rhs) - } else if #[cfg(all(target_feature = "simd128", target_arch = "wasm32"))] { - _mul_wasm(self, rhs) + if #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] { + mul_neon(self, rhs) + } else if #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))] { + mul_wasm(self, rhs) } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] { - _mul_avx512(self, rhs) + mul_avx512(self, rhs) } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] { - _mul_avx2(self, rhs) + mul_avx2(self, rhs) } else { - _mul_simd(self, rhs) + mul_simd(self, rhs) } } } @@ -286,290 +285,299 @@ impl Sum for PackedM31 { } } -/// Returns `a * b`. -#[cfg(target_arch = "aarch64")] -pub(crate) fn _mul_neon(a: PackedM31, b: PackedM31) -> PackedM31 { - use core::arch::aarch64::{int32x2_t, vqdmull_s32}; - use std::simd::u32x4; - - let [a0, a1, a2, a3, a4, a5, a6, a7]: [int32x2_t; 8] = unsafe { transmute(a) }; - let [b0, b1, b2, b3, b4, b5, b6, b7]: [int32x2_t; 8] = unsafe { transmute(b) }; - - // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| - let c0: u32x4 = unsafe { transmute(vqdmull_s32(a0, b0)) }; - let c1: u32x4 = unsafe { transmute(vqdmull_s32(a1, b1)) }; - let c2: u32x4 = unsafe { transmute(vqdmull_s32(a2, b2)) }; - let c3: u32x4 = unsafe { transmute(vqdmull_s32(a3, b3)) }; - let c4: u32x4 = unsafe { transmute(vqdmull_s32(a4, b4)) }; - let c5: u32x4 = unsafe { transmute(vqdmull_s32(a5, b5)) }; - let c6: u32x4 = unsafe { transmute(vqdmull_s32(a6, b6)) }; - let c7: u32x4 = unsafe { transmute(vqdmull_s32(a7, b7)) }; - - // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`. - // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. - let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); - let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); - let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); - let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); - - // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. - c0_c1_lo >>= 1; - c2_c3_lo >>= 1; - c4_c5_lo >>= 1; - c6_c7_lo >>= 1; - - let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; - let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; - - lo + hi -} +cfg_if::cfg_if! { + if #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] { + use core::arch::aarch64::{uint32x2_t, vmull_u32, int32x2_t, vqdmull_s32}; + use std::simd::u32x4; + + /// Returns `a * b`. + pub(crate) fn mul_neon(a: PackedM31, b: PackedM31) -> PackedM31 { + let [a0, a1, a2, a3, a4, a5, a6, a7]: [int32x2_t; 8] = unsafe { transmute(a) }; + let [b0, b1, b2, b3, b4, b5, b6, b7]: [int32x2_t; 8] = unsafe { transmute(b) }; + + // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| + let c0: u32x4 = unsafe { transmute(vqdmull_s32(a0, b0)) }; + let c1: u32x4 = unsafe { transmute(vqdmull_s32(a1, b1)) }; + let c2: u32x4 = unsafe { transmute(vqdmull_s32(a2, b2)) }; + let c3: u32x4 = unsafe { transmute(vqdmull_s32(a3, b3)) }; + let c4: u32x4 = unsafe { transmute(vqdmull_s32(a4, b4)) }; + let c5: u32x4 = unsafe { transmute(vqdmull_s32(a5, b5)) }; + let c6: u32x4 = unsafe { transmute(vqdmull_s32(a6, b6)) }; + let c7: u32x4 = unsafe { transmute(vqdmull_s32(a7, b7)) }; + + // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`. + // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. + let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); + let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); + let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); + let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); + + // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. + c0_c1_lo >>= 1; + c2_c3_lo >>= 1; + c4_c5_lo >>= 1; + c6_c7_lo >>= 1; + + let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; + let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; + + lo + hi + } -/// Returns `a * b`. -/// -/// `b_double` should be in the range `[0, 2P]`. -#[cfg(target_arch = "aarch64")] -pub(crate) fn _mul_doubled_neon(a: PackedM31, b_double: u32x16) -> PackedM31 { - use core::arch::aarch64::{uint32x2_t, vmull_u32}; - use std::simd::u32x4; - - let [a0, a1, a2, a3, a4, a5, a6, a7]: [uint32x2_t; 8] = unsafe { transmute(a) }; - let [b0, b1, b2, b3, b4, b5, b6, b7]: [uint32x2_t; 8] = unsafe { transmute(b_double) }; - - // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| - let c0: u32x4 = unsafe { transmute(vmull_u32(a0, b0)) }; - let c1: u32x4 = unsafe { transmute(vmull_u32(a1, b1)) }; - let c2: u32x4 = unsafe { transmute(vmull_u32(a2, b2)) }; - let c3: u32x4 = unsafe { transmute(vmull_u32(a3, b3)) }; - let c4: u32x4 = unsafe { transmute(vmull_u32(a4, b4)) }; - let c5: u32x4 = unsafe { transmute(vmull_u32(a5, b5)) }; - let c6: u32x4 = unsafe { transmute(vmull_u32(a6, b6)) }; - let c7: u32x4 = unsafe { transmute(vmull_u32(a7, b7)) }; - - // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`. - // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. - let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); - let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); - let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); - let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); - - // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. - c0_c1_lo >>= 1; - c2_c3_lo >>= 1; - c4_c5_lo >>= 1; - c6_c7_lo >>= 1; - - let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; - let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; - - lo + hi -} + /// Returns `a * b`. + /// + /// `b_double` should be in the range `[0, 2P]`. + pub(crate) fn mul_doubled_neon(a: PackedM31, b_double: u32x16) -> PackedM31 { + let [a0, a1, a2, a3, a4, a5, a6, a7]: [uint32x2_t; 8] = unsafe { transmute(a) }; + let [b0, b1, b2, b3, b4, b5, b6, b7]: [uint32x2_t; 8] = unsafe { transmute(b_double) }; + + // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| + let c0: u32x4 = unsafe { transmute(vmull_u32(a0, b0)) }; + let c1: u32x4 = unsafe { transmute(vmull_u32(a1, b1)) }; + let c2: u32x4 = unsafe { transmute(vmull_u32(a2, b2)) }; + let c3: u32x4 = unsafe { transmute(vmull_u32(a3, b3)) }; + let c4: u32x4 = unsafe { transmute(vmull_u32(a4, b4)) }; + let c5: u32x4 = unsafe { transmute(vmull_u32(a5, b5)) }; + let c6: u32x4 = unsafe { transmute(vmull_u32(a6, b6)) }; + let c7: u32x4 = unsafe { transmute(vmull_u32(a7, b7)) }; + + // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`. + // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. + let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); + let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); + let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); + let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); + + // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. + c0_c1_lo >>= 1; + c2_c3_lo >>= 1; + c4_c5_lo >>= 1; + c6_c7_lo >>= 1; + + let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; + let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; + + lo + hi + } + } else if #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))] { + use core::arch::wasm32::{i64x2_extmul_high_u32x4, i64x2_extmul_low_u32x4, v128}; + use std::simd::u32x4; -/// Returns `a * b`. -#[cfg(target_arch = "wasm32")] -pub(crate) fn _mul_wasm(a: PackedM31, b: PackedM31) -> PackedM31 { - _mul_doubled_wasm(a, b.0 + b.0) -} + /// Returns `a * b`. + pub(crate) fn mul_wasm(a: PackedM31, b: PackedM31) -> PackedM31 { + mul_doubled_wasm(a, b.0 + b.0) + } -/// Returns `a * b`. -/// -/// `b_double` should be in the range `[0, 2P]`. -#[cfg(target_arch = "wasm32")] -pub(crate) fn _mul_doubled_wasm(a: PackedM31, b_double: u32x16) -> PackedM31 { - use core::arch::wasm32::{i64x2_extmul_high_u32x4, i64x2_extmul_low_u32x4, v128}; - use std::simd::u32x4; - - let [a0, a1, a2, a3]: [v128; 4] = unsafe { transmute(a) }; - let [b_double0, b_double1, b_double2, b_double3]: [v128; 4] = unsafe { transmute(b_double) }; - - let c0_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a0, b_double0)) }; - let c0_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a0, b_double0)) }; - let c1_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a1, b_double1)) }; - let c1_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a1, b_double1)) }; - let c2_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a2, b_double2)) }; - let c2_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a2, b_double2)) }; - let c3_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a3, b_double3)) }; - let c3_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a3, b_double3)) }; - - let (mut c0_even, c0_odd) = c0_lo.deinterleave(c0_hi); - let (mut c1_even, c1_odd) = c1_lo.deinterleave(c1_hi); - let (mut c2_even, c2_odd) = c2_lo.deinterleave(c2_hi); - let (mut c3_even, c3_odd) = c3_lo.deinterleave(c3_hi); - c0_even >>= 1; - c1_even >>= 1; - c2_even >>= 1; - c3_even >>= 1; - - let even: PackedM31 = unsafe { transmute([c0_even, c1_even, c2_even, c3_even]) }; - let odd: PackedM31 = unsafe { transmute([c0_odd, c1_odd, c2_odd, c3_odd]) }; - - even + odd -} + /// Returns `a * b`. + /// + /// `b_double` should be in the range `[0, 2P]`. + pub(crate) fn mul_doubled_wasm(a: PackedM31, b_double: u32x16) -> PackedM31 { + let [a0, a1, a2, a3]: [v128; 4] = unsafe { transmute(a) }; + let [b_double0, b_double1, b_double2, b_double3]: [v128; 4] = unsafe { transmute(b_double) }; + + let c0_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a0, b_double0)) }; + let c0_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a0, b_double0)) }; + let c1_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a1, b_double1)) }; + let c1_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a1, b_double1)) }; + let c2_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a2, b_double2)) }; + let c2_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a2, b_double2)) }; + let c3_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a3, b_double3)) }; + let c3_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a3, b_double3)) }; + + let (mut c0_even, c0_odd) = c0_lo.deinterleave(c0_hi); + let (mut c1_even, c1_odd) = c1_lo.deinterleave(c1_hi); + let (mut c2_even, c2_odd) = c2_lo.deinterleave(c2_hi); + let (mut c3_even, c3_odd) = c3_lo.deinterleave(c3_hi); + c0_even >>= 1; + c1_even >>= 1; + c2_even >>= 1; + c3_even >>= 1; + + let even: PackedM31 = unsafe { transmute([c0_even, c1_even, c2_even, c3_even]) }; + let odd: PackedM31 = unsafe { transmute([c0_odd, c1_odd, c2_odd, c3_odd]) }; + + even + odd + } + } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] { + use std::arch::x86_64::{__m512i, _mm512_mul_epu32, _mm512_srli_epi64}; + use std::simd::Swizzle; -/// Returns `a * b`. -#[cfg(target_arch = "x86_64")] -pub(crate) fn _mul_avx512(a: PackedM31, b: PackedM31) -> PackedM31 { - _mul_doubled_avx512(a, b.0 + b.0) -} + use crate::core::backend::simd::utils::swizzle::{InterleaveEvens, InterleaveOdds}; -/// Returns `a * b`. -/// -/// `b_double` should be in the range `[0, 2P]`. -#[cfg(target_arch = "x86_64")] -pub(crate) fn _mul_doubled_avx512(a: PackedM31, b_double: u32x16) -> PackedM31 { - use std::arch::x86_64::{__m512i, _mm512_mul_epu32, _mm512_srli_epi64}; - - let a: __m512i = unsafe { transmute(a) }; - let b_double: __m512i = unsafe { transmute(b_double) }; - - // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of - // the first operand. - let a_e = a; - // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of - // the first operand. - let a_o = unsafe { _mm512_srli_epi64(a, 32) }; - - let b_dbl_e = b_double; - let b_dbl_o = unsafe { _mm512_srli_epi64(b_double, 32) }; - - // To compute prod = a * b start by multiplying a_e/odd by b_dbl_e/odd. - let prod_dbl_e: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_e, b_dbl_e)) }; - let prod_dbl_o: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_o, b_dbl_o)) }; - - // The result of a multiplication holds a*b in as 64-bits. - // Each 64b-bit word looks like this: - // 1 31 31 1 - // prod_dbl_e - |0|prod_e_h|prod_e_l|0| - // prod_dbl_o - |0|prod_o_h|prod_o_l|0| - - // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: - let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); - // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| - // Divide by 2: - prod_lo >>= 1; - // prod_lo - |0|prod_o_l|0|prod_e_l| - - // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: - let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); - // prod_hi - |0|prod_o_h|0|prod_e_h| - - PackedM31(prod_lo) + PackedM31(prod_hi) -} + /// Returns `a * b`. + pub(crate) fn mul_avx512(a: PackedM31, b: PackedM31) -> PackedM31 { + mul_doubled_avx512(a, b.0 + b.0) + } -/// Returns `a * b`. -#[cfg(target_arch = "x86_64")] -pub(crate) fn _mul_avx2(a: PackedM31, b: PackedM31) -> PackedM31 { - _mul_doubled_avx2(a, b.0 + b.0) -} + /// Returns `a * b`. + /// + /// `b_double` should be in the range `[0, 2P]`. + pub(crate) fn mul_doubled_avx512(a: PackedM31, b_double: u32x16) -> PackedM31 { + let a: __m512i = unsafe { transmute(a) }; + let b_double: __m512i = unsafe { transmute(b_double) }; + + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a_e = a; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. + let a_o = unsafe { _mm512_srli_epi64(a, 32) }; + + let b_dbl_e = b_double; + let b_dbl_o = unsafe { _mm512_srli_epi64(b_double, 32) }; + + // To compute prod = a * b start by multiplying a_e/odd by b_dbl_e/odd. + let prod_dbl_e: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_e, b_dbl_e)) }; + let prod_dbl_o: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_o, b_dbl_o)) }; + + // The result of a multiplication holds a*b in as 64-bits. + // Each 64b-bit word looks like this: + // 1 31 31 1 + // prod_dbl_e - |0|prod_e_h|prod_e_l|0| + // prod_dbl_o - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: + let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| + // Divide by 2: + prod_lo >>= 1; + // prod_lo - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: + let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_hi - |0|prod_o_h|0|prod_e_h| + + PackedM31(prod_lo) + PackedM31(prod_hi) + } + } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] { + use std::arch::x86_64::{__m256i, _mm256_mul_epu32, _mm256_srli_epi64}; + use std::simd::Swizzle; -/// Returns `a * b`. -/// -/// `b_double` should be in the range `[0, 2P]`. -#[cfg(target_arch = "x86_64")] -pub(crate) fn _mul_doubled_avx2(a: PackedM31, b_double: u32x16) -> PackedM31 { - use std::arch::x86_64::{__m256i, _mm256_mul_epu32, _mm256_srli_epi64}; - - let [a0, a1]: [__m256i; 2] = unsafe { transmute(a) }; - let [b0_dbl, b1_dbl]: [__m256i; 2] = unsafe { transmute(b_double) }; - - // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of - // the first operand. - let a0_e = a0; - let a1_e = a1; - // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of - // the first operand. - let a0_o = unsafe { _mm256_srli_epi64(a0, 32) }; - let a1_o = unsafe { _mm256_srli_epi64(a1, 32) }; - - let b0_dbl_e = b0_dbl; - let b1_dbl_e = b1_dbl; - let b0_dbl_o = unsafe { _mm256_srli_epi64(b0_dbl, 32) }; - let b1_dbl_o = unsafe { _mm256_srli_epi64(b1_dbl, 32) }; - - // To compute prod = a * b start by multiplying a0/1_e/odd by b0/1_e/odd. - let prod0_dbl_e = unsafe { _mm256_mul_epu32(a0_e, b0_dbl_e) }; - let prod0_dbl_o = unsafe { _mm256_mul_epu32(a0_o, b0_dbl_o) }; - let prod1_dbl_e = unsafe { _mm256_mul_epu32(a1_e, b1_dbl_e) }; - let prod1_dbl_o = unsafe { _mm256_mul_epu32(a1_o, b1_dbl_o) }; - - let prod_dbl_e: u32x16 = unsafe { transmute([prod0_dbl_e, prod1_dbl_e]) }; - let prod_dbl_o: u32x16 = unsafe { transmute([prod0_dbl_o, prod1_dbl_o]) }; - - // The result of a multiplication holds a*b in as 64-bits. - // Each 64b-bit word looks like this: - // 1 31 31 1 - // prod_dbl_e - |0|prod_e_h|prod_e_l|0| - // prod_dbl_o - |0|prod_o_h|prod_o_l|0| - - // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: - let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); - // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| - // Divide by 2: - prod_lo >>= 1; - // prod_lo - |0|prod_o_l|0|prod_e_l| - - // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: - let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); - // prod_hi - |0|prod_o_h|0|prod_e_h| - - PackedM31(prod_lo) + PackedM31(prod_hi) -} + use crate::core::backend::simd::utils::swizzle::{InterleaveEvens, InterleaveOdds}; -/// Returns `a * b`. -/// -/// Should only be used in the absence of a platform specific implementation. -pub(crate) fn _mul_simd(a: PackedM31, b: PackedM31) -> PackedM31 { - _mul_doubled_simd(a, b.0 + b.0) -} + /// Returns `a * b`. + pub(crate) fn mul_avx2(a: PackedM31, b: PackedM31) -> PackedM31 { + mul_doubled_avx2(a, b.0 + b.0) + } -/// Returns `a * b`. -/// -/// Should only be used in the absence of a platform specific implementation. -/// -/// `b_double` should be in the range `[0, 2P]`. -pub(crate) fn _mul_doubled_simd(a: PackedM31, b_double: u32x16) -> PackedM31 { - const MASK_EVENS: Simd = Simd::from_array([0xFFFFFFFF; { N_LANES / 2 }]); - - // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of - // the first operand. - let a_e = unsafe { transmute::<_, Simd>(a.0) & MASK_EVENS }; - // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of - // the first operand. - let a_o = unsafe { transmute::<_, Simd>(a) >> 32 }; - - let b_dbl_e = unsafe { transmute::<_, Simd>(b_double) & MASK_EVENS }; - let b_dbl_o = unsafe { transmute::<_, Simd>(b_double) >> 32 }; - - // To compute prod = a * b start by multiplying - // a_e/o by b_dbl_e/o. - let prod_e_dbl = a_e * b_dbl_e; - let prod_o_dbl = a_o * b_dbl_o; - - // The result of a multiplication holds a*b in as 64-bits. - // Each 64b-bit word looks like this: - // 1 31 31 1 - // prod_e_dbl - |0|prod_e_h|prod_e_l|0| - // prod_o_dbl - |0|prod_o_h|prod_o_l|0| - - // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl: - // let prod_lows = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, - // prod_o_dbl); - // prod_ls - |prod_o_l|0|prod_e_l|0| - let mut prod_lows = InterleaveEvens::concat_swizzle( - unsafe { transmute::<_, Simd>(prod_e_dbl) }, - unsafe { transmute::<_, Simd>(prod_o_dbl) }, - ); - // Divide by 2: - prod_lows >>= 1; - // prod_ls - |0|prod_o_l|0|prod_e_l| - - // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl: - let prod_highs = InterleaveOdds::concat_swizzle( - unsafe { transmute::<_, Simd>(prod_e_dbl) }, - unsafe { transmute::<_, Simd>(prod_o_dbl) }, - ); - - // prod_hs - |0|prod_o_h|0|prod_e_h| - PackedM31(prod_lows) + PackedM31(prod_highs) + /// Returns `a * b`. + /// + /// `b_double` should be in the range `[0, 2P]`. + pub(crate) fn mul_doubled_avx2(a: PackedM31, b_double: u32x16) -> PackedM31 { + let [a0, a1]: [__m256i; 2] = unsafe { transmute::(a) }; + let [b0_dbl, b1_dbl]: [__m256i; 2] = unsafe { transmute::(b_double) }; + + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a0_e = a0; + let a1_e = a1; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. + let a0_o = unsafe { _mm256_srli_epi64(a0, 32) }; + let a1_o = unsafe { _mm256_srli_epi64(a1, 32) }; + + let b0_dbl_e = b0_dbl; + let b1_dbl_e = b1_dbl; + let b0_dbl_o = unsafe { _mm256_srli_epi64(b0_dbl, 32) }; + let b1_dbl_o = unsafe { _mm256_srli_epi64(b1_dbl, 32) }; + + // To compute prod = a * b start by multiplying a0/1_e/odd by b0/1_e/odd. + let prod0_dbl_e = unsafe { _mm256_mul_epu32(a0_e, b0_dbl_e) }; + let prod0_dbl_o = unsafe { _mm256_mul_epu32(a0_o, b0_dbl_o) }; + let prod1_dbl_e = unsafe { _mm256_mul_epu32(a1_e, b1_dbl_e) }; + let prod1_dbl_o = unsafe { _mm256_mul_epu32(a1_o, b1_dbl_o) }; + + let prod_dbl_e: u32x16 = + unsafe { transmute::<[__m256i; 2], u32x16>([prod0_dbl_e, prod1_dbl_e]) }; + let prod_dbl_o: u32x16 = + unsafe { transmute::<[__m256i; 2], u32x16>([prod0_dbl_o, prod1_dbl_o]) }; + + // The result of a multiplication holds a*b in as 64-bits. + // Each 64b-bit word looks like this: + // 1 31 31 1 + // prod_dbl_e - |0|prod_e_h|prod_e_l|0| + // prod_dbl_o - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: + let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| + // Divide by 2: + prod_lo >>= 1; + // prod_lo - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: + let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_hi - |0|prod_o_h|0|prod_e_h| + + PackedM31(prod_lo) + PackedM31(prod_hi) + } + } else { + use std::simd::Swizzle; + + use crate::core::backend::simd::utils::swizzle::{InterleaveEvens, InterleaveOdds}; + + /// Returns `a * b`. + /// + /// Should only be used in the absence of a platform specific implementation. + pub(crate) fn mul_simd(a: PackedM31, b: PackedM31) -> PackedM31 { + mul_doubled_simd(a, b.0 + b.0) + } + + /// Returns `a * b`. + /// + /// Should only be used in the absence of a platform specific implementation. + /// + /// `b_double` should be in the range `[0, 2P]`. + pub(crate) fn mul_doubled_simd(a: PackedM31, b_double: u32x16) -> PackedM31 { + const MASK_EVENS: Simd = Simd::from_array([0xFFFFFFFF; { N_LANES / 2 }]); + + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a_e = + unsafe { transmute::, Simd>(a.0) & MASK_EVENS }; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. + let a_o = unsafe { transmute::>(a) >> 32 }; + + let b_dbl_e = unsafe { + transmute::, Simd>(b_double) & MASK_EVENS + }; + let b_dbl_o = + unsafe { transmute::, Simd>(b_double) >> 32 }; + + // To compute prod = a * b start by multiplying + // a_e/o by b_dbl_e/o. + let prod_e_dbl = a_e * b_dbl_e; + let prod_o_dbl = a_o * b_dbl_o; + + // The result of a multiplication holds a*b in as 64-bits. + // Each 64b-bit word looks like this: + // 1 31 31 1 + // prod_e_dbl - |0|prod_e_h|prod_e_l|0| + // prod_o_dbl - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl: + // let prod_lows = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, + // prod_o_dbl); + // prod_ls - |prod_o_l|0|prod_e_l|0| + let mut prod_lows = InterleaveEvens::concat_swizzle( + unsafe { transmute::, Simd>(prod_e_dbl) }, + unsafe { transmute::, Simd>(prod_o_dbl) }, + ); + // Divide by 2: + prod_lows >>= 1; + // prod_ls - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl: + let prod_highs = InterleaveOdds::concat_swizzle( + unsafe { transmute::, Simd>(prod_e_dbl) }, + unsafe { transmute::, Simd>(prod_o_dbl) }, + ); + + // prod_hs - |0|prod_o_h|0|prod_e_h| + PackedM31(prod_lows) + PackedM31(prod_highs) + } + } } #[cfg(test)] diff --git a/crates/prover/src/core/backend/simd/quotients.rs b/crates/prover/src/core/backend/simd/quotients.rs index ff8e1c580..9b5b02db5 100644 --- a/crates/prover/src/core/backend/simd/quotients.rs +++ b/crates/prover/src/core/backend/simd/quotients.rs @@ -290,13 +290,13 @@ mod tests { let e1: BaseColumn = (0..small_domain.size()) .map(|i| BaseField::from(2 * i)) .collect(); - let polys = vec![ + let polys = [ CircleEvaluation::::new(small_domain, e0) .interpolate(), CircleEvaluation::::new(small_domain, e1) .interpolate(), ]; - let columns = vec![polys[0].evaluate(domain), polys[1].evaluate(domain)]; + let columns = [polys[0].evaluate(domain), polys[1].evaluate(domain)]; let random_coeff = qm31!(1, 2, 3, 4); let a = polys[0].eval_at_point(SECURE_FIELD_CIRCLE_GEN); let b = polys[1].eval_at_point(SECURE_FIELD_CIRCLE_GEN); diff --git a/crates/prover/src/core/backend/simd/utils.rs b/crates/prover/src/core/backend/simd/utils.rs index 87dfd2246..f672520b7 100644 --- a/crates/prover/src/core/backend/simd/utils.rs +++ b/crates/prover/src/core/backend/simd/utils.rs @@ -1,52 +1,58 @@ -use std::simd::Swizzle; - -/// Used with [`Swizzle::concat_swizzle`] to interleave the even values of two vectors. -pub struct InterleaveEvens; - -impl Swizzle for InterleaveEvens { - const INDEX: [usize; N] = parity_interleave(false); -} +#[cfg(not(any( + all(target_arch = "aarch64", target_feature = "neon"), + all(target_arch = "wasm32", target_feature = "simd128") +)))] +pub mod swizzle { + use std::simd::Swizzle; + + /// Used with [`Swizzle::concat_swizzle`] to interleave the even values of two vectors. + pub struct InterleaveEvens; + + impl Swizzle for InterleaveEvens { + const INDEX: [usize; N] = parity_interleave(false); + } -/// Used with [`Swizzle::concat_swizzle`] to interleave the odd values of two vectors. -pub struct InterleaveOdds; + /// Used with [`Swizzle::concat_swizzle`] to interleave the odd values of two vectors. + pub struct InterleaveOdds; -impl Swizzle for InterleaveOdds { - const INDEX: [usize; N] = parity_interleave(true); -} + impl Swizzle for InterleaveOdds { + const INDEX: [usize; N] = parity_interleave(true); + } -const fn parity_interleave(odd: bool) -> [usize; N] { - let mut res = [0; N]; - let mut i = 0; - while i < N { - res[i] = (i % 2) * N + (i / 2) * 2 + if odd { 1 } else { 0 }; - i += 1; + const fn parity_interleave(odd: bool) -> [usize; N] { + let mut res = [0; N]; + let mut i = 0; + while i < N { + res[i] = (i % 2) * N + (i / 2) * 2 + if odd { 1 } else { 0 }; + i += 1; + } + res } - res -} -#[cfg(test)] -mod tests { - use std::simd::{u32x4, Swizzle}; + #[cfg(test)] + mod tests { + use std::simd::{u32x4, Swizzle}; - use super::{InterleaveEvens, InterleaveOdds}; + use super::{InterleaveEvens, InterleaveOdds}; - #[test] - fn interleave_evens() { - let lo = u32x4::from_array([0, 1, 2, 3]); - let hi = u32x4::from_array([4, 5, 6, 7]); + #[test] + fn interleave_evens() { + let lo = u32x4::from_array([0, 1, 2, 3]); + let hi = u32x4::from_array([4, 5, 6, 7]); - let res = InterleaveEvens::concat_swizzle(lo, hi); + let res = InterleaveEvens::concat_swizzle(lo, hi); - assert_eq!(res, u32x4::from_array([0, 4, 2, 6])); - } + assert_eq!(res, u32x4::from_array([0, 4, 2, 6])); + } - #[test] - fn interleave_odds() { - let lo = u32x4::from_array([0, 1, 2, 3]); - let hi = u32x4::from_array([4, 5, 6, 7]); + #[test] + fn interleave_odds() { + let lo = u32x4::from_array([0, 1, 2, 3]); + let hi = u32x4::from_array([4, 5, 6, 7]); - let res = InterleaveOdds::concat_swizzle(lo, hi); + let res = InterleaveOdds::concat_swizzle(lo, hi); - assert_eq!(res, u32x4::from_array([1, 5, 3, 7])); + assert_eq!(res, u32x4::from_array([1, 5, 3, 7])); + } } } diff --git a/crates/prover/src/core/channel/blake2s.rs b/crates/prover/src/core/channel/blake2s.rs index 3887b0b5f..dd2ca5b2f 100644 --- a/crates/prover/src/core/channel/blake2s.rs +++ b/crates/prover/src/core/channel/blake2s.rs @@ -74,7 +74,7 @@ impl Channel for Blake2sChannel { msg[1] = (nonce >> 32) as u32; let res = compress(std::array::from_fn(|i| digest[i]), msg, 0, 0, 0, 0); - self.update_digest(unsafe { std::mem::transmute(res) }); + self.update_digest(unsafe { std::mem::transmute::<[u32; 8], Blake2sHash>(res) }); } fn draw_felt(&mut self) -> SecureField { diff --git a/crates/prover/src/core/constraints.rs b/crates/prover/src/core/constraints.rs index 31711d98e..f66c8d93d 100644 --- a/crates/prover/src/core/constraints.rs +++ b/crates/prover/src/core/constraints.rs @@ -90,11 +90,11 @@ pub fn complex_conjugate_line( / (point.complex_conjugate().y - point.y) } -/// Evaluates the coefficients of a line between a point and its complex conjugate. Specifically, -/// `a, b, and c, s.t. a*x + b -c*y = 0` for (x,y) being (sample.y, sample.value) and -/// (conj(sample.y), conj(sample.value)). -/// Relies on the fact that every polynomial F over the base -/// field holds: F(p*) == F(p)* (* being the complex conjugate). +/// Evaluates the coefficients of a line between a point and its complex conjugate. +/// +/// Specifically, `a, b, and c, s.t. a*x + b -c*y = 0` for (x,y) being (sample.y, sample.value) and +/// (conj(sample.y), conj(sample.value)). Relies on the fact that every polynomial F over the base +/// field holds: `F(p*) == F(p)*` (`*` being the complex conjugate). pub fn complex_conjugate_line_coeffs( sample: &PointSample, alpha: SecureField, diff --git a/crates/prover/src/core/fri.rs b/crates/prover/src/core/fri.rs index 9f13de253..41b79b58a 100644 --- a/crates/prover/src/core/fri.rs +++ b/crates/prover/src/core/fri.rs @@ -97,8 +97,7 @@ pub trait FriOps: FieldOps + PolyOps + Sized + FieldOps /// Let `src` be the evaluation of a circle polynomial `f` on a /// [`CircleDomain`] `E`. This function computes evaluations of `f' = f0 /// + alpha * f1` on the x-coordinates of `E` such that `2f(p) = f0(px) + py * f1(px)`. The - /// evaluations of `f'` are accumulated into `dst` by the formula `dst = dst * alpha^2 + - /// f'`. + /// evaluations of `f'` are accumulated into `dst` by the formula `dst = dst * alpha^2 + f'`. /// /// # Panics /// @@ -728,7 +727,7 @@ impl FriLayerVerifier { let mut all_subline_evals = Vec::new(); // Group queries by the subline they reside in. - for subline_queries in queries.group_by(|a, b| a >> FOLD_STEP == b >> FOLD_STEP) { + for subline_queries in queries.chunk_by(|a, b| a >> FOLD_STEP == b >> FOLD_STEP) { let subline_start = (subline_queries[0] >> FOLD_STEP) << FOLD_STEP; let subline_end = subline_start + (1 << FOLD_STEP); @@ -801,7 +800,7 @@ impl, H: MerkleHasher> FriLayerProver { // Group queries by the subline they reside in. // TODO(andrew): Explain what a "subline" is at the top of the module. - for query_group in queries.group_by(|a, b| a >> FOLD_STEP == b >> FOLD_STEP) { + for query_group in queries.chunk_by(|a, b| a >> FOLD_STEP == b >> FOLD_STEP) { let subline_start = (query_group[0] >> FOLD_STEP) << FOLD_STEP; let subline_end = subline_start + (1 << FOLD_STEP); diff --git a/crates/prover/src/core/lookups/gkr_prover.rs b/crates/prover/src/core/lookups/gkr_prover.rs index 6e6ed2586..71c2763f4 100644 --- a/crates/prover/src/core/lookups/gkr_prover.rs +++ b/crates/prover/src/core/lookups/gkr_prover.rs @@ -470,7 +470,7 @@ pub fn prove_batch( // Seed the channel with the layer masks. for (&instance, mask) in zip(&sumcheck_instances, &masks) { - channel.mix_felts(mask.columns().flatten()); + channel.mix_felts(mask.columns().as_flattened()); layer_masks_by_instance[instance].push(mask.clone()); } diff --git a/crates/prover/src/core/lookups/gkr_verifier.rs b/crates/prover/src/core/lookups/gkr_verifier.rs index b65ceb162..69d8314f5 100644 --- a/crates/prover/src/core/lookups/gkr_verifier.rs +++ b/crates/prover/src/core/lookups/gkr_verifier.rs @@ -120,7 +120,7 @@ pub fn partially_verify_batch( for &instance in &sumcheck_instances { let n_unused = n_layers - instance_n_layers(instance); let mask = &layer_masks_by_instance[instance][layer - n_unused]; - channel.mix_felts(mask.columns().flatten()); + channel.mix_felts(mask.columns().as_flattened()); } // Set the OOD evaluation point for layer above. diff --git a/crates/prover/src/core/pcs/mod.rs b/crates/prover/src/core/pcs/mod.rs index d9acf524b..40f07c504 100644 --- a/crates/prover/src/core/pcs/mod.rs +++ b/crates/prover/src/core/pcs/mod.rs @@ -1,4 +1,5 @@ //! Implements a FRI polynomial commitment scheme. +//! //! This is a protocol where the prover can commit on a set of polynomials and then prove their //! opening on a set of points. //! Note: This implementation is not really a polynomial commitment scheme, because we are not in @@ -34,7 +35,8 @@ impl Default for PcsConfig { fn default() -> Self { Self { pow_bits: 5, - fri_config: FriConfig::new(0, 1, 3), + // fri_config: FriConfig::new(0, 1, 3), + fri_config: FriConfig::new(0, 1, 50), } } } diff --git a/crates/prover/src/core/pcs/verifier.rs b/crates/prover/src/core/pcs/verifier.rs index 81107b455..2182a88b6 100644 --- a/crates/prover/src/core/pcs/verifier.rs +++ b/crates/prover/src/core/pcs/verifier.rs @@ -104,7 +104,7 @@ impl CommitmentSchemeVerifier { }) .0 .into_iter() - .collect::>()?; + .collect::>()?; // Answer FRI queries. let samples = sampled_points diff --git a/crates/prover/src/core/poly/circle/canonic.rs b/crates/prover/src/core/poly/circle/canonic.rs index 837e648d9..4d068594d 100644 --- a/crates/prover/src/core/poly/circle/canonic.rs +++ b/crates/prover/src/core/poly/circle/canonic.rs @@ -2,12 +2,14 @@ use super::CircleDomain; use crate::core::circle::{CirclePoint, CirclePointIndex, Coset}; use crate::core::fields::m31::BaseField; -/// A coset of the form G_{2n} + , where G_n is the generator of the -/// subgroup of order n. The ordering on this coset is G_2n + i * G_n. -/// These cosets can be used as a [CircleDomain], and be interpolated on. -/// Note that this changes the ordering on the coset to be like [CircleDomain], -/// which is G_2n + i * G_n/2 and then -G_2n -i * G_n/2. -/// For example, the Xs below are a canonic coset with n=8. +/// A coset of the form `G_{2n} + `, where `G_n` is the generator of the subgroup of order `n`. +/// +/// The ordering on this coset is `G_2n + i * G_n`. +/// These cosets can be used as a [`CircleDomain`], and be interpolated on. +/// Note that this changes the ordering on the coset to be like [`CircleDomain`], +/// which is `G_{2n} + i * G_{n/2}` and then `-G_{2n} -i * G_{n/2}`. +/// For example, the `X`s below are a canonic coset with `n=8`. +/// /// ```text /// X O X /// O O diff --git a/crates/prover/src/core/poly/circle/domain.rs b/crates/prover/src/core/poly/circle/domain.rs index fba2bc3fb..b6ced6018 100644 --- a/crates/prover/src/core/poly/circle/domain.rs +++ b/crates/prover/src/core/poly/circle/domain.rs @@ -10,8 +10,9 @@ use crate::core::fields::m31::BaseField; pub const MAX_CIRCLE_DOMAIN_LOG_SIZE: u32 = M31_CIRCLE_LOG_ORDER - 1; /// A valid domain for circle polynomial interpolation and evaluation. -/// Valid domains are a disjoint union of two conjugate cosets: +-C + . -/// The ordering defined on this domain is C + iG_n, and then -C - iG_n. +/// +/// Valid domains are a disjoint union of two conjugate cosets: `+-C + `. +/// The ordering defined on this domain is `C + iG_n`, and then `-C - iG_n`. #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct CircleDomain { pub half_coset: Coset, diff --git a/crates/prover/src/core/poly/twiddles.rs b/crates/prover/src/core/poly/twiddles.rs index 53ea476c6..473d7c709 100644 --- a/crates/prover/src/core/poly/twiddles.rs +++ b/crates/prover/src/core/poly/twiddles.rs @@ -2,6 +2,7 @@ use super::circle::PolyOps; use crate::core::circle::Coset; /// Precomputed twiddles for a specific coset tower. +/// /// A coset tower is every repeated doubling of a `root_coset`. /// The largest CircleDomain that can be ffted using these twiddles is one with `root_coset` as /// its `half_coset`. diff --git a/crates/prover/src/core/vcs/blake2_merkle.rs b/crates/prover/src/core/vcs/blake2_merkle.rs index 293ed4ab3..f5a1f5762 100644 --- a/crates/prover/src/core/vcs/blake2_merkle.rs +++ b/crates/prover/src/core/vcs/blake2_merkle.rs @@ -20,7 +20,7 @@ impl MerkleHasher for Blake2sMerkleHasher { if let Some((left, right)) = children_hashes { state = compress( state, - unsafe { std::mem::transmute([left, right]) }, + unsafe { std::mem::transmute::<[Blake2sHash; 2], [u32; 16]>([left, right]) }, 0, 0, 0, @@ -33,9 +33,16 @@ impl MerkleHasher for Blake2sMerkleHasher { .copied() .chain(std::iter::repeat(BaseField::zero()).take(rem)); for chunk in padded_values.array_chunks::<16>() { - state = compress(state, unsafe { std::mem::transmute(chunk) }, 0, 0, 0, 0); + state = compress( + state, + unsafe { std::mem::transmute::<[BaseField; 16], [u32; 16]>(chunk) }, + 0, + 0, + 0, + 0, + ); } - state.map(|x| x.to_le_bytes()).flatten().into() + state.map(|x| x.to_le_bytes()).as_flattened().into() } } diff --git a/crates/prover/src/core/vcs/blake2s_ref.rs b/crates/prover/src/core/vcs/blake2s_ref.rs index ab32ea6d9..0630f8176 100644 --- a/crates/prover/src/core/vcs/blake2s_ref.rs +++ b/crates/prover/src/core/vcs/blake2s_ref.rs @@ -1,4 +1,5 @@ //! An AVX512 implementation of the BLAKE2s compression function. +//! //! Based on . pub const IV: [u32; 8] = [ @@ -30,22 +31,22 @@ fn xor(a: u32, b: u32) -> u32 { #[inline(always)] fn rot16(x: u32) -> u32 { - (x >> 16) | (x << (32 - 16)) + x.rotate_left(16) } #[inline(always)] fn rot12(x: u32) -> u32 { - (x >> 12) | (x << (32 - 12)) + x.rotate_right(12) } #[inline(always)] fn rot8(x: u32) -> u32 { - (x >> 8) | (x << (32 - 8)) + x.rotate_right(8) } #[inline(always)] fn rot7(x: u32) -> u32 { - (x >> 7) | (x << (32 - 7)) + x.rotate_right(7) } #[inline(always)] diff --git a/crates/prover/src/core/vcs/ops.rs b/crates/prover/src/core/vcs/ops.rs index 14093e536..b40a91bef 100644 --- a/crates/prover/src/core/vcs/ops.rs +++ b/crates/prover/src/core/vcs/ops.rs @@ -6,13 +6,12 @@ use crate::core::backend::{Col, ColumnOps}; use crate::core::fields::m31::BaseField; use crate::core::vcs::hash::Hash; -/// A Merkle node hash is a hash of: -/// [left_child_hash, right_child_hash], column0_value, column1_value, ... -/// "[]" denotes optional values. +/// A Merkle node hash is a hash of: `[left_child_hash, right_child_hash], column0_value, +/// column1_value, ...` where `[]` denotes optional values. +/// /// The largest Merkle layer has no left and right child hashes. The rest of the layers have -/// children hashes. -/// At each layer, the tree may have multiple columns of the same length as the layer. -/// Each node in that layer contains one value from each column. +/// children hashes. At each layer, the tree may have multiple columns of the same length as the +/// layer. Each node in that layer contains one value from each column. pub trait MerkleHasher: Debug + Default + Clone { type Hash: Hash; /// Hashes a single Merkle node. See [MerkleHasher] for more details. diff --git a/crates/prover/src/core/vcs/prover.rs b/crates/prover/src/core/vcs/prover.rs index 6312de114..77c8bf2da 100644 --- a/crates/prover/src/core/vcs/prover.rs +++ b/crates/prover/src/core/vcs/prover.rs @@ -64,8 +64,7 @@ impl, H: MerkleHasher> MerkleProver { /// /// # Arguments /// - /// * `queries_per_log_size` - A map from log_size to a vector of queries for columns of that - /// log_size. + /// * `queries_per_log_size` - Maps a log_size to a vector of queries for columns of that size. /// * `columns` - A vector of references to columns. /// /// # Returns diff --git a/crates/prover/src/core/vcs/verifier.rs b/crates/prover/src/core/vcs/verifier.rs index 53346bb93..57738afa2 100644 --- a/crates/prover/src/core/vcs/verifier.rs +++ b/crates/prover/src/core/vcs/verifier.rs @@ -28,9 +28,9 @@ impl MerkleVerifier { /// # Arguments /// /// * `queries_per_log_size` - A map from log_size to a vector of queries for columns of that - /// log_size. + /// log_size. /// * `queried_values` - A vector of vectors of queried values. For each column, there is a - /// vector of queried values to that column. + /// vector of queried values to that column. /// * `decommitment` - The decommitment object containing the witness and column values. /// /// # Errors diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index d25cc1865..b3ebfb5a9 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -445,7 +445,7 @@ mod tests { for i in 0..16 { internal_matrix[i][i] += BaseField::from_u32_unchecked(1 << (i + 1)); } - let matrix = RowMajorMatrix::::new(internal_matrix.flatten().to_vec()); + let matrix = RowMajorMatrix::::new(internal_matrix.as_flattened().to_vec()); let expected_state = matrix.mul(state); apply_internal_round_matrix(&mut state); diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index 1e9c3be74..166af9156 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -1,21 +1,19 @@ #![allow(incomplete_features)] +#![cfg_attr( + all(target_arch = "x86_64", target_feature = "avx512f"), + feature(stdarch_x86_avx512) +)] #![feature( array_chunks, - array_methods, array_try_from_fn, assert_matches, exact_size_is_empty, generic_const_exprs, get_many_mut, int_roundings, - is_sorted, iter_array_chunks, - new_uninit, portable_simd, - slice_first_last_chunk, - slice_flatten, - slice_group_by, - stdsimd + trait_upcasting )] pub mod constraint_framework; pub mod core; diff --git a/rust-toolchain.toml b/rust-toolchain.toml index a0f1a930e..b6930c7af 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "nightly-2024-01-04" +channel = "nightly-2024-09-21" diff --git a/scripts/clippy.sh b/scripts/clippy.sh index 8361cd25d..75b7243a3 100755 --- a/scripts/clippy.sh +++ b/scripts/clippy.sh @@ -1,3 +1,3 @@ #!/bin/bash -cargo +nightly-2024-01-04 clippy "$@" --all-targets --all-features -- -D warnings -D future-incompatible \ +cargo +nightly-2024-09-21 clippy "$@" --all-targets --all-features -- -D warnings -D future-incompatible \ -D nonstandard-style -D rust-2018-idioms -D unused diff --git a/scripts/rust_fmt.sh b/scripts/rust_fmt.sh index e4223f999..84c93820a 100755 --- a/scripts/rust_fmt.sh +++ b/scripts/rust_fmt.sh @@ -1,3 +1,3 @@ #!/bin/bash -cargo +nightly-2024-01-04 fmt --all -- "$@" +cargo +nightly-2024-09-21 fmt --all -- "$@" diff --git a/scripts/test_avx.sh b/scripts/test_avx.sh index d911a2479..5ac74726c 100755 --- a/scripts/test_avx.sh +++ b/scripts/test_avx.sh @@ -1,4 +1,4 @@ #!/bin/bash # Can be used as a drop in replacement for `cargo test` with avx512f flag on. # For example, `./scripts/test_avx.sh` will run all tests(not only avx). -RUSTFLAGS="-Awarnings -C target-cpu=native -C target-feature=+avx512f -C opt-level=2" cargo +nightly-2024-01-04 test "$@" +RUSTFLAGS="-Awarnings -C target-cpu=native -C target-feature=+avx512f -C opt-level=2" cargo +nightly-2024-09-21 test "$@"