Skip to content

Commit

Permalink
Mut chunks for VeryPacked Base and Secure columns.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alon-Ti committed Sep 18, 2024
1 parent a51f630 commit 2c88349
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 2 deletions.
73 changes: 71 additions & 2 deletions crates/prover/src/core/backend/simd/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ use num_traits::Zero;
use super::cm31::PackedCM31;
use super::m31::{PackedBaseField, N_LANES};
use super::qm31::{PackedQM31, PackedSecureField};
use super::very_packed_m31::{VeryPackedBaseField, VeryPackedSecureField, N_VERY_PACKED_ELEMS};
use super::very_packed_m31::{
VeryPackedBaseField, VeryPackedQM31, VeryPackedSecureField, N_VERY_PACKED_ELEMS,
};
use super::SimdBackend;
use crate::core::backend::{Column, CpuBackend};
use crate::core::fields::cm31::CM31;
Expand Down Expand Up @@ -213,6 +215,21 @@ impl<'a> BaseColumnMutSlice<'a> {
}
}

pub struct VeryPackedBaseColumnMutSlice<'a>(pub &'a mut [VeryPackedBaseField]);

impl<'a> VeryPackedBaseColumnMutSlice<'a> {
const N_ELEMS: usize = N_LANES * N_VERY_PACKED_ELEMS;
pub fn at(&self, index: usize) -> BaseField {
self.0[index / Self::N_ELEMS].to_array()[index % N_LANES]
}

pub fn set(&mut self, index: usize, value: BaseField) {
let mut packed = self.0[index / Self::N_ELEMS].to_array();
packed[index % Self::N_ELEMS] = value;
self.0[index / Self::N_ELEMS] = VeryPackedBaseField::from_array(packed)
}
}

/// An efficient structure for storing and operating on a arbitrary number of [`SecureField`]
/// values.
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -344,6 +361,33 @@ impl<'a> SecureColumnByCoordsMutSlice<'a> {
}
}

/// A mutable slice of a SecureColumnByCoords.
pub struct VeryPackedSecureColumnByCoordsMutSlice<'a>(
pub [VeryPackedBaseColumnMutSlice<'a>; SECURE_EXTENSION_DEGREE],
);

impl<'a> VeryPackedSecureColumnByCoordsMutSlice<'a> {
/// # Safety
///
/// `vec_index` must be a valid index.
pub unsafe fn packed_at(&self, vec_index: usize) -> VeryPackedSecureField {
VeryPackedQM31::from_very_packed_m31s(std::array::from_fn(|i| {
*self.0[i].0.get_unchecked(vec_index)
}))
}

/// # Safety
///
/// `vec_index` must be a valid index.
pub unsafe fn set_packed(&mut self, vec_index: usize, value: VeryPackedSecureField) {
let [a, b, c, d] = value.into_very_packed_m31s();
*self.0[0].0.get_unchecked_mut(vec_index) = a;
*self.0[1].0.get_unchecked_mut(vec_index) = b;
*self.0[2].0.get_unchecked_mut(vec_index) = c;
*self.0[3].0.get_unchecked_mut(vec_index) = d;
}
}

impl SecureColumnByCoords<SimdBackend> {
pub fn packed_len(&self) -> usize {
self.columns[0].data.len()
Expand Down Expand Up @@ -425,6 +469,13 @@ impl VeryPackedBaseColumn {
pub unsafe fn transform_under_ref(value: &BaseColumn) -> &Self {
&*(std::ptr::addr_of!(*value) as *const VeryPackedBaseColumn)
}

pub fn chunks_mut(&mut self, chunk_size: usize) -> Vec<VeryPackedBaseColumnMutSlice<'_>> {
self.data
.chunks_mut(chunk_size)
.map(VeryPackedBaseColumnMutSlice)
.collect_vec()
}
}

impl From<BaseColumn> for VeryPackedBaseColumn {
Expand Down Expand Up @@ -462,7 +513,11 @@ impl Column<BaseField> for VeryPackedBaseColumn {
}

fn to_cpu(&self) -> Vec<BaseField> {
todo!()
self.data
.iter()
.flat_map(|x| x.to_array())
.take(self.length)
.collect()
}

fn len(&self) -> usize {
Expand Down Expand Up @@ -571,6 +626,20 @@ impl VeryPackedSecureColumnByCoords {
pub unsafe fn transform_under_mut(value: &mut SecureColumnByCoords<SimdBackend>) -> &mut Self {
&mut *(std::ptr::addr_of!(*value) as *mut VeryPackedSecureColumnByCoords)
}

pub fn chunks_mut(
&mut self,
chunk_size: usize,
) -> Vec<VeryPackedSecureColumnByCoordsMutSlice<'_>> {
let [a, b, c, d] = self
.columns
.get_many_mut([0, 1, 2, 3])
.unwrap()
.map(|x| x.chunks_mut(chunk_size));
izip!(a, b, c, d)
.map(|(a, b, c, d)| VeryPackedSecureColumnByCoordsMutSlice([a, b, c, d]))
.collect_vec()
}
}

#[cfg(test)]
Expand Down
10 changes: 10 additions & 0 deletions crates/prover/src/core/backend/simd/very_packed_m31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ impl<A, const N: usize> Vectorized<A, N> {
}
}

impl<A, const N: usize> From<[A; N]> for Vectorized<A, N> {
fn from(array: [A; N]) -> Self {
Vectorized(array)
}
}

unsafe impl<A, const N: usize> Zeroable for Vectorized<A, N> {
fn zeroed() -> Self {
unsafe { core::mem::zeroed() }
Expand Down Expand Up @@ -79,6 +85,10 @@ impl VeryPackedQM31 {
pub fn from_very_packed_m31s([a, b, c, d]: [VeryPackedM31; 4]) -> Self {
Self::from_fn(|i| PackedQM31::from_packed_m31s([a.0[i], b.0[i], c.0[i], d.0[i]]))
}

pub fn into_very_packed_m31s(self) -> [VeryPackedM31; 4] {
std::array::from_fn(|i| VeryPackedM31::from(self.0.map(|v| v.into_packed_m31s()[i])))
}
}
impl From<M31> for VeryPackedM31 {
fn from(v: M31) -> Self {
Expand Down

0 comments on commit 2c88349

Please sign in to comment.