From aa5d95360403878c947d1e72b84d7f24cc5ce228 Mon Sep 17 00:00:00 2001 From: Advaith Nair Date: Fri, 6 Dec 2024 14:31:14 -0800 Subject: [PATCH] fix(sdk): proof request signatures (#1845) --- Cargo.lock | 1 + Cargo.toml | 1 - crates/sdk/Cargo.toml | 1 + crates/sdk/src/network-v2/client.rs | 6 +- crates/sdk/src/network-v2/json.rs | 57 ++++ crates/sdk/src/network-v2/mod.rs | 1 + crates/sdk/src/network-v2/proto/network.rs | 339 ++++++++++++++++++--- crates/sdk/src/network-v2/sign_message.rs | 33 +- 8 files changed, 397 insertions(+), 42 deletions(-) create mode 100644 crates/sdk/src/network-v2/json.rs diff --git a/Cargo.lock b/Cargo.lock index 21b1074ae0..243a8f6a1c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5958,6 +5958,7 @@ dependencies = [ "reqwest", "reqwest-middleware", "serde", + "serde_json", "sp1-build", "sp1-core-executor", "sp1-core-machine", diff --git a/Cargo.toml b/Cargo.toml index f48c993410..58e56265fc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -68,7 +68,6 @@ sp1-sdk = { path = "crates/sdk", version = "3.0.0" } sp1-cuda = { path = "crates/cuda", version = "3.0.0" } sp1-stark = { path = "crates/stark", version = "3.0.0" } sp1-lib = { path = "crates/zkvm/lib", version = "3.0.0", default-features = false } - # NOTE: The version in this crate is manually set to 3.0.1 right now. When upgrading SP1 versions, # make sure to update this crate. sp1-zkvm = { path = "crates/zkvm/entrypoint", version = "3.0.1", default-features = false } diff --git a/crates/sdk/Cargo.toml b/crates/sdk/Cargo.toml index a55b6caf39..f6fe5acdd3 100644 --- a/crates/sdk/Cargo.toml +++ b/crates/sdk/Cargo.toml @@ -12,6 +12,7 @@ categories = { workspace = true } [dependencies] prost = { version = "0.13", optional = true } serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } twirp = { package = "twirp-rs", version = "0.13.0-succinct", optional = true } async-trait = "0.1.81" reqwest-middleware = { version = "0.3.2", optional = true } diff --git a/crates/sdk/src/network-v2/client.rs b/crates/sdk/src/network-v2/client.rs index 8bd71e2e1b..509aaad444 100644 --- a/crates/sdk/src/network-v2/client.rs +++ b/crates/sdk/src/network-v2/client.rs @@ -22,8 +22,8 @@ use crate::network_v2::proto::network::{ prover_network_client::ProverNetworkClient, CreateProgramRequest, CreateProgramRequestBody, CreateProgramResponse, FulfillmentStatus, FulfillmentStrategy, GetNonceRequest, GetProgramRequest, GetProgramResponse, GetProofRequestStatusRequest, - GetProofRequestStatusResponse, ProofMode, RequestProofRequest, RequestProofRequestBody, - RequestProofResponse, + GetProofRequestStatusResponse, MessageFormat, ProofMode, RequestProofRequest, + RequestProofRequestBody, RequestProofResponse, }; use crate::network_v2::Signable; @@ -160,6 +160,7 @@ impl NetworkClient { Ok(rpc .create_program(CreateProgramRequest { + format: MessageFormat::Binary.into(), signature: request_body.sign(&self.signer).into(), body: Some(request_body), }) @@ -232,6 +233,7 @@ impl NetworkClient { }; let request_response = rpc .request_proof(RequestProofRequest { + format: MessageFormat::Binary.into(), signature: request_body.sign(&self.signer).into(), body: Some(request_body), }) diff --git a/crates/sdk/src/network-v2/json.rs b/crates/sdk/src/network-v2/json.rs new file mode 100644 index 0000000000..a0808d9be6 --- /dev/null +++ b/crates/sdk/src/network-v2/json.rs @@ -0,0 +1,57 @@ +use prost::Message; +#[allow(unused_imports)] +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +/// Errors that can occur during JSON formatting. +#[derive(Error, Debug)] +pub enum JsonFormatError { + #[error("Serialization error: {0}")] + SerializationError(String), +} + +/// Formats a Protobuf body into a JSON byte representation. +pub fn format_json_message(body: &T) -> Result, JsonFormatError> +where + T: Message + Serialize, +{ + match serde_json::to_string(body) { + Ok(json_str) => { + if json_str.starts_with('"') && json_str.ends_with('"') { + let inner = &json_str[1..json_str.len() - 1]; + let unescaped = inner.replace("\\\"", "\""); + Ok(unescaped.into_bytes()) + } else { + Ok(json_str.into_bytes()) + } + } + Err(e) => Err(JsonFormatError::SerializationError(e.to_string())), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use prost::Message as ProstMessage; + + // Test message for JSON formatting. + #[derive(Clone, ProstMessage, Serialize, Deserialize)] + struct TestMessage { + #[prost(string, tag = 1)] + value: String, + } + + #[test] + fn test_format_json_message_simple() { + let msg = TestMessage { value: "hello".to_string() }; + let result = format_json_message(&msg).unwrap(); + assert_eq!(result, b"{\"value\":\"hello\"}"); + } + + #[test] + fn test_format_json_message_with_quotes() { + let msg = TestMessage { value: "hello \"world\"".to_string() }; + let result = format_json_message(&msg).unwrap(); + assert_eq!(result, b"{\"value\":\"hello \\\"world\\\"\"}"); + } +} diff --git a/crates/sdk/src/network-v2/mod.rs b/crates/sdk/src/network-v2/mod.rs index e6444357b2..86f1a82d47 100644 --- a/crates/sdk/src/network-v2/mod.rs +++ b/crates/sdk/src/network-v2/mod.rs @@ -1,4 +1,5 @@ pub mod client; +mod json; pub mod prover; mod sign_message; #[rustfmt::skip] diff --git a/crates/sdk/src/network-v2/proto/network.rs b/crates/sdk/src/network-v2/proto/network.rs index cea0e67d82..dd0077db9a 100644 --- a/crates/sdk/src/network-v2/proto/network.rs +++ b/crates/sdk/src/network-v2/proto/network.rs @@ -1,11 +1,14 @@ // This file is @generated by prost-build. #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] pub struct RequestProofRequest { + /// The message format of the body. + #[prost(enumeration = "MessageFormat", tag = "1")] + pub format: i32, /// The signature of the sender. - #[prost(bytes = "vec", tag = "1")] + #[prost(bytes = "vec", tag = "2")] pub signature: ::prost::alloc::vec::Vec, /// The body of the request. - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "3")] pub body: ::core::option::Option, } #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] @@ -52,11 +55,14 @@ pub struct RequestProofResponseBody { } #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] pub struct FulfillProofRequest { + /// The message format of the body. + #[prost(enumeration = "MessageFormat", tag = "1")] + pub format: i32, /// The signature of the sender. - #[prost(bytes = "vec", tag = "1")] + #[prost(bytes = "vec", tag = "2")] pub signature: ::prost::alloc::vec::Vec, /// The body of the request. - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "3")] pub body: ::core::option::Option, } #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] @@ -84,11 +90,14 @@ pub struct FulfillProofResponse { pub struct FulfillProofResponseBody {} #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] pub struct ExecuteProofRequest { + /// The message format of the body. + #[prost(enumeration = "MessageFormat", tag = "1")] + pub format: i32, /// The signature of the sender. - #[prost(bytes = "vec", tag = "1")] + #[prost(bytes = "vec", tag = "2")] pub signature: ::prost::alloc::vec::Vec, /// The body of the request. - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "3")] pub body: ::core::option::Option, } #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] @@ -124,11 +133,14 @@ pub struct ExecuteProofResponse { pub struct ExecuteProofResponseBody {} #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] pub struct FailFulfillmentRequest { + /// The message format of the body. + #[prost(enumeration = "MessageFormat", tag = "1")] + pub format: i32, /// The signature of the sender. - #[prost(bytes = "vec", tag = "1")] + #[prost(bytes = "vec", tag = "2")] pub signature: ::prost::alloc::vec::Vec, /// The body of the request. - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "3")] pub body: ::core::option::Option, } #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] @@ -153,11 +165,14 @@ pub struct FailFulfillmentResponse { pub struct FailFulfillmentResponseBody {} #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] pub struct FailExecutionRequest { + /// The message format of the body. + #[prost(enumeration = "MessageFormat", tag = "1")] + pub format: i32, /// The signature of the sender. - #[prost(bytes = "vec", tag = "1")] + #[prost(bytes = "vec", tag = "2")] pub signature: ::prost::alloc::vec::Vec, /// The body of the request. - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "3")] pub body: ::core::option::Option, } #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] @@ -479,11 +494,14 @@ pub struct GetFilteredDelegationsResponse { } #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] pub struct AddDelegationRequest { + /// The message format of the body. + #[prost(enumeration = "MessageFormat", tag = "1")] + pub format: i32, /// The signature of the sender. - #[prost(bytes = "vec", tag = "1")] + #[prost(bytes = "vec", tag = "2")] pub signature: ::prost::alloc::vec::Vec, /// The body of the request. - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "3")] pub body: ::core::option::Option, } #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] @@ -508,11 +526,14 @@ pub struct AddDelegationResponse { pub struct AddDelegationResponseBody {} #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] pub struct RemoveDelegationRequest { + /// The message format of the body. + #[prost(enumeration = "MessageFormat", tag = "1")] + pub format: i32, /// The signature of the sender. - #[prost(bytes = "vec", tag = "1")] + #[prost(bytes = "vec", tag = "2")] pub signature: ::prost::alloc::vec::Vec, /// The body of the request. - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "3")] pub body: ::core::option::Option, } #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] @@ -537,11 +558,14 @@ pub struct RemoveDelegationResponse { pub struct RemoveDelegationResponseBody {} #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] pub struct AcceptDelegationRequest { + /// The message format of the body. + #[prost(enumeration = "MessageFormat", tag = "1")] + pub format: i32, /// The signature of the sender. - #[prost(bytes = "vec", tag = "1")] + #[prost(bytes = "vec", tag = "2")] pub signature: ::prost::alloc::vec::Vec, /// The body of the request. - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "3")] pub body: ::core::option::Option, } #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] @@ -566,11 +590,14 @@ pub struct AcceptDelegationResponse { pub struct AcceptDelegationResponseBody {} #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] pub struct SetAccountNameRequest { + /// The message format of the body. + #[prost(enumeration = "MessageFormat", tag = "1")] + pub format: i32, /// The signature of the sender. - #[prost(bytes = "vec", tag = "1")] + #[prost(bytes = "vec", tag = "2")] pub signature: ::prost::alloc::vec::Vec, /// The body of the request. - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "3")] pub body: ::core::option::Option, } #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] @@ -610,6 +637,50 @@ pub struct GetAccountNameResponse { pub name: ::core::option::Option<::prost::alloc::string::String>, } #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] +pub struct GetTermsSignatureRequest { + /// The address of the account. + #[prost(bytes = "vec", tag = "1")] + pub address: ::prost::alloc::vec::Vec, +} +#[derive(serde::Serialize, serde::Deserialize, Clone, Copy, PartialEq, ::prost::Message)] +pub struct GetTermsSignatureResponse { + /// Whether the account has signed the terms. + #[prost(bool, tag = "1")] + pub is_signed: bool, +} +#[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] +pub struct SetTermsSignatureRequest { + /// The message format of the body. + #[prost(enumeration = "MessageFormat", tag = "1")] + pub format: i32, + /// The signature of the sender. + #[prost(bytes = "vec", tag = "2")] + pub signature: ::prost::alloc::vec::Vec, + /// The body of the request. + #[prost(message, optional, tag = "3")] + pub body: ::core::option::Option, +} +#[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] +pub struct SetTermsSignatureRequestBody { + /// The account nonce of the sender. + #[prost(uint64, tag = "1")] + pub nonce: u64, + /// The address of the account. + #[prost(bytes = "vec", tag = "2")] + pub address: ::prost::alloc::vec::Vec, +} +#[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] +pub struct SetTermsSignatureResponse { + /// The transaction hash. + #[prost(bytes = "vec", tag = "1")] + pub tx_hash: ::prost::alloc::vec::Vec, + /// The body of the response. + #[prost(message, optional, tag = "2")] + pub body: ::core::option::Option, +} +#[derive(serde::Serialize, serde::Deserialize, Clone, Copy, PartialEq, ::prost::Message)] +pub struct SetTermsSignatureResponseBody {} +#[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] pub struct Program { /// The verification key hash. #[prost(bytes = "vec", tag = "1")] @@ -644,11 +715,14 @@ pub struct GetProgramResponse { } #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] pub struct CreateProgramRequest { + /// The message format of the body. + #[prost(enumeration = "MessageFormat", tag = "1")] + pub format: i32, /// The signature of the sender. - #[prost(bytes = "vec", tag = "1")] + #[prost(bytes = "vec", tag = "2")] pub signature: ::prost::alloc::vec::Vec, /// The body of the request. - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "3")] pub body: ::core::option::Option, } #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] @@ -679,11 +753,14 @@ pub struct CreateProgramResponse { pub struct CreateProgramResponseBody {} #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] pub struct SetProgramNameRequest { + /// The message format of the body. + #[prost(enumeration = "MessageFormat", tag = "1")] + pub format: i32, /// The signature of the sender. - #[prost(bytes = "vec", tag = "1")] + #[prost(bytes = "vec", tag = "2")] pub signature: ::prost::alloc::vec::Vec, /// The body of the request. - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "3")] pub body: ::core::option::Option, } #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] @@ -771,11 +848,14 @@ pub struct GetFilteredBalanceLogsResponse { } #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] pub struct AddCreditRequest { + /// The message format of the body. + #[prost(enumeration = "MessageFormat", tag = "1")] + pub format: i32, /// The signature of the sender. - #[prost(bytes = "vec", tag = "1")] + #[prost(bytes = "vec", tag = "2")] pub signature: ::prost::alloc::vec::Vec, /// The body of the request. - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "3")] pub body: ::core::option::Option, } #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] @@ -887,11 +967,14 @@ pub struct GetFilteredReservationsResponse { } #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] pub struct AddReservationRequest { + /// The message format of the body. + #[prost(enumeration = "MessageFormat", tag = "1")] + pub format: i32, /// The signature of the sender. - #[prost(bytes = "vec", tag = "1")] + #[prost(bytes = "vec", tag = "2")] pub signature: ::prost::alloc::vec::Vec, /// The body of the request. - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "3")] pub body: ::core::option::Option, } #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] @@ -919,11 +1002,14 @@ pub struct AddReservationResponse { pub struct AddReservationResponseBody {} #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] pub struct RemoveReservationRequest { + /// The message format of the body. + #[prost(enumeration = "MessageFormat", tag = "1")] + pub format: i32, /// The signature of the sender. - #[prost(bytes = "vec", tag = "1")] + #[prost(bytes = "vec", tag = "2")] pub signature: ::prost::alloc::vec::Vec, /// The body of the request. - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "3")] pub body: ::core::option::Option, } #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] @@ -948,11 +1034,14 @@ pub struct RemoveReservationResponse { pub struct RemoveReservationResponseBody {} #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] pub struct BidRequest { + /// The message format of the body. + #[prost(enumeration = "MessageFormat", tag = "1")] + pub format: i32, /// The signature of the sender. - #[prost(bytes = "vec", tag = "1")] + #[prost(bytes = "vec", tag = "2")] pub signature: ::prost::alloc::vec::Vec, /// The body of the request. - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "3")] pub body: ::core::option::Option, } #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] @@ -977,11 +1066,14 @@ pub struct BidResponse { pub struct BidResponseBody {} #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] pub struct SettleRequest { + /// The message format of the body. + #[prost(enumeration = "MessageFormat", tag = "1")] + pub format: i32, /// The signature of the sender. - #[prost(bytes = "vec", tag = "1")] + #[prost(bytes = "vec", tag = "2")] pub signature: ::prost::alloc::vec::Vec, /// The body of the request. - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "3")] pub body: ::core::option::Option, } #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, ::prost::Message)] @@ -1004,6 +1096,51 @@ pub struct SettleResponse { } #[derive(serde::Serialize, serde::Deserialize, Clone, Copy, PartialEq, ::prost::Message)] pub struct SettleResponseBody {} +/// Format to help decode signature in backend. +#[derive( + serde::Serialize, + serde::Deserialize, + Clone, + Copy, + Debug, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + ::prost::Enumeration, +)] +#[repr(i32)] +pub enum MessageFormat { + /// Unspecified message format. + UnspecifiedMessageFormat = 0, + /// The message is in binary format. + Binary = 1, + /// The message is in JSON format. + Json = 2, +} +impl MessageFormat { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::UnspecifiedMessageFormat => "UNSPECIFIED_MESSAGE_FORMAT", + Self::Binary => "BINARY", + Self::Json => "JSON", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "UNSPECIFIED_MESSAGE_FORMAT" => Some(Self::UnspecifiedMessageFormat), + "BINARY" => Some(Self::Binary), + "JSON" => Some(Self::Json), + _ => None, + } + } +} #[derive( serde::Serialize, serde::Deserialize, @@ -1649,7 +1786,7 @@ pub mod prover_network_client { req.extensions_mut().insert(GrpcMethod::new("network.ProverNetwork", "AddDelegation")); self.inner.unary(req, path, codec).await } - /// // Remove a delegation. Only callable by the owner of an account. + /// Remove a delegation. Only callable by the owner of an account. pub async fn remove_delegation( &mut self, request: impl tonic::IntoRequest, @@ -1727,6 +1864,46 @@ pub mod prover_network_client { req.extensions_mut().insert(GrpcMethod::new("network.ProverNetwork", "GetAccountName")); self.inner.unary(req, path, codec).await } + /// Get whether the account has signed the terms. + pub async fn get_terms_signature( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> + { + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = + http::uri::PathAndQuery::from_static("/network.ProverNetwork/GetTermsSignature"); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("network.ProverNetwork", "GetTermsSignature")); + self.inner.unary(req, path, codec).await + } + /// Set whether the account has signed the terms. + pub async fn set_terms_signature( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> + { + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = + http::uri::PathAndQuery::from_static("/network.ProverNetwork/SetTermsSignature"); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("network.ProverNetwork", "SetTermsSignature")); + self.inner.unary(req, path, codec).await + } /// Get metadata about a program. pub async fn get_program( &mut self, @@ -2095,7 +2272,7 @@ pub mod prover_network_server { &self, request: tonic::Request, ) -> std::result::Result, tonic::Status>; - /// // Remove a delegation. Only callable by the owner of an account. + /// Remove a delegation. Only callable by the owner of an account. async fn remove_delegation( &self, request: tonic::Request, @@ -2115,6 +2292,16 @@ pub mod prover_network_server { &self, request: tonic::Request, ) -> std::result::Result, tonic::Status>; + /// Get whether the account has signed the terms. + async fn get_terms_signature( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + /// Set whether the account has signed the terms. + async fn set_terms_signature( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; /// Get metadata about a program. async fn get_program( &self, @@ -3060,6 +3247,90 @@ pub mod prover_network_server { }; Box::pin(fut) } + "/network.ProverNetwork/GetTermsSignature" => { + #[allow(non_camel_case_types)] + struct GetTermsSignatureSvc(pub Arc); + impl + tonic::server::UnaryService + for GetTermsSignatureSvc + { + type Response = super::GetTermsSignatureResponse; + type Future = BoxFuture, tonic::Status>; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::get_terms_signature(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let method = GetTermsSignatureSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/network.ProverNetwork/SetTermsSignature" => { + #[allow(non_camel_case_types)] + struct SetTermsSignatureSvc(pub Arc); + impl + tonic::server::UnaryService + for SetTermsSignatureSvc + { + type Response = super::SetTermsSignatureResponse; + type Future = BoxFuture, tonic::Status>; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::set_terms_signature(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let method = SetTermsSignatureSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } "/network.ProverNetwork/GetProgram" => { #[allow(non_camel_case_types)] struct GetProgramSvc(pub Arc); diff --git a/crates/sdk/src/network-v2/sign_message.rs b/crates/sdk/src/network-v2/sign_message.rs index f8b794a5cb..a71290e315 100644 --- a/crates/sdk/src/network-v2/sign_message.rs +++ b/crates/sdk/src/network-v2/sign_message.rs @@ -2,20 +2,26 @@ use alloy_primitives::{Address, Signature}; use prost::Message; use thiserror::Error; -use crate::network_v2::proto::network::{FulfillProofRequest, RequestProofRequest}; +use crate::network_v2::json::{format_json_message, JsonFormatError}; +use crate::network_v2::proto::network::{FulfillProofRequest, MessageFormat, RequestProofRequest}; #[allow(dead_code)] pub trait SignedMessage { fn signature(&self) -> Vec; fn nonce(&self) -> Result; fn message(&self) -> Result, MessageError>; - fn recover_sender(&self) -> Result; + fn recover_sender(&self) -> Result<(Address, Vec), RecoverSenderError>; } #[derive(Error, Debug)] +#[allow(dead_code)] pub enum MessageError { #[error("Empty message")] EmptyMessage, + #[error("JSON error: {0}")] + JsonError(String), + #[error("Binary error: {0}")] + BinaryError(String), } #[derive(Error, Debug)] @@ -43,15 +49,32 @@ macro_rules! impl_signed_message { } fn message(&self) -> Result, MessageError> { + let format = MessageFormat::try_from(self.format).unwrap_or(MessageFormat::Binary); + match &self.body { - Some(body) => Ok(body.encode_to_vec()), + Some(body) => match format { + MessageFormat::Json => format_json_message(body).map_err(|e| match e { + JsonFormatError::SerializationError(msg) => { + MessageError::JsonError(msg) + } + }), + MessageFormat::Binary => { + let proto_bytes = body.encode_to_vec(); + Ok(proto_bytes) + } + MessageFormat::UnspecifiedMessageFormat => { + let proto_bytes = body.encode_to_vec(); + Ok(proto_bytes) + } + }, None => Err(MessageError::EmptyMessage), } } - fn recover_sender(&self) -> Result { + fn recover_sender(&self) -> Result<(Address, Vec), RecoverSenderError> { let message = self.message().map_err(|_| RecoverSenderError::EmptyMessage)?; - recover_sender_raw(self.signature.clone(), message) + let sender = recover_sender_raw(self.signature.clone(), message.clone())?; + Ok((sender, message)) } } };