Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic implementation of combinators #14

Merged
merged 7 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ http = { version = "1.1.0", optional = true }
googletest = "0.11.0"
test-log = { version = "0.2.16", default-features = false, features = ["trace", "color"] }
assert2 = "0.3.14"
prost-build = "0.13.2"
prost-build = "=0.13.3"
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Shared core to build SDKs in various languages. Currently used by:

* [Python SDK](https://github.com/restatedev/sdk-python)
* [Rust SDK](https://github.com/restatedev/sdk-rust)

## Versions

Expand Down
22 changes: 22 additions & 0 deletions service-protocol-ext/combinators.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
*
* This file is part of the Restate SDK for Node.js/TypeScript,
* which is released under the MIT license.
*
* You can find a copy of the license in file LICENSE in the root
* directory of this repository or package, or at
* https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
*/

syntax = "proto3";

package dev.restate.service.protocol.extensions;

// Type: 0xFC00 + 2
message CombinatorEntryMessage {
repeated uint32 completed_entries_order = 1;

// Entry name
string name = 12;
}
48 changes: 46 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ mod vm;

use bytes::Bytes;
use std::borrow::Cow;
use std::fmt;
use std::time::Duration;

pub use crate::retries::RetryPolicy;
use crate::vm::AsyncResultAccessTrackerInner;
pub use headers::HeaderMap;
pub use request_identity::*;
pub use vm::CoreVM;
Expand Down Expand Up @@ -83,7 +85,7 @@ pub struct Target {
pub key: Option<String>,
}

#[derive(Debug, Clone, Copy, Eq, PartialEq)]
#[derive(Debug, Hash, Clone, Copy, Eq, PartialEq)]
pub struct AsyncResultHandle(u32);

impl From<u32> for AsyncResultHandle {
Expand All @@ -106,6 +108,7 @@ pub enum Value {
Failure(Failure),
/// Only returned for get_state_keys
StateKeys(Vec<String>),
CombinatorResult(Vec<AsyncResultHandle>),
}

/// Terminal failure
Expand Down Expand Up @@ -162,8 +165,21 @@ pub enum TakeOutputResult {

pub type VMResult<T> = Result<T, VMError>;

pub struct VMOptions {
/// If true, false when two concurrent async results are awaited at the same time. If false, just log it.
pub fail_on_wait_concurrent_async_result: bool,
}

impl Default for VMOptions {
fn default() -> Self {
Self {
fail_on_wait_concurrent_async_result: true,
}
}
}

pub trait VM: Sized {
fn new(request_headers: impl HeaderMap) -> VMResult<Self>;
fn new(request_headers: impl HeaderMap, options: VMOptions) -> VMResult<Self>;

fn get_response_head(&self) -> ResponseHead;

Expand Down Expand Up @@ -257,6 +273,12 @@ pub trait VM: Sized {

/// Returns true if the state machine is between a sys_run_enter and sys_run_exit
fn is_inside_run(&self) -> bool;

/// Returns false if the combinator can't be completed yet.
fn sys_try_complete_combinator(
&mut self,
combinator: impl AsyncResultCombinator + fmt::Debug,
) -> VMResult<Option<AsyncResultHandle>>;
}

// HOW TO USE THIS API
Expand Down Expand Up @@ -300,5 +322,27 @@ pub trait VM: Sized {
// }
// io.close()

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AsyncResultState {
Success,
Failure,
NotReady,
}

pub struct AsyncResultAccessTracker(AsyncResultAccessTrackerInner);

impl AsyncResultAccessTracker {
pub fn get_state(&mut self, handle: AsyncResultHandle) -> AsyncResultState {
self.0.get_state(handle)
}
}

pub trait AsyncResultCombinator {
fn try_complete(
&self,
tracker: &mut AsyncResultAccessTracker,
) -> Option<Vec<AsyncResultHandle>>;
}

#[cfg(test)]
mod tests;
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// This file is @generated by prost-build.
/// Type: 0xFC00 + 2
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct CombinatorEntryMessage {
#[prost(uint32, repeated, tag = "1")]
pub completed_entries_order: ::prost::alloc::vec::Vec<u32>,
/// Entry name
#[prost(string, tag = "12")]
pub name: ::prost::alloc::string::String,
}
Original file line number Diff line number Diff line change
Expand Up @@ -509,9 +509,9 @@ impl ServiceProtocolVersion {
/// (if the ProtoBuf definition does not change) and safe for programmatic use.
pub fn as_str_name(&self) -> &'static str {
match self {
ServiceProtocolVersion::Unspecified => "SERVICE_PROTOCOL_VERSION_UNSPECIFIED",
ServiceProtocolVersion::V1 => "V1",
ServiceProtocolVersion::V2 => "V2",
Self::Unspecified => "SERVICE_PROTOCOL_VERSION_UNSPECIFIED",
Self::V1 => "V1",
Self::V2 => "V2",
}
}
/// Creates an enum from field names used in the ProtoBuf definition.
Expand Down
5 changes: 5 additions & 0 deletions src/service_protocol/header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ pub enum MessageType {
GetPromiseEntry,
PeekPromiseEntry,
CompletePromiseEntry,
CombinatorEntry,
CustomEntry(u16),
}

Expand Down Expand Up @@ -75,6 +76,7 @@ impl MessageType {
MessageType::GetPromiseEntry => MessageKind::State,
MessageType::PeekPromiseEntry => MessageKind::State,
MessageType::CompletePromiseEntry => MessageKind::State,
MessageType::CombinatorEntry => MessageKind::Syscall,
MessageType::CustomEntry(_) => MessageKind::CustomEntry,
}
}
Expand Down Expand Up @@ -127,6 +129,7 @@ const BACKGROUND_INVOKE_ENTRY_MESSAGE_TYPE: u16 = 0x0C02;
const AWAKEABLE_ENTRY_MESSAGE_TYPE: u16 = 0x0C03;
const COMPLETE_AWAKEABLE_ENTRY_MESSAGE_TYPE: u16 = 0x0C04;
const SIDE_EFFECT_ENTRY_MESSAGE_TYPE: u16 = 0x0C05;
const COMBINATOR_ENTRY_MESSAGE_TYPE: u16 = 0xFC02;

impl From<MessageType> for MessageTypeId {
fn from(mt: MessageType) -> Self {
Expand All @@ -153,6 +156,7 @@ impl From<MessageType> for MessageTypeId {
MessageType::GetPromiseEntry => GET_PROMISE_ENTRY_MESSAGE_TYPE,
MessageType::PeekPromiseEntry => PEEK_PROMISE_ENTRY_MESSAGE_TYPE,
MessageType::CompletePromiseEntry => COMPLETE_PROMISE_ENTRY_MESSAGE_TYPE,
MessageType::CombinatorEntry => COMBINATOR_ENTRY_MESSAGE_TYPE,
MessageType::CustomEntry(id) => id,
}
}
Expand Down Expand Up @@ -189,6 +193,7 @@ impl TryFrom<MessageTypeId> for MessageType {
PEEK_PROMISE_ENTRY_MESSAGE_TYPE => Ok(MessageType::PeekPromiseEntry),
COMPLETE_PROMISE_ENTRY_MESSAGE_TYPE => Ok(MessageType::CompletePromiseEntry),
SIDE_EFFECT_ENTRY_MESSAGE_TYPE => Ok(MessageType::RunEntry),
COMBINATOR_ENTRY_MESSAGE_TYPE => Ok(MessageType::CombinatorEntry),
v if ((v & CUSTOM_MESSAGE_MASK) != 0) => Ok(MessageType::CustomEntry(v)),
v => Err(UnknownMessageType(v)),
}
Expand Down
19 changes: 19 additions & 0 deletions src/service_protocol/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ impl<M: CompletableEntryMessage> WriteableRestateMessage for M {
}

include!("./generated/dev.restate.service.protocol.rs");
include!("./generated/dev.restate.service.protocol.extensions.rs");

macro_rules! impl_message_traits {
($name:ident: core) => {
Expand Down Expand Up @@ -233,6 +234,24 @@ impl EntryMessageHeaderEq for RunEntryMessage {
}
}

impl_message_traits!(CombinatorEntry: message);
impl_message_traits!(CombinatorEntry: entry);
impl WriteableRestateMessage for CombinatorEntryMessage {
fn generate_header(&self, never_ack: bool) -> MessageHeader {
MessageHeader::new_ackable_entry_header(
MessageType::CombinatorEntry,
None,
if never_ack { Some(false) } else { Some(true) },
self.encoded_len() as u32,
)
}
}
impl EntryMessageHeaderEq for CombinatorEntryMessage {
fn header_eq(&self, _: &Self) -> bool {
true
}
}

// --- Completion extraction

impl TryFrom<get_state_entry_message::Result> for Value {
Expand Down
6 changes: 5 additions & 1 deletion src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ use test_log::test;

impl CoreVM {
fn mock_init(version: Version) -> CoreVM {
let vm = CoreVM::new(vec![("content-type".to_owned(), version.to_string())]).unwrap();
let vm = CoreVM::new(
vec![("content-type".to_owned(), version.to_string())],
VMOptions::default(),
)
.unwrap();

assert_that!(
vm.get_response_head().headers,
Expand Down
14 changes: 7 additions & 7 deletions src/tests/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ fn get_state_handler(vm: &mut CoreVM) {
vm.sys_end().unwrap();
return;
}
Value::StateKeys(_) => panic!("Unexpected variant"),
_ => panic!("Unexpected variant"),
};

vm.sys_write_output(NonEmptyValue::Success(Bytes::copy_from_slice(
Expand Down Expand Up @@ -374,7 +374,7 @@ mod eager {
vm.sys_end().unwrap();
return;
}
Value::StateKeys(_) => panic!("Unexpected variant"),
_ => panic!("Unexpected variant"),
};

vm.sys_write_output(NonEmptyValue::Success(Bytes::copy_from_slice(
Expand Down Expand Up @@ -619,7 +619,7 @@ mod eager {
vm.sys_end().unwrap();
return;
}
Value::StateKeys(_) => panic!("Unexpected variant"),
_ => panic!("Unexpected variant"),
};

vm.sys_state_set(
Expand All @@ -644,7 +644,7 @@ mod eager {
vm.sys_end().unwrap();
return;
}
Value::StateKeys(_) => panic!("Unexpected variant"),
_ => panic!("Unexpected variant"),
};

vm.sys_write_output(NonEmptyValue::Success(second_get_result))
Expand Down Expand Up @@ -799,7 +799,7 @@ mod eager {
vm.sys_end().unwrap();
return;
}
Value::StateKeys(_) => panic!("Unexpected variant"),
_ => panic!("Unexpected variant"),
};

vm.sys_state_clear("STATE".to_owned()).unwrap();
Expand Down Expand Up @@ -958,7 +958,7 @@ mod eager {
vm.sys_end().unwrap();
return;
}
Value::StateKeys(_) => panic!("Unexpected variant"),
_ => panic!("Unexpected variant"),
};

vm.sys_state_clear_all().unwrap();
Expand Down Expand Up @@ -1232,9 +1232,9 @@ mod state_keys {
}

let output = match h1_result.unwrap().unwrap() {
Value::Void | Value::Success(_) => panic!("Unexpected variants"),
Value::Failure(f) => NonEmptyValue::Failure(f),
Value::StateKeys(keys) => NonEmptyValue::Success(Bytes::from(keys.join(","))),
_ => panic!("Unexpected variants"),
};

vm.sys_write_output(output).unwrap();
Expand Down
22 changes: 21 additions & 1 deletion src/vm/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::service_protocol::messages::{
WriteableRestateMessage,
};
use crate::service_protocol::{Encoder, MessageType, Version};
use crate::{EntryRetryInfo, VMError, Value};
use crate::{AsyncResultHandle, AsyncResultState, EntryRetryInfo, VMError, VMOptions, Value};
use bytes::Bytes;
use bytes_utils::SegmentedBuf;
use std::collections::{HashMap, VecDeque};
Expand Down Expand Up @@ -182,6 +182,24 @@ impl AsyncResultsState {
self.ready_results.insert(idx, value);
}
}

pub(crate) fn get_ready_results_state(&self) -> HashMap<AsyncResultHandle, AsyncResultState> {
self.ready_results
.iter()
.map(|(idx, val)| {
(
AsyncResultHandle(*idx),
match val {
Value::Void
| Value::Success(_)
| Value::StateKeys(_)
| Value::CombinatorResult(_) => AsyncResultState::Success,
Value::Failure(_) => AsyncResultState::Failure,
},
)
})
.collect()
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -286,6 +304,8 @@ pub(crate) struct Context {

// Used by the error handler to set ErrorMessage.next_retry_delay
pub(crate) next_retry_delay: Option<Duration>,

pub(crate) options: VMOptions,
}

impl Context {
Expand Down
8 changes: 7 additions & 1 deletion src/vm/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ pub mod codes {
pub const UNSUPPORTED_MEDIA_TYPE: InvocationErrorCode = InvocationErrorCode(415);
pub const JOURNAL_MISMATCH: InvocationErrorCode = InvocationErrorCode(570);
pub const PROTOCOL_VIOLATION: InvocationErrorCode = InvocationErrorCode(571);
pub const AWAITING_TWO_ASYNC_RESULTS: InvocationErrorCode = InvocationErrorCode(572);
}

// Const errors
Expand Down Expand Up @@ -118,6 +119,11 @@ pub const INPUT_CLOSED_WHILE_WAITING_ENTRIES: VMError = VMError::new_const(
"The input was closed while still waiting to receive all the `known_entries`",
);

pub const BAD_COMBINATOR_ENTRY: VMError = VMError::new_const(
codes::PROTOCOL_VIOLATION,
"The combinator cannot be replayed. This is most likely caused by non deterministic code.",
);

// Other errors

#[derive(Debug, Clone, thiserror::Error)]
Expand Down Expand Up @@ -223,7 +229,7 @@ impl WithInvocationErrorCode for DecodingError {
}
impl_error_code!(UnavailableEntryError, JOURNAL_MISMATCH);
impl_error_code!(UnexpectedStateError, PROTOCOL_VIOLATION);
impl_error_code!(AwaitingTwoAsyncResultError, INTERNAL);
impl_error_code!(AwaitingTwoAsyncResultError, AWAITING_TWO_ASYNC_RESULTS);
impl_error_code!(BadEagerStateKeyError, INTERNAL);
impl_error_code!(DecodeStateKeysProst, PROTOCOL_VIOLATION);
impl_error_code!(DecodeStateKeysUtf8, PROTOCOL_VIOLATION);
Expand Down
Loading