diff --git a/erasurecoding/src/lib.rs b/erasurecoding/src/lib.rs index 6f77750..47b75a3 100644 --- a/erasurecoding/src/lib.rs +++ b/erasurecoding/src/lib.rs @@ -1,14 +1,18 @@ use libc::{c_int, size_t}; use std::slice; +const MAX_SHARDS: usize = 65535; + /// Reed-Solomon encode function for FFI /// Takes sharded original data as byte array and produces recovery shards /// Parameters: /// - original_shards: number of original data shards +/// - original_shards_len: length of original_shards /// - recovery_shards: number of recovery shards to generate /// - original_shards: input byte array containing flattened shards /// - shard_size: size of each shard in bytes /// - recovery_shards_out: buffer to store generated and flattened recovery shards +/// - recovery_shards_out_len: length of recovery_shards_out /// Returns 0 on success, -1 on error #[no_mangle] pub unsafe extern "C" fn reed_solomon_encode( @@ -16,29 +20,58 @@ pub unsafe extern "C" fn reed_solomon_encode( recovery_shards_count: size_t, shard_size: size_t, original_shards: *const u8, + original_shards_len: size_t, recovery_shards_out: *mut u8, + recovery_shards_out_len: size_t, ) -> c_int { - let original_shards: Vec<_> = slice::from_raw_parts( - original_shards, - original_shards_count as usize * shard_size as usize, - ) - .chunks(shard_size as usize) - .collect(); + if let Some(sum) = original_shards_count.checked_add(recovery_shards_count) { + if sum > MAX_SHARDS { + return -1; + } + } else { + // Overflow. + return -1; + } + + if !(original_shards_count > 0 && recovery_shards_count > 0) { + return -1; + } + + if original_shards.is_null() || recovery_shards_out.is_null() { + return -1; + } + + // Shard size must be a multiple of 2. + if !(shard_size > 0 && shard_size % 2 == 0) { + return -1; + } + + if original_shards_len % shard_size != 0 + || original_shards_len / shard_size != original_shards_count + { + return -1; + } + + if recovery_shards_out_len != recovery_shards_count * shard_size { + return -1; + } + + let original_shards = + slice::from_raw_parts(original_shards, original_shards_count * shard_size) + .chunks(shard_size); match reed_solomon_simd::encode( - original_shards_count as usize, - recovery_shards_count as usize, + original_shards_count, + recovery_shards_count, original_shards, ) { Ok(recovery) => { - let output_slice = slice::from_raw_parts_mut( - recovery_shards_out, - recovery_shards_count as usize * shard_size as usize, - ); + let output_slice = + slice::from_raw_parts_mut(recovery_shards_out, recovery_shards_count * shard_size); for (i, shard) in recovery.iter().enumerate() { - let start = i * shard_size as usize; - let end = start + shard_size as usize; + let start = i * shard_size; + let end = start + shard_size; output_slice[start..end].copy_from_slice(shard); } 0 @@ -62,6 +95,7 @@ pub unsafe extern "C" fn reed_solomon_encode( /// - recovery_shards_len: length of recovery_shards_data /// - recovery_shards_indexes: indexes of recovery_shard_data /// - recovered_shards_out: buffer for recovered missing original shards +/// - recovered_shards_out_len: length of recovered_shards_out /// - recovered_shards_indexes_out: buffer for indexes of recovered original shards /// Returns 0 on success, -1 on error #[no_mangle] @@ -76,24 +110,68 @@ pub unsafe extern "C" fn reed_solomon_decode( recovery_shards_len: size_t, recovery_shards_indexes: *const size_t, recovered_shards_out: *mut u8, + recovered_shards_out_len: size_t, recovered_shards_indexes_out: *mut size_t, ) -> c_int { + if let Some(sum) = original_shards_count.checked_add(recovery_shards_count) { + if sum > MAX_SHARDS { + return -1; + } + } else { + // Overflow. + return -1; + } + + if original_shards.is_null() + || original_shards_indexes.is_null() + || recovery_shards.is_null() + || recovery_shards_indexes.is_null() + || recovered_shards_out.is_null() + || recovered_shards_indexes_out.is_null() + { + return -1; + } + + if !(original_shards_count > 0 && recovery_shards_count > 0) { + return -1; + } + + // Shard size must be a multiple of 2. + if !(shard_size > 0 && shard_size % 2 == 0) { + return -1; + } + + if original_shards_len % shard_size != 0 + || recovery_shards_len % shard_size != 0 + || recovery_shards_len % shard_size != 0 + { + return -1; + } + + // Expected recovered shards are original shards count - original shards + // provided. Since we only get back missing original shards. + if recovered_shards_out_len + != shard_size * (original_shards_count - (original_shards_len / shard_size)) + { + return -1; + } + // Create original shard pairs - let orig_data = slice::from_raw_parts(original_shards, original_shards_len); - let orig_indexes = + let original_shards = slice::from_raw_parts(original_shards, original_shards_len); + let original_shards_indexes = slice::from_raw_parts(original_shards_indexes, original_shards_len / shard_size); - let original_inputs = orig_indexes + let original_inputs = original_shards_indexes .iter() - .zip(orig_data.chunks(shard_size)) + .zip(original_shards.chunks(shard_size)) .map(|(&idx, chunk)| (idx, chunk)); // Create recovery shard pairs - let rec_data = slice::from_raw_parts(recovery_shards, recovery_shards_len); - let rec_indexes = + let recovery_shards = slice::from_raw_parts(recovery_shards, recovery_shards_len); + let recovery_shards_indexes = slice::from_raw_parts(recovery_shards_indexes, recovery_shards_len / shard_size); - let recovery_inputs = rec_indexes + let recovery_inputs = recovery_shards_indexes .iter() - .zip(rec_data.chunks(shard_size)) + .zip(recovery_shards.chunks(shard_size)) .map(|(&idx, chunk)| (idx, chunk)); match reed_solomon_simd::decode( @@ -103,20 +181,420 @@ pub unsafe extern "C" fn reed_solomon_decode( recovery_inputs, ) { Ok(restored) => { - let shards_recovered_out = + let recovered_shards_out = slice::from_raw_parts_mut(recovered_shards_out, restored.len() * shard_size); - let shards_recovered_indexes_out = + let recovered_shards_indexes_out = slice::from_raw_parts_mut(recovered_shards_indexes_out, restored.len()); for (i, (&shard_index, shard_data)) in restored.iter().enumerate() { let start = i * shard_size; let end = start + shard_size; - shards_recovered_out[start..end].copy_from_slice(shard_data); - shards_recovered_indexes_out[i] = shard_index; + recovered_shards_out[start..end].copy_from_slice(shard_data); + recovered_shards_indexes_out[i] = shard_index; } 0 } Err(_) => -1, } } + +#[cfg(test)] +mod tests { + use super::*; + use std::ptr; + + #[test] + fn test_encode_success() { + let original_shards_count = 2; + let recovery_shards_count = 1; + let shard_size = 4; + + let original_shards: [u8; 8] = [1, 2, 3, 4, 5, 6, 7, 8]; + let mut recovery_shards_out: [u8; 4] = [0; 4]; + + unsafe { + let result = reed_solomon_encode( + original_shards_count, + recovery_shards_count, + shard_size, + original_shards.as_ptr(), + original_shards.len(), + recovery_shards_out.as_mut_ptr(), + recovery_shards_out.len(), + ); + + assert_eq!(result, 0); // Success. + assert_eq!(recovery_shards_out, [4, 4, 4, 12]); + } + } + + #[test] + fn test_encode_minimum_shards() { + let original_shards_count = 1; + let recovery_shards_count = 1; + let shard_size = 4; + + let original_shards: [u8; 4] = [1, 2, 3, 4]; + let mut recovery_shards_out: [u8; 4] = [0; 4]; + + unsafe { + let result = reed_solomon_encode( + original_shards_count, + recovery_shards_count, + shard_size, + original_shards.as_ptr(), + original_shards.len(), + recovery_shards_out.as_mut_ptr(), + recovery_shards_out.len(), + ); + + assert_eq!(result, 0); // Success. + assert_eq!(recovery_shards_out, [1, 2, 3, 4]); + } + } + #[test] + fn test_encode_invalid_shard_size() { + let original_shards_count = 2; + let recovery_shards_count = 1; + let shard_size = 4; + + let original_shards: [u8; 9] = [1, 2, 3, 4, 5, 6, 7, 8, 9]; // Invalid shard size. + let mut recovery_shards_out: [u8; 4] = [0; 4]; + + unsafe { + let result = reed_solomon_encode( + original_shards_count, + recovery_shards_count, + shard_size, + original_shards.as_ptr(), + original_shards.len(), + recovery_shards_out.as_mut_ptr(), + recovery_shards_out.len(), + ); + + assert_eq!(result, -1); + } + } + + #[test] + fn test_encode_mismatched_shard_count() { + let original_shards_count = 3; // Invalid shard count. + let recovery_shards_count = 1; + let shard_size = 4; + + let original_shards: [u8; 8] = [1, 2, 3, 4, 5, 6, 7, 8]; + let mut recovery_shards_out: [u8; 4] = [0; 4]; + + unsafe { + let result = reed_solomon_encode( + original_shards_count, + recovery_shards_count, + shard_size, + original_shards.as_ptr(), + original_shards.len(), + recovery_shards_out.as_mut_ptr(), + recovery_shards_out.len(), + ); + + assert_eq!(result, -1); + } + } + + #[test] + fn test_encode_null_pointer() { + let original_shards_count = 2; + let recovery_shards_count = 1; + let shard_size = 4; + + let original_shards: [u8; 8] = [1, 2, 3, 4, 5, 6, 7, 8]; + let mut recovery_shards_out: [u8; 4] = [0; 4]; + + unsafe { + // Test null pointer for original_shards. + let result = reed_solomon_encode( + original_shards_count, + recovery_shards_count, + shard_size, + ptr::null(), + original_shards.len(), + recovery_shards_out.as_mut_ptr(), + recovery_shards_out.len(), + ); + assert_eq!(result, -1); + + // Test null pointer for recovery_shards_out. + let result = reed_solomon_encode( + original_shards_count, + recovery_shards_count, + shard_size, + original_shards.as_ptr(), + original_shards.len(), + ptr::null_mut(), + recovery_shards_out.len(), + ); + assert_eq!(result, -1); + } + } + + #[test] + fn test_decode_success() { + let original_shards_count = 2; + let recovery_shards_count = 1; + let shard_size = 4; + + let original_shards: [u8; 4] = [5, 6, 7, 8]; + let original_shards_indexes: [usize; 1] = [1]; + + let recovery_shards: [u8; 4] = [4, 4, 4, 12]; + let recovery_shards_indexes: [usize; 1] = [0]; + + let mut recovered_shards_out: [u8; 4] = [0; 4]; + let mut recovered_shards_indexes_out: [usize; 1] = [0; 1]; + + unsafe { + let result = reed_solomon_decode( + original_shards_count, + recovery_shards_count, + shard_size, + original_shards.as_ptr(), + original_shards.len(), + original_shards_indexes.as_ptr(), + recovery_shards.as_ptr(), + recovery_shards.len(), + recovery_shards_indexes.as_ptr(), + recovered_shards_out.as_mut_ptr(), + recovered_shards_out.len(), + recovered_shards_indexes_out.as_mut_ptr(), + ); + + assert_eq!(result, 0); // Success + assert_eq!(recovered_shards_out, [1, 2, 3, 4]); + assert_eq!(recovered_shards_indexes_out, [0]); + } + } + + #[test] + fn test_decode_invalid_shard_size() { + let original_shards_count = 2; + let recovery_shards_count = 1; + let shard_size = 4; + + let original_shards: [u8; 5] = [5, 6, 7, 8, 9]; // Invalid shard size. + let original_shards_indexes: [usize; 1] = [1]; + + let recovery_shards: [u8; 4] = [4, 4, 4, 12]; + let recovery_shards_indexes: [usize; 1] = [0]; + + let mut recovered_shards_out: [u8; 4] = [0; 4]; + let mut recovered_shards_indexes_out: [usize; 1] = [0; 1]; + + unsafe { + let result = reed_solomon_decode( + original_shards_count, + recovery_shards_count, + shard_size, + original_shards.as_ptr(), + original_shards.len(), + original_shards_indexes.as_ptr(), + recovery_shards.as_ptr(), + recovery_shards.len(), + recovery_shards_indexes.as_ptr(), + recovered_shards_out.as_mut_ptr(), + recovered_shards_out.len(), + recovered_shards_indexes_out.as_mut_ptr(), + ); + + assert_eq!(result, -1); // Failure due to invalid shard size. + } + } + + #[test] + fn test_decode_mismatched_shard_count() { + let original_shards_count = 3; // Incorrect count. + let recovery_shards_count = 1; + let shard_size = 4; + + let original_shards: [u8; 4] = [5, 6, 7, 8]; + let original_shards_indexes: [usize; 1] = [1]; + + let recovery_shards: [u8; 4] = [4, 4, 4, 12]; + let recovery_shards_indexes: [usize; 1] = [0]; + + let mut recovered_shards_out: [u8; 4] = [0; 4]; + let mut recovered_shards_indexes_out: [usize; 1] = [0; 1]; + + unsafe { + let result = reed_solomon_decode( + original_shards_count, + recovery_shards_count, + shard_size, + original_shards.as_ptr(), + original_shards.len(), + original_shards_indexes.as_ptr(), + recovery_shards.as_ptr(), + recovery_shards.len(), + recovery_shards_indexes.as_ptr(), + recovered_shards_out.as_mut_ptr(), + recovered_shards_out.len(), + recovered_shards_indexes_out.as_mut_ptr(), + ); + + assert_eq!(result, -1); // Failure due to mismatched shard count + } + } + + #[test] + fn test_decode_minimum_shards() { + let original_shards_count = 1; + let recovery_shards_count = 1; + let shard_size = 4; + + let recovery_shards: [u8; 4] = [1, 2, 3, 4]; + let recovery_shards_indexes: [usize; 1] = [0]; + + let mut recovered_shards_out: [u8; 4] = [0; 4]; + let mut recovered_shards_indexes_out: [usize; 1] = [0; 1]; + + unsafe { + let result = reed_solomon_decode( + original_shards_count, + recovery_shards_count, + shard_size, + [].as_ptr(), + 0, + [].as_ptr(), + recovery_shards.as_ptr(), + recovery_shards.len(), + recovery_shards_indexes.as_ptr(), + recovered_shards_out.as_mut_ptr(), + recovered_shards_out.len(), + recovered_shards_indexes_out.as_mut_ptr(), + ); + + assert_eq!(result, 0); // Success. + assert_eq!(recovered_shards_out, [1, 2, 3, 4]); + assert_eq!(recovered_shards_indexes_out, [0]); + } + } + + #[test] + fn test_decode_null_pointer() { + let original_shards_count = 2; + let recovery_shards_count = 1; + let shard_size = 4; + + let original_shards: [u8; 8] = [1, 2, 3, 4, 5, 6, 7, 8]; + let original_shards_indexes: [usize; 2] = [0, 1]; + + let recovery_shards: [u8; 4] = [9, 10, 11, 12]; + let recovery_shards_indexes: [usize; 1] = [2]; + + let mut recovered_shards_out: [u8; 4] = [0; 4]; + let mut recovered_shards_indexes_out: [usize; 1] = [0; 1]; + + unsafe { + // Test null pointer for original_shards. + let result = reed_solomon_decode( + original_shards_count, + recovery_shards_count, + shard_size, + ptr::null(), + original_shards.len(), + original_shards_indexes.as_ptr(), + recovery_shards.as_ptr(), + recovery_shards.len(), + recovery_shards_indexes.as_ptr(), + recovered_shards_out.as_mut_ptr(), + recovered_shards_out.len(), + recovered_shards_indexes_out.as_mut_ptr(), + ); + assert_eq!(result, -1); + + // Test null pointer for original_shards_indexes. + let result = reed_solomon_decode( + original_shards_count, + recovery_shards_count, + shard_size, + original_shards.as_ptr(), + original_shards.len(), + ptr::null(), + recovery_shards.as_ptr(), + recovery_shards.len(), + recovery_shards_indexes.as_ptr(), + recovered_shards_out.as_mut_ptr(), + recovered_shards_out.len(), + recovered_shards_indexes_out.as_mut_ptr(), + ); + assert_eq!(result, -1); + + // Test null pointer for recovery_shards. + let result = reed_solomon_decode( + original_shards_count, + recovery_shards_count, + shard_size, + original_shards.as_ptr(), + original_shards.len(), + original_shards_indexes.as_ptr(), + ptr::null(), + recovery_shards.len(), + recovery_shards_indexes.as_ptr(), + recovered_shards_out.as_mut_ptr(), + recovered_shards_out.len(), + recovered_shards_indexes_out.as_mut_ptr(), + ); + assert_eq!(result, -1); + + // Test null pointer for recovery_shards_indexes. + let result = reed_solomon_decode( + original_shards_count, + recovery_shards_count, + shard_size, + original_shards.as_ptr(), + original_shards.len(), + original_shards_indexes.as_ptr(), + recovery_shards.as_ptr(), + recovery_shards.len(), + ptr::null(), + recovered_shards_out.as_mut_ptr(), + recovered_shards_out.len(), + recovered_shards_indexes_out.as_mut_ptr(), + ); + assert_eq!(result, -1); + + // Test null pointer for recovered_shards_out. + let result = reed_solomon_decode( + original_shards_count, + recovery_shards_count, + shard_size, + original_shards.as_ptr(), + original_shards.len(), + original_shards_indexes.as_ptr(), + recovery_shards.as_ptr(), + recovery_shards.len(), + recovery_shards_indexes.as_ptr(), + ptr::null_mut(), + recovered_shards_out.len(), + recovered_shards_indexes_out.as_mut_ptr(), + ); + assert_eq!(result, -1); + + // Test null pointer for recovered_shards_indexes_out. + let result = reed_solomon_decode( + original_shards_count, + recovery_shards_count, + shard_size, + original_shards.as_ptr(), + original_shards.len(), + original_shards_indexes.as_ptr(), + recovery_shards.as_ptr(), + recovery_shards.len(), + recovery_shards_indexes.as_ptr(), + recovered_shards_out.as_mut_ptr(), + recovered_shards_out.len(), + ptr::null_mut(), + ); + assert_eq!(result, -1); + } + } +} diff --git a/internal/erasurecoding/reedsolomon/reedsolomon.go b/internal/erasurecoding/reedsolomon/reedsolomon.go index 5e62e99..424acf1 100644 --- a/internal/erasurecoding/reedsolomon/reedsolomon.go +++ b/internal/erasurecoding/reedsolomon/reedsolomon.go @@ -17,7 +17,9 @@ var ( recoveryShardsCount C.size_t, shardSize C.size_t, originalShards []byte, + originalShardsLen C.size_t, recoveryShardsOut []byte, + recoveryShardsLen C.size_t, ) (cerr int) reedSolomonDecode func( @@ -30,8 +32,9 @@ var ( recoveryShards []byte, recoveryShardsLen C.size_t, recoveryShardsIndexes []C.size_t, - shardsRecoveredOut []byte, - shardsRecoveredIndexesOut []C.size_t, + recoveredShards []byte, + recoveredShardsLength C.size_t, + recoveredShardsIndexesOut []C.size_t, ) (cerr int) ) @@ -93,9 +96,15 @@ func (r *Encoder) Encode( } shardSize := shardSize(shards) + if shardSize == 0 { + return errors.New("invalid shard size") + } flatOriginalShards := make([]byte, r.originalShardsCount*shardSize) for i, s := range shards[:r.originalShardsCount] { + if len(s) != shardSize { + return errors.New("inconsistent shard size") + } copy(flatOriginalShards[i*shardSize:], s) } @@ -106,7 +115,10 @@ func (r *Encoder) Encode( C.size_t(r.recoveryShardsCount), C.size_t(shardSize), flatOriginalShards, - recoveryShardsOut) + C.size_t(len(flatOriginalShards)), + recoveryShardsOut, + C.size_t(len(recoveryShardsOut)), + ) if result != 0 { return errors.New("unable to encode data") } @@ -129,7 +141,6 @@ func (r *Encoder) Decode(shards [][]byte) error { } shardSize := shardSize(shards) if shardSize == 0 { - // todo better error name return errors.New("invalid shard size") } @@ -137,6 +148,9 @@ func (r *Encoder) Decode(shards [][]byte) error { flatOriginalShardsIndexes := []C.size_t{} for i, s := range shards[:r.originalShardsCount] { if len(s) != 0 { + if len(s) != shardSize { + return errors.New("inconsistent shard size") + } flatOriginalShards = append(flatOriginalShards, s...) flatOriginalShardsIndexes = append(flatOriginalShardsIndexes, C.size_t(i)) } @@ -146,6 +160,9 @@ func (r *Encoder) Decode(shards [][]byte) error { flatRecoveryShardsIndexes := []C.size_t{} for i, s := range shards[r.originalShardsCount:] { if len(s) != 0 { + if len(s) != shardSize { + return errors.New("inconsistent shard size") + } flatRecoveryShards = append(flatRecoveryShards, s...) flatRecoveryShardsIndexes = append(flatRecoveryShardsIndexes, C.size_t(i)) } @@ -168,6 +185,7 @@ func (r *Encoder) Decode(shards [][]byte) error { C.size_t(len(flatRecoveryShards)), flatRecoveryShardsIndexes, restoredShards, + C.size_t(len(restoredShards)), restoredShardsIndexes) if result != 0 { return errors.New("unable to decode data")