Skip to content

Commit

Permalink
Min max protocol version + better errors (#18)
Browse files Browse the repository at this point in the history
* Expose min max version

* Improve the error situation
  • Loading branch information
slinkydeveloper authored Oct 1, 2024
1 parent fe0c80d commit 771df58
Show file tree
Hide file tree
Showing 18 changed files with 186 additions and 174 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ request_identity = ["dep:ring", "dep:sha2", "dep:jsonwebtoken", "dep:bs58"]
sha2_random_seed = ["dep:sha2"]

[dependencies]
thiserror = "1.0.61"
thiserror = "1.0.64"
prost = "0.13.2"
bytes = "1.6"
bytes-utils = "0.1.4"
Expand Down
68 changes: 50 additions & 18 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,20 @@ use std::fmt;
use std::time::Duration;

pub use crate::retries::RetryPolicy;
use crate::vm::AsyncResultAccessTrackerInner;
pub use headers::HeaderMap;
#[cfg(feature = "request_identity")]
pub use request_identity::*;
pub use service_protocol::Version;
pub use vm::CoreVM;

// Re-export only some stuff from vm::errors
pub mod error {
pub use crate::vm::errors::codes;
pub use crate::vm::errors::InvocationErrorCode;
}

use crate::vm::AsyncResultAccessTrackerInner;

#[derive(Debug, Eq, PartialEq)]
pub struct Header {
pub key: Cow<'static, str>,
Expand All @@ -27,6 +35,7 @@ pub struct Header {
pub struct ResponseHead {
pub status_code: u16,
pub headers: Vec<Header>,
pub version: Version,
}

#[derive(Debug, Clone, Copy, thiserror::Error)]
Expand All @@ -35,21 +44,25 @@ pub struct SuspendedError;

#[derive(Debug, Clone, thiserror::Error)]
#[error("VM Error [{code}]: {message}. Description: {description}")]
pub struct VMError {
pub struct Error {
code: u16,
message: Cow<'static, str>,
description: Cow<'static, str>,
}

impl VMError {
impl Error {
pub fn new(code: impl Into<u16>, message: impl Into<Cow<'static, str>>) -> Self {
VMError {
Error {
code: code.into(),
message: message.into(),
description: Default::default(),
}
}

pub fn internal(message: impl Into<Cow<'static, str>>) -> Self {
Self::new(error::codes::INTERNAL, message)
}

pub fn code(&self) -> u16 {
self.code
}
Expand All @@ -61,14 +74,38 @@ impl VMError {
pub fn description(&self) -> &str {
&self.description
}

pub fn with_description(mut self, description: impl Into<Cow<'static, str>>) -> Self {
self.description = description.into();
self
}

/// Append the given description to the original one, in case the code is the same
pub fn append_description_for_code(
mut self,
code: impl Into<u16>,
description: impl Into<Cow<'static, str>>,
) -> Self {
let c = code.into();
if self.code == c {
if self.description.is_empty() {
self.description = description.into();
} else {
self.description = format!("{}. {}", self.description, description.into()).into();
}
self
} else {
self
}
}
}

#[derive(Debug, Clone, thiserror::Error)]
pub enum SuspendedOrVMError {
#[error(transparent)]
Suspended(SuspendedError),
#[error(transparent)]
VM(VMError),
VM(Error),
}

#[derive(Debug, Eq, PartialEq)]
Expand Down Expand Up @@ -107,15 +144,15 @@ pub enum Value {
/// a void/None/undefined success
Void,
Success(Bytes),
Failure(Failure),
Failure(TerminalFailure),
/// Only returned for get_state_keys
StateKeys(Vec<String>),
CombinatorResult(Vec<AsyncResultHandle>),
}

/// Terminal failure
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Failure {
pub struct TerminalFailure {
pub code: u16,
pub message: String,
}
Expand All @@ -137,17 +174,17 @@ pub enum RunEnterResult {
#[derive(Debug, Clone)]
pub enum RunExitResult {
Success(Bytes),
TerminalFailure(Failure),
TerminalFailure(TerminalFailure),
RetryableFailure {
attempt_duration: Duration,
failure: Failure,
error: Error,
},
}

#[derive(Debug, Clone)]
pub enum NonEmptyValue {
Success(Bytes),
Failure(Failure),
Failure(TerminalFailure),
}

impl From<NonEmptyValue> for Value {
Expand All @@ -165,7 +202,7 @@ pub enum TakeOutputResult {
EOF,
}

pub type VMResult<T> = Result<T, VMError>;
pub type VMResult<T> = Result<T, Error>;

pub struct VMOptions {
/// If true, false when two concurrent async results are awaited at the same time. If false, just log it.
Expand Down Expand Up @@ -193,20 +230,15 @@ pub trait VM: Sized {

// --- Errors

fn notify_error(
&mut self,
message: Cow<'static, str>,
description: Cow<'static, str>,
next_retry_delay: Option<Duration>,
);
fn notify_error(&mut self, error: Error, next_retry_delay: Option<Duration>);

// --- Output stream

fn take_output(&mut self) -> TakeOutputResult;

// --- Execution start waiting point

fn is_ready_to_execute(&self) -> Result<bool, VMError>;
fn is_ready_to_execute(&self) -> VMResult<bool>;

// --- Async results

Expand Down
16 changes: 8 additions & 8 deletions src/service_protocol/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ impl Encoder {
pub fn new(service_protocol_version: Version) -> Self {
assert_eq!(
service_protocol_version,
Version::latest(),
Version::maximum_supported_version(),
"Encoder only supports service protocol version {:?}",
Version::latest()
Version::maximum_supported_version()
);
Self {}
}
Expand Down Expand Up @@ -107,9 +107,9 @@ impl Decoder {
pub fn new(service_protocol_version: Version) -> Self {
assert_eq!(
service_protocol_version,
Version::latest(),
Version::maximum_supported_version(),
"Decoder only supports service protocol version {:?}",
Version::latest()
Version::maximum_supported_version()
);
Self {
buf: SegmentedBuf::new(),
Expand Down Expand Up @@ -185,8 +185,8 @@ mod tests {

#[test]
fn fill_decoder_with_several_messages() {
let encoder = Encoder::new(Version::latest());
let mut decoder = Decoder::new(Version::latest());
let encoder = Encoder::new(Version::maximum_supported_version());
let mut decoder = Decoder::new(Version::maximum_supported_version());

let expected_msg_0 = messages::StartMessage {
id: "key".into(),
Expand Down Expand Up @@ -260,8 +260,8 @@ mod tests {
}

fn partial_decoding_test(split_index: usize) {
let encoder = Encoder::new(Version::latest());
let mut decoder = Decoder::new(Version::latest());
let encoder = Encoder::new(Version::maximum_supported_version());
let mut decoder = Decoder::new(Version::maximum_supported_version());

let expected_msg = messages::InputEntryMessage {
value: Bytes::from_static("input".as_bytes()),
Expand Down
32 changes: 16 additions & 16 deletions src/service_protocol/messages.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::service_protocol::messages::get_state_keys_entry_message::StateKeys;
use crate::service_protocol::{MessageHeader, MessageType};
use crate::vm::errors::{DecodeStateKeysProst, DecodeStateKeysUtf8, EmptyStateKeys};
use crate::{NonEmptyValue, VMError, Value};
use crate::{Error, NonEmptyValue, Value};
use paste::paste;
use prost::Message;

Expand All @@ -25,7 +25,7 @@ pub trait EntryMessageHeaderEq {

pub trait CompletableEntryMessage: RestateMessage + EntryMessage + EntryMessageHeaderEq {
fn is_completed(&self) -> bool;
fn into_completion(self) -> Result<Option<Value>, VMError>;
fn into_completion(self) -> Result<Option<Value>, Error>;
fn completion_parsing_hint() -> CompletionParsingHint;
}

Expand Down Expand Up @@ -74,7 +74,7 @@ macro_rules! impl_message_traits {
self.result.is_some()
}

fn into_completion(self) -> Result<Option<Value>, VMError> {
fn into_completion(self) -> Result<Option<Value>, Error> {
self.result.map(TryInto::try_into).transpose()
}

Expand Down Expand Up @@ -133,7 +133,7 @@ impl CompletableEntryMessage for GetStateKeysEntryMessage {
self.result.is_some()
}

fn into_completion(self) -> Result<Option<Value>, VMError> {
fn into_completion(self) -> Result<Option<Value>, Error> {
self.result.map(TryInto::try_into).transpose()
}

Expand Down Expand Up @@ -255,7 +255,7 @@ impl EntryMessageHeaderEq for CombinatorEntryMessage {
// --- Completion extraction

impl TryFrom<get_state_entry_message::Result> for Value {
type Error = VMError;
type Error = Error;

fn try_from(value: get_state_entry_message::Result) -> Result<Self, Self::Error> {
Ok(match value {
Expand All @@ -267,7 +267,7 @@ impl TryFrom<get_state_entry_message::Result> for Value {
}

impl TryFrom<get_state_keys_entry_message::Result> for Value {
type Error = VMError;
type Error = Error;

fn try_from(value: get_state_keys_entry_message::Result) -> Result<Self, Self::Error> {
match value {
Expand All @@ -286,7 +286,7 @@ impl TryFrom<get_state_keys_entry_message::Result> for Value {
}

impl TryFrom<sleep_entry_message::Result> for Value {
type Error = VMError;
type Error = Error;

fn try_from(value: sleep_entry_message::Result) -> Result<Self, Self::Error> {
Ok(match value {
Expand All @@ -297,7 +297,7 @@ impl TryFrom<sleep_entry_message::Result> for Value {
}

impl TryFrom<call_entry_message::Result> for Value {
type Error = VMError;
type Error = Error;

fn try_from(value: call_entry_message::Result) -> Result<Self, Self::Error> {
Ok(match value {
Expand All @@ -308,7 +308,7 @@ impl TryFrom<call_entry_message::Result> for Value {
}

impl TryFrom<awakeable_entry_message::Result> for Value {
type Error = VMError;
type Error = Error;

fn try_from(value: awakeable_entry_message::Result) -> Result<Self, Self::Error> {
Ok(match value {
Expand All @@ -319,7 +319,7 @@ impl TryFrom<awakeable_entry_message::Result> for Value {
}

impl TryFrom<get_promise_entry_message::Result> for Value {
type Error = VMError;
type Error = Error;

fn try_from(value: get_promise_entry_message::Result) -> Result<Self, Self::Error> {
Ok(match value {
Expand All @@ -330,7 +330,7 @@ impl TryFrom<get_promise_entry_message::Result> for Value {
}

impl TryFrom<peek_promise_entry_message::Result> for Value {
type Error = VMError;
type Error = Error;

fn try_from(value: peek_promise_entry_message::Result) -> Result<Self, Self::Error> {
Ok(match value {
Expand All @@ -342,7 +342,7 @@ impl TryFrom<peek_promise_entry_message::Result> for Value {
}

impl TryFrom<complete_promise_entry_message::Result> for Value {
type Error = VMError;
type Error = Error;

fn try_from(value: complete_promise_entry_message::Result) -> Result<Self, Self::Error> {
Ok(match value {
Expand All @@ -363,16 +363,16 @@ impl From<run_entry_message::Result> for NonEmptyValue {

// --- Other conversions

impl From<crate::Failure> for Failure {
fn from(value: crate::Failure) -> Self {
impl From<crate::TerminalFailure> for Failure {
fn from(value: crate::TerminalFailure) -> Self {
Self {
code: value.code as u32,
message: value.message,
}
}
}

impl From<Failure> for crate::Failure {
impl From<Failure> for crate::TerminalFailure {
fn from(value: Failure) -> Self {
Self {
code: value.code as u16,
Expand All @@ -391,7 +391,7 @@ pub(crate) enum CompletionParsingHint {
}

impl CompletionParsingHint {
pub(crate) fn parse(self, result: completion_message::Result) -> Result<Value, VMError> {
pub(crate) fn parse(self, result: completion_message::Result) -> Result<Value, Error> {
match self {
CompletionParsingHint::StateKeys => match result {
completion_message::Result::Empty(_) => Err(EmptyStateKeys.into()),
Expand Down
10 changes: 7 additions & 3 deletions src/service_protocol/version.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::str::FromStr;

#[derive(Debug, Clone, Copy, Ord, PartialOrd, Eq, PartialEq)]
pub enum Version {
V1,
V2,
V1 = 1,
V2 = 2,
}

const CONTENT_TYPE_V1: &str = "application/vnd.restate.invocation.v1";
Expand All @@ -18,7 +18,11 @@ impl Version {
}
}

pub const fn latest() -> Self {
pub const fn minimum_supported_version() -> Self {
Version::V2
}

pub const fn maximum_supported_version() -> Self {
Version::V2
}
}
Expand Down
Loading

0 comments on commit 771df58

Please sign in to comment.