diff --git a/src/lib.rs b/src/lib.rs index f3c2d21..c5a1257 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -148,12 +148,6 @@ pub enum State { /// The responder sent a response. Responded = 4, - #[doc(hidden)] - CancelingRequested = 10, - #[doc(hidden)] - CancelingBuildingResponse = 11, - /// The requester canceled the request. Responder needs to acknowledge to return to `Idle` - /// state. Canceled = 12, } @@ -171,11 +165,7 @@ impl From for State { 2 => State::Requested, 3 => State::BuildingResponse, 4 => State::Responded, - - 10 => State::CancelingRequested, - 11 => State::CancelingBuildingResponse, 12 => State::Canceled, - _ => State::Idle, } } @@ -401,6 +391,12 @@ impl Channel { pub fn split(&self) -> Option<(Requester<'_, Rq, Rp>, Responder<'_, Rq, Rp>)> { Some((self.requester()?, self.responder()?)) } + + fn transition(&self, from: State, to: State) -> bool { + self.state + .compare_exchange(from as u8, to as u8, Ordering::AcqRel, Ordering::Relaxed) + .is_ok() + } } impl Default for Channel { @@ -426,12 +422,8 @@ impl<'i, Rq, Rp> Drop for Requester<'i, Rq, Rp> { } impl<'i, Rq, Rp> Requester<'i, Rq, Rp> { - #[inline] - fn transition(&self, from: State, to: State) -> bool { + pub fn channel(&self) -> &'i Channel { self.channel - .state - .compare_exchange(from as u8, to as u8, Ordering::SeqCst, Ordering::SeqCst) - .is_ok() } #[cfg(not(loom))] @@ -505,23 +497,17 @@ impl<'i, Rq, Rp> Requester<'i, Rq, Rp> { /// /// In other cases (`Idle` or `Reponsed`) there is nothing to cancel and we fail. pub fn cancel(&mut self) -> Result, Error> { - // we canceled before the responder was even aware of the request. - if self.transition(State::Requested, State::CancelingRequested) { - self.channel - .state - .store(State::Idle as u8, Ordering::Release); - return Ok(Some(unsafe { self.with_data_mut(|i| i.take_rq()) })); + if self + .channel + .transition(State::BuildingResponse, State::Canceled) + { + // we canceled after the responder took the request, but before they answered. + return Ok(None); } - // we canceled after the responder took the request, but before they answered. - if self.transition(State::BuildingResponse, State::CancelingRequested) { - // this may not yet be None in case the responder switched state to - // BuildingResponse but did not take out the request yet. - // assert!(self.data.is_none()); - self.channel - .state - .store(State::Canceled as u8, Ordering::Release); - return Ok(None); + if self.channel.transition(State::Requested, State::Idle) { + // we canceled before the responder was even aware of the request. + return Ok(Some(unsafe { self.with_data_mut(|i| i.take_rq()) })); } Err(Error) @@ -534,7 +520,7 @@ impl<'i, Rq, Rp> Requester<'i, Rq, Rp> { // this is likely correct #[cfg(not(loom))] pub fn response(&self) -> Result<&Rp, Error> { - if self.transition(State::Responded, State::Responded) { + if self.channel.transition(State::Responded, State::Responded) { Ok(unsafe { self.data().rp_ref() }) } else { Err(Error) @@ -545,7 +531,7 @@ impl<'i, Rq, Rp> Requester<'i, Rq, Rp> { /// /// This may be called multiple times. pub fn with_response(&self, f: impl FnOnce(&Rp) -> R) -> Result { - if self.transition(State::Responded, State::Responded) { + if self.channel.transition(State::Responded, State::Responded) { Ok(unsafe { self.with_data(|i| f(i.rp_ref())) }) } else { Err(Error) @@ -560,7 +546,7 @@ impl<'i, Rq, Rp> Requester<'i, Rq, Rp> { // It is a logic error to call this method if we're Idle or Canceled, but // it seems unnecessary to model this. pub fn take_response(&mut self) -> Option { - if self.transition(State::Responded, State::Idle) { + if self.channel.transition(State::Responded, State::Idle) { Some(unsafe { self.with_data_mut(|i| i.take_rp()) }) } else { None @@ -576,8 +562,10 @@ where /// /// This is usefull to build large structures in-place pub fn with_request_mut(&mut self, f: impl FnOnce(&mut Rq) -> R) -> Result { - if self.transition(State::Idle, State::BuildingRequest) - || self.transition(State::BuildingRequest, State::BuildingRequest) + if self.channel.transition(State::Idle, State::BuildingRequest) + || self + .channel + .transition(State::BuildingRequest, State::BuildingRequest) { let res = unsafe { self.with_data_mut(|i| { @@ -600,8 +588,10 @@ where // this is likely correct #[cfg(not(loom))] pub fn request_mut(&mut self) -> Result<&mut Rq, Error> { - if self.transition(State::Idle, State::BuildingRequest) - || self.transition(State::BuildingRequest, State::BuildingRequest) + if self.channel.transition(State::Idle, State::BuildingRequest) + || self + .channel + .transition(State::BuildingRequest, State::BuildingRequest) { unsafe { self.with_data_mut(|i| { @@ -620,7 +610,9 @@ where /// `with_request_mut`. pub fn send_request(&mut self) -> Result<(), Error> { if State::BuildingRequest == self.channel.state.load(Ordering::Acquire) - && self.transition(State::BuildingRequest, State::Requested) + && self + .channel + .transition(State::BuildingRequest, State::Requested) { Ok(()) } else { @@ -647,12 +639,8 @@ impl<'i, Rq, Rp> Drop for Responder<'i, Rq, Rp> { } impl<'i, Rq, Rp> Responder<'i, Rq, Rp> { - #[inline] - fn transition(&self, from: State, to: State) -> bool { + pub fn channel(&self) -> &'i Channel { self.channel - .state - .compare_exchange(from as u8, to as u8, Ordering::SeqCst, Ordering::SeqCst) - .is_ok() } #[cfg(not(loom))] @@ -701,7 +689,10 @@ impl<'i, Rq, Rp> Responder<'i, Rq, Rp> { /// This may be called only once as it move the state to BuildingResponse. /// If you need copies, use `take_request` pub fn with_request(&self, f: impl FnOnce(&Rq) -> R) -> Result { - if self.transition(State::Requested, State::BuildingResponse) { + if self + .channel + .transition(State::Requested, State::BuildingResponse) + { Ok(unsafe { self.with_data(|i| f(i.rq_ref())) }) } else { Err(Error) @@ -715,7 +706,10 @@ impl<'i, Rq, Rp> Responder<'i, Rq, Rp> { // this is likely correct #[cfg(not(loom))] pub fn request(&self) -> Result<&Rq, Error> { - if self.transition(State::Requested, State::BuildingResponse) { + if self + .channel + .transition(State::Requested, State::BuildingResponse) + { Ok(unsafe { self.data().rq_ref() }) } else { Err(Error) @@ -727,7 +721,10 @@ impl<'i, Rq, Rp> Responder<'i, Rq, Rp> { /// This may be called only once as it move the state to BuildingResponse. /// If you need copies, clone the request. pub fn take_request(&mut self) -> Option { - if self.transition(State::Requested, State::BuildingResponse) { + if self + .channel + .transition(State::Requested, State::BuildingResponse) + { Some(unsafe { self.with_data_mut(|i| i.take_rq()) }) } else { None @@ -743,17 +740,7 @@ impl<'i, Rq, Rp> Responder<'i, Rq, Rp> { // // It is a logic error to call this method if there is no pending cancellation. pub fn acknowledge_cancel(&self) -> Result<(), Error> { - if self - .channel - .state - .compare_exchange( - State::Canceled as u8, - State::Idle as u8, - Ordering::SeqCst, - Ordering::SeqCst, - ) - .is_ok() - { + if self.channel.transition(State::Canceled, State::Idle) { Ok(()) } else { Err(Error) @@ -770,10 +757,14 @@ impl<'i, Rq, Rp> Responder<'i, Rq, Rp> { unsafe { self.with_data_mut(|i| *i = Message::from_rp(response)); } - self.channel - .state - .store(State::Responded as u8, Ordering::Release); - Ok(()) + if self + .channel + .transition(State::BuildingResponse, State::Responded) + { + Ok(()) + } else { + Err(Error) + } } else { Err(Error) } @@ -788,8 +779,12 @@ where /// /// This is usefull to build large structures in-place pub fn with_response_mut(&mut self, f: impl FnOnce(&mut Rp) -> R) -> Result { - if self.transition(State::Requested, State::BuildingResponse) - || self.transition(State::BuildingResponse, State::BuildingResponse) + if self + .channel + .transition(State::Requested, State::BuildingResponse) + || self + .channel + .transition(State::BuildingResponse, State::BuildingResponse) { let res = unsafe { self.with_data_mut(|i| { @@ -812,8 +807,12 @@ where // this is likely correct #[cfg(not(loom))] pub fn response_mut(&mut self) -> Result<&mut Rp, Error> { - if self.transition(State::Requested, State::BuildingResponse) - || self.transition(State::BuildingResponse, State::BuildingResponse) + if self + .channel + .transition(State::Requested, State::BuildingResponse) + || self + .channel + .transition(State::BuildingResponse, State::BuildingResponse) { unsafe { self.with_data_mut(|i| { @@ -832,7 +831,9 @@ where /// `with_response_mut`. pub fn send_response(&mut self) -> Result<(), Error> { if State::BuildingResponse == self.channel.state.load(Ordering::Acquire) - && self.transition(State::BuildingResponse, State::Responded) + && self + .channel + .transition(State::BuildingResponse, State::Responded) { Ok(()) } else { diff --git a/tests/loom.rs b/tests/loom.rs index e688186..d944980 100644 --- a/tests/loom.rs +++ b/tests/loom.rs @@ -10,10 +10,10 @@ use interchange::{Channel, Requester, Responder}; use std::sync::atomic::Ordering::Acquire; use std::sync::atomic::{AtomicBool, Ordering::Release}; -static BRANCHES_USED: [AtomicBool; 6] = { +static BRANCHES_USED: [AtomicBool; 14] = { #[allow(clippy::declare_interior_mutable_const)] const ATOMIC_BOOL_INIT: AtomicBool = AtomicBool::new(false); - [ATOMIC_BOOL_INIT; 6] + [ATOMIC_BOOL_INIT; 14] }; #[cfg(loom)] @@ -51,42 +51,73 @@ fn test_function() { fn requester_thread(mut requester: Requester<'static, u64, u64>) -> Option<()> { requester.request(53).unwrap(); - requester.with_response(|r| assert_eq!(*r, 63)).ok()?; - requester.with_response(|r| assert_eq!(*r, 63)).ok()?; + match requester.cancel() { + Ok(Some(53) | None) => { + BRANCHES_USED[0].store(true, Release); + return None; + } + Ok(_) => panic!("Invalid state"), + Err(_) => { + BRANCHES_USED[1].store(true, Release); + } + } + requester + .with_response(|r| { + BRANCHES_USED[2].store(true, Release); + assert_eq!(*r, 63) + }) + .ok() + .or_else(|| { + BRANCHES_USED[3].store(true, Release); + None + })?; + requester.with_response(|r| assert_eq!(*r, 63)).unwrap(); requester.take_response().unwrap(); requester.with_request_mut(|r| *r = 51).unwrap(); requester.send_request().unwrap(); thread::yield_now(); match requester.cancel() { - Ok(Some(51) | None) => BRANCHES_USED[0].store(true, Release), + Ok(Some(51) | None) => BRANCHES_USED[4].store(true, Release), Ok(_) => panic!("Invalid state"), Err(_) => { - BRANCHES_USED[1].store(true, Release); - assert_eq!(requester.take_response().unwrap(), 79); + BRANCHES_USED[5].store(true, Release); + match requester.take_response() { + Some(i) => { + assert_eq!(i, 79); + BRANCHES_USED[6].store(true, Release); + } + None => BRANCHES_USED[7].store(true, Release), + } } } - BRANCHES_USED[4].store(true, Release); + BRANCHES_USED[8].store(true, Release); None } fn responder_thread(mut responder: Responder<'static, u64, u64>) -> Option<()> { - let req = responder.take_request()?; + let req = responder.take_request().or_else(|| { + BRANCHES_USED[9].store(true, Release); + None + })?; assert_eq!(req, 53); - responder.respond(req + 10).unwrap(); + responder.respond(req + 10).ok().or_else(|| { + BRANCHES_USED[10].store(true, Release); + None + })?; thread::yield_now(); responder .with_request(|r| { - BRANCHES_USED[2].store(true, Release); + BRANCHES_USED[11].store(true, Release); assert_eq!(*r, 51) }) .map(|_| assert!(responder.with_request(|_| {}).is_err())) .or_else(|_| { - BRANCHES_USED[3].store(true, Release); + BRANCHES_USED[12].store(true, Release); responder.acknowledge_cancel() }) .ok()?; responder.with_response_mut(|r| *r = 79).ok(); responder.send_response().ok(); - BRANCHES_USED[5].store(true, Release); + BRANCHES_USED[13].store(true, Release); None }