Skip to content

Commit

Permalink
feat: optimize runtime speed on fast (#1373)
Browse files Browse the repository at this point in the history
  • Loading branch information
jtguibas authored Aug 26, 2024
1 parent a7afc1b commit e8efd00
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 71 deletions.
163 changes: 99 additions & 64 deletions crates/core/executor/src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ pub struct Executor<'a> {
/// The maximum number of cycles for a syscall.
pub max_syscall_cycles: u32,

/// Whether to emit events during execution.
pub emit_events: bool,
/// The mode the executor is running in.
pub executor_mode: ExecutorMode,

/// Report of the program execution.
pub report: ExecutionReport,
Expand All @@ -102,6 +102,17 @@ pub struct Executor<'a> {
pub memory_checkpoint: HashMap<u32, Option<MemoryRecord>, BuildNoHashHasher<u32>>,
}

/// The different modes the executor can run in.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExecutorMode {
/// Run the execution with no tracing or checkpointing.
Simple,
/// Run the execution with checkpoints for memory.
Checkpoint,
/// Run the execution with full tracing of events.
Trace,
}

/// Errors that the [``Executor``] can throw.
#[derive(Error, Debug, Serialize, Deserialize)]
pub enum ExecutionError {
Expand Down Expand Up @@ -192,7 +203,7 @@ impl<'a> Executor<'a> {
unconstrained: false,
unconstrained_state: ForkState::default(),
syscall_map,
emit_events: true,
executor_mode: ExecutorMode::Trace,
max_syscall_cycles,
report: ExecutionReport::default(),
print_report: false,
Expand Down Expand Up @@ -239,15 +250,22 @@ impl<'a> Executor<'a> {
let mut registers = [0; 32];
for i in 0..32 {
let addr = Register::from_u32(i as u32) as u32;
registers[i] = match self.state.memory.get(&addr) {
Some(record) => {
self.memory_checkpoint.entry(addr).or_insert(Some(*record));
record.value
}
None => {
self.memory_checkpoint.entry(addr).or_insert(None);
0
let record = self.state.memory.get(&addr);

if self.executor_mode != ExecutorMode::Simple {
match record {
Some(record) => {
self.memory_checkpoint.entry(addr).or_insert_with(|| Some(*record));
}
None => {
self.memory_checkpoint.entry(addr).or_insert(None);
}
}
}

registers[i] = match record {
Some(record) => record.value,
None => 0,
};
}
registers
Expand All @@ -257,33 +275,46 @@ impl<'a> Executor<'a> {
#[must_use]
pub fn register(&mut self, register: Register) -> u32 {
let addr = register as u32;
#[allow(clippy::single_match_else)]
match self.state.memory.get(&addr) {
Some(record) => {
self.memory_checkpoint.entry(addr).or_insert(Some(*record));
record.value
}
None => {
self.memory_checkpoint.entry(addr).or_insert(None);
0
let record = self.state.memory.get(&addr);

if self.executor_mode != ExecutorMode::Simple {
match record {
Some(record) => {
self.memory_checkpoint.entry(addr).or_insert_with(|| Some(*record));
}
None => {
self.memory_checkpoint.entry(addr).or_insert(None);
}
}
}

match record {
Some(record) => record.value,
None => 0,
}
}

/// Get the current value of a word.
#[must_use]
pub fn word(&mut self, addr: u32) -> u32 {
#[allow(clippy::single_match_else)]
match self.state.memory.get(&addr) {
Some(record) => {
self.memory_checkpoint.entry(addr).or_insert(Some(*record));
record.value
}
None => {
self.memory_checkpoint.entry(addr).or_insert(None);
0
let record = self.state.memory.get(&addr);

if self.executor_mode != ExecutorMode::Simple {
match record {
Some(record) => {
self.memory_checkpoint.entry(addr).or_insert_with(|| Some(*record));
}
None => {
self.memory_checkpoint.entry(addr).or_insert(None);
}
}
}

match record {
Some(record) => record.value,
None => 0,
}
}

/// Get the current value of a byte.
Expand Down Expand Up @@ -317,13 +348,15 @@ impl<'a> Executor<'a> {
pub fn mr(&mut self, addr: u32, shard: u32, timestamp: u32) -> MemoryReadRecord {
// Get the memory record entry.
let entry = self.state.memory.entry(addr);
match entry {
Entry::Occupied(ref entry) => {
let record = entry.get();
self.memory_checkpoint.entry(addr).or_insert(Some(*record));
}
Entry::Vacant(_) => {
self.memory_checkpoint.entry(addr).or_insert(None);
if self.executor_mode != ExecutorMode::Simple {
match entry {
Entry::Occupied(ref entry) => {
let record = entry.get();
self.memory_checkpoint.entry(addr).or_insert_with(|| Some(*record));
}
Entry::Vacant(_) => {
self.memory_checkpoint.entry(addr).or_insert(None);
}
}
}

Expand Down Expand Up @@ -360,13 +393,15 @@ impl<'a> Executor<'a> {
pub fn mw(&mut self, addr: u32, value: u32, shard: u32, timestamp: u32) -> MemoryWriteRecord {
// Get the memory record entry.
let entry = self.state.memory.entry(addr);
match entry {
Entry::Occupied(ref entry) => {
let record = entry.get();
self.memory_checkpoint.entry(addr).or_insert(Some(*record));
}
Entry::Vacant(_) => {
self.memory_checkpoint.entry(addr).or_insert(None);
if self.executor_mode != ExecutorMode::Simple {
match entry {
Entry::Occupied(ref entry) => {
let record = entry.get();
self.memory_checkpoint.entry(addr).or_insert_with(|| Some(*record));
}
Entry::Vacant(_) => {
self.memory_checkpoint.entry(addr).or_insert(None);
}
}
}

Expand Down Expand Up @@ -410,7 +445,7 @@ impl<'a> Executor<'a> {
let record = self.mr(addr, self.shard(), self.timestamp(&position));

// If we're not in unconstrained mode, record the access for the current cycle.
if !self.unconstrained && self.emit_events {
if !self.unconstrained && self.executor_mode == ExecutorMode::Trace {
match position {
MemoryAccessPosition::A => self.memory_accesses.a = Some(record.into()),
MemoryAccessPosition::B => self.memory_accesses.b = Some(record.into()),
Expand All @@ -435,7 +470,7 @@ impl<'a> Executor<'a> {
let record = self.mw(addr, value, self.shard(), self.timestamp(&position));

// If we're not in unconstrained mode, record the access for the current cycle.
if !self.unconstrained {
if !self.unconstrained && self.executor_mode == ExecutorMode::Trace {
match position {
MemoryAccessPosition::A => {
assert!(self.memory_accesses.a.is_none());
Expand Down Expand Up @@ -596,7 +631,7 @@ impl<'a> Executor<'a> {
lookup_id: u128,
) {
self.rw(rd, a);
if self.emit_events {
if self.executor_mode == ExecutorMode::Trace {
self.emit_alu(self.state.clk, instruction.opcode, a, b, c, lookup_id);
}
}
Expand Down Expand Up @@ -649,10 +684,14 @@ impl<'a> Executor<'a> {
let (a, b, c): (u32, u32, u32);
let (addr, memory_read_value): (u32, u32);
let mut memory_store_value: Option<u32> = None;
self.memory_accesses = MemoryAccessRecord::default();

let lookup_id = create_alu_lookup_id();
let syscall_lookup_id = create_alu_lookup_id();
if self.executor_mode != ExecutorMode::Simple {
self.memory_accesses = MemoryAccessRecord::default();
}
let lookup_id =
if self.executor_mode == ExecutorMode::Simple { 0 } else { create_alu_lookup_id() };
let syscall_lookup_id =
if self.executor_mode == ExecutorMode::Simple { 0 } else { create_alu_lookup_id() };

if self.print_report && !self.unconstrained {
self.report
Expand Down Expand Up @@ -1027,7 +1066,7 @@ impl<'a> Executor<'a> {
}

// Emit the CPU event for this cycle.
if self.emit_events {
if self.executor_mode == ExecutorMode::Trace {
self.emit_cpu(
self.shard(),
channel,
Expand Down Expand Up @@ -1100,7 +1139,7 @@ impl<'a> Executor<'a> {
///
/// This function will return an error if the program execution fails.
pub fn execute_record(&mut self) -> Result<(Vec<ExecutionRecord>, bool), ExecutionError> {
self.emit_events = true;
self.executor_mode = ExecutorMode::Trace;
self.print_report = true;
let done = self.execute()?;
Ok((std::mem::take(&mut self.records), done))
Expand All @@ -1114,7 +1153,7 @@ impl<'a> Executor<'a> {
/// This function will return an error if the program execution fails.
pub fn execute_state(&mut self) -> Result<(ExecutionState, bool), ExecutionError> {
self.memory_checkpoint.clear();
self.emit_events = false;
self.executor_mode = ExecutorMode::Checkpoint;
self.print_report = true;

// Take memory out and then clone so that memory is not cloned.
Expand Down Expand Up @@ -1163,8 +1202,8 @@ impl<'a> Executor<'a> {
/// # Errors
///
/// This function will return an error if the program execution fails.
pub fn run_untraced(&mut self) -> Result<(), ExecutionError> {
self.emit_events = false;
pub fn run_fast(&mut self) -> Result<(), ExecutionError> {
self.executor_mode = ExecutorMode::Simple;
self.print_report = true;
while !self.execute()? {}
Ok(())
Expand All @@ -1176,22 +1215,12 @@ impl<'a> Executor<'a> {
///
/// This function will return an error if the program execution fails.
pub fn run(&mut self) -> Result<(), ExecutionError> {
self.emit_events = true;
self.executor_mode = ExecutorMode::Trace;
self.print_report = true;
while !self.execute()? {}
Ok(())
}

/// Executes the program without emitting events.
///
/// # Panics
///
/// This function will panic if the program execution fails.
pub fn dry_run(&mut self) {
self.emit_events = false;
while !self.execute().unwrap() {}
}

/// Executes up to `self.shard_batch_size` cycles of the program, returning whether the program
/// has finished.
fn execute(&mut self) -> Result<bool, ExecutionError> {
Expand Down Expand Up @@ -1359,6 +1388,12 @@ impl<'a> Executor<'a> {
}
}

impl Default for ExecutorMode {
fn default() -> Self {
Self::Simple
}
}

// TODO: FIX
/// Aligns an address to the nearest word below or equal to it.
#[must_use]
Expand Down
3 changes: 2 additions & 1 deletion crates/core/executor/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::{
events::MemoryRecord,
record::{ExecutionRecord, MemoryAccessRecord},
syscalls::SyscallCode,
ExecutorMode,
};

/// Holds data describing the current state of a program's execution.
Expand Down Expand Up @@ -116,7 +117,7 @@ pub struct ForkState {
/// The original execution record at the fork point.
pub record: ExecutionRecord,
/// Whether `emit_events` was enabled at the fork point.
pub emit_events: bool,
pub executor_mode: ExecutorMode,
}

impl ExecutionState {
Expand Down
8 changes: 4 additions & 4 deletions crates/core/executor/src/syscalls/unconstrained.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use hashbrown::HashMap;

use crate::state::ForkState;
use crate::{state::ForkState, ExecutorMode};

use super::{Syscall, SyscallContext};

Expand All @@ -19,9 +19,9 @@ impl Syscall for EnterUnconstrainedSyscall {
memory_diff: HashMap::default(),
record: std::mem::take(&mut ctx.rt.record),
op_record: std::mem::take(&mut ctx.rt.memory_accesses),
emit_events: ctx.rt.emit_events,
executor_mode: ctx.rt.executor_mode,
};
ctx.rt.emit_events = false;
ctx.rt.executor_mode = ExecutorMode::Simple;
Some(1)
}
}
Expand All @@ -48,7 +48,7 @@ impl Syscall for ExitUnconstrainedSyscall {
}
ctx.rt.record = std::mem::take(&mut ctx.rt.unconstrained_state.record);
ctx.rt.memory_accesses = std::mem::take(&mut ctx.rt.unconstrained_state.op_record);
ctx.rt.emit_events = ctx.rt.unconstrained_state.emit_events;
ctx.rt.executor_mode = ctx.rt.unconstrained_state.executor_mode;
ctx.rt.unconstrained = false;
}
ctx.rt.unconstrained_state = ForkState::default();
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
for (proof, vkey) in stdin.proofs.iter() {
runtime.write_proof(proof.clone(), vkey.clone());
}
runtime.run_untraced()?;
runtime.run_fast()?;
Ok((SP1PublicValues::from(&runtime.state.public_values_stream), runtime.report))
}

Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub fn get_cycles(elf: &[u8], stdin: &SP1Stdin) -> u64 {
let program = Program::from(elf).unwrap();
let mut runtime = Executor::new(program, SP1CoreOpts::default());
runtime.write_vecs(&stdin.buffer);
runtime.dry_run();
runtime.run_fast().unwrap();
runtime.state.global_clk
}

Expand Down

0 comments on commit e8efd00

Please sign in to comment.