From 0ffb3bdaf63c791e397e19e069debd2d7fb41ea7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bartosz=20R=C3=B3=C5=BCa=C5=84ski?= Date: Wed, 18 Dec 2024 13:32:40 +0100 Subject: [PATCH 1/2] allow reading chunks from the input stream Syscall HintRead reads `len` bytes and keeps remaining data for later reads. --- core/src/io.rs | 59 ++++---- core/src/runtime/io.rs | 12 +- core/src/runtime/state.rs | 2 +- core/src/runtime/syscall.rs | 15 +- core/src/syscall/hint.rs | 174 +++++++++++++++------- core/src/syscall/write.rs | 4 +- examples/Cargo.lock | 8 +- examples/wallet/script/src/bin/execute.rs | 9 +- sdk/src/lib.rs | 8 +- tests/entrypoint/src/main.rs | 5 +- vm/entrypoint/src/io.rs | 66 ++------ 11 files changed, 197 insertions(+), 165 deletions(-) diff --git a/core/src/io.rs b/core/src/io.rs index 5e1d4e87..2c31ce4f 100644 --- a/core/src/io.rs +++ b/core/src/io.rs @@ -1,13 +1,12 @@ +use std::{cmp::min, collections::VecDeque, io::Write}; + use crate::utils::Buffer; use serde::{de::DeserializeOwned, Deserialize, Serialize}; /// Standard input. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AthenaStdin { - /// Input stored as a vec of vec of bytes. It's stored this way because the read syscall reads - /// a vec of bytes at a time. - pub buffer: Vec>, - pub ptr: usize, + pub buffer: VecDeque, } /// Public values for the runner. @@ -17,49 +16,55 @@ pub struct AthenaPublicValues { } impl AthenaStdin { - /// Create a new `AthenaStdin`. pub const fn new() -> Self { Self { - buffer: Vec::new(), - ptr: 0, + buffer: VecDeque::new(), } } /// Create a `AthenaStdin` from a slice of bytes. pub fn from(data: &[u8]) -> Self { - Self { - buffer: vec![data.to_vec()], - ptr: 0, - } - } - - /// Read a value from the buffer. - pub fn read(&mut self) -> T { - let result: T = bincode::deserialize(&self.buffer[self.ptr]).expect("failed to deserialize"); - self.ptr += 1; - result + let mut stdin = Self::new(); + stdin.write_slice(data); + stdin } /// Read a slice of bytes from the buffer. 
+ /// Reads up to slice.len() bytes from the beginning of the buffer into the provided slice. pub fn read_slice(&mut self, slice: &mut [u8]) { - slice.copy_from_slice(&self.buffer[self.ptr]); - self.ptr += 1; + let bytes_to_read = min(slice.len(), self.buffer.len()); + if bytes_to_read == 0 { + return; + } + + // Get the two contiguous slices from the VecDeque + let (first, second) = self.buffer.as_slices(); + + // Copy from the first slice + let first_copy = min(first.len(), bytes_to_read); + slice[..first_copy].copy_from_slice(&first[..first_copy]); + + // If we need more bytes and there's a second slice, copy from it + if first_copy < bytes_to_read { + let second_copy = bytes_to_read - first_copy; + slice[first_copy..bytes_to_read].copy_from_slice(&second[..second_copy]); + } + + self.buffer.drain(..bytes_to_read); } /// Write a value to the buffer. pub fn write(&mut self, data: &T) { - let mut tmp = Vec::new(); - bincode::serialize_into(&mut tmp, data).expect("serialization failed"); - self.buffer.push(tmp); + bincode::serialize_into(&mut self.buffer, data).expect("serialization failed"); } /// Write a slice of bytes to the buffer. pub fn write_slice(&mut self, slice: &[u8]) { - self.buffer.push(slice.to_vec()); + self.buffer.write_all(slice).expect("pushing to buffer"); } pub fn write_vec(&mut self, vec: Vec) { - self.buffer.push(vec); + self.write_slice(&vec); } } @@ -71,10 +76,6 @@ impl AthenaPublicValues { } } - pub fn raw(&self) -> String { - format!("0x{}", hex::encode(self.buffer.data.clone())) - } - /// Create a `AthenaPublicValues` from a slice of bytes. 
pub fn from(data: &[u8]) -> Self { Self { diff --git a/core/src/runtime/io.rs b/core/src/runtime/io.rs index d2ba640b..05d4a4e6 100644 --- a/core/src/runtime/io.rs +++ b/core/src/runtime/io.rs @@ -14,19 +14,15 @@ impl Read for Runtime<'_> { impl Runtime<'_> { pub fn write_stdin(&mut self, input: &U) { - let mut buf = Vec::new(); - bincode::serialize_into(&mut buf, input).expect("serialization failed"); - self.state.input_stream.push(buf); + bincode::serialize_into(&mut self.state.input_stream, input).expect("serialization failed"); } pub fn write_stdin_slice(&mut self, input: &[u8]) { - self.state.input_stream.push(input.to_vec()); + self.write_from(input.iter().copied()); } - pub fn write_vecs(&mut self, inputs: &[Vec]) { - for input in inputs { - self.state.input_stream.push(input.clone()); - } + pub fn write_from>(&mut self, input: T) { + self.state.input_stream.extend(input); } pub fn read_public_values(&mut self) -> U { diff --git a/core/src/runtime/state.rs b/core/src/runtime/state.rs index e79680af..11939921 100644 --- a/core/src/runtime/state.rs +++ b/core/src/runtime/state.rs @@ -23,7 +23,7 @@ pub struct ExecutionState { pub uninitialized_memory: HashMap>, /// A stream of input values (global to the entire program). - pub input_stream: Vec>, + pub input_stream: Vec, /// A ptr to the current position in the input stream incremented by HINT_READ opcode. 
pub input_stream_ptr: usize, diff --git a/core/src/runtime/syscall.rs b/core/src/runtime/syscall.rs index e8fd1fb1..096e0d3e 100644 --- a/core/src/runtime/syscall.rs +++ b/core/src/runtime/syscall.rs @@ -158,15 +158,19 @@ impl<'a, 'h> SyscallContext<'a, 'h> { #[tracing::instrument(skip(self))] pub fn bytes(&self, mut addr: u32, len: usize) -> Vec { let mut bytes = Vec::new(); + let mut bytes_to_read = len; + // handle case when addr is not aligned to 4B - let addr_offset = (addr % 4) as usize; + let addr_offset = addr % 4; if addr_offset != 0 { - let word = self.word(addr - addr_offset as u32).to_le_bytes(); - bytes.extend_from_slice(&word[addr_offset..]); + tracing::debug!(addr, len, addr_offset, "addr not aligned"); + let word = self.word(addr - addr_offset).to_le_bytes(); + bytes.extend_from_slice(&word[addr_offset as usize..]); addr += bytes.len() as u32; + bytes_to_read = bytes_to_read.saturating_sub(bytes.len()); } - for addr in (addr..addr + (len - bytes.len()) as u32).step_by(4) { + for addr in (addr..addr + bytes_to_read as u32).step_by(4) { bytes.extend_from_slice(&self.word(addr).to_le_bytes()); } bytes.truncate(len); // handle case when len is not a multiple of 4 @@ -290,5 +294,8 @@ mod tests { // address not aligned and length not a multiple of 4 let read = ctx.bytes(0x103, 59); assert_eq!(read, memory[3..3 + 59]); + + let read = ctx.bytes(0x1, 2); + assert_eq!(read, memory[1..1 + 2]); } } diff --git a/core/src/syscall/hint.rs b/core/src/syscall/hint.rs index c5c7a622..2184c9e0 100644 --- a/core/src/syscall/hint.rs +++ b/core/src/syscall/hint.rs @@ -7,17 +7,18 @@ pub(crate) struct SyscallHintLen; impl Syscall for SyscallHintLen { fn execute(&self, ctx: &mut SyscallContext, _: u32, _: u32) -> SyscallResult { - if ctx.rt.state.input_stream_ptr >= ctx.rt.state.input_stream.len() { - tracing::debug!( - "no more data to read in stdin: input_stream_ptr={}, input_stream_len={}", - ctx.rt.state.input_stream_ptr, - ctx.rt.state.input_stream.len(), - ); - return 
Ok(Outcome::Result(Some(0))); - } - Ok(Outcome::Result(Some( - ctx.rt.state.input_stream[ctx.rt.state.input_stream_ptr].len() as u32, - ))) + let len = if ctx.rt.state.input_stream_ptr >= ctx.rt.state.input_stream.len() { + 0 + } else { + ctx.rt.state.input_stream.len() - ctx.rt.state.input_stream_ptr + }; + tracing::debug!( + ptr = ctx.rt.state.input_stream_ptr, + total = ctx.rt.state.input_stream.len(), + len, + "hinted remaning data in the input stream" + ); + Ok(Outcome::Result(Some(len as u32))) } } @@ -26,59 +27,82 @@ pub(crate) struct SyscallHintRead; impl Syscall for SyscallHintRead { fn execute(&self, ctx: &mut SyscallContext, ptr: u32, len: u32) -> SyscallResult { - if ctx.rt.state.input_stream_ptr >= ctx.rt.state.input_stream.len() { + if ctx.rt.unconstrained { + tracing::error!("hint read should not be used in a unconstrained block"); + return Err(StatusCode::StaticModeViolation); + } + let data = ctx.rt.state.input_stream[ctx.rt.state.input_stream_ptr..].to_vec(); + if len as usize > data.len() { tracing::debug!( - "failed reading stdin due to insufficient input data: input_stream_ptr={}, input_stream_len={}", - ctx.rt.state.input_stream_ptr, - ctx.rt.state.input_stream.len() + ptr = ctx.rt.state.input_stream_ptr, + total = ctx.rt.state.input_stream.len(), + available = data.len(), + len, + "failed reading stdin due to insufficient input data", ); return Err(StatusCode::InsufficientInput); } - let vec = &ctx.rt.state.input_stream[ctx.rt.state.input_stream_ptr]; - ctx.rt.state.input_stream_ptr += 1; - assert!( - !ctx.rt.unconstrained, - "hint read should not be used in a unconstrained block" - ); - if vec.len() != len as usize { + let mut data = &data[..len as usize]; + let mut address = ptr; + + // Handle unaligned start + if address % 4 != 0 { + let aligned_addr = address & !3; // Round down to aligned address + let offset = (address % 4) as usize; + let bytes_to_write = std::cmp::min(4 - offset, data.len()); tracing::debug!( - "hint input stream read 
length mismatch: expected={}, actual={}", - len, - vec.len() + address, + aligned_addr, + offset, + bytes_to_write, + "hint read address not aligned to 4 bytes" ); - return Err(StatusCode::InvalidSyscallArgument); + + let mut word_bytes = ctx.rt.mr(aligned_addr).to_le_bytes(); + tracing::debug!(word = hex::encode(word_bytes), "read existing word"); + + word_bytes[offset..offset + bytes_to_write].copy_from_slice(&data[..bytes_to_write]); + + ctx.rt.mw(aligned_addr, u32::from_le_bytes(word_bytes)); + tracing::debug!(word = hex::encode(word_bytes), "written updated word"); + + address = aligned_addr + 4; + data = &data[bytes_to_write..]; } - if ptr % 4 != 0 { - tracing::debug!("hint read address not aligned to 4 bytes"); - return Err(StatusCode::InvalidSyscallArgument); + + // Iterate through the remaining data in 4-byte chunks + let mut chunks = data.chunks_exact(4); + for chunk in &mut chunks { + // unwrap() won't panic, which is guaranteed by chunks_exact(4) + let word = u32::from_le_bytes(chunk.try_into().unwrap()); + ctx.rt.mw(address, word); + address += 4; } - // Iterate through the vec in 4-byte chunks - for i in (0..len).step_by(4) { - // Get each byte in the chunk - let b1 = vec[i as usize]; - // In case the vec is not a multiple of 4, right-pad with 0s. This is fine because we - // are assuming the word is uninitialized, so filling it with 0s makes sense. - let b2 = vec.get(i as usize + 1).copied().unwrap_or(0); - let b3 = vec.get(i as usize + 2).copied().unwrap_or(0); - let b4 = vec.get(i as usize + 3).copied().unwrap_or(0); - let word = u32::from_le_bytes([b1, b2, b3, b4]); - - // Save the data into runtime state so the runtime will use the desired data instead of - // 0 when first reading/writing from this address. - ctx - .rt - .state - .uninitialized_memory - .entry(ptr + i) - .and_modify(|_| panic!("hint read address is initialized already")) - .or_insert(word); + // In case the vec is not a multiple of 4, right-pad with 0s.
This is fine because we + // are assuming the word is uninitialized, so filling it with 0s makes sense. + let remainder = chunks.remainder(); + if !remainder.is_empty() { + let mut word_array = [0u8; 4]; + let len = remainder.len(); + word_array[..len].copy_from_slice(remainder); + ctx.rt.mw(address, u32::from_le_bytes(word_array)); } - Ok(Outcome::Result(None)) + tracing::debug!( + from = ptr, + to = address as usize + remainder.len(), + read = len, + "HintRead syscall finished" + ); + tracing::trace!(data = hex::encode(data)); + ctx.rt.state.input_stream_ptr += len as usize; + Ok(Outcome::Result(Some(len))) } } #[cfg(test)] mod tests { + use athena_interface::StatusCode; + use crate::{ runtime::{Outcome, Program, Runtime, Syscall, SyscallContext}, utils::AthenaCoreOpts, @@ -96,9 +120,55 @@ mod tests { // with inputs let data = [vec![1, 2, 3, 4, 5], vec![6, 7]]; - ctx.rt.write_vecs(&data); + ctx.rt.write_stdin_slice(&data[0]); + ctx.rt.write_stdin_slice(&data[1]); let result = syscall.execute(&mut ctx, 0, 0).unwrap(); - assert_eq!(Outcome::Result(Some(data[0].len() as u32)), result); + assert_eq!( + Outcome::Result(Some((data[0].len() + data[1].len()) as u32)), + result + ); + } + + #[test] + fn hint_read_cant_run_in_unconstrained() { + let mut rt = Runtime::new(Program::default(), None, AthenaCoreOpts::default(), None); + rt.unconstrained = true; + let mut ctx = SyscallContext::new(&mut rt); + let syscall = super::SyscallHintRead {}; + + let result = syscall.execute(&mut ctx, 0, 0); + assert_eq!(Err(StatusCode::StaticModeViolation), result); + } + + #[test] + fn hint_read() { + let mut rt = Runtime::new(Program::default(), None, AthenaCoreOpts::default(), None); + let mut ctx = SyscallContext::new(&mut rt); + let syscall = super::SyscallHintRead {}; + + // no inputs + let result = syscall.execute(&mut ctx, 0, 10); + assert_eq!(Err(StatusCode::InsufficientInput), result); + + let data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; + ctx.rt.write_stdin(&data); + + // can't read 
more than available + let result = syscall.execute(&mut ctx, 0, data.len() as u32 + 1); + assert_eq!(Err(StatusCode::InsufficientInput), result); + + // read only up to `len` + let len = 3; + let result = syscall.execute(&mut ctx, 0, len as u32); + assert_eq!(Ok(Outcome::Result(Some(len as u32))), result); + assert_eq!(&data[..len], ctx.bytes(0, len).as_slice()); + + // read the rest + let address = len; + let len = data.len() - len; + let result = syscall.execute(&mut ctx, address as u32, len as u32); + assert_eq!(Ok(Outcome::Result(Some(len as u32))), result); + assert_eq!(data, ctx.bytes(0, data.len()).as_slice()); } } diff --git a/core/src/syscall/write.rs b/core/src/syscall/write.rs index bb0bbd5b..80eabc50 100644 --- a/core/src/syscall/write.rs +++ b/core/src/syscall/write.rs @@ -21,7 +21,7 @@ impl Syscall for SyscallWrite { let rt = &mut ctx.rt; let nbytes = rt.register(Register::X12); // Read nbytes from memory starting at write_buf. - let bytes = (0..nbytes) + let mut bytes = (0..nbytes) .map(|i| rt.byte(write_buf + i)) .collect::>(); match fd { @@ -72,7 +72,7 @@ impl Syscall for SyscallWrite { rt.state.public_values_stream.extend_from_slice(&bytes); } 4 => { - rt.state.input_stream.push(bytes); + rt.state.input_stream.append(&mut bytes); } fd => { tracing::debug!("syscall write called with invalid fd: {fd}"); diff --git a/examples/Cargo.lock b/examples/Cargo.lock index e79d4947..e42e199e 100644 --- a/examples/Cargo.lock +++ b/examples/Cargo.lock @@ -1472,18 +1472,18 @@ checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" [[package]] name = "thiserror" -version = "2.0.6" +version = "2.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fec2a1820ebd077e2b90c4df007bebf344cd394098a13c563957d0afc83ea47" +checksum = "08f5383f3e0071702bf93ab5ee99b52d26936be9dedd9413067cbdcddcb6141a" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "2.0.6" +version = "2.0.8" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "d65750cab40f4ff1929fb1ba509e9914eb756131cef4210da8d5d700d26f6312" +checksum = "f2f357fcec90b3caef6623a099691be676d033b40a058ac95d2a6ade6fa0c943" dependencies = [ "proc-macro2", "quote", diff --git a/examples/wallet/script/src/bin/execute.rs b/examples/wallet/script/src/bin/execute.rs index 02797dcd..8d4fb9fa 100644 --- a/examples/wallet/script/src/bin/execute.rs +++ b/examples/wallet/script/src/bin/execute.rs @@ -44,10 +44,11 @@ fn spawn(host: &mut MockHost, owner: &Pubkey) -> Result> let method_selector = MethodSelector::from("athexp_spawn"); + let max = 10_000; let client = ExecutionClient::new(); - let (mut result, _) = - client.execute_function(ELF, &method_selector, stdin, Some(host), None, None)?; - + let (mut result, gas) = + client.execute_function(ELF, &method_selector, stdin, Some(host), Some(max), None)?; + dbg!(max - gas.unwrap()); Ok(result.read()) } @@ -81,7 +82,7 @@ fn main() { amount: 120, }; - stdin.write_vec(wallet.clone()); + stdin.write_slice(wallet); stdin.write_vec(args.encode()); let alice_balance = host.get_balance(&ADDRESS_ALICE); diff --git a/sdk/src/lib.rs b/sdk/src/lib.rs index 43b6f538..f123113d 100644 --- a/sdk/src/lib.rs +++ b/sdk/src/lib.rs @@ -63,7 +63,7 @@ impl ExecutionClient { } }; let mut runtime = Runtime::new(program, host, opts, context); - runtime.write_vecs(&stdin.buffer); + runtime.write_from(stdin.buffer); runtime.execute().map(|gas_left| { ( AthenaPublicValues::from(&runtime.state.public_values_stream), @@ -90,7 +90,7 @@ impl ExecutionClient { } }; let mut runtime = Runtime::new(program, host, opts, context); - runtime.write_vecs(&stdin.buffer); + runtime.write_from(stdin.buffer); runtime .execute_function_by_selector(selector) .map(|gas_left| { @@ -133,7 +133,7 @@ impl ExecutionClient { }; let listener = std::net::TcpListener::bind(format!("127.0.0.1:{gdb_port}")).unwrap(); let mut runtime = Runtime::new(program, host, opts, context); - 
runtime.write_vecs(&stdin.buffer); + runtime.write_from(stdin.buffer); runtime.initialize(); athena_core::runtime::gdbstub::run_under_gdb(&mut runtime, listener, None).map(|gas_left| { ( @@ -177,7 +177,7 @@ impl ExecutionClient { }; let listener = std::net::TcpListener::bind(format!("127.0.0.1:{gdb_port}")).unwrap(); let mut runtime = Runtime::new(program, host, opts, context); - runtime.write_vecs(&stdin.buffer); + runtime.write_from(stdin.buffer); runtime.initialize(); athena_core::runtime::gdbstub::run_under_gdb(&mut runtime, listener, Some(function)).map( |gas_left| { diff --git a/tests/entrypoint/src/main.rs b/tests/entrypoint/src/main.rs index e53b90ab..98c17788 100644 --- a/tests/entrypoint/src/main.rs +++ b/tests/entrypoint/src/main.rs @@ -15,9 +15,8 @@ athena_vm::entrypoint!(); impl EntrypointTest { #[callable] fn test1() { - let input = athena_vm::io::read_vec(); - let address = - bincode::deserialize(&input).expect("input address malformed, failed to deserialize"); + let address = bincode::deserialize_from(athena_vm::io::Io::default()) + .expect("input address malformed, failed to deserialize"); // recursive call to self call(address, None, Some("athexp_test2"), 0); diff --git a/vm/entrypoint/src/io.rs b/vm/entrypoint/src/io.rs index 0f40cad1..8b1a1895 100644 --- a/vm/entrypoint/src/io.rs +++ b/vm/entrypoint/src/io.rs @@ -2,9 +2,7 @@ use crate::syscalls::syscall_write; use crate::syscalls::{syscall_hint_len, syscall_hint_read}; use serde::de::DeserializeOwned; use serde::Serialize; -use std::alloc::Layout; -use std::io::Result; -use std::io::Write; +use std::io::{Result, Write}; /// The file descriptor for public values. pub const FD_PUBLIC_VALUES: u32 = 3; @@ -30,41 +28,6 @@ impl Write for SyscallWriter { } } -/// Read a buffer from the input stream. 
-/// -/// ### Examples -/// ```ignore -/// let data: Vec = athena_vm::io::read_vec(); -/// ``` -pub fn read_vec() -> Vec { - // Round up to the nearest multiple of 4 so that the memory allocated is in whole words - let len = syscall_hint_len(); - if len == 0 { - return vec![]; - } - let capacity = (len + 3) / 4 * 4; - - // Allocate a buffer of the required length that is 4 byte aligned - let layout = Layout::from_size_align(capacity, 4).expect("vec is too large"); - let ptr = unsafe { std::alloc::alloc(layout) }; - - // SAFETY: - // 1. `ptr` was allocated using alloc - // 2. We assuume that the VM global allocator doesn't dealloc - // 3/6. Size is correct from above - // 4/5. Length is 0 - // 7. Layout::from_size_align already checks this - let mut vec = unsafe { Vec::from_raw_parts(ptr, 0, capacity) }; - - // Read the vec into uninitialized memory. The syscall assumes the memory is uninitialized, - // which should be true because the allocator does not dealloc, so a new alloc should be fresh. - unsafe { - syscall_hint_read(ptr, len); - vec.set_len(len); - } - vec -} - /// Read a deserializable object from the input stream. /// /// ### Examples @@ -80,8 +43,7 @@ pub fn read_vec() -> Vec { /// let data: MyStruct = athena_vm::io::read(); /// ``` pub fn read() -> T { - let vec = read_vec(); - bincode::deserialize(&vec).expect("deserialization failed") + bincode::deserialize_from(Io::default()).unwrap() } /// Write a serializable object to the public values stream. @@ -106,7 +68,7 @@ pub fn write(value: &T) { let writer = SyscallWriter { fd: FD_PUBLIC_VALUES, }; - bincode::serialize_into(writer, value).expect("serialization failed"); + bincode::serialize_into(writer, value).unwrap(); } /// Write bytes to the public values stream. 
@@ -143,7 +105,7 @@ pub fn write_slice(buf: &[u8]) { /// ``` pub fn hint(value: &T) { let writer = SyscallWriter { fd: FD_HINT }; - bincode::serialize_into(writer, value).expect("serialization failed"); + bincode::serialize_into(writer, value).unwrap(); } /// Hint bytes to the hint stream. @@ -159,27 +121,23 @@ pub fn hint_slice(buf: &[u8]) { } #[derive(Default)] -pub struct Io { - /// The remaining bytes to be read from the input stream. - /// The IO only supports reading whole 'lines' at a time. - /// We cache the remaining bytes that `read()` has not consumed yet - /// for future reads. This is important for cases when `read()` doesn't - /// consume the whole line. - read_remainder: std::collections::VecDeque, -} +pub struct Io {} impl std::io::Read for Io { fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - if self.read_remainder.is_empty() { - self.read_remainder.extend(crate::io::read_vec()); + let len = std::cmp::min(buf.len(), syscall_hint_len()); + if len == 0 { + return Ok(0); } - self.read_remainder.read(buf) + + syscall_hint_read(buf.as_mut_ptr(), len); + Ok(len) } } impl std::io::Write for Io { fn write(&mut self, buf: &[u8]) -> std::io::Result { - crate::io::write_slice(buf); + write_slice(buf); Ok(buf.len()) } From aa0c90f4bb95e814afcf9f2aa3f2857575f785c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bartosz=20R=C3=B3=C5=BCa=C5=84ski?= Date: Thu, 19 Dec 2024 09:05:12 +0100 Subject: [PATCH 2/2] Simplify AthenaStdin --- core/src/io.rs | 51 ++++++++++---------------------------------------- sdk/src/lib.rs | 8 ++++---- 2 files changed, 14 insertions(+), 45 deletions(-) diff --git a/core/src/io.rs b/core/src/io.rs index 2c31ce4f..ff6fafe4 100644 --- a/core/src/io.rs +++ b/core/src/io.rs @@ -1,12 +1,10 @@ -use std::{cmp::min, collections::VecDeque, io::Write}; - use crate::utils::Buffer; use serde::{de::DeserializeOwned, Deserialize, Serialize}; /// Standard input. 
-#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone)] pub struct AthenaStdin { - pub buffer: VecDeque, + buffer: Vec, } /// Public values for the runner. @@ -17,40 +15,7 @@ pub struct AthenaPublicValues { impl AthenaStdin { pub const fn new() -> Self { - Self { - buffer: VecDeque::new(), - } - } - - /// Create a `AthenaStdin` from a slice of bytes. - pub fn from(data: &[u8]) -> Self { - let mut stdin = Self::new(); - stdin.write_slice(data); - stdin - } - - /// Read a slice of bytes from the buffer. - /// Reads up to slice.len() bytes from the beginning of the buffer into the provided slice. - pub fn read_slice(&mut self, slice: &mut [u8]) { - let bytes_to_read = min(slice.len(), self.buffer.len()); - if bytes_to_read == 0 { - return; - } - - // Get the two contiguous slices from the VecDeque - let (first, second) = self.buffer.as_slices(); - - // Copy from the first slice - let first_copy = min(first.len(), bytes_to_read); - slice[..first_copy].copy_from_slice(&first[..first_copy]); - - // If we need more bytes and there's a second slice, copy from it - if first_copy < bytes_to_read { - let second_copy = bytes_to_read - first_copy; - slice[first_copy..bytes_to_read].copy_from_slice(&second[..second_copy]); - } - - self.buffer.drain(..bytes_to_read); + Self { buffer: Vec::new() } } /// Write a value to the buffer. @@ -60,11 +25,15 @@ impl AthenaStdin { /// Write a slice of bytes to the buffer. 
pub fn write_slice(&mut self, slice: &[u8]) { - self.buffer.write_all(slice).expect("pushing to buffer"); + self.buffer.extend_from_slice(slice); + } + + pub fn write_vec(&mut self, mut vec: Vec) { + self.buffer.append(&mut vec); } - pub fn write_vec(&mut self, vec: Vec) { - self.write_slice(&vec); + pub fn to_vec(self) -> Vec { + self.buffer } } diff --git a/sdk/src/lib.rs b/sdk/src/lib.rs index f123113d..05bdb47b 100644 --- a/sdk/src/lib.rs +++ b/sdk/src/lib.rs @@ -63,7 +63,7 @@ impl ExecutionClient { } }; let mut runtime = Runtime::new(program, host, opts, context); - runtime.write_from(stdin.buffer); + runtime.write_from(stdin.to_vec()); runtime.execute().map(|gas_left| { ( AthenaPublicValues::from(&runtime.state.public_values_stream), @@ -90,7 +90,7 @@ impl ExecutionClient { } }; let mut runtime = Runtime::new(program, host, opts, context); - runtime.write_from(stdin.buffer); + runtime.write_from(stdin.to_vec()); runtime .execute_function_by_selector(selector) .map(|gas_left| { @@ -133,7 +133,7 @@ impl ExecutionClient { }; let listener = std::net::TcpListener::bind(format!("127.0.0.1:{gdb_port}")).unwrap(); let mut runtime = Runtime::new(program, host, opts, context); - runtime.write_from(stdin.buffer); + runtime.write_from(stdin.to_vec()); runtime.initialize(); athena_core::runtime::gdbstub::run_under_gdb(&mut runtime, listener, None).map(|gas_left| { ( @@ -177,7 +177,7 @@ impl ExecutionClient { }; let listener = std::net::TcpListener::bind(format!("127.0.0.1:{gdb_port}")).unwrap(); let mut runtime = Runtime::new(program, host, opts, context); - runtime.write_from(stdin.buffer); + runtime.write_from(stdin.to_vec()); runtime.initialize(); athena_core::runtime::gdbstub::run_under_gdb(&mut runtime, listener, Some(function)).map( |gas_left| {