diff --git a/Cargo.toml b/Cargo.toml index da61fa7..443e66f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,4 +27,4 @@ http = { version = "1.1.0", optional = true } googletest = "0.11.0" test-log = { version = "0.2.16", default-features = false, features = ["trace", "color"] } assert2 = "0.3.14" -prost-build = "0.13.2" +prost-build = "=0.13.3" diff --git a/README.md b/README.md index 603a575..3c98014 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ Shared core to build SDKs in various languages. Currently used by: * [Python SDK](https://github.com/restatedev/sdk-python) +* [Rust SDK](https://github.com/restatedev/sdk-rust) ## Versions diff --git a/service-protocol-ext/combinators.proto b/service-protocol-ext/combinators.proto new file mode 100644 index 0000000..01df28f --- /dev/null +++ b/service-protocol-ext/combinators.proto @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH + * + * This file is part of the Restate SDK for Node.js/TypeScript, + * which is released under the MIT license. + * + * You can find a copy of the license in file LICENSE in the root + * directory of this repository or package, or at + * https://github.com/restatedev/sdk-typescript/blob/main/LICENSE + */ + +syntax = "proto3"; + +package dev.restate.service.protocol.extensions; + +// Type: 0xFC00 + 2 +message CombinatorEntryMessage { + repeated uint32 completed_entries_order = 1; + + // Entry name + string name = 12; +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 6279819..1b647d7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,9 +6,11 @@ mod vm; use bytes::Bytes; use std::borrow::Cow; +use std::fmt; use std::time::Duration; pub use crate::retries::RetryPolicy; +use crate::vm::AsyncResultAccessTrackerInner; pub use headers::HeaderMap; pub use request_identity::*; pub use vm::CoreVM; @@ -83,7 +85,7 @@ pub struct Target { pub key: Option, } -#[derive(Debug, Clone, Copy, Eq, PartialEq)] +#[derive(Debug, Hash, Clone, Copy, Eq, PartialEq)] pub struct AsyncResultHandle(u32); impl From for AsyncResultHandle { @@ -106,6 +108,7 @@ pub enum Value { Failure(Failure), /// Only returned for get_state_keys StateKeys(Vec), + CombinatorResult(Vec), } /// Terminal failure @@ -162,8 +165,21 @@ pub enum TakeOutputResult { pub type VMResult = Result; +pub struct VMOptions { + /// If true, false when two concurrent async results are awaited at the same time. If false, just log it. + pub fail_on_wait_concurrent_async_result: bool, +} + +impl Default for VMOptions { + fn default() -> Self { + Self { + fail_on_wait_concurrent_async_result: true, + } + } +} + pub trait VM: Sized { - fn new(request_headers: impl HeaderMap) -> VMResult; + fn new(request_headers: impl HeaderMap, options: VMOptions) -> VMResult; fn get_response_head(&self) -> ResponseHead; @@ -257,6 +273,12 @@ pub trait VM: Sized { /// Returns true if the state machine is between a sys_run_enter and sys_run_exit fn is_inside_run(&self) -> bool; + + /// Returns false if the combinator can't be completed yet. + fn sys_try_complete_combinator( + &mut self, + combinator: impl AsyncResultCombinator + fmt::Debug, + ) -> VMResult>; } // HOW TO USE THIS API @@ -300,5 +322,27 @@ pub trait VM: Sized { // } // io.close() +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AsyncResultState { + Success, + Failure, + NotReady, +} + +pub struct AsyncResultAccessTracker(AsyncResultAccessTrackerInner); + +impl AsyncResultAccessTracker { + pub fn get_state(&mut self, handle: AsyncResultHandle) -> AsyncResultState { + self.0.get_state(handle) + } +} + +pub trait AsyncResultCombinator { + fn try_complete( + &self, + tracker: &mut AsyncResultAccessTracker, + ) -> Option>; +} + #[cfg(test)] mod tests; diff --git a/src/service_protocol/generated/dev.restate.service.protocol.extensions.rs b/src/service_protocol/generated/dev.restate.service.protocol.extensions.rs new file mode 100644 index 0000000..c347398 --- /dev/null +++ b/src/service_protocol/generated/dev.restate.service.protocol.extensions.rs @@ -0,0 +1,10 @@ +// This file is @generated by prost-build. +/// Type: 0xFC00 + 2 +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CombinatorEntryMessage { + #[prost(uint32, repeated, tag = "1")] + pub completed_entries_order: ::prost::alloc::vec::Vec, + /// Entry name + #[prost(string, tag = "12")] + pub name: ::prost::alloc::string::String, +} diff --git a/src/service_protocol/generated/dev.restate.service.protocol.rs b/src/service_protocol/generated/dev.restate.service.protocol.rs index b6f9c69..b16279d 100644 --- a/src/service_protocol/generated/dev.restate.service.protocol.rs +++ b/src/service_protocol/generated/dev.restate.service.protocol.rs @@ -509,9 +509,9 @@ impl ServiceProtocolVersion { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - ServiceProtocolVersion::Unspecified => "SERVICE_PROTOCOL_VERSION_UNSPECIFIED", - ServiceProtocolVersion::V1 => "V1", - ServiceProtocolVersion::V2 => "V2", + Self::Unspecified => "SERVICE_PROTOCOL_VERSION_UNSPECIFIED", + Self::V1 => "V1", + Self::V2 => "V2", } } /// Creates an enum from field names used in the ProtoBuf definition. diff --git a/src/service_protocol/header.rs b/src/service_protocol/header.rs index f5f94a7..7d00ab0 100644 --- a/src/service_protocol/header.rs +++ b/src/service_protocol/header.rs @@ -47,6 +47,7 @@ pub enum MessageType { GetPromiseEntry, PeekPromiseEntry, CompletePromiseEntry, + CombinatorEntry, CustomEntry(u16), } @@ -75,6 +76,7 @@ impl MessageType { MessageType::GetPromiseEntry => MessageKind::State, MessageType::PeekPromiseEntry => MessageKind::State, MessageType::CompletePromiseEntry => MessageKind::State, + MessageType::CombinatorEntry => MessageKind::Syscall, MessageType::CustomEntry(_) => MessageKind::CustomEntry, } } @@ -127,6 +129,7 @@ const BACKGROUND_INVOKE_ENTRY_MESSAGE_TYPE: u16 = 0x0C02; const AWAKEABLE_ENTRY_MESSAGE_TYPE: u16 = 0x0C03; const COMPLETE_AWAKEABLE_ENTRY_MESSAGE_TYPE: u16 = 0x0C04; const SIDE_EFFECT_ENTRY_MESSAGE_TYPE: u16 = 0x0C05; +const COMBINATOR_ENTRY_MESSAGE_TYPE: u16 = 0xFC02; impl From for MessageTypeId { fn from(mt: MessageType) -> Self { @@ -153,6 +156,7 @@ impl From for MessageTypeId { MessageType::GetPromiseEntry => GET_PROMISE_ENTRY_MESSAGE_TYPE, MessageType::PeekPromiseEntry => PEEK_PROMISE_ENTRY_MESSAGE_TYPE, MessageType::CompletePromiseEntry => COMPLETE_PROMISE_ENTRY_MESSAGE_TYPE, + MessageType::CombinatorEntry => COMBINATOR_ENTRY_MESSAGE_TYPE, MessageType::CustomEntry(id) => id, } } @@ -189,6 +193,7 @@ impl TryFrom for MessageType { PEEK_PROMISE_ENTRY_MESSAGE_TYPE => Ok(MessageType::PeekPromiseEntry), COMPLETE_PROMISE_ENTRY_MESSAGE_TYPE => Ok(MessageType::CompletePromiseEntry), SIDE_EFFECT_ENTRY_MESSAGE_TYPE => Ok(MessageType::RunEntry), + COMBINATOR_ENTRY_MESSAGE_TYPE => Ok(MessageType::CombinatorEntry), v if ((v & CUSTOM_MESSAGE_MASK) != 0) => Ok(MessageType::CustomEntry(v)), v => Err(UnknownMessageType(v)), } diff --git a/src/service_protocol/messages.rs b/src/service_protocol/messages.rs index 9227429..37eef62 100644 --- a/src/service_protocol/messages.rs +++ b/src/service_protocol/messages.rs @@ -40,6 +40,7 @@ impl WriteableRestateMessage for M { } include!("./generated/dev.restate.service.protocol.rs"); +include!("./generated/dev.restate.service.protocol.extensions.rs"); macro_rules! impl_message_traits { ($name:ident: core) => { @@ -233,6 +234,24 @@ impl EntryMessageHeaderEq for RunEntryMessage { } } +impl_message_traits!(CombinatorEntry: message); +impl_message_traits!(CombinatorEntry: entry); +impl WriteableRestateMessage for CombinatorEntryMessage { + fn generate_header(&self, never_ack: bool) -> MessageHeader { + MessageHeader::new_ackable_entry_header( + MessageType::CombinatorEntry, + None, + if never_ack { Some(false) } else { Some(true) }, + self.encoded_len() as u32, + ) + } +} +impl EntryMessageHeaderEq for CombinatorEntryMessage { + fn header_eq(&self, _: &Self) -> bool { + true + } +} + // --- Completion extraction impl TryFrom for Value { diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 1746ca9..3e1bc96 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -24,7 +24,11 @@ use test_log::test; impl CoreVM { fn mock_init(version: Version) -> CoreVM { - let vm = CoreVM::new(vec![("content-type".to_owned(), version.to_string())]).unwrap(); + let vm = CoreVM::new( + vec![("content-type".to_owned(), version.to_string())], + VMOptions::default(), + ) + .unwrap(); assert_that!( vm.get_response_head().headers, diff --git a/src/tests/state.rs b/src/tests/state.rs index 4774153..d62911d 100644 --- a/src/tests/state.rs +++ b/src/tests/state.rs @@ -25,7 +25,7 @@ fn get_state_handler(vm: &mut CoreVM) { vm.sys_end().unwrap(); return; } - Value::StateKeys(_) => panic!("Unexpected variant"), + _ => panic!("Unexpected variant"), }; vm.sys_write_output(NonEmptyValue::Success(Bytes::copy_from_slice( @@ -374,7 +374,7 @@ mod eager { vm.sys_end().unwrap(); return; } - Value::StateKeys(_) => panic!("Unexpected variant"), + _ => panic!("Unexpected variant"), }; vm.sys_write_output(NonEmptyValue::Success(Bytes::copy_from_slice( @@ -619,7 +619,7 @@ mod eager { vm.sys_end().unwrap(); return; } - Value::StateKeys(_) => panic!("Unexpected variant"), + _ => panic!("Unexpected variant"), }; vm.sys_state_set( @@ -644,7 +644,7 @@ mod eager { vm.sys_end().unwrap(); return; } - Value::StateKeys(_) => panic!("Unexpected variant"), + _ => panic!("Unexpected variant"), }; vm.sys_write_output(NonEmptyValue::Success(second_get_result)) @@ -799,7 +799,7 @@ mod eager { vm.sys_end().unwrap(); return; } - Value::StateKeys(_) => panic!("Unexpected variant"), + _ => panic!("Unexpected variant"), }; vm.sys_state_clear("STATE".to_owned()).unwrap(); @@ -958,7 +958,7 @@ mod eager { vm.sys_end().unwrap(); return; } - Value::StateKeys(_) => panic!("Unexpected variant"), + _ => panic!("Unexpected variant"), }; vm.sys_state_clear_all().unwrap(); @@ -1232,9 +1232,9 @@ mod state_keys { } let output = match h1_result.unwrap().unwrap() { - Value::Void | Value::Success(_) => panic!("Unexpected variants"), Value::Failure(f) => NonEmptyValue::Failure(f), Value::StateKeys(keys) => NonEmptyValue::Success(Bytes::from(keys.join(","))), + _ => panic!("Unexpected variants"), }; vm.sys_write_output(output).unwrap(); diff --git a/src/vm/context.rs b/src/vm/context.rs index 891bfba..43d5483 100644 --- a/src/vm/context.rs +++ b/src/vm/context.rs @@ -3,7 +3,7 @@ use crate::service_protocol::messages::{ WriteableRestateMessage, }; use crate::service_protocol::{Encoder, MessageType, Version}; -use crate::{EntryRetryInfo, VMError, Value}; +use crate::{AsyncResultHandle, AsyncResultState, EntryRetryInfo, VMError, VMOptions, Value}; use bytes::Bytes; use bytes_utils::SegmentedBuf; use std::collections::{HashMap, VecDeque}; @@ -182,6 +182,24 @@ impl AsyncResultsState { self.ready_results.insert(idx, value); } } + + pub(crate) fn get_ready_results_state(&self) -> HashMap { + self.ready_results + .iter() + .map(|(idx, val)| { + ( + AsyncResultHandle(*idx), + match val { + Value::Void + | Value::Success(_) + | Value::StateKeys(_) + | Value::CombinatorResult(_) => AsyncResultState::Success, + Value::Failure(_) => AsyncResultState::Failure, + }, + ) + }) + .collect() + } } #[derive(Debug)] @@ -286,6 +304,8 @@ pub(crate) struct Context { // Used by the error handler to set ErrorMessage.next_retry_delay pub(crate) next_retry_delay: Option, + + pub(crate) options: VMOptions, } impl Context { diff --git a/src/vm/errors.rs b/src/vm/errors.rs index 685bf71..ac64fc6 100644 --- a/src/vm/errors.rs +++ b/src/vm/errors.rs @@ -61,6 +61,7 @@ pub mod codes { pub const UNSUPPORTED_MEDIA_TYPE: InvocationErrorCode = InvocationErrorCode(415); pub const JOURNAL_MISMATCH: InvocationErrorCode = InvocationErrorCode(570); pub const PROTOCOL_VIOLATION: InvocationErrorCode = InvocationErrorCode(571); + pub const AWAITING_TWO_ASYNC_RESULTS: InvocationErrorCode = InvocationErrorCode(572); } // Const errors @@ -118,6 +119,11 @@ pub const INPUT_CLOSED_WHILE_WAITING_ENTRIES: VMError = VMError::new_const( "The input was closed while still waiting to receive all the `known_entries`", ); +pub const BAD_COMBINATOR_ENTRY: VMError = VMError::new_const( + codes::PROTOCOL_VIOLATION, + "The combinator cannot be replayed. This is most likely caused by non deterministic code.", +); + // Other errors #[derive(Debug, Clone, thiserror::Error)] @@ -223,7 +229,7 @@ impl WithInvocationErrorCode for DecodingError { } impl_error_code!(UnavailableEntryError, JOURNAL_MISMATCH); impl_error_code!(UnexpectedStateError, PROTOCOL_VIOLATION); -impl_error_code!(AwaitingTwoAsyncResultError, INTERNAL); +impl_error_code!(AwaitingTwoAsyncResultError, AWAITING_TWO_ASYNC_RESULTS); impl_error_code!(BadEagerStateKeyError, INTERNAL); impl_error_code!(DecodeStateKeysProst, PROTOCOL_VIOLATION); impl_error_code!(DecodeStateKeysUtf8, PROTOCOL_VIOLATION); diff --git a/src/vm/mod.rs b/src/vm/mod.rs index 0b9590f..8a36ba2 100644 --- a/src/vm/mod.rs +++ b/src/vm/mod.rs @@ -13,8 +13,9 @@ use crate::vm::context::{EagerGetState, EagerGetStateKeys}; use crate::vm::errors::UnexpectedStateError; use crate::vm::transitions::*; use crate::{ - AsyncResultHandle, Header, Input, NonEmptyValue, ResponseHead, RetryPolicy, RunEnterResult, - RunExitResult, SuspendedOrVMError, TakeOutputResult, Target, VMError, VMResult, Value, + AsyncResultCombinator, AsyncResultHandle, Header, Input, NonEmptyValue, ResponseHead, + RetryPolicy, RunEnterResult, RunExitResult, SuspendedOrVMError, TakeOutputResult, Target, + VMError, VMOptions, VMResult, Value, }; use base64::engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig}; use base64::{alphabet, Engine}; @@ -32,6 +33,8 @@ mod context; pub(crate) mod errors; mod transitions; +pub(crate) use transitions::AsyncResultAccessTrackerInner; + const CONTENT_TYPE: &str = "content-type"; #[derive(Debug, IntoStaticStr)] @@ -72,6 +75,17 @@ pub struct CoreVM { last_transition: Result, } +impl CoreVM { + // Returns empty string if the invocation id is not present + fn debug_invocation_id(&self) -> &str { + if let Some(start_info) = self.context.start_info() { + &start_info.debug_id + } else { + "" + } + } +} + impl fmt::Debug for CoreVM { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut s = f.debug_struct("CoreVM"); @@ -98,7 +112,7 @@ const _: () = is_send::(); impl super::VM for CoreVM { #[instrument(level = "debug", skip_all, ret)] - fn new(request_headers: impl HeaderMap) -> Result { + fn new(request_headers: impl HeaderMap, options: VMOptions) -> Result { let version = request_headers .extract(CONTENT_TYPE) .map_err(|e| { @@ -127,12 +141,18 @@ impl super::VM for CoreVM { journal: Default::default(), eager_state: Default::default(), next_retry_delay: None, + options, }, last_transition: Ok(State::WaitingStart), }) } - #[instrument(level = "debug", ret)] + #[instrument( + level = "debug", + skip(self), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn get_response_head(&self) -> ResponseHead { ResponseHead { status_code: 200, @@ -143,7 +163,12 @@ impl super::VM for CoreVM { } } - #[instrument(level = "debug", ret)] + #[instrument( + level = "debug", + skip(self), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn notify_input(&mut self, buffer: Bytes) { self.decoder.push(buffer); loop { @@ -171,13 +196,23 @@ impl super::VM for CoreVM { } } - #[instrument(level = "debug", ret)] + #[instrument( + level = "debug", + skip(self), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn notify_input_closed(&mut self) { self.context.input_is_closed = true; let _ = self.do_transition(NotifyInputClosed); } - #[instrument(level = "debug", ret)] + #[instrument( + level = "debug", + skip(self), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn notify_error( &mut self, message: Cow<'static, str>, @@ -194,7 +229,12 @@ impl super::VM for CoreVM { }); } - #[instrument(level = "debug", ret)] + #[instrument( + level = "debug", + skip(self), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn take_output(&mut self) -> TakeOutputResult { if self.context.output.buffer.has_remaining() { TakeOutputResult::Buffer( @@ -210,7 +250,12 @@ impl super::VM for CoreVM { } } - #[instrument(level = "debug", ret)] + #[instrument( + level = "debug", + skip(self), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn is_ready_to_execute(&self) -> Result { match &self.last_transition { Ok(State::WaitingStart) | Ok(State::WaitingReplayEntries { .. }) => Ok(false), @@ -220,12 +265,22 @@ impl super::VM for CoreVM { } } - #[instrument(level = "debug", ret)] + #[instrument( + level = "debug", + skip(self), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn notify_await_point(&mut self, AsyncResultHandle(await_point): AsyncResultHandle) { let _ = self.do_transition(NotifyAwaitPoint(await_point)); } - #[instrument(level = "debug", ret)] + #[instrument( + level = "debug", + skip(self), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn take_async_result( &mut self, handle: AsyncResultHandle, @@ -237,12 +292,22 @@ impl super::VM for CoreVM { } } - #[instrument(level = "debug", ret)] + #[instrument( + level = "debug", + skip(self), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn sys_input(&mut self) -> Result { self.do_transition(SysInput) } - #[instrument(level = "debug", ret)] + #[instrument( + level = "debug", + skip(self), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn sys_state_get(&mut self, key: String) -> Result { let result = match self.context.eager_state.get(&key) { EagerGetState::Unknown => None, @@ -259,7 +324,12 @@ impl super::VM for CoreVM { )) } - #[instrument(level = "debug", ret)] + #[instrument( + level = "debug", + skip(self), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn sys_state_get_keys(&mut self) -> VMResult { let result = match self.context.eager_state.get_keys() { EagerGetStateKeys::Unknown => None, @@ -278,7 +348,12 @@ impl super::VM for CoreVM { )) } - #[instrument(level = "debug", ret)] + #[instrument( + level = "debug", + skip(self, value), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn sys_state_set(&mut self, key: String, value: Bytes) -> Result<(), VMError> { self.context.eager_state.set(key.clone(), value.clone()); self.do_transition(SysNonCompletableEntry( @@ -291,7 +366,12 @@ impl super::VM for CoreVM { )) } - #[instrument(level = "debug", ret)] + #[instrument( + level = "debug", + skip(self), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn sys_state_clear(&mut self, key: String) -> Result<(), VMError> { self.context.eager_state.clear(key.clone()); self.do_transition(SysNonCompletableEntry( @@ -303,7 +383,12 @@ impl super::VM for CoreVM { )) } - #[instrument(level = "debug", ret)] + #[instrument( + level = "debug", + skip(self), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn sys_state_clear_all(&mut self) -> Result<(), VMError> { self.context.eager_state.clear_all(); self.do_transition(SysNonCompletableEntry( @@ -312,6 +397,12 @@ impl super::VM for CoreVM { )) } + #[instrument( + level = "debug", + skip(self), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn sys_sleep(&mut self, duration: Duration) -> VMResult { self.do_transition(SysCompletableEntry( "SysSleep", @@ -323,6 +414,12 @@ impl super::VM for CoreVM { )) } + #[instrument( + level = "debug", + skip(self, input), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn sys_call(&mut self, target: Target, input: Bytes) -> VMResult { self.do_transition(SysCompletableEntry( "SysCall", @@ -336,6 +433,12 @@ impl super::VM for CoreVM { )) } + #[instrument( + level = "debug", + skip(self, input), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn sys_send(&mut self, target: Target, input: Bytes, delay: Option) -> VMResult<()> { self.do_transition(SysNonCompletableEntry( "SysOneWayCall", @@ -355,6 +458,12 @@ impl super::VM for CoreVM { )) } + #[instrument( + level = "debug", + skip(self), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn sys_awakeable(&mut self) -> VMResult<(String, AsyncResultHandle)> { self.do_transition(SysCompletableEntry( "SysAwakeable", @@ -371,6 +480,12 @@ impl super::VM for CoreVM { }) } + #[instrument( + level = "debug", + skip(self, value), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn sys_complete_awakeable(&mut self, id: String, value: NonEmptyValue) -> VMResult<()> { self.do_transition(SysNonCompletableEntry( "SysCompleteAwakeable", @@ -387,6 +502,12 @@ impl super::VM for CoreVM { )) } + #[instrument( + level = "debug", + skip(self), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn sys_get_promise(&mut self, key: String) -> VMResult { self.do_transition(SysCompletableEntry( "SysGetPromise", @@ -397,6 +518,12 @@ impl super::VM for CoreVM { )) } + #[instrument( + level = "debug", + skip(self), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn sys_peek_promise(&mut self, key: String) -> VMResult { self.do_transition(SysCompletableEntry( "SysPeekPromise", @@ -407,6 +534,12 @@ impl super::VM for CoreVM { )) } + #[instrument( + level = "debug", + skip(self, value), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn sys_complete_promise( &mut self, key: String, @@ -429,12 +562,22 @@ impl super::VM for CoreVM { )) } - #[instrument(level = "debug", ret)] + #[instrument( + level = "debug", + skip(self), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn sys_run_enter(&mut self, name: String) -> Result { self.do_transition(SysRunEnter(name)) } - #[instrument(level = "debug", ret)] + #[instrument( + level = "debug", + skip(self, value, retry_policy), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn sys_run_exit( &mut self, value: RunExitResult, @@ -443,7 +586,12 @@ impl super::VM for CoreVM { self.do_transition(SysRunExit(value, retry_policy)) } - #[instrument(level = "debug", ret)] + #[instrument( + level = "debug", + skip(self, value), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn sys_write_output(&mut self, value: NonEmptyValue) -> Result<(), VMError> { self.do_transition(SysNonCompletableEntry( "SysWriteOutput", @@ -457,7 +605,12 @@ impl super::VM for CoreVM { )) } - #[instrument(level = "debug", ret)] + #[instrument( + level = "debug", + skip(self), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] fn sys_end(&mut self) -> Result<(), VMError> { self.do_transition(SysEnd) } @@ -475,6 +628,19 @@ impl super::VM for CoreVM { }) ) } + + #[instrument( + level = "debug", + skip(self), + fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), + ret + )] + fn sys_try_complete_combinator( + &mut self, + combinator: impl AsyncResultCombinator + fmt::Debug, + ) -> VMResult> { + self.do_transition(SysTryCompleteCombinator(combinator)) + } } const INDIFFERENT_PAD: GeneralPurposeConfig = GeneralPurposeConfig::new() diff --git a/src/vm/transitions/async_results.rs b/src/vm/transitions/async_results.rs index 08adaa7..66bef1f 100644 --- a/src/vm/transitions/async_results.rs +++ b/src/vm/transitions/async_results.rs @@ -5,6 +5,7 @@ use crate::vm::errors::{ use crate::vm::transitions::{HitSuspensionPoint, Transition, TransitionAndReturn}; use crate::vm::State; use crate::{SuspendedError, VMError, Value}; +use tracing::warn; pub(crate) struct NotifyInputClosed; @@ -52,11 +53,21 @@ impl Transition for State { } => { if let Some(previous) = current_await_point { if *previous != await_point { - return Err(AwaitingTwoAsyncResultError { - previous: *previous, - current: await_point, + if context.options.fail_on_wait_concurrent_async_result { + return Err(AwaitingTwoAsyncResultError { + previous: *previous, + current: await_point, + } + .into()); + } else { + warn!( + "{}", + AwaitingTwoAsyncResultError { + previous: *previous, + current: await_point, + } + ) } - .into()); } } if context.input_is_closed && !async_results.has_ready_result(await_point) { diff --git a/src/vm/transitions/combinators.rs b/src/vm/transitions/combinators.rs new file mode 100644 index 0000000..9dcad9c --- /dev/null +++ b/src/vm/transitions/combinators.rs @@ -0,0 +1,204 @@ +use crate::service_protocol::messages::{CombinatorEntryMessage, SuspensionMessage}; +use crate::vm::context::Context; +use crate::vm::errors::{UnexpectedStateError, BAD_COMBINATOR_ENTRY}; +use crate::vm::transitions::{PopJournalEntry, TransitionAndReturn}; +use crate::vm::State; +use crate::{ + AsyncResultAccessTracker, AsyncResultCombinator, AsyncResultHandle, AsyncResultState, VMError, + Value, +}; +use std::collections::HashMap; +use std::iter::Peekable; +use std::vec::IntoIter; + +pub(crate) enum AsyncResultAccessTrackerInner { + Processing { + known_results: HashMap, + tracked_access_to_completed_results: Vec, + tracked_access_to_uncompleted_results: Vec, + }, + Replaying { + replay_combinators: Peekable>, + }, +} + +impl AsyncResultAccessTrackerInner { + pub fn get_state(&mut self, handle: AsyncResultHandle) -> AsyncResultState { + match self { + AsyncResultAccessTrackerInner::Processing { + known_results, + tracked_access_to_completed_results, + tracked_access_to_uncompleted_results, + } => { + // Record if a known result is available + if let Some(res) = known_results.get(&handle) { + tracked_access_to_completed_results.push(handle); + *res + } else { + tracked_access_to_uncompleted_results.push(handle); + AsyncResultState::NotReady + } + } + AsyncResultAccessTrackerInner::Replaying { + replay_combinators: replay_status, + } => { + if let Some((_, result)) = + replay_status.next_if(|(peeked_handle, _)| *peeked_handle == handle) + { + result + } else { + // It's a not completed handle! + AsyncResultState::NotReady + } + } + } + } +} + +pub(crate) struct SysTryCompleteCombinator(pub(crate) C); + +impl TransitionAndReturn> for State +where + C: AsyncResultCombinator, +{ + type Output = Option; + + fn transition_and_return( + mut self, + context: &mut Context, + SysTryCompleteCombinator(combinator): SysTryCompleteCombinator, + ) -> Result<(Self, Self::Output), VMError> { + self.check_side_effect_guard()?; + match self { + State::Processing { + ref mut async_results, + .. + } => { + // Try complete the combinator + let mut async_result_tracker = + AsyncResultAccessTracker(AsyncResultAccessTrackerInner::Processing { + known_results: async_results.get_ready_results_state(), + tracked_access_to_completed_results: vec![], + tracked_access_to_uncompleted_results: vec![], + }); + + if let Some(combinator_result) = combinator.try_complete(&mut async_result_tracker) + { + // --- Combinator is ready! + + // Prepare the message to write out + let completed_entries_order = match async_result_tracker.0 { + AsyncResultAccessTrackerInner::Processing { + tracked_access_to_completed_results, + .. + } => tracked_access_to_completed_results, + _ => unreachable!(), + }; + let message = CombinatorEntryMessage { + completed_entries_order: completed_entries_order + .into_iter() + .map(Into::into) + .collect(), + ..CombinatorEntryMessage::default() + }; + + // Let's execute the transition + context.journal.transition(&message); + let current_journal_index = context.journal.expect_index(); + + // Cache locally the Combinator result, the user will be able to access this once the ack is received. + async_results.insert_waiting_ack_result( + current_journal_index, + Value::CombinatorResult(combinator_result), + ); + + // Write out the combinator message + context.output.send(&message); + + Ok((self, Some(AsyncResultHandle(current_journal_index)))) + } else { + // --- The combinator is not ready yet! Let's wait for more completions to come. + + if context.input_is_closed { + let uncompleted_entries_order = match async_result_tracker.0 { + AsyncResultAccessTrackerInner::Processing { + tracked_access_to_uncompleted_results, + .. + } => tracked_access_to_uncompleted_results, + _ => unreachable!(), + }; + + // We can't do progress anymore, let's suspend + context.output.send(&SuspensionMessage { + entry_indexes: uncompleted_entries_order + .into_iter() + .map(Into::into) + .collect(), + }); + context.output.send_eof(); + + Ok((State::Suspended, None)) + } else { + Ok((self, None)) + } + } + } + s => { + let expected = CombinatorEntryMessage::default(); + + // We increment the index now only if we're not processing. + context.journal.transition(&expected); + let current_journal_index = context.journal.expect_index(); + + // We should get the combinator message now + let (mut s, msg) = s.transition_and_return( + context, + PopJournalEntry("SysTryCompleteCombinator", expected), + )?; + + match s { + State::Replaying { + ref mut async_results, + .. + } + | State::Processing { + ref mut async_results, + .. + } => { + let ar_states = async_results.get_ready_results_state(); + + // Compute the replay_combinators + let mut replay_combinators = + Vec::with_capacity(msg.completed_entries_order.capacity()); + for idx in msg.completed_entries_order { + let handle = AsyncResultHandle(idx); + let async_result_state = + ar_states.get(&handle).ok_or(BAD_COMBINATOR_ENTRY)?; + replay_combinators.push((handle, *async_result_state)); + } + + // Replay combinator + let mut async_result_tracker = + AsyncResultAccessTracker(AsyncResultAccessTrackerInner::Replaying { + replay_combinators: replay_combinators.into_iter().peekable(), + }); + let combinator_result = combinator + .try_complete(&mut async_result_tracker) + .ok_or(BAD_COMBINATOR_ENTRY)?; + + // Store the ready result + async_results.insert_ready_result( + current_journal_index, + Value::CombinatorResult(combinator_result), + ); + + Ok((s, Some(AsyncResultHandle(current_journal_index)))) + } + s => { + Err(UnexpectedStateError::new(s.into(), "SysTryCompleteCombinator").into()) + } + } + } + } + } +} diff --git a/src/vm/transitions/journal.rs b/src/vm/transitions/journal.rs index 04a40fc..94dcda1 100644 --- a/src/vm/transitions/journal.rs +++ b/src/vm/transitions/journal.rs @@ -20,7 +20,7 @@ use sha2::{Digest, Sha256}; use std::{fmt, mem}; impl State { - fn check_side_effect_guard(&self) -> Result<(), VMError> { + pub(crate) fn check_side_effect_guard(&self) -> Result<(), VMError> { if let State::Processing { run_state, .. } = self { if run_state.is_running() { return Err(INSIDE_RUN); @@ -30,7 +30,7 @@ impl State { } } -struct PopJournalEntry(&'static str, M); +pub(crate) struct PopJournalEntry(pub(crate) &'static str, pub(crate) M); impl TransitionAndReturn> for State diff --git a/src/vm/transitions/mod.rs b/src/vm/transitions/mod.rs index ba9dd73..5338fb8 100644 --- a/src/vm/transitions/mod.rs +++ b/src/vm/transitions/mod.rs @@ -1,4 +1,5 @@ mod async_results; +mod combinators; mod input; mod journal; mod terminal; @@ -8,6 +9,7 @@ use crate::vm::context::Context; use crate::vm::State; use crate::{CoreVM, VMError}; pub(crate) use async_results::*; +pub(crate) use combinators::*; pub(crate) use input::*; pub(crate) use journal::*; use std::mem; diff --git a/tests/bootstrap.rs b/tests/bootstrap.rs index 5b15d16..ff52b70 100644 --- a/tests/bootstrap.rs +++ b/tests/bootstrap.rs @@ -13,8 +13,14 @@ fn bootstrap() { .protoc_arg("--experimental_allow_proto3_optional") .out_dir(out_dir.clone()) .compile_protos( - &[root_dir.join("service-protocol/dev/restate/service/protocol.proto")], - &[root_dir.join("service-protocol")], + &[ + root_dir.join("service-protocol/dev/restate/service/protocol.proto"), + root_dir.join("service-protocol-ext/combinators.proto"), + ], + &[ + root_dir.join("service-protocol"), + root_dir.join("service-protocol-ext"), + ], ) { panic!("failed to compile `console-api` protobuf: {}", error);