From 771df58ed7fa8976f18a865496e1f3ee0dae73f4 Mon Sep 17 00:00:00 2001 From: Francesco Guardiani Date: Tue, 1 Oct 2024 11:55:56 +0200 Subject: [PATCH] Min max protocol version + better errors (#18) * Expose min max version * Improve the error situation --- Cargo.toml | 2 +- src/lib.rs | 68 +++++++++++++++++++++-------- src/service_protocol/encoding.rs | 16 +++---- src/service_protocol/messages.rs | 32 +++++++------- src/service_protocol/version.rs | 10 +++-- src/tests/failures.rs | 6 +-- src/tests/mod.rs | 14 +++--- src/tests/promise.rs | 22 ++++------ src/tests/run.rs | 23 +++------- src/vm/context.rs | 6 +-- src/vm/errors.rs | 32 +++++++------- src/vm/mod.rs | 50 +++++++++------------ src/vm/transitions/async_results.rs | 8 ++-- src/vm/transitions/combinators.rs | 4 +- src/vm/transitions/input.rs | 16 +++---- src/vm/transitions/journal.rs | 31 +++++++------ src/vm/transitions/mod.rs | 10 ++--- src/vm/transitions/terminal.rs | 10 ++--- 18 files changed, 186 insertions(+), 174 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 30d4567..a5a75bc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ request_identity = ["dep:ring", "dep:sha2", "dep:jsonwebtoken", "dep:bs58"] sha2_random_seed = ["dep:sha2"] [dependencies] -thiserror = "1.0.61" +thiserror = "1.0.64" prost = "0.13.2" bytes = "1.6" bytes-utils = "0.1.4" diff --git a/src/lib.rs b/src/lib.rs index ced8a6d..7183281 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,12 +11,20 @@ use std::fmt; use std::time::Duration; pub use crate::retries::RetryPolicy; -use crate::vm::AsyncResultAccessTrackerInner; pub use headers::HeaderMap; #[cfg(feature = "request_identity")] pub use request_identity::*; +pub use service_protocol::Version; pub use vm::CoreVM; +// Re-export only some stuff from vm::errors +pub mod error { + pub use crate::vm::errors::codes; + pub use crate::vm::errors::InvocationErrorCode; +} + +use crate::vm::AsyncResultAccessTrackerInner; + #[derive(Debug, Eq, PartialEq)] pub struct Header { pub key: Cow<'static, str>, @@ -27,6 +35,7 @@ pub struct Header { pub struct ResponseHead { pub status_code: u16, pub headers: Vec
, + pub version: Version, } #[derive(Debug, Clone, Copy, thiserror::Error)] @@ -35,21 +44,25 @@ pub struct SuspendedError; #[derive(Debug, Clone, thiserror::Error)] #[error("VM Error [{code}]: {message}. Description: {description}")] -pub struct VMError { +pub struct Error { code: u16, message: Cow<'static, str>, description: Cow<'static, str>, } -impl VMError { +impl Error { pub fn new(code: impl Into, message: impl Into>) -> Self { - VMError { + Error { code: code.into(), message: message.into(), description: Default::default(), } } + pub fn internal(message: impl Into>) -> Self { + Self::new(error::codes::INTERNAL, message) + } + pub fn code(&self) -> u16 { self.code } @@ -61,6 +74,30 @@ impl VMError { pub fn description(&self) -> &str { &self.description } + + pub fn with_description(mut self, description: impl Into>) -> Self { + self.description = description.into(); + self + } + + /// Append the given description to the original one, in case the code is the same + pub fn append_description_for_code( + mut self, + code: impl Into, + description: impl Into>, + ) -> Self { + let c = code.into(); + if self.code == c { + if self.description.is_empty() { + self.description = description.into(); + } else { + self.description = format!("{}. {}", self.description, description.into()).into(); + } + self + } else { + self + } + } } #[derive(Debug, Clone, thiserror::Error)] @@ -68,7 +105,7 @@ pub enum SuspendedOrVMError { #[error(transparent)] Suspended(SuspendedError), #[error(transparent)] - VM(VMError), + VM(Error), } #[derive(Debug, Eq, PartialEq)] @@ -107,7 +144,7 @@ pub enum Value { /// a void/None/undefined success Void, Success(Bytes), - Failure(Failure), + Failure(TerminalFailure), /// Only returned for get_state_keys StateKeys(Vec), CombinatorResult(Vec), @@ -115,7 +152,7 @@ pub enum Value { /// Terminal failure #[derive(Debug, Clone, Eq, PartialEq)] -pub struct Failure { +pub struct TerminalFailure { pub code: u16, pub message: String, } @@ -137,17 +174,17 @@ pub enum RunEnterResult { #[derive(Debug, Clone)] pub enum RunExitResult { Success(Bytes), - TerminalFailure(Failure), + TerminalFailure(TerminalFailure), RetryableFailure { attempt_duration: Duration, - failure: Failure, + error: Error, }, } #[derive(Debug, Clone)] pub enum NonEmptyValue { Success(Bytes), - Failure(Failure), + Failure(TerminalFailure), } impl From for Value { @@ -165,7 +202,7 @@ pub enum TakeOutputResult { EOF, } -pub type VMResult = Result; +pub type VMResult = Result; pub struct VMOptions { /// If true, false when two concurrent async results are awaited at the same time. If false, just log it. @@ -193,12 +230,7 @@ pub trait VM: Sized { // --- Errors - fn notify_error( - &mut self, - message: Cow<'static, str>, - description: Cow<'static, str>, - next_retry_delay: Option, - ); + fn notify_error(&mut self, error: Error, next_retry_delay: Option); // --- Output stream @@ -206,7 +238,7 @@ pub trait VM: Sized { // --- Execution start waiting point - fn is_ready_to_execute(&self) -> Result; + fn is_ready_to_execute(&self) -> VMResult; // --- Async results diff --git a/src/service_protocol/encoding.rs b/src/service_protocol/encoding.rs index aac9dd9..f41729b 100644 --- a/src/service_protocol/encoding.rs +++ b/src/service_protocol/encoding.rs @@ -38,9 +38,9 @@ impl Encoder { pub fn new(service_protocol_version: Version) -> Self { assert_eq!( service_protocol_version, - Version::latest(), + Version::maximum_supported_version(), "Encoder only supports service protocol version {:?}", - Version::latest() + Version::maximum_supported_version() ); Self {} } @@ -107,9 +107,9 @@ impl Decoder { pub fn new(service_protocol_version: Version) -> Self { assert_eq!( service_protocol_version, - Version::latest(), + Version::maximum_supported_version(), "Decoder only supports service protocol version {:?}", - Version::latest() + Version::maximum_supported_version() ); Self { buf: SegmentedBuf::new(), @@ -185,8 +185,8 @@ mod tests { #[test] fn fill_decoder_with_several_messages() { - let encoder = Encoder::new(Version::latest()); - let mut decoder = Decoder::new(Version::latest()); + let encoder = Encoder::new(Version::maximum_supported_version()); + let mut decoder = Decoder::new(Version::maximum_supported_version()); let expected_msg_0 = messages::StartMessage { id: "key".into(), @@ -260,8 +260,8 @@ mod tests { } fn partial_decoding_test(split_index: usize) { - let encoder = Encoder::new(Version::latest()); - let mut decoder = Decoder::new(Version::latest()); + let encoder = Encoder::new(Version::maximum_supported_version()); + let mut decoder = Decoder::new(Version::maximum_supported_version()); let expected_msg = messages::InputEntryMessage { value: Bytes::from_static("input".as_bytes()), diff --git a/src/service_protocol/messages.rs b/src/service_protocol/messages.rs index 37eef62..d1da271 100644 --- a/src/service_protocol/messages.rs +++ b/src/service_protocol/messages.rs @@ -1,7 +1,7 @@ use crate::service_protocol::messages::get_state_keys_entry_message::StateKeys; use crate::service_protocol::{MessageHeader, MessageType}; use crate::vm::errors::{DecodeStateKeysProst, DecodeStateKeysUtf8, EmptyStateKeys}; -use crate::{NonEmptyValue, VMError, Value}; +use crate::{Error, NonEmptyValue, Value}; use paste::paste; use prost::Message; @@ -25,7 +25,7 @@ pub trait EntryMessageHeaderEq { pub trait CompletableEntryMessage: RestateMessage + EntryMessage + EntryMessageHeaderEq { fn is_completed(&self) -> bool; - fn into_completion(self) -> Result, VMError>; + fn into_completion(self) -> Result, Error>; fn completion_parsing_hint() -> CompletionParsingHint; } @@ -74,7 +74,7 @@ macro_rules! impl_message_traits { self.result.is_some() } - fn into_completion(self) -> Result, VMError> { + fn into_completion(self) -> Result, Error> { self.result.map(TryInto::try_into).transpose() } @@ -133,7 +133,7 @@ impl CompletableEntryMessage for GetStateKeysEntryMessage { self.result.is_some() } - fn into_completion(self) -> Result, VMError> { + fn into_completion(self) -> Result, Error> { self.result.map(TryInto::try_into).transpose() } @@ -255,7 +255,7 @@ impl EntryMessageHeaderEq for CombinatorEntryMessage { // --- Completion extraction impl TryFrom for Value { - type Error = VMError; + type Error = Error; fn try_from(value: get_state_entry_message::Result) -> Result { Ok(match value { @@ -267,7 +267,7 @@ impl TryFrom for Value { } impl TryFrom for Value { - type Error = VMError; + type Error = Error; fn try_from(value: get_state_keys_entry_message::Result) -> Result { match value { @@ -286,7 +286,7 @@ impl TryFrom for Value { } impl TryFrom for Value { - type Error = VMError; + type Error = Error; fn try_from(value: sleep_entry_message::Result) -> Result { Ok(match value { @@ -297,7 +297,7 @@ impl TryFrom for Value { } impl TryFrom for Value { - type Error = VMError; + type Error = Error; fn try_from(value: call_entry_message::Result) -> Result { Ok(match value { @@ -308,7 +308,7 @@ impl TryFrom for Value { } impl TryFrom for Value { - type Error = VMError; + type Error = Error; fn try_from(value: awakeable_entry_message::Result) -> Result { Ok(match value { @@ -319,7 +319,7 @@ impl TryFrom for Value { } impl TryFrom for Value { - type Error = VMError; + type Error = Error; fn try_from(value: get_promise_entry_message::Result) -> Result { Ok(match value { @@ -330,7 +330,7 @@ impl TryFrom for Value { } impl TryFrom for Value { - type Error = VMError; + type Error = Error; fn try_from(value: peek_promise_entry_message::Result) -> Result { Ok(match value { @@ -342,7 +342,7 @@ impl TryFrom for Value { } impl TryFrom for Value { - type Error = VMError; + type Error = Error; fn try_from(value: complete_promise_entry_message::Result) -> Result { Ok(match value { @@ -363,8 +363,8 @@ impl From for NonEmptyValue { // --- Other conversions -impl From for Failure { - fn from(value: crate::Failure) -> Self { +impl From for Failure { + fn from(value: crate::TerminalFailure) -> Self { Self { code: value.code as u32, message: value.message, @@ -372,7 +372,7 @@ impl From for Failure { } } -impl From for crate::Failure { +impl From for crate::TerminalFailure { fn from(value: Failure) -> Self { Self { code: value.code as u16, @@ -391,7 +391,7 @@ pub(crate) enum CompletionParsingHint { } impl CompletionParsingHint { - pub(crate) fn parse(self, result: completion_message::Result) -> Result { + pub(crate) fn parse(self, result: completion_message::Result) -> Result { match self { CompletionParsingHint::StateKeys => match result { completion_message::Result::Empty(_) => Err(EmptyStateKeys.into()), diff --git a/src/service_protocol/version.rs b/src/service_protocol/version.rs index 84d8264..f35ec94 100644 --- a/src/service_protocol/version.rs +++ b/src/service_protocol/version.rs @@ -3,8 +3,8 @@ use std::str::FromStr; #[derive(Debug, Clone, Copy, Ord, PartialOrd, Eq, PartialEq)] pub enum Version { - V1, - V2, + V1 = 1, + V2 = 2, } const CONTENT_TYPE_V1: &str = "application/vnd.restate.invocation.v1"; @@ -18,7 +18,11 @@ impl Version { } } - pub const fn latest() -> Self { + pub const fn minimum_supported_version() -> Self { + Version::V2 + } + + pub const fn maximum_supported_version() -> Self { Version::V2 } } diff --git a/src/tests/failures.rs b/src/tests/failures.rs index 0841e99..2baf381 100644 --- a/src/tests/failures.rs +++ b/src/tests/failures.rs @@ -8,8 +8,8 @@ use test_log::test; #[test] fn got_closed_stream_before_end_of_replay() { - let mut vm = CoreVM::mock_init(Version::latest()); - let encoder = Encoder::new(Version::latest()); + let mut vm = CoreVM::mock_init(Version::maximum_supported_version()); + let encoder = Encoder::new(Version::maximum_supported_version()); vm.notify_input(encoder.encode(&StartMessage { id: Bytes::from_static(b"123"), @@ -86,7 +86,7 @@ fn one_way_call_entry_mismatch() { fn test_entry_mismatch( expected: M, actual: M, - user_code: impl FnOnce(&mut CoreVM) -> Result, + user_code: impl FnOnce(&mut CoreVM) -> Result, ) { let mut output = VMTestCase::new() .input(StartMessage { diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 3e1bc96..659b7d3 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -50,8 +50,8 @@ struct VMTestCase { impl VMTestCase { fn new() -> Self { Self { - encoder: Encoder::new(Version::latest()), - vm: CoreVM::mock_init(Version::latest()), + encoder: Encoder::new(Version::maximum_supported_version()), + vm: CoreVM::mock_init(Version::maximum_supported_version()), } } @@ -85,7 +85,7 @@ struct OutputIterator(Decoder); impl OutputIterator { fn collect_vm(vm: &mut impl VM) -> Self { - let mut decoder = Decoder::new(Version::latest()); + let mut decoder = Decoder::new(Version::maximum_supported_version()); while let TakeOutputResult::Buffer(b) = vm.take_output() { decoder.push(b); } @@ -113,8 +113,8 @@ impl Iterator for OutputIterator { // --- Matchers /// Matcher for VMError -pub fn eq_vm_error(vm_error: VMError) -> impl Matcher { - pat!(VMError { +pub fn eq_vm_error(vm_error: Error) -> impl Matcher { + pat!(Error { code: eq(vm_error.code), message: eq(vm_error.message), description: eq(vm_error.description) @@ -122,7 +122,7 @@ pub fn eq_vm_error(vm_error: VMError) -> impl Matcher { } /// Matcher for ErrorMessage to equal VMError -pub fn error_message_as_vm_error(vm_error: VMError) -> impl Matcher { +pub fn error_message_as_vm_error(vm_error: Error) -> impl Matcher { pat!(ErrorMessage { code: eq(vm_error.code as u32), message: eq(vm_error.message), @@ -210,7 +210,7 @@ pub fn input_entry_message(b: impl AsRef<[u8]>) -> InputEntryMessage { #[test] fn take_output_on_newly_initialized_vm() { - let mut vm = CoreVM::mock_init(Version::latest()); + let mut vm = CoreVM::mock_init(Version::maximum_supported_version()); assert_that!( vm.take_output(), eq(TakeOutputResult::Buffer(Bytes::default())) diff --git a/src/tests/promise.rs b/src/tests/promise.rs index 70ee543..a7d4fdc 100644 --- a/src/tests/promise.rs +++ b/src/tests/promise.rs @@ -412,13 +412,10 @@ mod complete_promise { entry_index: 1, result: Some(completion_message::Result::Empty(Empty::default())), }) - .run(handler(NonEmptyValue::Failure( - Failure { - code: 500, - message: "my failure".to_owned(), - } - .into(), - ))); + .run(handler(NonEmptyValue::Failure(TerminalFailure { + code: 500, + message: "my failure".to_owned(), + }))); assert_eq!( output @@ -467,13 +464,10 @@ mod complete_promise { message: "cannot write promise".to_owned(), })), }) - .run(handler(NonEmptyValue::Failure( - Failure { - code: 500, - message: "my failure".to_owned(), - } - .into(), - ))); + .run(handler(NonEmptyValue::Failure(TerminalFailure { + code: 500, + message: "my failure".to_owned(), + }))); assert_eq!( output diff --git a/src/tests/run.rs b/src/tests/run.rs index a8b36bb..3f27f54 100644 --- a/src/tests/run.rs +++ b/src/tests/run.rs @@ -209,7 +209,7 @@ fn enter_then_exit_then_ack_with_failure() { ); let handle = vm .sys_run_exit( - RunExitResult::TerminalFailure(Failure { + RunExitResult::TerminalFailure(TerminalFailure { code: 500, message: "my-failure".to_string(), }), @@ -291,15 +291,15 @@ fn enter_then_notify_error() { vm.sys_run_enter("my-side-effect".to_owned()).unwrap() ); vm.notify_error( - Cow::Borrowed("my-error"), - Cow::Borrowed("my-error-description"), + Error::internal(Cow::Borrowed("my-error")) + .with_description(Cow::Borrowed("my-error-description")), None, ); }); assert_that!( output.next_decoded::().unwrap(), - error_message_as_vm_error(VMError { + error_message_as_vm_error(Error { code: 500, message: Cow::Borrowed("my-error"), description: Cow::Borrowed("my-error-description"), @@ -515,10 +515,7 @@ mod retry_policy { let handle = vm .sys_run_exit( RunExitResult::RetryableFailure { - failure: Failure { - code: 500, - message: "my-error".to_string(), - }, + error: Error::internal("my-error"), attempt_duration, }, retry_policy, @@ -581,10 +578,7 @@ mod retry_policy { assert!(vm .sys_run_exit( RunExitResult::RetryableFailure { - failure: Failure { - code: 500, - message: "my-error".to_string(), - }, + error: Error::internal("my-error"), attempt_duration }, retry_policy @@ -691,10 +685,7 @@ mod retry_policy { assert!(vm .sys_run_exit( RunExitResult::RetryableFailure { - failure: Failure { - code: 500, - message: "my-error".to_string(), - }, + error: Error::internal("my-error"), attempt_duration: Duration::from_millis(99) }, RetryPolicy::FixedDelay { diff --git a/src/vm/context.rs b/src/vm/context.rs index 43d5483..5532f19 100644 --- a/src/vm/context.rs +++ b/src/vm/context.rs @@ -3,7 +3,7 @@ use crate::service_protocol::messages::{ WriteableRestateMessage, }; use crate::service_protocol::{Encoder, MessageType, Version}; -use crate::{AsyncResultHandle, AsyncResultState, EntryRetryInfo, VMError, VMOptions, Value}; +use crate::{AsyncResultHandle, AsyncResultState, EntryRetryInfo, Error, VMOptions, Value}; use bytes::Bytes; use bytes_utils::SegmentedBuf; use std::collections::{HashMap, VecDeque}; @@ -108,7 +108,7 @@ impl AsyncResultsState { &mut self, index: u32, completion_parsing_hint: CompletionParsingHint, - ) -> Result<(), VMError> { + ) -> Result<(), Error> { if let Some(unparsed_completion_or_parsing_hint) = self.unparsed_completions_or_parsing_hints.remove(&index) { @@ -134,7 +134,7 @@ impl AsyncResultsState { &mut self, index: u32, result: completion_message::Result, - ) -> Result<(), VMError> { + ) -> Result<(), Error> { if let Some(unparsed_completion_or_parsing_hint) = self.unparsed_completions_or_parsing_hints.remove(&index) { diff --git a/src/vm/errors.rs b/src/vm/errors.rs index ac64fc6..a2be1dc 100644 --- a/src/vm/errors.rs +++ b/src/vm/errors.rs @@ -1,5 +1,5 @@ use crate::service_protocol::{DecodingError, MessageType, UnsupportedVersionError}; -use crate::VMError; +use crate::Error; use std::borrow::Cow; use std::fmt; @@ -66,9 +66,9 @@ pub mod codes { // Const errors -impl VMError { +impl Error { const fn new_const(code: InvocationErrorCode, message: &'static str) -> Self { - VMError { + Error { code: code.0, message: Cow::Borrowed(message), description: Cow::Borrowed(""), @@ -76,50 +76,50 @@ impl VMError { } } -pub const MISSING_CONTENT_TYPE: VMError = VMError::new_const( +pub const MISSING_CONTENT_TYPE: Error = Error::new_const( codes::UNSUPPORTED_MEDIA_TYPE, "Missing content type when invoking", ); -pub const UNEXPECTED_INPUT_MESSAGE: VMError = VMError::new_const( +pub const UNEXPECTED_INPUT_MESSAGE: Error = Error::new_const( codes::PROTOCOL_VIOLATION, "Expected input message to be entry", ); -pub const KNOWN_ENTRIES_IS_ZERO: VMError = - VMError::new_const(codes::INTERNAL, "Known entries is zero, expected >= 1"); +pub const KNOWN_ENTRIES_IS_ZERO: Error = + Error::new_const(codes::INTERNAL, "Known entries is zero, expected >= 1"); -pub const UNEXPECTED_ENTRY_MESSAGE: VMError = VMError::new_const( +pub const UNEXPECTED_ENTRY_MESSAGE: Error = Error::new_const( codes::PROTOCOL_VIOLATION, "Expected entry messages only when waiting replay entries", ); -pub const UNEXPECTED_NONE_RUN_RESULT: VMError = VMError::new_const( +pub const UNEXPECTED_NONE_RUN_RESULT: Error = Error::new_const( codes::PROTOCOL_VIOLATION, "Expected RunEntryMessage to contain a result", ); -pub const EXPECTED_COMPLETION_RESULT: VMError = VMError::new_const( +pub const EXPECTED_COMPLETION_RESULT: Error = Error::new_const( codes::PROTOCOL_VIOLATION, "The completion message MUST contain a result", ); -pub const INSIDE_RUN: VMError = VMError::new_const( +pub const INSIDE_RUN: Error = Error::new_const( codes::INTERNAL, "A syscall was invoked from within a run operation", ); -pub const INVOKED_RUN_EXIT_WITHOUT_ENTER: VMError = VMError::new_const( +pub const INVOKED_RUN_EXIT_WITHOUT_ENTER: Error = Error::new_const( codes::INTERNAL, "Invoked sys_run_exit without invoking sys_run_enter before", ); -pub const INPUT_CLOSED_WHILE_WAITING_ENTRIES: VMError = VMError::new_const( +pub const INPUT_CLOSED_WHILE_WAITING_ENTRIES: Error = Error::new_const( codes::PROTOCOL_VIOLATION, "The input was closed while still waiting to receive all the `known_entries`", ); -pub const BAD_COMBINATOR_ENTRY: VMError = VMError::new_const( +pub const BAD_COMBINATOR_ENTRY: Error = Error::new_const( codes::PROTOCOL_VIOLATION, "The combinator cannot be replayed. This is most likely caused by non deterministic code.", ); @@ -202,9 +202,9 @@ trait WithInvocationErrorCode { fn code(&self) -> InvocationErrorCode; } -impl From for VMError { +impl From for Error { fn from(value: T) -> Self { - VMError::new(value.code().0, value.to_string()) + Error::new(value.code().0, value.to_string()) } } diff --git a/src/vm/mod.rs b/src/vm/mod.rs index 8a36ba2..f5e91ec 100644 --- a/src/vm/mod.rs +++ b/src/vm/mod.rs @@ -13,9 +13,9 @@ use crate::vm::context::{EagerGetState, EagerGetStateKeys}; use crate::vm::errors::UnexpectedStateError; use crate::vm::transitions::*; use crate::{ - AsyncResultCombinator, AsyncResultHandle, Header, Input, NonEmptyValue, ResponseHead, + AsyncResultCombinator, AsyncResultHandle, Error, Header, Input, NonEmptyValue, ResponseHead, RetryPolicy, RunEnterResult, RunExitResult, SuspendedOrVMError, TakeOutputResult, Target, - VMError, VMOptions, VMResult, Value, + VMOptions, VMResult, Value, }; use base64::engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig}; use base64::{alphabet, Engine}; @@ -59,7 +59,7 @@ pub(crate) enum State { } impl State { - fn as_unexpected_state(&self, event: &'static str) -> VMError { + fn as_unexpected_state(&self, event: &'static str) -> Error { UnexpectedStateError::new(self.into(), event).into() } } @@ -72,7 +72,7 @@ pub struct CoreVM { // State machine context: Context, - last_transition: Result, + last_transition: Result, } impl CoreVM { @@ -112,11 +112,11 @@ const _: () = is_send::(); impl super::VM for CoreVM { #[instrument(level = "debug", skip_all, ret)] - fn new(request_headers: impl HeaderMap, options: VMOptions) -> Result { + fn new(request_headers: impl HeaderMap, options: VMOptions) -> Result { let version = request_headers .extract(CONTENT_TYPE) .map_err(|e| { - VMError::new( + Error::new( errors::codes::BAD_REQUEST, format!("cannot read '{CONTENT_TYPE}' header: {e:?}"), ) @@ -124,8 +124,8 @@ impl super::VM for CoreVM { .ok_or(errors::MISSING_CONTENT_TYPE)? .parse::()?; - if version != Version::latest() { - return Err(VMError::new( + if version != Version::maximum_supported_version() { + return Err(Error::new( errors::codes::UNSUPPORTED_MEDIA_TYPE, format!("Unsupported protocol version {:?}", version), )); @@ -160,6 +160,7 @@ impl super::VM for CoreVM { key: Cow::Borrowed(CONTENT_TYPE), value: Cow::Borrowed(self.version.content_type()), }], + version: self.version, } } @@ -213,18 +214,9 @@ impl super::VM for CoreVM { fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), ret )] - fn notify_error( - &mut self, - message: Cow<'static, str>, - description: Cow<'static, str>, - next_retry_delay: Option, - ) { + fn notify_error(&mut self, error: Error, next_retry_delay: Option) { let _ = self.do_transition(HitError { - error: VMError { - code: errors::codes::INTERNAL.into(), - message, - description, - }, + error, next_retry_delay, }); } @@ -256,7 +248,7 @@ impl super::VM for CoreVM { fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), ret )] - fn is_ready_to_execute(&self) -> Result { + fn is_ready_to_execute(&self) -> Result { match &self.last_transition { Ok(State::WaitingStart) | Ok(State::WaitingReplayEntries { .. }) => Ok(false), Ok(State::Processing { .. }) | Ok(State::Replaying { .. }) => Ok(true), @@ -298,7 +290,7 @@ impl super::VM for CoreVM { fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), ret )] - fn sys_input(&mut self) -> Result { + fn sys_input(&mut self) -> Result { self.do_transition(SysInput) } @@ -308,7 +300,7 @@ impl super::VM for CoreVM { fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), ret )] - fn sys_state_get(&mut self, key: String) -> Result { + fn sys_state_get(&mut self, key: String) -> Result { let result = match self.context.eager_state.get(&key) { EagerGetState::Unknown => None, EagerGetState::Empty => Some(get_state_entry_message::Result::Empty(Empty::default())), @@ -354,7 +346,7 @@ impl super::VM for CoreVM { fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), ret )] - fn sys_state_set(&mut self, key: String, value: Bytes) -> Result<(), VMError> { + fn sys_state_set(&mut self, key: String, value: Bytes) -> Result<(), Error> { self.context.eager_state.set(key.clone(), value.clone()); self.do_transition(SysNonCompletableEntry( "SysStateSet", @@ -372,7 +364,7 @@ impl super::VM for CoreVM { fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), ret )] - fn sys_state_clear(&mut self, key: String) -> Result<(), VMError> { + fn sys_state_clear(&mut self, key: String) -> Result<(), Error> { self.context.eager_state.clear(key.clone()); self.do_transition(SysNonCompletableEntry( "SysStateClear", @@ -389,7 +381,7 @@ impl super::VM for CoreVM { fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), ret )] - fn sys_state_clear_all(&mut self) -> Result<(), VMError> { + fn sys_state_clear_all(&mut self) -> Result<(), Error> { self.context.eager_state.clear_all(); self.do_transition(SysNonCompletableEntry( "SysStateClearAll", @@ -568,7 +560,7 @@ impl super::VM for CoreVM { fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), ret )] - fn sys_run_enter(&mut self, name: String) -> Result { + fn sys_run_enter(&mut self, name: String) -> Result { self.do_transition(SysRunEnter(name)) } @@ -582,7 +574,7 @@ impl super::VM for CoreVM { &mut self, value: RunExitResult, retry_policy: RetryPolicy, - ) -> Result { + ) -> Result { self.do_transition(SysRunExit(value, retry_policy)) } @@ -592,7 +584,7 @@ impl super::VM for CoreVM { fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), ret )] - fn sys_write_output(&mut self, value: NonEmptyValue) -> Result<(), VMError> { + fn sys_write_output(&mut self, value: NonEmptyValue) -> Result<(), Error> { self.do_transition(SysNonCompletableEntry( "SysWriteOutput", OutputEntryMessage { @@ -611,7 +603,7 @@ impl super::VM for CoreVM { fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), ret )] - fn sys_end(&mut self) -> Result<(), VMError> { + fn sys_end(&mut self) -> Result<(), Error> { self.do_transition(SysEnd) } diff --git a/src/vm/transitions/async_results.rs b/src/vm/transitions/async_results.rs index 66bef1f..5777d8e 100644 --- a/src/vm/transitions/async_results.rs +++ b/src/vm/transitions/async_results.rs @@ -4,13 +4,13 @@ use crate::vm::errors::{ }; use crate::vm::transitions::{HitSuspensionPoint, Transition, TransitionAndReturn}; use crate::vm::State; -use crate::{SuspendedError, VMError, Value}; +use crate::{Error, SuspendedError, Value}; use tracing::warn; pub(crate) struct NotifyInputClosed; impl Transition for State { - fn transition(self, context: &mut Context, _: NotifyInputClosed) -> Result { + fn transition(self, context: &mut Context, _: NotifyInputClosed) -> Result { match self { State::Replaying { current_await_point: Some(await_point), @@ -39,7 +39,7 @@ impl Transition for State { mut self, context: &mut Context, NotifyAwaitPoint(await_point): NotifyAwaitPoint, - ) -> Result { + ) -> Result { match self { State::Replaying { ref mut current_await_point, @@ -92,7 +92,7 @@ impl TransitionAndReturn for State { mut self, _: &mut Context, TakeAsyncResult(async_result): TakeAsyncResult, - ) -> Result<(Self, Self::Output), VMError> { + ) -> Result<(Self, Self::Output), Error> { match self { State::Processing { ref mut current_await_point, diff --git a/src/vm/transitions/combinators.rs b/src/vm/transitions/combinators.rs index 9dcad9c..ac6676a 100644 --- a/src/vm/transitions/combinators.rs +++ b/src/vm/transitions/combinators.rs @@ -4,7 +4,7 @@ use crate::vm::errors::{UnexpectedStateError, BAD_COMBINATOR_ENTRY}; use crate::vm::transitions::{PopJournalEntry, TransitionAndReturn}; use crate::vm::State; use crate::{ - AsyncResultAccessTracker, AsyncResultCombinator, AsyncResultHandle, AsyncResultState, VMError, + AsyncResultAccessTracker, AsyncResultCombinator, AsyncResultHandle, AsyncResultState, Error, Value, }; use std::collections::HashMap; @@ -67,7 +67,7 @@ where mut self, context: &mut Context, SysTryCompleteCombinator(combinator): SysTryCompleteCombinator, - ) -> Result<(Self, Self::Output), VMError> { + ) -> Result<(Self, Self::Output), Error> { self.check_side_effect_guard()?; match self { State::Processing { diff --git a/src/vm/transitions/input.rs b/src/vm/transitions/input.rs index e5f8dad..ed31a9e 100644 --- a/src/vm/transitions/input.rs +++ b/src/vm/transitions/input.rs @@ -4,18 +4,14 @@ use crate::vm::context::{Context, EagerState, StartInfo}; use crate::vm::errors::{BadEagerStateKeyError, KNOWN_ENTRIES_IS_ZERO, UNEXPECTED_INPUT_MESSAGE}; use crate::vm::transitions::Transition; use crate::vm::{errors, State}; -use crate::VMError; +use crate::Error; use bytes::Bytes; use tracing::debug; pub(crate) struct NewMessage(pub(crate) RawMessage); impl Transition for State { - fn transition( - self, - context: &mut Context, - NewMessage(msg): NewMessage, - ) -> Result { + fn transition(self, context: &mut Context, NewMessage(msg): NewMessage) -> Result { match msg.ty() { MessageType::Start => { self.transition(context, NewStartMessage(msg.decode_to::()?)) @@ -41,7 +37,7 @@ impl Transition for State { self, context: &mut Context, NewStartMessage(msg): NewStartMessage, - ) -> Result { + ) -> Result { context.start_info = Some(StartInfo { id: msg.id, debug_id: msg.debug_id, @@ -83,7 +79,7 @@ impl Transition for State { mut self, _: &mut Context, NewCompletionMessage(msg): NewCompletionMessage, - ) -> Result { + ) -> Result { // Add completion to completions buffer let CompletionMessage { entry_index, @@ -124,7 +120,7 @@ impl Transition for State { mut self, _: &mut Context, NewEntryAckMessage(msg): NewEntryAckMessage, - ) -> Result { + ) -> Result { match self { State::WaitingReplayEntries { ref mut async_results, @@ -156,7 +152,7 @@ impl Transition for State { self, context: &mut Context, NewEntryMessage(msg): NewEntryMessage, - ) -> Result { + ) -> Result { match self { State::WaitingReplayEntries { mut entries, diff --git a/src/vm/transitions/journal.rs b/src/vm/transitions/journal.rs index e3f25b3..9af29fe 100644 --- a/src/vm/transitions/journal.rs +++ b/src/vm/transitions/journal.rs @@ -12,13 +12,13 @@ use crate::vm::errors::{ use crate::vm::transitions::{Transition, TransitionAndReturn}; use crate::vm::State; use crate::{ - AsyncResultHandle, Header, Input, NonEmptyValue, RetryPolicy, RunEnterResult, RunExitResult, - VMError, + AsyncResultHandle, Error, Header, Input, NonEmptyValue, RetryPolicy, RunEnterResult, + RunExitResult, TerminalFailure, }; use std::{fmt, mem}; impl State { - pub(crate) fn check_side_effect_guard(&self) -> Result<(), VMError> { + pub(crate) fn check_side_effect_guard(&self) -> Result<(), Error> { if let State::Processing { run_state, .. } = self { if run_state.is_running() { return Err(INSIDE_RUN); @@ -39,7 +39,7 @@ impl self, context: &mut Context, PopJournalEntry(sys_name, expected): PopJournalEntry, - ) -> Result<(Self, Self::Output), VMError> { + ) -> Result<(Self, Self::Output), Error> { match self { State::Replaying { mut entries, @@ -85,7 +85,7 @@ impl, - ) -> Result<(Self, Self::Output), VMError> { + ) -> Result<(Self, Self::Output), Error> { match self { State::Processing { .. } => { context.output.send(&expected); @@ -105,7 +105,7 @@ impl TransitionAndReturn for State { self, context: &mut Context, _: SysInput, - ) -> Result<(Self, Self::Output), VMError> { + ) -> Result<(Self, Self::Output), Error> { context.journal.transition(&InputEntryMessage::default()); self.check_side_effect_guard()?; let (s, msg) = TransitionAndReturn::transition_and_return( @@ -163,7 +163,7 @@ impl, - ) -> Result { + ) -> Result { context.journal.transition(&expected); self.check_side_effect_guard()?; let (s, _) = @@ -189,7 +189,7 @@ impl< self, context: &mut Context, SysCompletableEntry(sys_name, expected): SysCompletableEntry, - ) -> Result<(Self, Self::Output), VMError> { + ) -> Result<(Self, Self::Output), Error> { context.journal.transition(&expected); self.check_side_effect_guard()?; let (mut s, actual) = TransitionAndReturn::transition_and_return( @@ -232,7 +232,7 @@ impl TransitionAndReturn for State { mut self, context: &mut Context, SysRunEnter(name): SysRunEnter, - ) -> Result<(Self, Self::Output), VMError> { + ) -> Result<(Self, Self::Output), Error> { let expected = RunEntryMessage { name: name.clone(), ..RunEntryMessage::default() @@ -271,7 +271,7 @@ impl TransitionAndReturn for State { mut self, context: &mut Context, SysRunExit(run_exit_result, retry_policy): SysRunExit, - ) -> Result<(Self, Self::Output), VMError> { + ) -> Result<(Self, Self::Output), Error> { match self { State::Processing { ref mut async_results, @@ -290,7 +290,7 @@ impl TransitionAndReturn for State { RunExitResult::Success(s) => NonEmptyValue::Success(s), RunExitResult::TerminalFailure(f) => NonEmptyValue::Failure(f), RunExitResult::RetryableFailure { - failure, + error: failure, attempt_duration, } => { let mut retry_info = context.infer_entry_retry_info(); @@ -301,11 +301,14 @@ impl TransitionAndReturn for State { NextRetry::Retry(next_retry_interval) => { // We need to retry! context.next_retry_delay = next_retry_interval; - return Err(VMError::new(failure.code, failure.message)); + return Err(Error::new(failure.code, failure.message)); } NextRetry::DoNotRetry => { // We don't retry, but convert the retryable error to actual error - NonEmptyValue::Failure(failure) + NonEmptyValue::Failure(TerminalFailure { + code: failure.code, + message: failure.message.to_string(), + }) } } } @@ -333,7 +336,7 @@ impl TransitionAndReturn for State { fn check_entry_header_match( actual: &M, expected: &M, -) -> Result<(), VMError> { +) -> Result<(), Error> { if !actual.header_eq(expected) { return Err(EntryMismatchError::new(actual.clone(), expected.clone()).into()); } diff --git a/src/vm/transitions/mod.rs b/src/vm/transitions/mod.rs index 5338fb8..5ee57de 100644 --- a/src/vm/transitions/mod.rs +++ b/src/vm/transitions/mod.rs @@ -7,7 +7,7 @@ mod terminal; use crate::service_protocol::messages::ErrorMessage; use crate::vm::context::Context; use crate::vm::State; -use crate::{CoreVM, VMError}; +use crate::{CoreVM, Error}; pub(crate) use async_results::*; pub(crate) use combinators::*; pub(crate) use input::*; @@ -19,7 +19,7 @@ trait Transition where Self: Sized, { - fn transition(self, context: &mut CTX, event: E) -> Result; + fn transition(self, context: &mut CTX, event: E) -> Result; } pub(crate) trait TransitionAndReturn @@ -31,7 +31,7 @@ where self, context: &mut CTX, event: E, - ) -> Result<(Self, Self::Output), VMError>; + ) -> Result<(Self, Self::Output), Error>; } impl TransitionAndReturn for STATE @@ -44,13 +44,13 @@ where self, context: &mut CTX, event: E, - ) -> Result<(Self, Self::Output), VMError> { + ) -> Result<(Self, Self::Output), Error> { Transition::transition(self, context, event).map(|s| (s, ())) } } impl CoreVM { - pub(super) fn do_transition(&mut self, event: E) -> Result + pub(super) fn do_transition(&mut self, event: E) -> Result where State: TransitionAndReturn, { diff --git a/src/vm/transitions/terminal.rs b/src/vm/transitions/terminal.rs index 0e2cb55..f1972a9 100644 --- a/src/vm/transitions/terminal.rs +++ b/src/vm/transitions/terminal.rs @@ -3,11 +3,11 @@ use crate::vm::context::Context; use crate::vm::errors::UnexpectedStateError; use crate::vm::transitions::Transition; use crate::vm::State; -use crate::VMError; +use crate::Error; use std::time::Duration; pub(crate) struct HitError { - pub(crate) error: VMError, + pub(crate) error: Error, pub(crate) next_retry_delay: Option, } @@ -19,7 +19,7 @@ impl Transition for State { error, next_retry_delay, }: HitError, - ) -> Result { + ) -> Result { ctx.next_retry_delay = next_retry_delay; // We let CoreVM::do_transition handle this @@ -34,7 +34,7 @@ impl Transition for State { self, context: &mut Context, HitSuspensionPoint(await_point): HitSuspensionPoint, - ) -> Result { + ) -> Result { if matches!(self, State::Suspended | State::Ended) { // Nothing to do return Ok(self); @@ -51,7 +51,7 @@ impl Transition for State { pub(crate) struct SysEnd; impl Transition for State { - fn transition(self, context: &mut Context, _: SysEnd) -> Result { + fn transition(self, context: &mut Context, _: SysEnd) -> Result { match self { State::Processing { .. } => { context.output.send(&EndMessage {});