Skip to content

Commit

Permalink
Merge pull request #267 from athenavm/keep-program-input-in-single-vec
Browse files Browse the repository at this point in the history
allow reading chunks from the input stream
  • Loading branch information
poszu authored Jan 3, 2025
2 parents 77da089 + 2b4e269 commit 8a10e23
Show file tree
Hide file tree
Showing 10 changed files with 179 additions and 174 deletions.
52 changes: 11 additions & 41 deletions core/src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@ use crate::utils::Buffer;
use serde::{de::DeserializeOwned, Deserialize, Serialize};

/// Standard input.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone)]
pub struct AthenaStdin {
/// Input stored as a vec of vec of bytes. It's stored this way because the read syscall reads
/// a vec of bytes at a time.
pub buffer: Vec<Vec<u8>>,
pub ptr: usize,
buffer: Vec<u8>,
}

/// Public values for the runner.
Expand All @@ -17,49 +14,26 @@ pub struct AthenaPublicValues {
}

impl AthenaStdin {
/// Create a new `AthenaStdin`.
pub const fn new() -> Self {
Self {
buffer: Vec::new(),
ptr: 0,
}
}

/// Create a `AthenaStdin` from a slice of bytes.
pub fn from(data: &[u8]) -> Self {
Self {
buffer: vec![data.to_vec()],
ptr: 0,
}
}

/// Read a value from the buffer.
pub fn read<T: DeserializeOwned>(&mut self) -> T {
let result: T = bincode::deserialize(&self.buffer[self.ptr]).expect("failed to deserialize");
self.ptr += 1;
result
}

/// Read a slice of bytes from the buffer.
pub fn read_slice(&mut self, slice: &mut [u8]) {
slice.copy_from_slice(&self.buffer[self.ptr]);
self.ptr += 1;
Self { buffer: Vec::new() }
}

/// Write a value to the buffer.
pub fn write<T: Serialize>(&mut self, data: &T) {
let mut tmp = Vec::new();
bincode::serialize_into(&mut tmp, data).expect("serialization failed");
self.buffer.push(tmp);
bincode::serialize_into(&mut self.buffer, data).expect("serialization failed");
}

/// Write a slice of bytes to the buffer.
pub fn write_slice(&mut self, slice: &[u8]) {
self.buffer.push(slice.to_vec());
self.buffer.extend_from_slice(slice);
}

pub fn write_vec(&mut self, vec: Vec<u8>) {
self.buffer.push(vec);
pub fn write_vec(&mut self, mut vec: Vec<u8>) {
self.buffer.append(&mut vec);
}

/// Consume this `AthenaStdin` and return its underlying byte buffer.
///
/// No copy is made: the internal `Vec<u8>` is moved out to the caller.
pub fn to_vec(self) -> Vec<u8> {
self.buffer
}
}

Expand All @@ -71,10 +45,6 @@ impl AthenaPublicValues {
}
}

pub fn raw(&self) -> String {
format!("0x{}", hex::encode(self.buffer.data.clone()))
}

/// Create a `AthenaPublicValues` from a slice of bytes.
pub fn from(data: &[u8]) -> Self {
Self {
Expand Down
12 changes: 4 additions & 8 deletions core/src/runtime/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,15 @@ impl Read for Runtime<'_> {

impl Runtime<'_> {
pub fn write_stdin<U: Serialize>(&mut self, input: &U) {
let mut buf = Vec::new();
bincode::serialize_into(&mut buf, input).expect("serialization failed");
self.state.input_stream.push(buf);
bincode::serialize_into(&mut self.state.input_stream, input).expect("serialization failed");
}

pub fn write_stdin_slice(&mut self, input: &[u8]) {
self.state.input_stream.push(input.to_vec());
self.write_from(input.iter().copied());
}

pub fn write_vecs(&mut self, inputs: &[Vec<u8>]) {
for input in inputs {
self.state.input_stream.push(input.clone());
}
/// Append every byte yielded by `input` to the end of the runtime's input stream.
///
/// Accepts anything iterable over `u8`, so callers can pass owned vectors or
/// copied slice iterators alike.
pub fn write_from<T: IntoIterator<Item = u8>>(&mut self, input: T) {
self.state.input_stream.extend(input);
}

pub fn read_public_values<U: DeserializeOwned>(&mut self) -> U {
Expand Down
2 changes: 1 addition & 1 deletion core/src/runtime/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub struct ExecutionState {
pub uninitialized_memory: HashMap<u32, u32, BuildNoHashHasher<u32>>,

/// A stream of input values (global to the entire program).
pub input_stream: Vec<Vec<u8>>,
pub input_stream: Vec<u8>,

/// A ptr to the current position in the input stream incremented by HINT_READ opcode.
pub input_stream_ptr: usize,
Expand Down
15 changes: 11 additions & 4 deletions core/src/runtime/syscall.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,15 +158,19 @@ impl<'a, 'h> SyscallContext<'a, 'h> {
#[tracing::instrument(skip(self))]
pub fn bytes(&self, mut addr: u32, len: usize) -> Vec<u8> {
let mut bytes = Vec::new();
let mut bytes_to_read = len;

// handle case when addr is not aligned to 4B
let addr_offset = (addr % 4) as usize;
let addr_offset = addr % 4;
if addr_offset != 0 {
let word = self.word(addr - addr_offset as u32).to_le_bytes();
bytes.extend_from_slice(&word[addr_offset..]);
tracing::debug!(addr, len, addr_offset, "addr not aligned");
let word = self.word(addr - addr_offset).to_le_bytes();
bytes.extend_from_slice(&word[addr_offset as usize..]);
addr += bytes.len() as u32;
bytes_to_read = bytes_to_read.saturating_sub(bytes.len());
}

for addr in (addr..addr + (len - bytes.len()) as u32).step_by(4) {
for addr in (addr..addr + bytes_to_read as u32).step_by(4) {
bytes.extend_from_slice(&self.word(addr).to_le_bytes());
}
bytes.truncate(len); // handle case when len is not a multiple of 4
Expand Down Expand Up @@ -290,5 +294,8 @@ mod tests {
// address not aligned and length not a multiple of 4
let read = ctx.bytes(0x103, 59);
assert_eq!(read, memory[3..3 + 59]);

let read = ctx.bytes(0x1, 2);
assert_eq!(read, memory[1..1 + 2]);
}
}
174 changes: 122 additions & 52 deletions core/src/syscall/hint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@ pub(crate) struct SyscallHintLen;

impl Syscall for SyscallHintLen {
fn execute(&self, ctx: &mut SyscallContext, _: u32, _: u32) -> SyscallResult {
if ctx.rt.state.input_stream_ptr >= ctx.rt.state.input_stream.len() {
tracing::debug!(
"no more data to read in stdin: input_stream_ptr={}, input_stream_len={}",
ctx.rt.state.input_stream_ptr,
ctx.rt.state.input_stream.len(),
);
return Ok(Outcome::Result(Some(0)));
}
Ok(Outcome::Result(Some(
ctx.rt.state.input_stream[ctx.rt.state.input_stream_ptr].len() as u32,
)))
let len = if ctx.rt.state.input_stream_ptr >= ctx.rt.state.input_stream.len() {
0
} else {
ctx.rt.state.input_stream.len() - ctx.rt.state.input_stream_ptr
};
tracing::debug!(
ptr = ctx.rt.state.input_stream_ptr,
total = ctx.rt.state.input_stream.len(),
len,
"hinted remaning data in the input stream"
);
Ok(Outcome::Result(Some(len as u32)))
}
}

Expand All @@ -26,59 +27,82 @@ pub(crate) struct SyscallHintRead;

impl Syscall for SyscallHintRead {
fn execute(&self, ctx: &mut SyscallContext, ptr: u32, len: u32) -> SyscallResult {
if ctx.rt.state.input_stream_ptr >= ctx.rt.state.input_stream.len() {
if ctx.rt.unconstrained {
tracing::error!("hint read should not be used in a unconstrained block");
return Err(StatusCode::StaticModeViolation);
}
let data = ctx.rt.state.input_stream[ctx.rt.state.input_stream_ptr..].to_vec();
if len as usize > data.len() {
tracing::debug!(
"failed reading stdin due to insufficient input data: input_stream_ptr={}, input_stream_len={}",
ctx.rt.state.input_stream_ptr,
ctx.rt.state.input_stream.len()
ptr = ctx.rt.state.input_stream_ptr,
total = ctx.rt.state.input_stream.len(),
available = data.len(),
len,
"failed reading stdin due to insufficient input data",
);
return Err(StatusCode::InsufficientInput);
}
let vec = &ctx.rt.state.input_stream[ctx.rt.state.input_stream_ptr];
ctx.rt.state.input_stream_ptr += 1;
assert!(
!ctx.rt.unconstrained,
"hint read should not be used in a unconstrained block"
);
if vec.len() != len as usize {
let mut data = &data[..len as usize];
let mut address = ptr;

// Handle unaligned start
if address % 4 != 0 {
let aligned_addr = address & !3; // Round down to aligned address
let offset = (address % 4) as usize;
let bytes_to_write = std::cmp::min(4 - offset, data.len());
tracing::debug!(
"hint input stream read length mismatch: expected={}, actual={}",
len,
vec.len()
address,
aligned_addr,
offset,
bytes_to_write,
"hint read address not aligned to 4 bytes"
);
return Err(StatusCode::InvalidSyscallArgument);

let mut word_bytes = ctx.rt.mr(aligned_addr).to_le_bytes();
tracing::debug!(word = hex::encode(word_bytes), "read existing word");

word_bytes[offset..offset + bytes_to_write].copy_from_slice(&data[..bytes_to_write]);

ctx.rt.mw(aligned_addr, u32::from_le_bytes(word_bytes));
tracing::debug!(word = hex::encode(word_bytes), "written updated word");

address = aligned_addr + 4;
data = &data[bytes_to_write..];
}
if ptr % 4 != 0 {
tracing::debug!("hint read address not aligned to 4 bytes");
return Err(StatusCode::InvalidSyscallArgument);

// Iterate through the remaining data in 4-byte chunks
let mut chunks = data.chunks_exact(4);
for chunk in &mut chunks {
// unwrap() won't panic: chunks_exact(4) only yields slices of exactly 4 bytes
let word = u32::from_le_bytes(chunk.try_into().unwrap());
ctx.rt.mw(address, word);
address += 4;
}
// Iterate through the vec in 4-byte chunks
for i in (0..len).step_by(4) {
// Get each byte in the chunk
let b1 = vec[i as usize];
// In case the vec is not a multiple of 4, right-pad with 0s. This is fine because we
// are assuming the word is uninitialized, so filling it with 0s makes sense.
let b2 = vec.get(i as usize + 1).copied().unwrap_or(0);
let b3 = vec.get(i as usize + 2).copied().unwrap_or(0);
let b4 = vec.get(i as usize + 3).copied().unwrap_or(0);
let word = u32::from_le_bytes([b1, b2, b3, b4]);

// Save the data into runtime state so the runtime will use the desired data instead of
// 0 when first reading/writing from this address.
ctx
.rt
.state
.uninitialized_memory
.entry(ptr + i)
.and_modify(|_| panic!("hint read address is initialized already"))
.or_insert(word);
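// In case the remaining data is not a multiple of 4 bytes, right-pad the final word with 0s.
// This is fine because we are assuming the word is uninitialized, so filling it with 0s makes
// sense. NOTE(review): unlike the unaligned-start path, this write does not preserve any
// existing trailing bytes of that word — confirm that is intended.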
let remainder = chunks.remainder();
if !remainder.is_empty() {
let mut word_array = [0u8; 4];
let len = remainder.len();
word_array[..len].copy_from_slice(remainder);
ctx.rt.mw(address, u32::from_le_bytes(word_array));
}
Ok(Outcome::Result(None))
tracing::debug!(
from = ptr,
to = address as usize + remainder.len(),
read = len,
"HintRead syscall finished"
);
tracing::trace!(data = hex::encode(data));
ctx.rt.state.input_stream_ptr += len as usize;
Ok(Outcome::Result(Some(len)))
}
}

#[cfg(test)]
mod tests {
use athena_interface::StatusCode;

use crate::{
runtime::{Outcome, Program, Runtime, Syscall, SyscallContext},
utils::AthenaCoreOpts,
Expand All @@ -96,9 +120,55 @@ mod tests {

// with inputs
let data = [vec![1, 2, 3, 4, 5], vec![6, 7]];
ctx.rt.write_vecs(&data);
ctx.rt.write_stdin_slice(&data[0]);
ctx.rt.write_stdin_slice(&data[1]);

let result = syscall.execute(&mut ctx, 0, 0).unwrap();
assert_eq!(Outcome::Result(Some(data[0].len() as u32)), result);
assert_eq!(
Outcome::Result(Some((data[0].len() + data[1].len()) as u32)),
result
);
}

// `SyscallHintRead` must be rejected with `StaticModeViolation` while the
// runtime is flagged as unconstrained, regardless of the requested length.
#[test]
fn hint_read_cant_run_in_unconstrained() {
let mut rt = Runtime::new(Program::default(), None, AthenaCoreOpts::default(), None);
rt.unconstrained = true;
let mut ctx = SyscallContext::new(&mut rt);
let syscall = super::SyscallHintRead {};

let result = syscall.execute(&mut ctx, 0, 0);
assert_eq!(Err(StatusCode::StaticModeViolation), result);
}

// Exercises `SyscallHintRead` against an empty stream, an over-long request,
// a partial read, and a final read that starts at an unaligned address.
#[test]
fn hint_read() {
let mut rt = Runtime::new(Program::default(), None, AthenaCoreOpts::default(), None);
let mut ctx = SyscallContext::new(&mut rt);
let syscall = super::SyscallHintRead {};

// no inputs
let result = syscall.execute(&mut ctx, 0, 10);
assert_eq!(Err(StatusCode::InsufficientInput), result);

// the assertions below rely on the stream holding exactly `data.len()` bytes after this write
let data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
ctx.rt.write_stdin(&data);

// can't read more than available
let result = syscall.execute(&mut ctx, 0, data.len() as u32 + 1);
assert_eq!(Err(StatusCode::InsufficientInput), result);

// read only up to `len`
let len = 3;
let result = syscall.execute(&mut ctx, 0, len as u32);
assert_eq!(Ok(Outcome::Result(Some(len as u32))), result);
assert_eq!(&data[..len], ctx.bytes(0, len).as_slice());

// read the rest
// `address` is 3 here, so this read starts at a non-4-byte-aligned address
let address = len;
let len = data.len() - len;
let result = syscall.execute(&mut ctx, address as u32, len as u32);
assert_eq!(Ok(Outcome::Result(Some(len as u32))), result);
// both reads together must have laid the whole input out contiguously from address 0
assert_eq!(data, ctx.bytes(0, data.len()).as_slice());
}
}
9 changes: 4 additions & 5 deletions core/src/syscall/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,16 @@ impl Syscall for SyscallWrite {
}
}
3 => {
rt.state.public_values_stream.extend_from_slice(&bytes);
rt.state.public_values_stream.extend(bytes);
}
4 => {
rt.state.input_stream.push(bytes);
rt.state.input_stream.extend(bytes);
}
fd => {
tracing::debug!(fd, "executing hook");
match rt.execute_hook(fd, &bytes) {
Ok(result) => {
rt.state.input_stream.push(result);
rt.state.input_stream.extend(result);
}
Err(err) => {
tracing::debug!(fd, ?err, "hook failed");
Expand Down Expand Up @@ -115,7 +115,6 @@ mod tests {

let result = SyscallWrite {}.execute(&mut SyscallContext { rt: &mut runtime }, 7, 0);
result.unwrap();
let result = runtime.state.input_stream.pop().unwrap();
assert_eq!(vec![1, 2, 3, 4, 5], result);
assert_eq!(vec![1, 2, 3, 4, 5], runtime.state.input_stream);
}
}
Loading

0 comments on commit 8a10e23

Please sign in to comment.