From b53282937dd92c33732d643979fedbaf60336ef0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustavo=20Gir=C3=A1ldez?= Date: Wed, 4 Oct 2023 17:59:28 -0400 Subject: [PATCH 1/8] feat(acvm): Refactor Brillig solver struct to capture running state To allow resuming Brillig execution after resolving a foreign call and in preparation to support step-by-step execution of Brillig opcodes when run inside an ACIR block. --- acvm-repo/acvm/src/pwg/brillig.rs | 130 ++++++++++++++++++++---------- acvm-repo/acvm/src/pwg/mod.rs | 17 +++- acvm-repo/brillig_vm/src/lib.rs | 4 + 3 files changed, 104 insertions(+), 47 deletions(-) diff --git a/acvm-repo/acvm/src/pwg/brillig.rs b/acvm-repo/acvm/src/pwg/brillig.rs index 9b0ecd87492..54181921426 100644 --- a/acvm-repo/acvm/src/pwg/brillig.rs +++ b/acvm-repo/acvm/src/pwg/brillig.rs @@ -14,28 +14,79 @@ use crate::{pwg::OpcodeNotSolvable, OpcodeResolutionError}; use super::{get_value, insert_value}; -pub(super) struct BrilligSolver; +pub(super) enum BrilligSolverStatus { + Finished, + InProgress, + ForeignCallWait(ForeignCallWaitInfo), +} -impl BrilligSolver { - pub(super) fn solve( - initial_witness: &mut WitnessMap, - brillig: &Brillig, - bb_solver: &B, +pub(super) struct BrilligSolver<'b, B: BlackBoxFunctionSolver> { + witness: &'b mut WitnessMap, + brillig: &'b Brillig, + acir_index: usize, + vm: VM<'b, B>, +} + +impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { + pub(super) fn build_or_skip( + initial_witness: &'b mut WitnessMap, + brillig: &'b Brillig, + bb_solver: &'b B, acir_index: usize, - ) -> Result, OpcodeResolutionError> { + ) -> Result, OpcodeResolutionError> { + if Self::should_skip(initial_witness, brillig)? { + Self::zero_out_brillig_outputs(initial_witness, brillig)?; + return Ok(None); + } + + let vm = Self::setup_vm(initial_witness, brillig, bb_solver)?; + Ok(Some( + Self { + witness: initial_witness, + brillig, + acir_index, + vm, + } + )) + } + + fn should_skip(witness: &mut WitnessMap, brillig: &Brillig) -> Result { // If the predicate is `None`, then we simply return the value 1 // If the predicate is `Some` but we cannot find a value, then we return stalled let pred_value = match &brillig.predicate { - Some(pred) => get_value(pred, initial_witness), + Some(pred) => get_value(pred, witness), None => Ok(FieldElement::one()), }?; // A zero predicate indicates the oracle should be skipped, and its outputs zeroed. - if pred_value.is_zero() { - Self::zero_out_brillig_outputs(initial_witness, brillig)?; - return Ok(None); + Ok(pred_value.is_zero()) + } + + /// Assigns the zero value to all outputs of the given [`Brillig`] bytecode. + fn zero_out_brillig_outputs( + initial_witness: &mut WitnessMap, + brillig: &Brillig, + ) -> Result<(), OpcodeResolutionError> { + for output in &brillig.outputs { + match output { + BrilligOutputs::Simple(witness) => { + insert_value(witness, FieldElement::zero(), initial_witness)?; + } + BrilligOutputs::Array(witness_arr) => { + for witness in witness_arr { + insert_value(witness, FieldElement::zero(), initial_witness)?; + } + } + } } + Ok(()) + } + fn setup_vm( + witness: &mut WitnessMap, + brillig: &Brillig, + bb_solver: &'b B, + ) -> Result, OpcodeResolutionError> { // Set input values let mut input_register_values: Vec = Vec::new(); let mut input_memory: Vec = Vec::new(); @@ -45,7 +96,7 @@ impl BrilligSolver { // If a certain expression is not solvable, we stall the ACVM and do not proceed with Brillig VM execution. for input in &brillig.inputs { match input { - BrilligInputs::Single(expr) => match get_value(expr, initial_witness) { + BrilligInputs::Single(expr) => match get_value(expr, witness) { Ok(value) => input_register_values.push(value.into()), Err(_) => { return Err(OpcodeResolutionError::OpcodeNotSolvable( @@ -57,7 +108,7 @@ impl BrilligSolver { // Attempt to fetch all array input values let memory_pointer = input_memory.len(); for expr in expr_arr.iter() { - match get_value(expr, initial_witness) { + match get_value(expr, witness) { Ok(value) => input_memory.push(value.into()), Err(_) => { return Err(OpcodeResolutionError::OpcodeNotSolvable( @@ -76,39 +127,32 @@ impl BrilligSolver { // Instantiate a Brillig VM given the solved input registers and memory // along with the Brillig bytecode, and any present foreign call results. let input_registers = Registers::load(input_register_values); - let mut vm = VM::new( + Ok(VM::new( input_registers, input_memory, brillig.bytecode.clone(), brillig.foreign_call_results.clone(), bb_solver, - ); + )) + } + pub(super) fn solve(&mut self) -> Result { // Run the Brillig VM on these inputs, bytecode, etc! - let vm_status = vm.process_opcodes(); + while matches!(self.vm.process_opcode(), VMStatus::InProgress) {} + self.finish_execution() + } + + pub(super) fn finish_execution(&mut self) -> Result { // Check the status of the Brillig VM. // It may be finished, in-progress, failed, or may be waiting for results of a foreign call. // Return the "resolution" to the caller who may choose to make subsequent calls // (when it gets foreign call results for example). + let vm_status = self.vm.get_status(); match vm_status { VMStatus::Finished => { - for (i, output) in brillig.outputs.iter().enumerate() { - let register_value = vm.get_registers().get(RegisterIndex::from(i)); - match output { - BrilligOutputs::Simple(witness) => { - insert_value(witness, register_value.to_field(), initial_witness)?; - } - BrilligOutputs::Array(witness_arr) => { - // Treat the register value as a pointer to memory - for (i, witness) in witness_arr.iter().enumerate() { - let value = &vm.get_memory()[register_value.to_usize() + i]; - insert_value(witness, value.to_field(), initial_witness)?; - } - } - } - } - Ok(None) + self.write_brillig_outputs()?; + Ok(BrilligSolverStatus::Finished) } VMStatus::InProgress => unreachable!("Brillig VM has not completed execution"), VMStatus::Failure { message, call_stack } => { @@ -117,31 +161,31 @@ impl BrilligSolver { call_stack: call_stack .iter() .map(|brillig_index| OpcodeLocation::Brillig { - acir_index, + acir_index: self.acir_index, brillig_index: *brillig_index, }) .collect(), }) } VMStatus::ForeignCallWait { function, inputs } => { - Ok(Some(ForeignCallWaitInfo { function, inputs })) + Ok(BrilligSolverStatus::ForeignCallWait(ForeignCallWaitInfo { function, inputs })) } } } - /// Assigns the zero value to all outputs of the given [`Brillig`] bytecode. - fn zero_out_brillig_outputs( - initial_witness: &mut WitnessMap, - brillig: &Brillig, - ) -> Result<(), OpcodeResolutionError> { - for output in &brillig.outputs { + fn write_brillig_outputs(&mut self) -> Result<(), OpcodeResolutionError> { + // Write VM execution results into the witness map + for (i, output) in self.brillig.outputs.iter().enumerate() { + let register_value = self.vm.get_registers().get(RegisterIndex::from(i)); match output { BrilligOutputs::Simple(witness) => { - insert_value(witness, FieldElement::zero(), initial_witness)?; + insert_value(witness, register_value.to_field(), self.witness)?; } BrilligOutputs::Array(witness_arr) => { - for witness in witness_arr { - insert_value(witness, FieldElement::zero(), initial_witness)?; + // Treat the register value as a pointer to memory + for (i, witness) in witness_arr.iter().enumerate() { + let value = &self.vm.get_memory()[register_value.to_usize() + i]; + insert_value(witness, value.to_field(), self.witness)?; } } } diff --git a/acvm-repo/acvm/src/pwg/mod.rs b/acvm-repo/acvm/src/pwg/mod.rs index 3fcf1088225..7fc94433da8 100644 --- a/acvm-repo/acvm/src/pwg/mod.rs +++ b/acvm-repo/acvm/src/pwg/mod.rs @@ -11,7 +11,7 @@ use acir::{ use acvm_blackbox_solver::BlackBoxResolutionError; use self::{ - arithmetic::ArithmeticSolver, brillig::BrilligSolver, directives::solve_directives, + arithmetic::ArithmeticSolver, brillig::{BrilligSolver, BrilligSolverStatus}, directives::solve_directives, memory_op::MemoryOpSolver, }; use crate::{BlackBoxFunctionSolver, Language}; @@ -258,13 +258,22 @@ impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> { solver.solve_memory_op(op, &mut self.witness_map, predicate) } Opcode::Brillig(brillig) => { - match BrilligSolver::solve( + let result = BrilligSolver::build_or_skip( &mut self.witness_map, brillig, self.backend, self.instruction_pointer, - ) { - Ok(Some(foreign_call)) => return self.wait_for_foreign_call(foreign_call), + ); + match result { + Ok(Some(mut solver)) => { + match solver.solve() { + Ok(BrilligSolverStatus::ForeignCallWait(foreign_call)) => + return self.wait_for_foreign_call(foreign_call), + Ok(BrilligSolverStatus::InProgress) => + unreachable!("Brillig solver still in progress"), + res => res.map(|_| ()), + } + } res => res.map(|_| ()), } } diff --git a/acvm-repo/brillig_vm/src/lib.rs b/acvm-repo/brillig_vm/src/lib.rs index ca8c52756d1..84165dbc8ec 100644 --- a/acvm-repo/brillig_vm/src/lib.rs +++ b/acvm-repo/brillig_vm/src/lib.rs @@ -112,6 +112,10 @@ impl<'bb_solver, B: BlackBoxFunctionSolver> VM<'bb_solver, B> { status } + pub fn get_status(&self) -> VMStatus { + self.status.clone() + } + /// Sets the current status of the VM to Finished (completed execution). fn finish(&mut self) -> VMStatus { self.status(VMStatus::Finished) From 477437ed024b341944c8033f005def5faa6798e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustavo=20Gir=C3=A1ldez?= Date: Fri, 6 Oct 2023 18:13:28 -0400 Subject: [PATCH 2/8] feat(brillig): Allow insertion of foreign calls results Add a new method `resolve_foreign_call` to push the result to the results vector. Then the VM can be resumed and retrieve the pending foreign call result without restarting. --- acvm-repo/brillig_vm/src/lib.rs | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/acvm-repo/brillig_vm/src/lib.rs b/acvm-repo/brillig_vm/src/lib.rs index 84165dbc8ec..3f961286bb8 100644 --- a/acvm-repo/brillig_vm/src/lib.rs +++ b/acvm-repo/brillig_vm/src/lib.rs @@ -131,6 +131,14 @@ impl<'bb_solver, B: BlackBoxFunctionSolver> VM<'bb_solver, B> { self.status(VMStatus::ForeignCallWait { function, inputs }) } + pub fn resolve_foreign_call(&mut self, foreign_call_result: ForeignCallResult) { + if self.foreign_call_counter < self.foreign_call_results.len() { + panic!("No unresolved foreign calls"); + } + self.foreign_call_results.push(foreign_call_result); + self.status(VMStatus::InProgress); + } + /// Sets the current status of the VM to `fail`. /// Indicating that the VM encountered a `Trap` Opcode /// or an invalid state. @@ -933,7 +941,7 @@ mod tests { ); // Push result we're waiting for - vm.foreign_call_results.push( + vm.resolve_foreign_call( Value::from(10u128).into(), // Result of doubling 5u128 ); @@ -994,7 +1002,7 @@ mod tests { ); // Push result we're waiting for - vm.foreign_call_results.push(expected_result.clone().into()); + vm.resolve_foreign_call(expected_result.clone().into()); // Resume VM brillig_execute(&mut vm); @@ -1067,7 +1075,7 @@ mod tests { ); // Push result we're waiting for - vm.foreign_call_results.push(ForeignCallResult { + vm.resolve_foreign_call(ForeignCallResult { values: vec![ForeignCallParam::Array(output_string.clone())], }); @@ -1129,7 +1137,7 @@ mod tests { ); // Push result we're waiting for - vm.foreign_call_results.push(expected_result.clone().into()); + vm.resolve_foreign_call(expected_result.clone().into()); // Resume VM brillig_execute(&mut vm); @@ -1214,7 +1222,7 @@ mod tests { ); // Push result we're waiting for - vm.foreign_call_results.push(expected_result.clone().into()); + vm.resolve_foreign_call(expected_result.clone().into()); // Resume VM brillig_execute(&mut vm); From fe1fb0fc58e3de826c6c7c44cd2bfddeaf77e3ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustavo=20Gir=C3=A1ldez?= Date: Fri, 6 Oct 2023 18:33:11 -0400 Subject: [PATCH 3/8] feat(acvm): Save Brillig solver in ACVM This allows retaining state and reusing the Brillig VM when resuming execution after a pending foreign call is resolved. This is also in preparation to allow stepping into the Brillig code when debugging an ACIR program. --- acvm-repo/acvm/src/pwg/brillig.rs | 88 +++++++++++++++++++------------ acvm-repo/acvm/src/pwg/mod.rs | 59 +++++++++++++-------- 2 files changed, 90 insertions(+), 57 deletions(-) diff --git a/acvm-repo/acvm/src/pwg/brillig.rs b/acvm-repo/acvm/src/pwg/brillig.rs index 54181921426..e205750a176 100644 --- a/acvm-repo/acvm/src/pwg/brillig.rs +++ b/acvm-repo/acvm/src/pwg/brillig.rs @@ -1,5 +1,5 @@ use acir::{ - brillig::{ForeignCallParam, RegisterIndex, Value}, + brillig::{ForeignCallParam, ForeignCallResult, RegisterIndex, Value}, circuit::{ brillig::{Brillig, BrilligInputs, BrilligOutputs}, OpcodeLocation, @@ -21,16 +21,17 @@ pub(super) enum BrilligSolverStatus { } pub(super) struct BrilligSolver<'b, B: BlackBoxFunctionSolver> { - witness: &'b mut WitnessMap, - brillig: &'b Brillig, - acir_index: usize, vm: VM<'b, B>, + acir_index: usize, } impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { - pub(super) fn build_or_skip( - initial_witness: &'b mut WitnessMap, - brillig: &'b Brillig, + /// Constructs a solver for a Brillig block given the bytecode and initial + /// witness. If the block should be skipped entirely because its predicate + /// evaluates to false, zero out the block outputs and return Ok(None). + pub(super) fn build_or_skip<'w>( + initial_witness: &'w mut WitnessMap, + brillig: &'w Brillig, bb_solver: &'b B, acir_index: usize, ) -> Result, OpcodeResolutionError> { @@ -39,18 +40,11 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { return Ok(None); } - let vm = Self::setup_vm(initial_witness, brillig, bb_solver)?; - Ok(Some( - Self { - witness: initial_witness, - brillig, - acir_index, - vm, - } - )) + let vm = Self::build_vm(initial_witness, brillig, bb_solver)?; + Ok(Some(Self { vm, acir_index })) } - fn should_skip(witness: &mut WitnessMap, brillig: &Brillig) -> Result { + fn should_skip(witness: &WitnessMap, brillig: &Brillig) -> Result { // If the predicate is `None`, then we simply return the value 1 // If the predicate is `Some` but we cannot find a value, then we return stalled let pred_value = match &brillig.predicate { @@ -82,8 +76,8 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { Ok(()) } - fn setup_vm( - witness: &mut WitnessMap, + fn build_vm( + witness: &WitnessMap, brillig: &Brillig, bb_solver: &'b B, ) -> Result, OpcodeResolutionError> { @@ -137,24 +131,21 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { } pub(super) fn solve(&mut self) -> Result { - // Run the Brillig VM on these inputs, bytecode, etc! - while matches!(self.vm.process_opcode(), VMStatus::InProgress) {} - - self.finish_execution() + let status = self.vm.process_opcodes(); + self.handle_vm_status(status) } - pub(super) fn finish_execution(&mut self) -> Result { - // Check the status of the Brillig VM. + fn handle_vm_status( + &self, + vm_status: VMStatus, + ) -> Result { + // Check the status of the Brillig VM and return a resolution. // It may be finished, in-progress, failed, or may be waiting for results of a foreign call. // Return the "resolution" to the caller who may choose to make subsequent calls // (when it gets foreign call results for example). - let vm_status = self.vm.get_status(); match vm_status { - VMStatus::Finished => { - self.write_brillig_outputs()?; - Ok(BrilligSolverStatus::Finished) - } - VMStatus::InProgress => unreachable!("Brillig VM has not completed execution"), + VMStatus::Finished => Ok(BrilligSolverStatus::Finished), + VMStatus::InProgress => Ok(BrilligSolverStatus::InProgress), VMStatus::Failure { message, call_stack } => { Err(OpcodeResolutionError::BrilligFunctionFailed { message, @@ -173,25 +164,52 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { } } - fn write_brillig_outputs(&mut self) -> Result<(), OpcodeResolutionError> { + pub(super) fn finalize( + self, + witness: &mut WitnessMap, + brillig: &Brillig, + ) -> Result<(), OpcodeResolutionError> { + // Finish the Brillig execution by writing the outputs to the witness map + let vm_status = self.vm.get_status(); + match vm_status { + VMStatus::Finished => { + self.write_brillig_outputs(witness, brillig)?; + Ok(()) + } + _ => panic!("Brillig VM has not completed execution"), + } + } + + fn write_brillig_outputs( + &self, + witness_map: &mut WitnessMap, + brillig: &Brillig, + ) -> Result<(), OpcodeResolutionError> { // Write VM execution results into the witness map - for (i, output) in self.brillig.outputs.iter().enumerate() { + for (i, output) in brillig.outputs.iter().enumerate() { let register_value = self.vm.get_registers().get(RegisterIndex::from(i)); match output { BrilligOutputs::Simple(witness) => { - insert_value(witness, register_value.to_field(), self.witness)?; + insert_value(witness, register_value.to_field(), witness_map)?; } BrilligOutputs::Array(witness_arr) => { // Treat the register value as a pointer to memory for (i, witness) in witness_arr.iter().enumerate() { let value = &self.vm.get_memory()[register_value.to_usize() + i]; - insert_value(witness, value.to_field(), self.witness)?; + insert_value(witness, value.to_field(), witness_map)?; } } } } Ok(()) } + + pub(super) fn resolve_pending_foreign_call(&mut self, foreign_call_result: ForeignCallResult) { + match self.vm.get_status() { + VMStatus::ForeignCallWait { .. } => self.vm.resolve_foreign_call(foreign_call_result), + _ => unreachable!("Brillig VM is not waiting for a foreign call"), + } + } } /// Encapsulates a request from a Brillig VM process that encounters a [foreign call opcode][acir::brillig_vm::Opcode::ForeignCall] diff --git a/acvm-repo/acvm/src/pwg/mod.rs b/acvm-repo/acvm/src/pwg/mod.rs index 7fc94433da8..532e7fbd0e0 100644 --- a/acvm-repo/acvm/src/pwg/mod.rs +++ b/acvm-repo/acvm/src/pwg/mod.rs @@ -11,7 +11,9 @@ use acir::{ use acvm_blackbox_solver::BlackBoxResolutionError; use self::{ - arithmetic::ArithmeticSolver, brillig::{BrilligSolver, BrilligSolverStatus}, directives::solve_directives, + arithmetic::ArithmeticSolver, + brillig::{BrilligSolver, BrilligSolverStatus}, + directives::solve_directives, memory_op::MemoryOpSolver, }; use crate::{BlackBoxFunctionSolver, Language}; @@ -140,6 +142,8 @@ pub struct ACVM<'backend, B: BlackBoxFunctionSolver> { instruction_pointer: usize, witness_map: WitnessMap, + + brillig_solver: Option>, } impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> { @@ -152,6 +156,7 @@ impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> { opcodes, instruction_pointer: 0, witness_map: initial_witness, + brillig_solver: None, } } @@ -216,12 +221,8 @@ impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> { panic!("ACVM is not expecting a foreign call response as no call was made"); } - // We want to inject the foreign call result into the brillig opcode which initiated the call. - let opcode = &mut self.opcodes[self.instruction_pointer]; - let Opcode::Brillig(brillig) = opcode else { - unreachable!("ACVM can only enter `RequiresForeignCall` state on a Brillig opcode"); - }; - brillig.foreign_call_results.push(foreign_call_result); + let brillig_solver = self.brillig_solver.as_mut().expect("No active Brillig solver"); + brillig_solver.resolve_pending_foreign_call(foreign_call_result); // Now that the foreign call has been resolved then we can resume execution. self.status(ACVMStatus::InProgress); @@ -258,22 +259,36 @@ impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> { solver.solve_memory_op(op, &mut self.witness_map, predicate) } Opcode::Brillig(brillig) => { - let result = BrilligSolver::build_or_skip( - &mut self.witness_map, - brillig, - self.backend, - self.instruction_pointer, - ); - match result { - Ok(Some(mut solver)) => { - match solver.solve() { - Ok(BrilligSolverStatus::ForeignCallWait(foreign_call)) => - return self.wait_for_foreign_call(foreign_call), - Ok(BrilligSolverStatus::InProgress) => - unreachable!("Brillig solver still in progress"), - res => res.map(|_| ()), + let witness = &mut self.witness_map; + // get the active Brillig solver, or try to build one if necessary + // (Brillig execution maybe bypassed by constraints) + let maybe_solver = match self.brillig_solver.as_mut() { + Some(solver) => Ok(Some(solver)), + None => BrilligSolver::build_or_skip( + witness, + brillig, + self.backend, + self.instruction_pointer, + ) + .and_then(|optional_solver| { + Ok(optional_solver + .and_then(|solver| Some(self.brillig_solver.insert(solver)))) + }), + }; + match maybe_solver { + Ok(Some(solver)) => match solver.solve() { + Ok(BrilligSolverStatus::ForeignCallWait(foreign_call)) => { + return self.wait_for_foreign_call(foreign_call); } - } + Ok(BrilligSolverStatus::InProgress) => { + unreachable!("Brillig solver still in progress") + } + Ok(BrilligSolverStatus::Finished) => { + // clear active Brillig solver and write execution outputs + self.brillig_solver.take().unwrap().finalize(witness, brillig) + } + res => res.map(|_| ()), + }, res => res.map(|_| ()), } } From 07e3988e507925899ca9af122db8e6fa4d4e1e2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustavo=20Gir=C3=A1ldez?= Date: Fri, 6 Oct 2023 19:40:56 -0400 Subject: [PATCH 4/8] chore: Resolve clippy warnings --- acvm-repo/acvm/src/pwg/mod.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/acvm-repo/acvm/src/pwg/mod.rs b/acvm-repo/acvm/src/pwg/mod.rs index 532e7fbd0e0..1ac17b135c3 100644 --- a/acvm-repo/acvm/src/pwg/mod.rs +++ b/acvm-repo/acvm/src/pwg/mod.rs @@ -270,9 +270,8 @@ impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> { self.backend, self.instruction_pointer, ) - .and_then(|optional_solver| { - Ok(optional_solver - .and_then(|solver| Some(self.brillig_solver.insert(solver)))) + .map(|optional_solver| { + optional_solver.map(|solver| self.brillig_solver.insert(solver)) }), }; match maybe_solver { From 1147e841d8d83ca9e93d02974d95adb931d52c5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustavo=20Gir=C3=A1ldez?= Date: Tue, 10 Oct 2023 16:54:21 -0400 Subject: [PATCH 5/8] feat: Address code review to simplify code and improve readability --- acvm-repo/acvm/src/pwg/brillig.rs | 57 ++++++++++--------------- acvm-repo/acvm/src/pwg/mod.rs | 71 +++++++++++++++++-------------- 2 files changed, 60 insertions(+), 68 deletions(-) diff --git a/acvm-repo/acvm/src/pwg/brillig.rs b/acvm-repo/acvm/src/pwg/brillig.rs index e205750a176..12b408760a7 100644 --- a/acvm-repo/acvm/src/pwg/brillig.rs +++ b/acvm-repo/acvm/src/pwg/brillig.rs @@ -26,38 +26,21 @@ pub(super) struct BrilligSolver<'b, B: BlackBoxFunctionSolver> { } impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { - /// Constructs a solver for a Brillig block given the bytecode and initial - /// witness. If the block should be skipped entirely because its predicate - /// evaluates to false, zero out the block outputs and return Ok(None). - pub(super) fn build_or_skip<'w>( - initial_witness: &'w mut WitnessMap, - brillig: &'w Brillig, - bb_solver: &'b B, - acir_index: usize, - ) -> Result, OpcodeResolutionError> { - if Self::should_skip(initial_witness, brillig)? { - Self::zero_out_brillig_outputs(initial_witness, brillig)?; - return Ok(None); - } - - let vm = Self::build_vm(initial_witness, brillig, bb_solver)?; - Ok(Some(Self { vm, acir_index })) - } - - fn should_skip(witness: &WitnessMap, brillig: &Brillig) -> Result { - // If the predicate is `None`, then we simply return the value 1 + /// Evaluates if the Brillig block should be skipped entirely + pub(super) fn should_skip( + witness: &WitnessMap, + brillig: &Brillig, + ) -> Result { + // If the predicate is `None`, the block should never be skipped // If the predicate is `Some` but we cannot find a value, then we return stalled - let pred_value = match &brillig.predicate { - Some(pred) => get_value(pred, witness), - None => Ok(FieldElement::one()), - }?; - - // A zero predicate indicates the oracle should be skipped, and its outputs zeroed. - Ok(pred_value.is_zero()) + match &brillig.predicate { + Some(pred) => Ok(get_value(pred, witness)?.is_zero()), + None => Ok(false), + } } /// Assigns the zero value to all outputs of the given [`Brillig`] bytecode. - fn zero_out_brillig_outputs( + pub(super) fn zero_out_brillig_outputs( initial_witness: &mut WitnessMap, brillig: &Brillig, ) -> Result<(), OpcodeResolutionError> { @@ -76,11 +59,14 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { Ok(()) } - fn build_vm( - witness: &WitnessMap, + /// Constructs a solver for a Brillig block given the bytecode and initial + /// witness. + pub(super) fn new( + initial_witness: &mut WitnessMap, brillig: &Brillig, bb_solver: &'b B, - ) -> Result, OpcodeResolutionError> { + acir_index: usize, + ) -> Result { // Set input values let mut input_register_values: Vec = Vec::new(); let mut input_memory: Vec = Vec::new(); @@ -90,7 +76,7 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { // If a certain expression is not solvable, we stall the ACVM and do not proceed with Brillig VM execution. for input in &brillig.inputs { match input { - BrilligInputs::Single(expr) => match get_value(expr, witness) { + BrilligInputs::Single(expr) => match get_value(expr, initial_witness) { Ok(value) => input_register_values.push(value.into()), Err(_) => { return Err(OpcodeResolutionError::OpcodeNotSolvable( @@ -102,7 +88,7 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { // Attempt to fetch all array input values let memory_pointer = input_memory.len(); for expr in expr_arr.iter() { - match get_value(expr, witness) { + match get_value(expr, initial_witness) { Ok(value) => input_memory.push(value.into()), Err(_) => { return Err(OpcodeResolutionError::OpcodeNotSolvable( @@ -121,13 +107,14 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { // Instantiate a Brillig VM given the solved input registers and memory // along with the Brillig bytecode, and any present foreign call results. let input_registers = Registers::load(input_register_values); - Ok(VM::new( + let vm = VM::new( input_registers, input_memory, brillig.bytecode.clone(), brillig.foreign_call_results.clone(), bb_solver, - )) + ); + Ok(Self { vm, acir_index }) } pub(super) fn solve(&mut self) -> Result { diff --git a/acvm-repo/acvm/src/pwg/mod.rs b/acvm-repo/acvm/src/pwg/mod.rs index 1ac17b135c3..bd672906369 100644 --- a/acvm-repo/acvm/src/pwg/mod.rs +++ b/acvm-repo/acvm/src/pwg/mod.rs @@ -258,39 +258,10 @@ impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> { let solver = self.block_solvers.entry(*block_id).or_default(); solver.solve_memory_op(op, &mut self.witness_map, predicate) } - Opcode::Brillig(brillig) => { - let witness = &mut self.witness_map; - // get the active Brillig solver, or try to build one if necessary - // (Brillig execution maybe bypassed by constraints) - let maybe_solver = match self.brillig_solver.as_mut() { - Some(solver) => Ok(Some(solver)), - None => BrilligSolver::build_or_skip( - witness, - brillig, - self.backend, - self.instruction_pointer, - ) - .map(|optional_solver| { - optional_solver.map(|solver| self.brillig_solver.insert(solver)) - }), - }; - match maybe_solver { - Ok(Some(solver)) => match solver.solve() { - Ok(BrilligSolverStatus::ForeignCallWait(foreign_call)) => { - return self.wait_for_foreign_call(foreign_call); - } - Ok(BrilligSolverStatus::InProgress) => { - unreachable!("Brillig solver still in progress") - } - Ok(BrilligSolverStatus::Finished) => { - // clear active Brillig solver and write execution outputs - self.brillig_solver.take().unwrap().finalize(witness, brillig) - } - res => res.map(|_| ()), - }, - res => res.map(|_| ()), - } - } + Opcode::Brillig(_) => match self.solve_brillig_opcode() { + Ok(Some(foreign_call)) => return self.wait_for_foreign_call(foreign_call), + res => res.map(|_| ()), + }, }; match resolution { Ok(()) => { @@ -324,6 +295,40 @@ impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> { } } } + + fn solve_brillig_opcode( + &mut self, + ) -> Result, OpcodeResolutionError> { + let brillig = match &self.opcodes[self.instruction_pointer] { + Opcode::Brillig(brillig) => brillig, + _ => unreachable!("Not executing a Brillig opcode"), + }; + let witness = &mut self.witness_map; + if BrilligSolver::::should_skip(witness, brillig)? { + BrilligSolver::::zero_out_brillig_outputs(witness, brillig).map(|_| None) + } else { + let mut solver = match self.brillig_solver.take() { + None => { + BrilligSolver::new(witness, brillig, self.backend, self.instruction_pointer)? + } + Some(solver) => solver, + }; + match solver.solve()? { + BrilligSolverStatus::ForeignCallWait(foreign_call) => { + _ = self.brillig_solver.insert(solver); + Ok(Some(foreign_call)) + } + BrilligSolverStatus::InProgress => { + unreachable!("Brillig solver still in progress") + } + BrilligSolverStatus::Finished => { + // Write execution outputs + solver.finalize(witness, brillig)?; + Ok(None) + } + } + } + } } // Returns the concrete value for a particular witness From 58629cad4233f779efb9fc714498368c5ffd112c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustavo=20Gir=C3=A1ldez?= Date: Tue, 10 Oct 2023 17:08:10 -0400 Subject: [PATCH 6/8] chore: Use let-else instead of match --- acvm-repo/acvm/src/pwg/mod.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/acvm-repo/acvm/src/pwg/mod.rs b/acvm-repo/acvm/src/pwg/mod.rs index bd672906369..7407966d796 100644 --- a/acvm-repo/acvm/src/pwg/mod.rs +++ b/acvm-repo/acvm/src/pwg/mod.rs @@ -299,9 +299,8 @@ impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> { fn solve_brillig_opcode( &mut self, ) -> Result, OpcodeResolutionError> { - let brillig = match &self.opcodes[self.instruction_pointer] { - Opcode::Brillig(brillig) => brillig, - _ => unreachable!("Not executing a Brillig opcode"), + let Opcode::Brillig(brillig) = &self.opcodes[self.instruction_pointer] else { + unreachable!("Not executing a Brillig opcode"); }; let witness = &mut self.witness_map; if BrilligSolver::::should_skip(witness, brillig)? { From 8567062f7186bbd1b22c6850a6bac0a860afb2cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustavo=20Gir=C3=A1ldez?= Date: Tue, 10 Oct 2023 18:36:42 -0400 Subject: [PATCH 7/8] chore: Apply suggestions from code review Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com> --- acvm-repo/acvm/src/pwg/mod.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/acvm-repo/acvm/src/pwg/mod.rs b/acvm-repo/acvm/src/pwg/mod.rs index 7407966d796..5b61e3910fa 100644 --- a/acvm-repo/acvm/src/pwg/mod.rs +++ b/acvm-repo/acvm/src/pwg/mod.rs @@ -306,15 +306,18 @@ impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> { if BrilligSolver::::should_skip(witness, brillig)? { BrilligSolver::::zero_out_brillig_outputs(witness, brillig).map(|_| None) } else { - let mut solver = match self.brillig_solver.take() { + // If we're resuming execution after resolving a foreign call then + // there will be a cached `BrilligSolver` to avoid recomputation. + let mut solver: BrilligSolver<'_, B> = match self.brillig_solver.take() { + Some(solver) => solver, None => { BrilligSolver::new(witness, brillig, self.backend, self.instruction_pointer)? } - Some(solver) => solver, }; match solver.solve()? { BrilligSolverStatus::ForeignCallWait(foreign_call) => { - _ = self.brillig_solver.insert(solver); + // Cache the current state of the solver + self.brillig_solver = Some(solver); Ok(Some(foreign_call)) } BrilligSolverStatus::InProgress => { From 43217387dea9fcc4dc30b61be5aeec5d796b3cdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustavo=20Gir=C3=A1ldez?= Date: Wed, 11 Oct 2023 19:03:46 -0400 Subject: [PATCH 8/8] chore: Format with cargo fmt --- acvm-repo/acvm/src/pwg/brillig.rs | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/acvm-repo/acvm/src/pwg/brillig.rs b/acvm-repo/acvm/src/pwg/brillig.rs index f035015648a..6fc54d42eab 100644 --- a/acvm-repo/acvm/src/pwg/brillig.rs +++ b/acvm-repo/acvm/src/pwg/brillig.rs @@ -107,13 +107,7 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { // Instantiate a Brillig VM given the solved input registers and memory // along with the Brillig bytecode. let input_registers = Registers::load(input_register_values); - let vm = VM::new( - input_registers, - input_memory, - &brillig.bytecode, - vec![], - bb_solver, - ); + let vm = VM::new(input_registers, input_memory, &brillig.bytecode, vec![], bb_solver); Ok(Self { vm, acir_index }) }