Skip to content

Commit

Permalink
Merge pull request #296 from athenavm/integrate-runtime-hooks
Browse files Browse the repository at this point in the history
integrate runtime hooks
  • Loading branch information
poszu authored Jan 3, 2025
2 parents 54e3b36 + fe4912a commit 3f2b321
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 8 deletions.
55 changes: 55 additions & 0 deletions core/src/runtime/hooks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use std::collections::HashMap;

use super::Runtime;

#[mockall::automock]
#[allow(clippy::needless_lifetimes)] // the lifetimes are needed by automock
pub trait Hook {
fn execute<'r, 'h>(&self, env: HookEnv<'r, 'h>, data: &[u8]) -> anyhow::Result<Vec<u8>>;
}

/// A registry of hooks to call, indexed by the file descriptors through which they are accessed.
#[derive(Default)]
pub struct HookRegistry {
/// Table of registered hooks.
table: HashMap<u32, Box<dyn Hook>>,
}

impl HookRegistry {
/// Create a registry with the default hooks.
pub fn new() -> Self {
Default::default()
}

/// Register a hook under a given FD.
/// Will fail if a hook is already registered on a given FD
/// or a FD is <= 4.
pub fn register(&mut self, fd: u32, hook: Box<dyn Hook>) -> anyhow::Result<()> {
anyhow::ensure!(fd > 4, "FDs 0-4 are reserved for internal usage");
anyhow::ensure!(
!self.table.contains_key(&fd),
"there is already a hook for FD {fd} registered"
);
self.table.insert(fd, hook);
Ok(())
}

pub(crate) fn get(&self, fd: u32) -> Option<&dyn Hook> {
self.table.get(&fd).map(AsRef::as_ref)
}
}

/// Environment that a hook may read from.
pub struct HookEnv<'r, 'h> {
pub runtime: &'r Runtime<'h>,
}

#[cfg(test)]
pub mod tests {
use super::*;

#[test]
pub fn registry_new_is_empty() {
assert!(HookRegistry::new().table.is_empty());
}
}
71 changes: 70 additions & 1 deletion core/src/runtime/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod gdbstub;
pub mod hooks;
mod instruction;
mod io;
mod opcode;
Expand All @@ -8,6 +9,7 @@ mod state;
mod syscall;
mod utils;

use anyhow::anyhow;
use athena_interface::MethodSelector;
pub use instruction::*;
pub use opcode::*;
Expand Down Expand Up @@ -75,6 +77,8 @@ pub struct Runtime<'host> {
pub max_syscall_cycles: u32,

breakpoints: BTreeSet<u32>,

hook_registry: hooks::HookRegistry,
}

/// An record of a write to a memory address.
Expand Down Expand Up @@ -165,6 +169,7 @@ impl<'host> Runtime<'host> {
syscall_map,
max_syscall_cycles,
breakpoints: BTreeSet::new(),
hook_registry: Default::default(),
}
}

Expand Down Expand Up @@ -813,6 +818,20 @@ impl<'host> Runtime<'host> {
fn get_syscall(&mut self, code: SyscallCode) -> Option<&Arc<dyn Syscall>> {
self.syscall_map.get(&code)
}

/// Register a new hook.
/// Will fail if a hook is already registered on a given FD
/// or a FD is <= 4.
pub fn register_hook(&mut self, fd: u32, hook: Box<dyn hooks::Hook>) -> anyhow::Result<()> {
self.hook_registry.register(fd, hook)
}
pub(crate) fn execute_hook(&self, fd: u32, data: &[u8]) -> anyhow::Result<Vec<u8>> {
let hook = self
.hook_registry
.get(fd)
.ok_or_else(|| anyhow!("no hook registered for FD {fd}"))?;
hook.execute(hooks::HookEnv { runtime: self }, data)
}
}

#[cfg(test)]
Expand All @@ -828,7 +847,10 @@ pub mod tests {
utils::{tests::TEST_PANIC_ELF, AthenaCoreOpts},
};

use super::syscall::SyscallCode;
use super::{
hooks::{self, MockHook},
syscall::SyscallCode,
};
use super::{Instruction, Opcode, Program, Runtime};

