Commit 185154b

Implement MX lock monitor

Dennis Diatlov committed Jul 24, 2023
1 parent 291b86f commit 185154b
Showing 8 changed files with 224 additions and 49 deletions.
16 changes: 10 additions & 6 deletions examples/waiter/src/code.rs
@@ -81,13 +81,17 @@ async fn main() {
                 exec::wake(msg_id.into()).expect("Failed to wake up the message");
             }
             Command::MxLock(lock_duration, continuation) => {
-                let lock_guard = unsafe {
-                    MUTEX
-                        .lock()
-                        .own_up_for(lock_duration)
-                        .expect("Failed to set mx ownership duration")
-                        .await
+                let lock = if let Some(lock_duration) = lock_duration {
+                    unsafe {
+                        MUTEX
+                            .lock()
+                            .own_up_for(lock_duration)
+                            .expect("Failed to set mx lock ownership duration")
+                    }
+                } else {
+                    unsafe { MUTEX.lock() }
                 };
+                let lock_guard = lock.await;
                 process_mx_lock_continuation(
                     unsafe { &mut MUTEX_LOCK_GUARD },
                     lock_guard,
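Note: the rewritten arm means a caller chooses between an explicit ownership cap and the default taken from Config::mx_lock_duration(). A minimal sketch of both paths, assuming a plain (non-mut) static and the gstd::lock re-exports shown in this commit; the function name is hypothetical:

    use gstd::lock::Mutex;

    static MUTEX: Mutex<u32> = Mutex::new(0);

    async fn lock_with(duration: Option<u32>) {
        let _guard = match duration {
            // Explicit cap: ownership expires `blocks` blocks after acquisition.
            Some(blocks) => MUTEX
                .lock()
                .own_up_for(blocks)
                .expect("Failed to set mx lock ownership duration")
                .await,
            // No cap given: the default from Config::mx_lock_duration() applies.
            None => MUTEX.lock().await,
        };
        // The guard is dropped here, unlocking the mutex.
    }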
2 changes: 1 addition & 1 deletion examples/waiter/src/lib.rs
@@ -101,7 +101,7 @@ pub enum Command {
     ReplyAndWait(WaitSubcommand),
     SleepFor(Vec<u32>, SleepForWaitType),
     WakeUp([u8; 32]),
-    MxLock(u32, MxLockContinuation),
+    MxLock(Option<u32>, MxLockContinuation),
     MxLockStaticAccess(LockStaticAccessSubcommand),
     RwLock(RwLockType, RwLockContinuation),
     RwLockStaticAccess(RwLockType, LockStaticAccessSubcommand),
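Note: the Option<u32> payload distinguishes an explicit lock ownership duration from the program's default (Config::mx_lock_duration()). A sketch of both encodings, with an arbitrary block count:

    // Cap the mutex ownership at 5 blocks.
    let bounded = Command::MxLock(
        Some(5),
        MxLockContinuation::General(LockContinuation::MoveToStatic),
    );
    // Fall back to the default ownership duration.
    let with_default = Command::MxLock(
        None,
        MxLockContinuation::General(LockContinuation::MoveToStatic),
    );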
2 changes: 1 addition & 1 deletion examples/waiter/tests/mx_lock_access.rs
@@ -56,7 +56,7 @@ fn init_fixture(system: &System) -> (Program<'_>, MessageId) {
     let lock_result = program.send(
         USER_ID,
         Command::MxLock(
-            u32::MAX,
+            None,
             MxLockContinuation::General(LockContinuation::MoveToStatic),
         ),
     );
81 changes: 57 additions & 24 deletions gstd/src/async_runtime/locks.rs
@@ -19,30 +19,32 @@
 use crate::{
     config::WaitType,
     errors::{Error, Result},
-    exec, BTreeMap, Config, MessageId,
+    exec,
+    lock::MutexId,
+    BTreeMap, BlockCount, BlockNumber, Config, MessageId,
 };
 use core::cmp::Ordering;
 use hashbrown::HashMap;
 
 /// Type of wait locks.
 #[derive(Debug, PartialEq, Eq)]
 pub(crate) enum LockType {
-    WaitFor(u32),
-    WaitUpTo(u32),
+    WaitFor(BlockCount),
+    WaitUpTo(BlockCount),
 }
 
 /// Wait lock
 #[derive(Debug, PartialEq, Eq)]
 pub struct Lock {
     /// The start block number of this lock.
-    pub at: u32,
+    pub at: BlockNumber,
     /// The type of this lock.
     ty: LockType,
 }
 
 impl Lock {
     /// Wait for
-    pub fn exactly(b: u32) -> Result<Self> {
+    pub fn exactly(b: BlockCount) -> Result<Self> {
         if b == 0 {
             return Err(Error::EmptyWaitDuration);
         }
@@ -54,7 +56,7 @@ impl Lock {
     }
 
     /// Wait up to
-    pub fn up_to(b: u32) -> Result<Self> {
+    pub fn up_to(b: BlockCount) -> Result<Self> {
         if b == 0 {
             return Err(Error::EmptyWaitDuration);
         }
@@ -86,14 +88,14 @@ impl Lock {
     }
 
     /// Gets the deadline of the current lock.
-    pub fn deadline(&self) -> u32 {
+    pub fn deadline(&self) -> BlockNumber {
         match &self.ty {
             LockType::WaitFor(d) | LockType::WaitUpTo(d) => self.at.saturating_add(*d),
         }
     }
 
     /// Check if this lock is timed out.
-    pub fn timeout(&self) -> Option<(u32, u32)> {
+    pub fn timeout(&self) -> Option<(BlockNumber, BlockNumber)> {
         let current = exec::block_height();
         let expected = self.deadline();
 
@@ -134,8 +136,12 @@ impl Default for LockType {
 
 #[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
 enum LockContext {
+    // Used for waiting a reply to message 'MessageId'
     ReplyTo(MessageId),
-    Sleep(u32),
+    // Used for sending a message to sleep until block 'BlockNumber'
+    Sleep(BlockNumber),
+    // Used for waking up a message for an attempt to seize lock for mutex 'MutexId'
+    MxLockMonitor(MutexId),
 }
 
 /// DoubleMap for wait locks.
@@ -145,7 +151,7 @@ pub struct LocksMap(HashMap<MessageId, BTreeMap<LockContext, Lock>>);
 impl LocksMap {
     /// Trigger waiting for the message.
     pub fn wait(&mut self, message_id: MessageId) {
-        let map = self.0.entry(message_id).or_insert_with(Default::default);
+        let map = self.message_locks(message_id);
         if map.is_empty() {
             // If there is no `waiting_reply_to` id specified, use
             // the message id as the key of the message lock.
@@ -172,35 +178,58 @@ impl LocksMap {
 
     /// Lock message.
     pub fn lock(&mut self, message_id: MessageId, waiting_reply_to: MessageId, lock: Lock) {
-        let locks = self.0.entry(message_id).or_insert_with(Default::default);
-        locks.insert(LockContext::ReplyTo(waiting_reply_to), lock);
+        self.message_locks(message_id)
+            .insert(LockContext::ReplyTo(waiting_reply_to), lock);
     }
 
     /// Remove message lock.
     pub fn remove(&mut self, message_id: MessageId, waiting_reply_to: MessageId) {
-        let locks = self.0.entry(message_id).or_insert_with(Default::default);
-        locks.remove(&LockContext::ReplyTo(waiting_reply_to));
+        self.message_locks(message_id)
+            .remove(&LockContext::ReplyTo(waiting_reply_to));
     }
 
     /// Inserts a lock for putting a message into sleep.
-    pub fn insert_sleep(&mut self, message_id: MessageId, until_block: u32) {
-        let locks = self.0.entry(message_id).or_insert_with(Default::default);
+    pub fn insert_sleep(&mut self, message_id: MessageId, wake_up_at: BlockNumber) {
+        let locks = self.message_locks(message_id);
         let current_block = exec::block_height();
-        if current_block < until_block {
+        if current_block < wake_up_at {
             locks.insert(
-                LockContext::Sleep(until_block),
-                Lock::exactly(until_block - current_block)
+                LockContext::Sleep(wake_up_at),
+                Lock::exactly(wake_up_at - current_block)
                     .expect("Never fails with block count > 0"),
             );
         } else {
-            locks.remove(&LockContext::Sleep(until_block));
+            locks.remove(&LockContext::Sleep(wake_up_at));
         }
     }
 
     /// Removes a sleep lock.
-    pub fn remove_sleep(&mut self, message_id: MessageId, until_block: u32) {
-        let locks = self.0.entry(message_id).or_insert_with(Default::default);
-        locks.remove(&LockContext::Sleep(until_block));
+    pub fn remove_sleep(&mut self, message_id: MessageId, wake_up_at: BlockNumber) {
+        self.message_locks(message_id)
+            .remove(&LockContext::Sleep(wake_up_at));
    }
+
+    pub(crate) fn insert_mx_lock_monitor(
+        &mut self,
+        message_id: MessageId,
+        mutex_id: MutexId,
+        wake_up_at: BlockNumber,
+    ) {
+        let locks = self.message_locks(message_id);
+        locks.insert(
+            LockContext::MxLockMonitor(mutex_id),
+            Lock::exactly(
+                wake_up_at
+                    .checked_sub(exec::block_height())
+                    .expect("Value of after_block must be greater than current block"),
+            )
+            .expect("Never fails with block count > 0"),
+        );
+    }
+
+    pub(crate) fn remove_mx_lock_monitor(&mut self, message_id: MessageId, mutex_id: MutexId) {
+        self.message_locks(message_id)
+            .remove(&LockContext::MxLockMonitor(mutex_id));
+    }
 
     pub fn remove_message_entry(&mut self, message_id: MessageId) {
@@ -217,11 +246,15 @@ impl LocksMap {
         &mut self,
         message_id: MessageId,
         waiting_reply_to: MessageId,
-    ) -> Option<(u32, u32)> {
+    ) -> Option<(BlockNumber, BlockNumber)> {
         self.0.get(&message_id).and_then(|locks| {
             locks
                 .get(&LockContext::ReplyTo(waiting_reply_to))
                 .and_then(|l| l.timeout())
         })
     }
+
+    fn message_locks(&mut self, message_id: MessageId) -> &mut BTreeMap<LockContext, Lock> {
+        self.0.entry(message_id).or_insert_with(Default::default)
+    }
 }
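Note: a monitor entry is just a WaitFor lock sized to the owner's remaining ownership time. A worked example of the arithmetic, with assumed block numbers:

    // Assume exec::block_height() == 100 and the owner's deadline
    // (wake_up_at) == 110 when insert_mx_lock_monitor is called.
    // The stored entry is Lock::exactly(110 - 100), i.e. roughly
    // Lock { at: 100, ty: LockType::WaitFor(10) }, whose deadline()
    // is 100 + 10 = 110 (via saturating_add): the waiting message is
    // woken exactly when the owner's ownership duration expires.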
12 changes: 12 additions & 0 deletions gstd/src/lock/access.rs
@@ -44,6 +44,18 @@ impl AccessQueue {
         inner.as_ref().map_or(false, |v| v.contains(message_id))
     }
 
+    pub fn len(&self) -> usize {
+        let inner = unsafe { &*self.0.get() };
+
+        inner.as_ref().map_or(0, |v| v.len())
+    }
+
+    pub fn first(&self) -> Option<&MessageId> {
+        let inner = unsafe { &*self.0.get() };
+
+        inner.as_ref().and_then(|v| v.front())
+    }
+
     pub const fn new() -> Self {
         AccessQueue(UnsafeCell::new(None))
     }
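Note: len() and first() give the mutex code a non-destructive peek at the wait queue. A sketch of how the poll path (in mutex.rs below) uses them, with hypothetical messages A and B:

    // Owner A acquires the lock: arm a monitor for the head of the queue.
    if let Some(next_rival_msg_id) = mutex.queue.first() {
        // wake *next_rival_msg_id at A's ownership deadline
    }
    // Rival B enqueues itself while A holds the lock: if B is the only
    // waiter, the owner never saw it, so B must arm the monitor itself.
    if mutex.queue.len() == 1 {
        // insert a MxLockMonitor lock for B at A's deadline
    }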
2 changes: 2 additions & 0 deletions gstd/src/lock/mod.rs
@@ -38,3 +38,5 @@ pub use self::{
     mutex::{Mutex, MutexGuard, MutexLockFuture},
     rwlock::{RwLock, RwLockReadFuture, RwLockReadGuard, RwLockWriteFuture, RwLockWriteGuard},
 };
+
+pub(crate) use self::mutex::MutexId;
74 changes: 69 additions & 5 deletions gstd/src/lock/mutex.rs
@@ -30,6 +30,22 @@ use core::{
     task::{Context, Poll},
 };
 
+static mut NEXT_MUTEX_ID: MutexId = MutexId::new();
+
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
+pub(crate) struct MutexId(u32);
+
+impl MutexId {
+    pub const fn new() -> Self {
+        MutexId(1)
+    }
+
+    pub fn next(self) -> Self {
+        let id = self.0.wrapping_add(1);
+        MutexId(if id == 0 { 1 } else { id })
+    }
+}
+
 /// A mutual exclusion primitive useful for protecting shared data.
 ///
 /// This mutex will block the execution waiting for the lock to become
@@ -81,6 +97,7 @@ use core::{
 /// # fn main() {}
 /// ```
 pub struct Mutex<T> {
+    id: UnsafeCell<Option<MutexId>>,
     locked: UnsafeCell<Option<(MessageId, BlockNumber)>>,
     value: UnsafeCell<T>,
     queue: AccessQueue,
@@ -90,6 +107,7 @@ impl<T> Mutex<T> {
     /// Create a new mutex in an unlocked state ready for use.
     pub const fn new(t: T) -> Mutex<T> {
         Mutex {
+            id: UnsafeCell::new(None),
             value: UnsafeCell::new(t),
             locked: UnsafeCell::new(None),
             queue: AccessQueue::new(),
@@ -107,10 +125,20 @@ impl<T> Mutex<T> {
     /// of scope, the mutex will be unlocked.
     pub fn lock(&self) -> MutexLockFuture<'_, T> {
         MutexLockFuture {
+            mutex_id: self.get_or_set_id(),
             mutex: self,
             own_up_for: Config::mx_lock_duration(),
         }
     }
+
+    fn get_or_set_id(&self) -> MutexId {
+        let id = unsafe { &mut *self.id.get() };
+        *id.get_or_insert_with(|| unsafe {
+            let id = NEXT_MUTEX_ID;
+            NEXT_MUTEX_ID = NEXT_MUTEX_ID.next();
+            id
+        })
+    }
 }
 
 /// An RAII implementation of a "scoped lock" of a mutex. When this structure is
@@ -230,6 +258,7 @@ unsafe impl<T> Sync for Mutex<T> {}
 /// # fn main() {}
 /// ```
 pub struct MutexLockFuture<'a, T> {
+    mutex_id: MutexId,
     mutex: &'a Mutex<T>,
     own_up_for: BlockCount,
 }
@@ -242,6 +271,7 @@ impl<'a, T> MutexLockFuture<'a, T> {
             Err(Error::ZeroMxLockDuration)
         } else {
             Ok(MutexLockFuture {
+                mutex_id: self.mutex_id,
                 mutex: self.mutex,
                 own_up_for: block_count,
             })
@@ -253,21 +283,48 @@ impl<'a, T> MutexLockFuture<'a, T> {
         owner_msg_id: MessageId,
         current_block: BlockNumber,
     ) -> Poll<MutexGuard<'a, T>> {
+        let owner_deadline_block = current_block.saturating_add(self.own_up_for);
+        async_runtime::locks().remove_mx_lock_monitor(owner_msg_id, self.mutex_id);
+        if let Some(next_rival_msg_id) = self.mutex.queue.first() {
+            // Give the next rival message a chance to own the lock after this owner
+            // exceeds the lock ownership duration
+            async_runtime::locks().insert_mx_lock_monitor(
+                *next_rival_msg_id,
+                self.mutex_id,
+                owner_deadline_block,
+            );
+        }
         let locked_by = unsafe { &mut *self.mutex.locked.get() };
-        *locked_by = Some((owner_msg_id, current_block.saturating_add(self.own_up_for)));
+        *locked_by = Some((owner_msg_id, owner_deadline_block));
         Poll::Ready(MutexGuard {
             mutex: self.mutex,
             holder_msg_id: owner_msg_id,
         })
     }
 
-    fn queue_for_lock_ownership(&mut self, rival_msg_id: MessageId) -> Poll<MutexGuard<'a, T>> {
+    fn queue_for_lock_ownership(
+        &mut self,
+        rival_msg_id: MessageId,
+        owner_deadline_block: Option<BlockNumber>,
+    ) -> Poll<MutexGuard<'a, T>> {
         // If the message is already in the access queue, and we come here,
         // it means the message has just been woken up from the waitlist.
        // In that case we do not want to register yet another access attempt
-        // and just go back to the waitlist.
+        // and just go back to the waitlist
         if !self.mutex.queue.contains(&rival_msg_id) {
             self.mutex.queue.enqueue(rival_msg_id);
+            if let Some(owner_deadline_block) = owner_deadline_block {
+                // Lock owner did not know about this message when it was getting into
+                // lock ownership. We have to take care of ourselves and give us a chance
+                // to oust the lock owner when the lock ownership duration expires
+                if self.mutex.queue.len() == 1 {
+                    async_runtime::locks().insert_mx_lock_monitor(
+                        rival_msg_id,
+                        self.mutex_id,
+                        owner_deadline_block,
+                    );
+                }
+            }
         }
         Poll::Pending
     }
@@ -299,14 +356,21 @@ impl<'a, T> Future for MutexLockFuture<'a, T> {
                 }
                 *locked_by = None;
                 exec::wake(next_msg_id).expect("Failed to wake the message");
-                return self.get_mut().queue_for_lock_ownership(current_msg_id);
+                // We have just woken up the next lock owner, but we do not know its ownership
+                // duration, thus we pass None as owner_deadline_block. The woken up message will
+                // give us a chance to own the lock itself by registering a lock monitor for us
+                return self
+                    .get_mut()
+                    .queue_for_lock_ownership(current_msg_id, None);
             }
 
             return self
                 .get_mut()
                 .acquire_lock_ownership(current_msg_id, current_block);
         }
-        return self.get_mut().queue_for_lock_ownership(current_msg_id);
+        return self
+            .get_mut()
+            .queue_for_lock_ownership(current_msg_id, Some(deadline_block));
     }
 
     self.get_mut()
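Note: taken together, the monitor guarantees a waiter is woken no later than the current owner's deadline. An assumed timeline for two messages contending for one mutex (block numbers are illustrative):

    // Block 100: msg A polls, the lock is free -> acquire_lock_ownership;
    //            locked = Some((A, deadline)), e.g. deadline = 110.
    // Block 101: msg B polls, lock held -> queue_for_lock_ownership(B, Some(110));
    //            B is the sole waiter (queue.len() == 1), so B arms a
    //            MxLockMonitor lock that wakes it at block 110.
    // Block 110: the monitor wakes B; B sees A past its deadline and, per the
    //            poll logic above, takes over the lock (waking the next queued
    //            message instead, if one precedes it), then re-arms the monitor
    //            for the new head of the queue via acquire_lock_ownership.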