diff --git a/crates/batcher/src/communication.rs b/crates/batcher/src/communication.rs index 308255980f..d1784672f0 100644 --- a/crates/batcher/src/communication.rs +++ b/crates/batcher/src/communication.rs @@ -1,5 +1,3 @@ -use std::net::IpAddr; - use async_trait::async_trait; use starknet_batcher_types::communication::{ BatcherRequest, @@ -7,13 +5,12 @@ use starknet_batcher_types::communication::{ BatcherResponse, }; use starknet_mempool_infra::component_definitions::ComponentRequestHandler; -use starknet_mempool_infra::component_server::{LocalComponentServer, RemoteComponentServer}; +use starknet_mempool_infra::component_server::LocalComponentServer; use tokio::sync::mpsc::Receiver; use crate::batcher::Batcher; pub type LocalBatcherServer = LocalComponentServer; -pub type RemoteBatcherServer = RemoteComponentServer; pub fn create_local_batcher_server( batcher: Batcher, @@ -22,14 +19,6 @@ pub fn create_local_batcher_server( LocalComponentServer::new(batcher, rx_batcher) } -pub fn create_remote_batcher_server( - batcher: Batcher, - ip_address: IpAddr, - port: u16, -) -> RemoteBatcherServer { - RemoteComponentServer::new(batcher, ip_address, port) -} - #[async_trait] impl ComponentRequestHandler for Batcher { async fn handle_request(&mut self, request: BatcherRequest) -> BatcherResponse { diff --git a/crates/consensus_manager/src/communication.rs b/crates/consensus_manager/src/communication.rs index 7dfff057e4..79e31ec9b6 100644 --- a/crates/consensus_manager/src/communication.rs +++ b/crates/consensus_manager/src/communication.rs @@ -1,5 +1,3 @@ -use std::net::IpAddr; - use async_trait::async_trait; use starknet_consensus_manager_types::communication::{ ConsensusManagerRequest, @@ -7,15 +5,13 @@ use starknet_consensus_manager_types::communication::{ ConsensusManagerResponse, }; use starknet_mempool_infra::component_definitions::ComponentRequestHandler; -use starknet_mempool_infra::component_server::{LocalActiveComponentServer, RemoteComponentServer}; +use starknet_mempool_infra::component_server::LocalActiveComponentServer; use tokio::sync::mpsc::Receiver; use crate::consensus_manager::ConsensusManager; pub type LocalConsensusManagerServer = LocalActiveComponentServer; -pub type RemoteConsensusManagerServer = - RemoteComponentServer; pub fn create_local_consensus_manager_server( consensus_manager: ConsensusManager, @@ -24,14 +20,6 @@ pub fn create_local_consensus_manager_server( LocalActiveComponentServer::new(consensus_manager, rx_consensus_manager) } -pub fn create_remote_consensus_manager_server( - consensus_manager: ConsensusManager, - ip_address: IpAddr, - port: u16, -) -> RemoteConsensusManagerServer { - RemoteComponentServer::new(consensus_manager, ip_address, port) -} - #[async_trait] impl ComponentRequestHandler for ConsensusManager diff --git a/crates/mempool/src/communication.rs b/crates/mempool/src/communication.rs index 946bfa377f..8afdf21bac 100644 --- a/crates/mempool/src/communication.rs +++ b/crates/mempool/src/communication.rs @@ -1,10 +1,8 @@ -use std::net::IpAddr; - use async_trait::async_trait; use starknet_api::executable_transaction::Transaction; use starknet_mempool_infra::component_definitions::ComponentRequestHandler; use starknet_mempool_infra::component_runner::ComponentStarter; -use starknet_mempool_infra::component_server::{LocalComponentServer, RemoteComponentServer}; +use starknet_mempool_infra::component_server::LocalComponentServer; use starknet_mempool_types::communication::{ MempoolRequest, MempoolRequestAndResponseSender, @@ -19,9 +17,6 @@ use crate::mempool::Mempool; pub type MempoolServer = LocalComponentServer; -pub type RemoteMempoolServer = - RemoteComponentServer; - pub fn create_mempool_server( mempool: Mempool, rx_mempool: Receiver, @@ -30,15 +25,6 @@ pub fn create_mempool_server( LocalComponentServer::new(communication_wrapper, rx_mempool) } -pub fn create_remote_mempool_server( - mempool: Mempool, - ip_address: IpAddr, - port: u16, -) -> RemoteMempoolServer { - let communication_wrapper = MempoolCommunicationWrapper::new(mempool); - RemoteComponentServer::new(communication_wrapper, ip_address, port) -} - /// Wraps the mempool to enable inbound async communication from other components. pub struct MempoolCommunicationWrapper { mempool: Mempool, diff --git a/crates/mempool_infra/src/component_client/local_component_client.rs b/crates/mempool_infra/src/component_client/local_component_client.rs index ce561bc3ac..39edc67978 100644 --- a/crates/mempool_infra/src/component_client/local_component_client.rs +++ b/crates/mempool_infra/src/component_client/local_component_client.rs @@ -74,7 +74,6 @@ where let (res_tx, mut res_rx) = channel::(1); let request_and_res_tx = ComponentRequestAndResponseSender { request, tx: res_tx }; self.tx.send(request_and_res_tx).await.expect("Outbound connection should be open."); - res_rx.recv().await.expect("Inbound connection should be open.") } } diff --git a/crates/mempool_infra/src/component_server/remote_component_server.rs b/crates/mempool_infra/src/component_server/remote_component_server.rs index 70dc099d6e..0db770ef2f 100644 --- a/crates/mempool_infra/src/component_server/remote_component_server.rs +++ b/crates/mempool_infra/src/component_server/remote_component_server.rs @@ -1,6 +1,4 @@ -use std::marker::PhantomData; use std::net::{IpAddr, SocketAddr}; -use std::sync::Arc; use async_trait::async_trait; use bincode::{deserialize, serialize}; @@ -10,14 +8,10 @@ use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Request as HyperRequest, Response as HyperResponse, Server, StatusCode}; use serde::de::DeserializeOwned; use serde::Serialize; -use tokio::sync::Mutex; use super::definitions::ComponentServerStarter; -use crate::component_definitions::{ - ComponentRequestHandler, - ServerError, - APPLICATION_OCTET_STREAM, -}; +use crate::component_client::LocalComponentClient; +use crate::component_definitions::{ServerError, APPLICATION_OCTET_STREAM}; /// The `RemoteComponentServer` struct is a generic server that handles requests and responses for a /// specified component. It receives requests, processes them using the provided component, and @@ -47,6 +41,7 @@ use crate::component_definitions::{ /// use starknet_mempool_infra::component_runner::{ComponentStartError, ComponentStarter}; /// use tokio::task; /// +/// use crate::starknet_mempool_infra::component_client::LocalComponentClient; /// use crate::starknet_mempool_infra::component_definitions::ComponentRequestHandler; /// use crate::starknet_mempool_infra::component_server::{ /// ComponentServerStarter, @@ -84,17 +79,17 @@ use crate::component_definitions::{ /// /// #[tokio::main] /// async fn main() { -/// // Instantiate the component. -/// let component = MyComponent {}; +/// // Instantiate a local client to communicate with component. +/// let (tx, _rx) = tokio::sync::mpsc::channel(32); +/// let local_client = LocalComponentClient::::new(tx); /// /// // Set the ip address and port of the server's socket. /// let ip_address = std::net::IpAddr::V6(std::net::Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)); /// let port: u16 = 8080; /// /// // Instantiate the server. -/// let mut server = RemoteComponentServer::::new( -/// component, ip_address, port, -/// ); +/// let mut server = +/// RemoteComponentServer::::new(local_client, ip_address, port); /// /// // Start the server in a new task. /// task::spawn(async move { @@ -102,49 +97,41 @@ use crate::component_definitions::{ /// }); /// } /// ``` -pub struct RemoteComponentServer +pub struct RemoteComponentServer where - Component: ComponentRequestHandler + Send + 'static, - Request: DeserializeOwned + Send + 'static, - Response: Serialize + 'static, + Request: DeserializeOwned + Send + Sync + 'static, + Response: Serialize + Send + Sync + 'static, { socket: SocketAddr, - component: Arc>, - _req: PhantomData, - _res: PhantomData, + local_client: LocalComponentClient, } -impl RemoteComponentServer +impl RemoteComponentServer where - Component: ComponentRequestHandler + Send + 'static, - Request: DeserializeOwned + Send + 'static, - Response: Serialize + 'static, + Request: DeserializeOwned + Send + Sync + 'static, + Response: Serialize + Send + Sync + 'static, { - pub fn new(component: Component, ip_address: IpAddr, port: u16) -> Self { - Self { - component: Arc::new(Mutex::new(component)), - socket: SocketAddr::new(ip_address, port), - _req: PhantomData, - _res: PhantomData, - } + pub fn new( + local_client: LocalComponentClient, + ip_address: IpAddr, + port: u16, + ) -> Self { + Self { local_client, socket: SocketAddr::new(ip_address, port) } } async fn handler( http_request: HyperRequest, - component: Arc>, + local_client: LocalComponentClient, ) -> Result, hyper::Error> { let body_bytes = to_bytes(http_request.into_body()).await?; let http_response = match deserialize(&body_bytes) { - Ok(component_request) => { - // Acquire the lock for component computation, release afterwards. - let component_response = - { component.lock().await.handle_request(component_request).await }; + Ok(request) => { + let response = local_client.send(request).await; HyperResponse::builder() .status(StatusCode::OK) .header(CONTENT_TYPE, APPLICATION_OCTET_STREAM) .body(Body::from( - serialize(&component_response) - .expect("Response serialization should succeed"), + serialize(&response).expect("Response serialization should succeed"), )) } Err(error) => { @@ -161,19 +148,17 @@ where } #[async_trait] -impl ComponentServerStarter - for RemoteComponentServer +impl ComponentServerStarter for RemoteComponentServer where - Component: ComponentRequestHandler + Send + 'static, Request: DeserializeOwned + Send + Sync + 'static, Response: Serialize + Send + Sync + 'static, { async fn start(&mut self) { let make_svc = make_service_fn(|_conn| { - let component = Arc::clone(&self.component); + let local_client = self.local_client.clone(); async { Ok::<_, hyper::Error>(service_fn(move |req| { - Self::handler(req, Arc::clone(&component)) + Self::handler(req, local_client.clone()) })) } }); diff --git a/crates/mempool_infra/tests/remote_component_client_server_test.rs b/crates/mempool_infra/tests/remote_component_client_server_test.rs index eea574cdf2..a1a032a763 100644 --- a/crates/mempool_infra/tests/remote_component_client_server_test.rs +++ b/crates/mempool_infra/tests/remote_component_client_server_test.rs @@ -22,13 +22,24 @@ use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Client, Request, Response, Server, StatusCode, Uri}; use rstest::rstest; use serde::Serialize; -use starknet_mempool_infra::component_client::{ClientError, ClientResult, RemoteComponentClient}; +use starknet_mempool_infra::component_client::{ + ClientError, + ClientResult, + LocalComponentClient, + RemoteComponentClient, +}; use starknet_mempool_infra::component_definitions::{ + ComponentRequestAndResponseSender, ComponentRequestHandler, ServerError, APPLICATION_OCTET_STREAM, }; -use starknet_mempool_infra::component_server::{ComponentServerStarter, RemoteComponentServer}; +use starknet_mempool_infra::component_server::{ + ComponentServerStarter, + LocalComponentServer, + RemoteComponentServer, +}; +use tokio::sync::mpsc::channel; use tokio::sync::Mutex; use tokio::task; @@ -108,10 +119,10 @@ impl ComponentRequestHandler for Componen } async fn verify_error( - a_client: impl ComponentAClientTrait, + a_remote_client: impl ComponentAClientTrait, expected_error_contained_keywords: &[&str], ) { - let Err(error) = a_client.a_get_value().await else { + let Err(error) = a_remote_client.a_get_value().await else { panic!("Expected an error."); }; assert_error_contains_keywords(error.to_string(), expected_error_contained_keywords) @@ -156,29 +167,41 @@ where } async fn setup_for_tests(setup_value: ValueB, a_port: u16, b_port: u16) { - let a_client = ComponentAClient::new(LOCAL_IP, a_port, MAX_RETRIES); - let b_client = ComponentBClient::new(LOCAL_IP, b_port, MAX_RETRIES); - - let component_a = ComponentA::new(Box::new(b_client)); - let component_b = ComponentB::new(setup_value, Box::new(a_client.clone())); - - let mut component_a_server = RemoteComponentServer::< - ComponentA, - ComponentARequest, - ComponentAResponse, - >::new(component_a, LOCAL_IP, a_port); - let mut component_b_server = RemoteComponentServer::< - ComponentB, - ComponentBRequest, - ComponentBResponse, - >::new(component_b, LOCAL_IP, b_port); + let a_remote_client = ComponentAClient::new(LOCAL_IP, a_port, MAX_RETRIES); + let b_remote_client = ComponentBClient::new(LOCAL_IP, b_port, MAX_RETRIES); + + let component_a = ComponentA::new(Box::new(b_remote_client)); + let component_b = ComponentB::new(setup_value, Box::new(a_remote_client.clone())); + + let (tx_a, rx_a) = + channel::>(32); + let (tx_b, rx_b) = + channel::>(32); + + let a_local_client = LocalComponentClient::::new(tx_a); + let b_local_client = LocalComponentClient::::new(tx_b); + + let mut component_a_local_server = LocalComponentServer::new(component_a, rx_a); + let mut component_b_local_server = LocalComponentServer::new(component_b, rx_b); + + let mut component_a_remote_server = + RemoteComponentServer::new(a_local_client, LOCAL_IP, a_port); + let mut component_b_remote_server = + RemoteComponentServer::new(b_local_client, LOCAL_IP, b_port); + + task::spawn(async move { + component_a_local_server.start().await; + }); + task::spawn(async move { + component_b_local_server.start().await; + }); task::spawn(async move { - component_a_server.start().await; + component_a_remote_server.start().await; }); task::spawn(async move { - component_b_server.start().await; + component_b_remote_server.start().await; }); // Todo(uriel): Get rid of this @@ -189,9 +212,9 @@ async fn setup_for_tests(setup_value: ValueB, a_port: u16, b_port: u16) { async fn test_proper_setup() { let setup_value: ValueB = 90; setup_for_tests(setup_value, A_PORT_TEST_SETUP, B_PORT_TEST_SETUP).await; - let a_client = ComponentAClient::new(LOCAL_IP, A_PORT_TEST_SETUP, MAX_RETRIES); - let b_client = ComponentBClient::new(LOCAL_IP, B_PORT_TEST_SETUP, MAX_RETRIES); - test_a_b_functionality(a_client, b_client, setup_value.into()).await; + let a_remote_client = ComponentAClient::new(LOCAL_IP, A_PORT_TEST_SETUP, MAX_RETRIES); + let b_remote_client = ComponentBClient::new(LOCAL_IP, B_PORT_TEST_SETUP, MAX_RETRIES); + test_a_b_functionality(a_remote_client, b_remote_client, setup_value.into()).await; } #[tokio::test]