pub fn simple_program() -> Program {
Expand Down Expand Up @@ -1739,6 +1761,53 @@ pub mod tests {
assert_eq!(runtime.register(Register::X12), 0x12346525);
assert_eq!(runtime.register(Register::X11), 0x65256525);
}

#[test]
fn registering_hook() {
let mut runtime = Runtime::new(
Program::new(vec![], 0, 0),
None,
AthenaCoreOpts::default(),
None,
);

// can't register on FD <= 4
for fd in 0..=4 {
runtime
.register_hook(fd, Box::new(hooks::MockHook::new()))
.unwrap_err();
}

// register on 5
runtime
.register_hook(5, Box::new(hooks::MockHook::new()))
.unwrap();

// can't register another hook on occupied FD
runtime
.register_hook(5, Box::new(hooks::MockHook::new()))
.unwrap_err();
}

#[test]
fn executing_hook() {
let mut runtime = Runtime::new(
Program::new(vec![], 0, 0),
None,
AthenaCoreOpts::default(),
None,
);

let mut hook = Box::new(MockHook::new());
hook.expect_execute().returning(|_, data| {
assert_eq!(&[1, 2, 3, 4], data);
Ok(vec![5, 6, 7])
});
runtime.register_hook(5, hook).unwrap();

let result = runtime.execute_hook(5, &[1, 2, 3, 4]).unwrap();
assert_eq!(vec![5, 6, 7], result);
}
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
Expand Down
64 changes: 57 additions & 7 deletions core/src/syscall/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,9 @@ pub(crate) struct SyscallWrite;

impl Syscall for SyscallWrite {
fn execute(&self, ctx: &mut SyscallContext, fd: u32, write_buf: u32) -> SyscallResult {
let nbytes = ctx.rt.register(Register::X12);
let bytes = ctx.bytes(write_buf, nbytes as usize);
let rt = &mut ctx.rt;
let nbytes = rt.register(Register::X12);
// Read nbytes from memory starting at write_buf.
let bytes = (0..nbytes)
.map(|i| rt.byte(write_buf + i))
.collect::<Vec<u8>>();
match fd {
1 => {
let s = core::str::from_utf8(&bytes).or(Err(StatusCode::InvalidSyscallArgument))?;
Expand Down Expand Up @@ -75,8 +72,16 @@ impl Syscall for SyscallWrite {
rt.state.input_stream.push(bytes);
}
fd => {
tracing::debug!("syscall write called with invalid fd: {fd}");
return Err(StatusCode::InvalidSyscallArgument);
tracing::debug!(fd, "executing hook");
match rt.execute_hook(fd, &bytes) {
Ok(result) => {
rt.state.input_stream.push(result);
}
Err(err) => {
tracing::debug!(fd, ?err, "hook failed");
return Err(StatusCode::InvalidSyscallArgument);
}
}
}
}
Ok(Outcome::Result(None))
Expand All @@ -101,3 +106,48 @@ fn update_io_buf(ctx: &mut SyscallContext, fd: u32, s: &str) -> Vec<String> {
vec![]
}
}

#[cfg(test)]
mod tests {
use athena_interface::StatusCode;

use crate::{
runtime::{hooks, Program, Runtime, Syscall, SyscallContext},
utils::AthenaCoreOpts,
};

use super::SyscallWrite;

#[test]
fn invoking_nonexisting_hook_fails() {
let mut runtime = Runtime::new(
Program::new(vec![], 0, 0),
None,
AthenaCoreOpts::default(),
None,
);

let result = SyscallWrite {}.execute(&mut SyscallContext { rt: &mut runtime }, 7, 0);
assert_eq!(Err(StatusCode::InvalidSyscallArgument), result);
}

#[test]
fn invoking_registered_hook() {
let mut hook_mock = hooks::MockHook::new();
let _ = hook_mock
.expect_execute()
.returning(|_, _| Ok(vec![1, 2, 3, 4, 5]));
let mut runtime = Runtime::new(
Program::new(vec![], 0, 0),
None,
AthenaCoreOpts::default(),
None,
);
runtime.register_hook(7, Box::new(hook_mock)).unwrap();

let result = SyscallWrite {}.execute(&mut SyscallContext { rt: &mut runtime }, 7, 0);
result.unwrap();
let result = runtime.state.input_stream.pop().unwrap();
assert_eq!(vec![1, 2, 3, 4, 5], result);
}
}

0 comments on commit 3f2b321

Please sign in to comment.