From e678e92bf01bc4bc914e74b6fed22c8b55b3cdc7 Mon Sep 17 00:00:00 2001 From: zeroqn Date: Wed, 12 Aug 2020 15:01:44 +0800 Subject: [PATCH] feat(network): verify chain id during protocol handshake (#406) * feat(network): add chain id identification to identify protocol * feat(network): discovery protocol require identification procedure Discovery protocol will wait identification result in identify protocol before accept multiaddrs from remote session. * fix(network): remove wrong dep crate * fix(network): discovery protocol compilation * feat(network): set_chain_id fn * chore(network): disable_chain_id_check feature for trust metric tests * fix(network): clippy warnings * fix: trust metric integrate test * Revert "feat(network): discovery protocol require identification procedure" This reverts commit 7570083e875d7c1a74e075ac3aa7b4366e4e8d39. * chore(network): remove identify protocol usage in discovery protocol * feat(network): expose wait_identified fn from identify protocol * feat(network): implement new UnidentifiedSession in PeerManagerEvent * feat(network): open other protocols after session identified * feat(network): if session is unidentified, close discovery protocol * fix(network): compilation failed * fix(network): unidentified session isn't removed on disconnected * fix(network): SessionInfoNotFound error * change(network): dail identify protocol first * change(network): open protocols from client side * refactor(network): identify protocol Client must wait an ack message before open other protocols * fix(network): clippy warnings * chore * fix: try fix trust metric test * fix(network): trust metric integrated test failed * fix(trust_metric_test): use random generate seckey for full node network Since identify protocol use lazy static HashMap, we need different peer id for each test. * refactor(network): identify protocol * feat(network): set up a timeout check for peer protocol open state * feat(network): more detailed warn message * test(network): add chain id verification integrated test * chore(network): remove unused debug log * feat(network): implement session pre-check for unidentified session So that we can reject an new session earlier * fix(network): IdentifyProtocol always insert new identification * fix(network): clippy warnings * feat(network): limit message size in identify protocol * feat(network): add FailedWithExceedMsgSize state * refactor(network): validate identify message before finish identify * fix: clippy warnings * test(network): add one identify unit test * test(network): more identify unit tests * test(network): more identify unit tests * test(network): identify protocol unit tests * test(network): add UnidentifiedSession event tests * fix(network): unused warnings in cfg(test) * test(network): add identify protocol disconnected test * fix(network): unused warnings * chore: format code * fix(network): clippy warnings * change!(network): bump identify protocol version * change(network): transmitter protocol should also check session state If session is not accepted, should close protocol stream instead. --- Cargo.toml | 5 + core/network/src/connection/keeper.rs | 4 +- core/network/src/event.rs | 13 + core/network/src/peer_manager/mod.rs | 225 +++++- core/network/src/peer_manager/test_manager.rs | 154 ++++ core/network/src/protocols/core.rs | 109 ++- core/network/src/protocols/discovery.rs | 4 +- .../src/protocols/discovery/behaviour.rs | 58 +- .../src/protocols/discovery/protocol.rs | 39 +- core/network/src/protocols/identify.rs | 63 +- .../src/protocols/identify/behaviour.rs | 244 +++--- .../src/protocols/identify/identification.rs | 181 +++++ .../network/src/protocols/identify/message.rs | 171 ++++- .../src/protocols/identify/protocol.rs | 617 ++++++++++++---- core/network/src/protocols/identify/tests.rs | 694 ++++++++++++++++++ core/network/src/protocols/mod.rs | 4 +- core/network/src/protocols/ping/protocol.rs | 6 +- core/network/src/protocols/transmitter.rs | 13 +- .../src/protocols/transmitter/protocol.rs | 23 +- core/network/src/service.rs | 7 +- core/network/src/test/mock.rs | 142 +++- src/default_start.rs | 3 + tests/common/mod.rs | 40 + tests/{trust_metric_all => common}/node.rs | 6 +- .../node/config.rs | 0 .../node/consts.rs | 0 .../node/diagnostic.rs | 0 .../node/full_node.rs | 4 +- .../node/full_node/builder.rs | 3 +- .../node/full_node/default_start.rs | 11 +- .../node/full_node/error.rs | 0 .../node/full_node/memory_db.rs | 0 .../{trust_metric_all => common}/node/sync.rs | 6 + tests/trust_metric.rs | 1 + .../{node => }/client_node.rs | 68 +- tests/trust_metric_all/common.rs | 45 +- tests/trust_metric_all/consensus.rs | 5 +- tests/trust_metric_all/mempool.rs | 6 +- tests/trust_metric_all/mod.rs | 25 +- tests/verify_chain_id.rs | 285 +++++++ 40 files changed, 2763 insertions(+), 521 deletions(-) create mode 100644 core/network/src/protocols/identify/identification.rs create mode 100644 core/network/src/protocols/identify/tests.rs create mode 100644 tests/common/mod.rs rename tests/{trust_metric_all => common}/node.rs (61%) rename tests/{trust_metric_all => common}/node/config.rs (100%) rename tests/{trust_metric_all => common}/node/consts.rs (100%) rename tests/{trust_metric_all => common}/node/diagnostic.rs (100%) rename tests/{trust_metric_all => common}/node/full_node.rs (95%) rename tests/{trust_metric_all => common}/node/full_node/builder.rs (96%) rename tests/{trust_metric_all => common}/node/full_node/default_start.rs (98%) rename tests/{trust_metric_all => common}/node/full_node/error.rs (100%) rename tests/{trust_metric_all => common}/node/full_node/memory_db.rs (100%) rename tests/{trust_metric_all => common}/node/sync.rs (97%) rename tests/trust_metric_all/{node => }/client_node.rs (88%) create mode 100644 tests/verify_chain_id.rs diff --git a/Cargo.toml b/Cargo.toml index c6f7bc7f8..9966ba1c8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -94,6 +94,11 @@ name = "trust_metric" path = "tests/trust_metric.rs" required-features = [ "core-network/diagnostic" ] +[[test]] +name = "verify_chain_id" +path = "tests/verify_chain_id.rs" +required-features = [ "core-network/diagnostic" ] + [[bench]] name = "bench_execute" path = "benchmark/mod.rs" diff --git a/core/network/src/connection/keeper.rs b/core/network/src/connection/keeper.rs index d0eea4c85..f915f0668 100644 --- a/core/network/src/connection/keeper.rs +++ b/core/network/src/connection/keeper.rs @@ -222,9 +222,9 @@ impl ServiceHandle for ConnectionServiceKeeper { let pid = pubkey.peer_id(); #[cfg(test)] let session_context = SessionContext::from(session_context).arced(); - let new_peer_session = PeerManagerEvent::NewSession { pid, pubkey, ctx: session_context }; + let new_unidentified_session = PeerManagerEvent::UnidentifiedSession { pid, pubkey, ctx: session_context }; - self.report_peer(new_peer_session); + self.report_peer(new_unidentified_session); } ServiceEvent::SessionClose { session_context } => { let pid = peer_pubkey!(&session_context).peer_id(); diff --git a/core/network/src/event.rs b/core/network/src/event.rs index 146de597f..582070b34 100644 --- a/core/network/src/event.rs +++ b/core/network/src/event.rs @@ -135,6 +135,19 @@ pub enum PeerManagerEvent { ctx: Arc, }, + #[display( + fmt = "unidentified session {} peer {:?} addr {} ty {:?}", + "ctx.id", + pid, + "ctx.address", + "ctx.ty" + )] + UnidentifiedSession { + pid: PeerId, + pubkey: PublicKey, + ctx: Arc, + }, + #[display(fmt = "repeated connection type {} session {} addr {}", ty, sid, addr)] RepeatedConnection { ty: ConnectionType, diff --git a/core/network/src/peer_manager/mod.rs b/core/network/src/peer_manager/mod.rs index d6101c9e9..62cd87cf9 100644 --- a/core/network/src/peer_manager/mod.rs +++ b/core/network/src/peer_manager/mod.rs @@ -13,17 +13,6 @@ mod trust_metric; #[cfg(feature = "diagnostic")] pub mod diagnostic; -use addr_set::PeerAddrSet; -use retry::Retry; -use save_restore::{NoPeerDatFile, PeerDatFile, SaveRestore}; -use session_book::{AcceptableSession, ArcSession, SessionContext}; -use tags::Tags; - -pub use peer::{ArcPeer, Connectedness}; -pub use session_book::SessionBook; -pub use shared::SharedSessions; -pub use trust_metric::{TrustMetric, TrustMetricConfig}; - #[cfg(test)] mod test_manager; @@ -42,6 +31,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; +use arc_swap::ArcSwap; use derive_more::Display; use futures::channel::mpsc::{UnboundedReceiver, UnboundedSender}; use futures::stream::Stream; @@ -62,8 +52,20 @@ use crate::event::{ ConnectionErrorKind, ConnectionEvent, ConnectionType, MisbehaviorKind, PeerManagerEvent, SessionErrorKind, }; +use crate::protocols::identify::{Identify, WaitIdentification}; use crate::traits::MultiaddrExt; +use addr_set::PeerAddrSet; +use retry::Retry; +use save_restore::{NoPeerDatFile, PeerDatFile, SaveRestore}; +use session_book::{AcceptableSession, ArcSession, SessionContext}; +use tags::Tags; + +pub use peer::{ArcPeer, Connectedness}; +pub use session_book::SessionBook; +pub use shared::SharedSessions; +pub use trust_metric::{TrustMetric, TrustMetricConfig}; + const SAME_IP_LIMIT_BAN: Duration = Duration::from_secs(5 * 60); const REPEATED_CONNECTION_TIMEOUT: u64 = 30; // seconds const BACKOFF_BASE: u64 = 2; @@ -75,6 +77,24 @@ const MAX_CONNECTING_MARGIN: usize = 10; const GOOD_TRUST_SCORE: u8 = 80u8; const WORSE_TRUST_SCALAR_RATIO: usize = 10; +#[derive(Debug, Display)] +pub enum NewSessionPreCheckError { + #[display(fmt = "peer banned")] + PeerBanned, + + #[display(fmt = "allow list peer only")] + AllowListOnly, + + #[display(fmt = "reach max connection")] + ReachMaxConnection, + + #[display(fmt = "peer already connected, only allow one connection per peer")] + PeerAlreadyConnected, + + #[display(fmt = "{}", _0)] + ReachSessionLimit(session_book::Error), +} + #[derive(Debug, Clone, Display, Serialize, Deserialize)] #[display(fmt = "{}", _0)] pub struct PeerMultiaddr(Multiaddr); @@ -193,7 +213,8 @@ impl Hash for ConnectingAttempt { } struct Inner { - our_id: Arc, + our_id: Arc, + chain_id: ArcSwap, sessions: SessionBook, consensus: RwLock>, @@ -206,6 +227,7 @@ impl Inner { pub fn new(our_id: PeerId, sessions: SessionBook) -> Self { Inner { our_id: Arc::new(our_id), + chain_id: ArcSwap::new(Arc::new(protocol::types::Hash::from_empty())), sessions, consensus: Default::default(), @@ -227,6 +249,14 @@ impl Inner { self.listen.write().remove(multiaddr); } + pub fn set_chain_id(&self, chain_id: protocol::types::Hash) { + self.chain_id.store(Arc::new(chain_id)); + } + + pub fn chain_id(&self) -> Arc { + self.chain_id.load_full() + } + pub fn connected(&self) -> usize { self.sessions.len() } @@ -293,6 +323,36 @@ impl Inner { } } +struct UnidentifiedSessionEvent { + pubkey: PublicKey, + ctx: Arc, +} + +struct UnidentifiedSession { + event: UnidentifiedSessionEvent, + ident_fut: WaitIdentification, +} + +impl Borrow for UnidentifiedSession { + fn borrow(&self) -> &SessionId { + &self.event.ctx.id + } +} + +impl PartialEq for UnidentifiedSession { + fn eq(&self, other: &UnidentifiedSession) -> bool { + self.event.ctx.id == other.event.ctx.id + } +} + +impl Eq for UnidentifiedSession {} + +impl Hash for UnidentifiedSession { + fn hash(&self, state: &mut H) { + self.event.ctx.id.hash(state) + } +} + #[derive(Debug, Clone)] pub struct PeerManagerConfig { /// Our Peer ID @@ -343,6 +403,18 @@ impl PeerManagerHandle { self.inner.session(sid).map(|s| s.peer.owned_id()) } + pub fn set_chain_id(&self, chain_id: protocol::types::Hash) { + self.inner.set_chain_id(chain_id); + } + + pub fn chain_id(&self) -> Arc { + self.inner.chain_id() + } + + pub fn contains_session(&self, session_id: SessionId) -> bool { + self.inner.session(session_id).is_some() + } + pub fn random_addrs(&self, max: usize, sid: SessionId) -> Vec { let mut rng = rand::thread_rng(); let book = self.inner.peers.read(); @@ -453,6 +525,9 @@ pub struct PeerManager { // peers currently connecting connecting: HashSet, + // unidentified session backlog + unidentified_backlog: HashSet, + event_rx: UnboundedReceiver, conn_tx: UnboundedSender, @@ -500,6 +575,7 @@ impl PeerManager { bootstraps, connecting: Default::default(), + unidentified_backlog: Default::default(), event_rx, conn_tx, @@ -586,7 +662,11 @@ impl PeerManager { } } - fn new_session(&mut self, pubkey: PublicKey, ctx: Arc) { + fn new_session_pre_check( + &mut self, + pubkey: &PublicKey, + ctx: &Arc, + ) -> Result { let remote_peer_id = pubkey.peer_id(); let remote_multiaddr = PeerMultiaddr::new(ctx.address.to_owned(), &remote_peer_id); @@ -595,13 +675,6 @@ impl PeerManager { let opt_peer = self.inner.peer(&remote_peer_id); let remote_peer = opt_peer.unwrap_or_else(|| ArcPeer::new(remote_peer_id.clone())); - if !remote_peer.has_pubkey() { - if let Err(e) = remote_peer.set_pubkey(pubkey) { - error!("impossible, set public key failed {}", e); - error!("new session without peer pubkey, chain book will not be updated"); - } - } - // Inbound address is client address, it's useless match ctx.ty { SessionType::Inbound => remote_peer.multiaddrs.remove(&remote_multiaddr), @@ -618,7 +691,7 @@ impl PeerManager { info!("banned peer {:?} incomming", remote_peer_id); remote_peer.mark_disconnected(); self.disconnect_session(ctx.id); - return; + return Err(NewSessionPreCheckError::PeerBanned); } if self.config.allowlist_only @@ -628,7 +701,7 @@ impl PeerManager { debug!("allowlist_only enabled, reject peer {:?}", remote_peer.id); remote_peer.mark_disconnected(); self.disconnect_session(ctx.id); - return; + return Err(NewSessionPreCheckError::AllowListOnly); } if self.inner.connected() >= self.config.max_connections { @@ -656,6 +729,10 @@ impl PeerManager { && session.peer.alive() > self.config.peer_trust_config.interval().as_secs() * 20 { + info!( + "session peer {:?} is been replaced by peer {:?}", + session.peer.id, remote_peer.id + ); self.disconnect_session(session.id); return true; } @@ -668,9 +745,11 @@ impl PeerManager { && !remote_peer.tags.contains(&PeerTag::Consensus) && !found_replacement() { + info!("reject peer {:?} due to max conn limit", remote_peer.id); + remote_peer.mark_disconnected(); self.disconnect_session(ctx.id); - return; + return Err(NewSessionPreCheckError::ReachMaxConnection); } } @@ -683,7 +762,7 @@ impl PeerManager { if exist_sid != ctx.id && self.inner.session(exist_sid).is_some() { // We don't support multiple connections, disconnect new one self.disconnect_session(ctx.id); - return; + return Err(NewSessionPreCheckError::PeerAlreadyConnected); } if self.inner.session(exist_sid).is_none() { @@ -711,16 +790,51 @@ impl PeerManager { remote_peer.mark_disconnected(); self.disconnect_session(ctx.id); + return Err(NewSessionPreCheckError::ReachSessionLimit(err)); + } + } + + Ok(session) + } + + fn new_unidentified_session(&mut self, pubkey: PublicKey, ctx: Arc) { + let peer_id = pubkey.peer_id(); + if let Err(err) = self.new_session_pre_check(&pubkey, &ctx) { + log::info!("reject unidentified session due to {}", err); + + Identify::wait_failed(&peer_id, err.to_string()); + return; + } + + let event = UnidentifiedSessionEvent { pubkey, ctx }; + let ident_fut = Identify::wait_identified(peer_id); + let unidentified_session = UnidentifiedSession { event, ident_fut }; + + self.unidentified_backlog.insert(unidentified_session); + } + + fn new_session(&mut self, pubkey: PublicKey, ctx: Arc) { + let session = match self.new_session_pre_check(&pubkey, &ctx) { + Ok(session) => session, + Err(err) => { + log::info!("reject new session due to {}", err); return; } + }; + + if !session.peer.has_pubkey() { + if let Err(e) = session.peer.set_pubkey(pubkey) { + error!("impossible, set public key failed {}", e); + } } // Currently we only save accepted peer. // TODO: save to database - if !self.inner.contains(&remote_peer_id) { - self.inner.add_peer(remote_peer.clone()); + if !self.inner.contains(&session.peer.id) { + self.inner.add_peer(session.peer.clone()); } + let remote_peer = session.peer.clone(); self.inner.sessions.insert(AcceptableSession(session)); remote_peer.mark_connected(ctx.id); @@ -738,6 +852,11 @@ impl PeerManager { fn session_closed(&mut self, sid: SessionId) { debug!("session {} closed", sid); + // Unidentified session + if self.unidentified_backlog.take(&sid).is_some() { + return; + } + let session = match self.inner.remove_session(sid) { Some(s) => s, None => return, /* Session may be removed by other event or rejected @@ -841,10 +960,9 @@ impl PeerManager { } fn session_failed(&self, sid: SessionId, error_kind: SessionErrorKind) { + warn!("session {} failed {}", sid, error_kind); use SessionErrorKind::{Io, Protocol, Unexpected}; - debug!("session {} failed", sid); - let session = match self.inner.remove_session(sid) { Some(s) => s, None => return, /* Session may be removed by other event or rejected @@ -893,6 +1011,7 @@ impl PeerManager { } fn peer_misbehave(&self, pid: PeerId, kind: MisbehaviorKind) { + warn!("peer {:?} misbehave {}", pid, kind); use MisbehaviorKind::{Discovery, PingTimeout, PingUnexpect}; let peer = match self.inner.peer(&pid) { @@ -1172,6 +1291,9 @@ impl PeerManager { match event { PeerManagerEvent::ConnectPeersNow { pids } => self.connect_peers_by_id(pids), PeerManagerEvent::ConnectFailed { addr, kind } => self.connect_failed(addr, kind), + PeerManagerEvent::UnidentifiedSession { pubkey, ctx, .. } => { + self.new_unidentified_session(pubkey, ctx) + } PeerManagerEvent::NewSession { pubkey, ctx, .. } => self.new_session(pubkey, ctx), // NOTE: Alice may disconnect to Bob, but bob didn't know // that, so the next time, Alice try to connect to Bob will @@ -1224,6 +1346,53 @@ impl Future for PeerManager { tokio::spawn(heart_beat); } + // Process unidentified sessions + let unidentified_sessions = self.unidentified_backlog.drain().collect::>(); + for mut session in unidentified_sessions { + let ident_fut = &mut session.ident_fut; + futures::pin_mut!(ident_fut); + + match ident_fut.poll(ctx) { + Poll::Pending => { + self.unidentified_backlog.insert(session); + } + Poll::Ready(ret) => match ret { + Ok(()) => { + let UnidentifiedSession { event, .. } = session; + let new_session_event = PeerManagerEvent::NewSession { + pid: event.pubkey.peer_id(), + pubkey: event.pubkey, + ctx: event.ctx, + }; + + // TODO: Remove duplicate diag code + #[cfg(feature = "diagnostic")] + let diag_event: Option< + diagnostic::DiagnosticEvent, + > = From::from(&new_session_event); + + self.process_event(new_session_event); + + #[cfg(feature = "diagnostic")] + if let (Some(hook), Some(event)) = + (self.diagnostic_hook.as_ref(), diag_event) + { + hook(event) + } + } + Err(err) => { + warn!( + "reject peer {:?} due to identification failed: {}", + session.event.pubkey.peer_id(), + err + ); + + self.disconnect_session(session.event.ctx.id); + } + }, + } + } + // Process manager events loop { let event_rx = &mut self.as_mut().event_rx; diff --git a/core/network/src/peer_manager/test_manager.rs b/core/network/src/peer_manager/test_manager.rs index 615277d22..438b3d030 100644 --- a/core/network/src/peer_manager/test_manager.rs +++ b/core/network/src/peer_manager/test_manager.rs @@ -3049,3 +3049,157 @@ async fn should_accept_peer_in_allowlist_even_reach_inbound_conn_limit() { "should accept peer in allowlist" ); } + +#[tokio::test] +async fn should_reject_new_connection_for_same_peer_on_unidentified_session() { + let (mut mgr, mut conn_rx) = make_manager(0, 20); + let remote_peers = make_sessions(&mut mgr, 1, 5000, SessionType::Outbound).await; + + let test_peer = remote_peers.first().expect("get first peer"); + let sess_ctx = SessionContext::make( + SessionId::new(99), + test_peer.multiaddrs.all_raw().pop().expect("get multiaddr"), + SessionType::Outbound, + test_peer.owned_pubkey().expect("pubkey"), + ); + let new_session = PeerManagerEvent::UnidentifiedSession { + pid: test_peer.owned_id(), + pubkey: test_peer.owned_pubkey().expect("pubkey"), + ctx: sess_ctx.arced(), + }; + mgr.poll_event(new_session).await; + + let conn_event = conn_rx.next().await.expect("should have disconnect event"); + match conn_event { + ConnectionEvent::Disconnect(sid) => assert_eq!(sid, 99.into(), "should be new session id"), + _ => panic!("should be disconnect event"), + } +} + +#[tokio::test] +async fn should_reject_same_ip_connection_when_reach_limit_on_unidentified_session() { + let manager_pubkey = make_pubkey(); + let manager_id = manager_pubkey.peer_id(); + let mut peer_dat_file = std::env::temp_dir(); + peer_dat_file.push("peer.dat"); + let peer_trust_config = Arc::new(TrustMetricConfig::default()); + let peer_fatal_ban = Duration::from_secs(50); + let peer_soft_ban = Duration::from_secs(10); + + let config = PeerManagerConfig { + our_id: manager_id, + pubkey: manager_pubkey, + bootstraps: Default::default(), + allowlist: vec![], + allowlist_only: false, + peer_trust_config, + peer_fatal_ban, + peer_soft_ban, + max_connections: 10, + same_ip_conn_limit: 1, + inbound_conn_limit: 5, + outbound_conn_limit: 5, + routine_interval: Duration::from_secs(10), + peer_dat_file, + }; + + let (conn_tx, mut conn_rx) = unbounded(); + let (mgr_tx, mgr_rx) = unbounded(); + let manager = PeerManager::new(config, mgr_rx, conn_tx); + + let mut mgr = MockManager::new(manager, mgr_tx); + make_sessions(&mut mgr, 1, 5000, SessionType::Outbound).await; + + let same_ip_peer = make_peer(9527); + + // Save same ip peer + let inner = mgr.core_inner(); + inner.add_peer(same_ip_peer.clone()); + + let sess_ctx = SessionContext::make( + SessionId::new(99), + same_ip_peer.multiaddrs.all_raw().pop().unwrap(), + SessionType::Outbound, + same_ip_peer.owned_pubkey().expect("pubkey"), + ); + let new_session = PeerManagerEvent::UnidentifiedSession { + pid: same_ip_peer.owned_id(), + pubkey: same_ip_peer.owned_pubkey().expect("pubkey"), + ctx: sess_ctx.arced(), + }; + mgr.poll_event(new_session).await; + + let inserted_same_ip_peer = inner.peer(&same_ip_peer.id).unwrap(); + assert_eq!( + inserted_same_ip_peer.tags.get_banned_until(), + Some(time::now() + SAME_IP_LIMIT_BAN.as_secs()) + ); + + let conn_event = conn_rx.next().await.expect("should have disconnect event"); + match conn_event { + ConnectionEvent::Disconnect(sid) => assert_eq!(sid, 99.into(), "should be new session id"), + _ => panic!("should be disconnect event"), + } +} + +#[tokio::test] +async fn should_accept_always_allow_peer_even_if_we_reach_max_connections_on_unidentified_session() +{ + let (mut mgr, mut conn_rx) = make_manager(0, 10); + let _remote_peers = make_sessions(&mut mgr, 10, 5000, SessionType::Outbound).await; + + let peer = make_peer(2019); + let always_allow_peer = make_peer(2077); + always_allow_peer.tags.insert(PeerTag::AlwaysAllow).unwrap(); + + let inner = mgr.core_inner(); + inner.add_peer(always_allow_peer.clone()); + + assert_eq!(inner.connected(), 10, "should have 10 connections"); + + // First one without AlwaysAllow tag + let sess_ctx = SessionContext::make( + SessionId::new(233), + peer.multiaddrs.all_raw().pop().expect("peer multiaddr"), + SessionType::Inbound, + peer.owned_pubkey().expect("pubkey"), + ); + let new_session = PeerManagerEvent::UnidentifiedSession { + pid: peer.owned_id(), + pubkey: peer.owned_pubkey().expect("pubkey"), + ctx: sess_ctx.arced(), + }; + mgr.poll_event(new_session).await; + let conn_event = conn_rx.next().await.expect("should have disconnect event"); + match conn_event { + ConnectionEvent::Disconnect(sid) => assert_eq!(sid, 233.into(), "should be new session id"), + _ => panic!("should be disconnect event"), + } + + // Now peer has AlwaysAllow tag + let sess_ctx = SessionContext::make( + SessionId::new(666), + always_allow_peer + .multiaddrs + .all_raw() + .pop() + .expect("peer multiaddr"), + SessionType::Inbound, + always_allow_peer + .owned_pubkey() + .expect("always allow peer's pubkey"), + ); + let new_session = PeerManagerEvent::UnidentifiedSession { + pid: always_allow_peer.owned_id(), + pubkey: always_allow_peer + .owned_pubkey() + .expect("always allow peer's pubkey"), + ctx: sess_ctx.arced(), + }; + mgr.poll_event(new_session).await; + + match conn_rx.try_next() { + Err(_) => (), // Err means channel is empty, it's expected + _ => panic!("should not have any disconnect event"), + } +} diff --git a/core/network/src/protocols/core.rs b/core/network/src/protocols/core.rs index a72a4de16..9c461bb2a 100644 --- a/core/network/src/protocols/core.rs +++ b/core/network/src/protocols/core.rs @@ -1,6 +1,11 @@ +use std::collections::{HashMap, HashSet}; +use std::iter::FromIterator; use std::time::Duration; use futures::channel::mpsc::UnboundedSender; +use lazy_static::lazy_static; +use parking_lot::RwLock; +use tentacle::secio::PeerId; use tentacle::service::{ProtocolMeta, TargetProtocol}; use tentacle::ProtocolId; @@ -17,6 +22,53 @@ pub const IDENTIFY_PROTOCOL_ID: usize = 2; pub const DISCOVERY_PROTOCOL_ID: usize = 3; pub const TRANSMITTER_PROTOCOL_ID: usize = 4; +lazy_static! { + // NOTE: Use peer id here because trust metric integrated test run in one process + static ref PEER_OPENED_PROTOCOLS: RwLock>> = RwLock::new(HashMap::new()); +} + +pub struct OpenedProtocols {} + +impl OpenedProtocols { + pub fn register(peer_id: PeerId, proto_id: ProtocolId) { + PEER_OPENED_PROTOCOLS + .write() + .entry(peer_id) + .and_modify(|protos| { + protos.insert(proto_id); + }) + .or_insert_with(|| HashSet::from_iter(vec![proto_id])); + } + + #[allow(dead_code)] + pub fn unregister(peer_id: &PeerId, proto_id: ProtocolId) { + if let Some(ref mut proto_ids) = PEER_OPENED_PROTOCOLS.write().get_mut(peer_id) { + proto_ids.remove(&proto_id); + } + } + + pub fn remove(peer_id: &PeerId) { + PEER_OPENED_PROTOCOLS.write().remove(peer_id); + } + + #[cfg(test)] + pub fn is_open(peer_id: &PeerId, proto_id: &ProtocolId) -> bool { + PEER_OPENED_PROTOCOLS + .read() + .get(peer_id) + .map(|ids| ids.contains(proto_id)) + .unwrap_or_else(|| false) + } + + pub fn is_all_opened(peer_id: &PeerId) -> bool { + PEER_OPENED_PROTOCOLS + .read() + .get(peer_id) + .map(|ids| ids.len() == 4) + .unwrap_or_else(|| false) + } +} + #[derive(Default)] pub struct CoreProtocolBuilder { ping: Option, @@ -42,12 +94,7 @@ impl CoreProtocol { impl NetworkProtocol for CoreProtocol { fn target() -> TargetProtocol { - TargetProtocol::Multi(vec![ - ProtocolId::new(PING_PROTOCOL_ID), - ProtocolId::new(IDENTIFY_PROTOCOL_ID), - ProtocolId::new(DISCOVERY_PROTOCOL_ID), - ProtocolId::new(TRANSMITTER_PROTOCOL_ID), - ]) + TargetProtocol::Single(ProtocolId::new(IDENTIFY_PROTOCOL_ID)) } fn metas(self) -> Vec { @@ -100,8 +147,12 @@ impl CoreProtocolBuilder { self } - pub fn transmitter(mut self, data_tx: UnboundedSender) -> Self { - let transmitter = Transmitter::new(data_tx); + pub fn transmitter( + mut self, + bytes_tx: UnboundedSender, + peer_mgr: PeerManagerHandle, + ) -> Self { + let transmitter = Transmitter::new(bytes_tx, peer_mgr); self.transmitter = Some(transmitter); self @@ -117,32 +168,20 @@ impl CoreProtocolBuilder { transmitter, } = self; - // Panic for missing protocol - assert!(ping.is_some(), "init: missing protocol ping"); - assert!(identify.is_some(), "init: missing protocol identify"); - assert!(discovery.is_some(), "init: missing protocol discovery"); - assert!(transmitter.is_some(), "init: missing protocol transmitter"); - - if let Some(ping) = ping { - metas.push(ping.build_meta(PING_PROTOCOL_ID.into())); - } - - if let Some(identify) = identify { - metas.push(identify.build_meta(IDENTIFY_PROTOCOL_ID.into())); - } - - if let Some(discovery) = discovery { - metas.push(discovery.build_meta(DISCOVERY_PROTOCOL_ID.into())); - } - - if let Some(transmitter) = transmitter.as_ref() { - let transmitter = transmitter.clone(); - metas.push(transmitter.build_meta(TRANSMITTER_PROTOCOL_ID.into())); - } - - CoreProtocol { - metas, - transmitter: transmitter.unwrap(), - } + let ping = ping.expect("init: missing protocol ping"); + let identify = identify.expect("init: missing protocol identify"); + let discovery = discovery.expect("init: missing protocol discovery"); + let transmitter = transmitter.expect("init: missing protocol transmitter"); + + metas.push(ping.build_meta(PING_PROTOCOL_ID.into())); + metas.push(identify.build_meta(IDENTIFY_PROTOCOL_ID.into())); + metas.push(discovery.build_meta(DISCOVERY_PROTOCOL_ID.into())); + metas.push( + transmitter + .clone() + .build_meta(TRANSMITTER_PROTOCOL_ID.into()), + ); + + CoreProtocol { metas, transmitter } } } diff --git a/core/network/src/protocols/discovery.rs b/core/network/src/protocols/discovery.rs index 5fb254686..da6e5471c 100644 --- a/core/network/src/protocols/discovery.rs +++ b/core/network/src/protocols/discovery.rs @@ -35,8 +35,8 @@ impl Discovery { #[cfg(not(feature = "global_ip_only"))] log::info!("turn off global ip only"); - let address_manager = AddressManager::new(peer_mgr, event_tx); - let behaviour = DiscoveryBehaviour::new(address_manager, Some(sync_interval)); + let address_manager = AddressManager::new(peer_mgr.clone(), event_tx); + let behaviour = DiscoveryBehaviour::new(address_manager, peer_mgr, Some(sync_interval)); Discovery(DiscoveryProtocol::new(behaviour)) } diff --git a/core/network/src/protocols/discovery/behaviour.rs b/core/network/src/protocols/discovery/behaviour.rs index ed71ce2c6..4948b643b 100644 --- a/core/network/src/protocols/discovery/behaviour.rs +++ b/core/network/src/protocols/discovery/behaviour.rs @@ -1,29 +1,23 @@ -use super::{ - addr::{AddressManager, ConnectableAddr, DEFAULT_MAX_KNOWN}, - message::{DiscoveryMessage, Nodes}, - substream::{RemoteAddress, Substream, SubstreamKey, SubstreamValue}, -}; - -use futures::{ - channel::mpsc::{channel, Receiver, Sender}, - stream::FusedStream, - Stream, -}; +use std::collections::{HashMap, HashSet, VecDeque}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::{Duration, Instant}; + +use futures::channel::mpsc::{channel, Receiver, Sender}; +use futures::stream::FusedStream; +use futures::Stream; use log::debug; use rand::seq::SliceRandom; -use tentacle::{ - multiaddr::Multiaddr, - utils::{is_reachable, multiaddr_to_socketaddr}, - SessionId, -}; +use tentacle::multiaddr::Multiaddr; +use tentacle::utils::{is_reachable, multiaddr_to_socketaddr}; +use tentacle::SessionId; use tokio::time::Interval; -use std::{ - collections::{HashMap, HashSet, VecDeque}, - pin::Pin, - task::{Context, Poll}, - time::{Duration, Instant}, -}; +use crate::peer_manager::PeerManagerHandle; + +use super::addr::{AddressManager, ConnectableAddr, DEFAULT_MAX_KNOWN}; +use super::message::{DiscoveryMessage, Nodes}; +use super::substream::{RemoteAddress, Substream, SubstreamKey, SubstreamValue}; const CHECK_INTERVAL: Duration = Duration::from_secs(3); @@ -34,6 +28,10 @@ pub struct DiscoveryBehaviour { // Address Manager addr_mgr: AddressManager, + // TODO: Remove address manager + // Peer Manager + peer_mgr: PeerManagerHandle, + // The Nodes not yet been yield pending_nodes: VecDeque<(SubstreamKey, SessionId, Nodes)>, @@ -55,17 +53,30 @@ pub struct DiscoveryBehaviour { #[derive(Clone)] pub struct DiscoveryBehaviourHandle { pub substream_sender: Sender, + pub peer_mgr: PeerManagerHandle, +} + +impl DiscoveryBehaviourHandle { + pub fn contains_session(&self, session_id: SessionId) -> bool { + self.peer_mgr.contains_session(session_id) + } } impl DiscoveryBehaviour { /// Query cycle means checking and synchronizing the cycle time of the /// currently connected node, default is 24 hours - pub fn new(addr_mgr: AddressManager, query_cycle: Option) -> DiscoveryBehaviour { + pub fn new( + addr_mgr: AddressManager, + peer_mgr: PeerManagerHandle, + query_cycle: Option, + ) -> DiscoveryBehaviour { let (substream_sender, substream_receiver) = channel(8); + DiscoveryBehaviour { check_interval: None, max_known: DEFAULT_MAX_KNOWN, addr_mgr, + peer_mgr, pending_nodes: VecDeque::default(), substreams: HashMap::default(), substream_sender, @@ -78,6 +89,7 @@ impl DiscoveryBehaviour { pub fn handle(&self) -> DiscoveryBehaviourHandle { DiscoveryBehaviourHandle { substream_sender: self.substream_sender.clone(), + peer_mgr: self.peer_mgr.clone(), } } diff --git a/core/network/src/protocols/discovery/protocol.rs b/core/network/src/protocols/discovery/protocol.rs index 500725340..d8eda0d4d 100644 --- a/core/network/src/protocols/discovery/protocol.rs +++ b/core/network/src/protocols/discovery/protocol.rs @@ -1,21 +1,15 @@ -use super::{ - behaviour::{DiscoveryBehaviour, DiscoveryBehaviourHandle}, - substream::Substream, -}; +use std::collections::HashMap; -use futures::{ - channel::mpsc::{channel, Sender}, - future::FutureExt, - stream::StreamExt, -}; +use futures::channel::mpsc::{channel, Sender}; +use futures::stream::StreamExt; +use futures::FutureExt; use log::{debug, warn}; -use tentacle::{ - context::{ProtocolContext, ProtocolContextMutRef}, - traits::ServiceProtocol, - SessionId, -}; +use tentacle::context::{ProtocolContext, ProtocolContextMutRef}; +use tentacle::traits::ServiceProtocol; +use tentacle::SessionId; -use std::collections::HashMap; +use super::behaviour::{DiscoveryBehaviour, DiscoveryBehaviourHandle}; +use super::substream::Substream; pub struct DiscoveryProtocol { behaviour: Option, @@ -66,6 +60,21 @@ impl ServiceProtocol for DiscoveryProtocol { session.id, session.address, session.ty ); + if !self.behaviour_handle.contains_session(session.id) { + let _ = context.close_protocol(session.id, context.proto_id()); + return; + } + + let peer_id = match context.session.remote_pubkey.as_ref() { + Some(pubkey) => pubkey.peer_id(), + None => { + log::warn!("peer connection must be encrypted"); + let _ = context.disconnect(context.session.id); + return; + } + }; + crate::protocols::OpenedProtocols::register(peer_id, context.proto_id()); + let (sender, receiver) = channel(8); self.discovery_senders.insert(session.id, sender); let substream = Substream::new(context, receiver); diff --git a/core/network/src/protocols/identify.rs b/core/network/src/protocols/identify.rs index b6b9576ab..d9dc3a023 100644 --- a/core/network/src/protocols/identify.rs +++ b/core/network/src/protocols/identify.rs @@ -1,23 +1,35 @@ mod behaviour; mod common; +mod identification; mod message; mod protocol; -use self::protocol::IdentifyProtocol; -use behaviour::IdentifyBehaviour; -use crate::{event::PeerManagerEvent, peer_manager::PeerManagerHandle}; +#[cfg(test)] +mod tests; + +use std::sync::Arc; use futures::channel::mpsc::UnboundedSender; -use tentacle::{ - builder::MetaBuilder, - service::{ProtocolHandle, ProtocolMeta}, - ProtocolId, -}; +use tentacle::builder::MetaBuilder; +use tentacle::secio::PeerId; +use tentacle::service::{ProtocolHandle, ProtocolMeta}; +use tentacle::ProtocolId; + +use crate::event::PeerManagerEvent; +use crate::peer_manager::PeerManagerHandle; + +use self::protocol::IdentifyProtocol; +use behaviour::IdentifyBehaviour; + +pub use self::identification::WaitIdentification; +pub use self::protocol::Error; pub const NAME: &str = "chain_identify"; -pub const SUPPORT_VERSIONS: [&str; 1] = ["0.1"]; +pub const SUPPORT_VERSIONS: [&str; 1] = ["0.2"]; -pub struct Identify(IdentifyProtocol); +pub struct Identify { + behaviour: Arc, +} impl Identify { pub fn new(peer_mgr: PeerManagerHandle, event_tx: UnboundedSender) -> Self { @@ -26,16 +38,41 @@ impl Identify { #[cfg(not(feature = "global_ip_only"))] log::info!("turn off global ip only"); - let behaviour = IdentifyBehaviour::new(peer_mgr, event_tx); - Identify(IdentifyProtocol::new(behaviour)) + let behaviour = Arc::new(IdentifyBehaviour::new(peer_mgr, event_tx)); + Identify { behaviour } } + #[cfg(not(test))] pub fn build_meta(self, protocol_id: ProtocolId) -> ProtocolMeta { + let behaviour = self.behaviour; + MetaBuilder::new() .id(protocol_id) .name(name!(NAME)) .support_versions(support_versions!(SUPPORT_VERSIONS)) - .service_handle(move || ProtocolHandle::Callback(Box::new(self.0))) + .session_handle(move || { + ProtocolHandle::Callback(Box::new(IdentifyProtocol::new(Arc::clone(&behaviour)))) + }) .build() } + + #[cfg(test)] + pub fn build_meta(self, protocol_id: ProtocolId) -> ProtocolMeta { + let _ = self.behaviour; + + MetaBuilder::new() + .id(protocol_id) + .name(name!(NAME)) + .support_versions(support_versions!(SUPPORT_VERSIONS)) + .session_handle(move || ProtocolHandle::Callback(Box::new(IdentifyProtocol::new()))) + .build() + } + + pub fn wait_identified(peer_id: PeerId) -> WaitIdentification { + IdentifyProtocol::wait(peer_id) + } + + pub fn wait_failed(peer_id: &PeerId, error: String) { + IdentifyProtocol::wait_failed(peer_id, error) + } } diff --git a/core/network/src/protocols/identify/behaviour.rs b/core/network/src/protocols/identify/behaviour.rs index dc010b012..83e005355 100644 --- a/core/network/src/protocols/identify/behaviour.rs +++ b/core/network/src/protocols/identify/behaviour.rs @@ -1,75 +1,17 @@ -use super::common::reachable; -use crate::{event::PeerManagerEvent, peer_manager::PeerManagerHandle}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; use futures::channel::mpsc::UnboundedSender; -use log::{debug, trace, warn}; -use tentacle::{ - context::{ProtocolContextMutRef, SessionContext}, - multiaddr::Multiaddr, - secio::PeerId, - service::SessionType, -}; - -use std::{ - sync::atomic::{AtomicBool, Ordering}, - sync::Arc, - time::{Duration, Instant}, -}; - -pub const MAX_ADDRS: usize = 10; - -/// The misbehavior to report to underlying peer storage -pub enum Misbehavior { - /// Repeat send listen addresses - DuplicateListenAddrs, - /// Repeat send observed address - DuplicateObservedAddr, - /// Timeout reached - Timeout, - /// Remote peer send invalid data - InvalidData, - /// Send too many addresses in listen addresses - TooManyAddresses(usize), -} - -/// Misbehavior report result -pub enum MisbehaveResult { - /// Continue to run - Continue, - /// Disconnect this peer - Disconnect, -} - -impl MisbehaveResult { - pub fn is_disconnect(&self) -> bool { - match self { - MisbehaveResult::Disconnect => true, - _ => false, - } - } -} +use tentacle::multiaddr::Multiaddr; +use tentacle::secio::PeerId; +use tentacle::service::SessionType; -pub struct RemoteInfo { - pub peer_id: PeerId, - pub session: SessionContext, - pub connected_at: Instant, - pub timeout: Duration, - pub listen_addrs: Option>, - pub observed_addr: Option, -} +use crate::event::PeerManagerEvent; +use crate::peer_manager::PeerManagerHandle; -impl RemoteInfo { - pub fn new(peer_id: PeerId, session: SessionContext, timeout: Duration) -> RemoteInfo { - RemoteInfo { - peer_id, - session, - connected_at: Instant::now(), - timeout, - listen_addrs: None, - observed_addr: None, - } - } -} +use super::common::reachable; +use super::message; +use super::protocol::StateContext; #[derive(Clone)] struct AddrReporter { @@ -86,13 +28,13 @@ impl AddrReporter { } // TODO: upstream heart-beat check - pub fn report(&mut self, event: PeerManagerEvent) { + pub fn report(&self, event: PeerManagerEvent) { if self.shutdown.load(Ordering::SeqCst) { return; } if self.inner.unbounded_send(event).is_err() { - debug!("network: discovery: peer manager offline"); + log::debug!("network: discovery: peer manager offline"); self.shutdown.store(true, Ordering::SeqCst); } @@ -105,6 +47,8 @@ pub struct IdentifyBehaviour { addr_reporter: AddrReporter, } +// Allow dead code for cfg(test) +#[allow(dead_code)] impl IdentifyBehaviour { pub fn new(peer_mgr: PeerManagerHandle, event_tx: UnboundedSender) -> Self { let addr_reporter = AddrReporter::new(event_tx); @@ -115,113 +59,109 @@ impl IdentifyBehaviour { } } - pub fn identify(&mut self) -> &str { - "Identify message" + pub fn chain_id(&self) -> String { + self.peer_mgr.chain_id().as_ref().as_hex() } - pub fn process_listens( - &mut self, - info: &mut RemoteInfo, - listens: Vec, - ) -> MisbehaveResult { - if info.listen_addrs.is_some() { - debug!("remote({:?}) repeat send observed address", info.peer_id); - self.misbehave(&info.peer_id, Misbehavior::DuplicateListenAddrs) - } else if listens.len() > MAX_ADDRS { - self.misbehave(&info.peer_id, Misbehavior::TooManyAddresses(listens.len())) - } else { - trace!("received listen addresses: {:?}", listens); - let reachable_addrs = listens.into_iter().filter(reachable).collect::>(); - - info.listen_addrs = Some(reachable_addrs.clone()); - self.add_remote_listen_addrs(&info.peer_id, reachable_addrs); + pub fn local_listen_addrs(&self) -> Vec { + let addrs = self.peer_mgr.listen_addrs(); + let reachable_addrs = addrs.into_iter().filter(reachable); - MisbehaveResult::Continue - } + reachable_addrs.take(message::MAX_LISTEN_ADDRS).collect() } - pub fn process_observed( - &mut self, - info: &mut RemoteInfo, - observed: Option, - ) -> MisbehaveResult { - if info.observed_addr.is_some() { - debug!("remote({:?}) repeat send listen addresses", info.peer_id); - self.misbehave(&info.peer_id, Misbehavior::DuplicateObservedAddr) - } else { - let observed = match observed { - Some(addr) => addr, - None => { - warn!("observed is none from peer {:?}", info.peer_id); - return MisbehaveResult::Disconnect; - } - }; - - trace!("received observed address: {}", observed); - let mut unobservable = |info: &mut RemoteInfo, observed| -> bool { - self.add_observed_addr(&info.peer_id, observed, info.session.ty) - .is_disconnect() - }; + pub fn send_identity(&self, context: &StateContext) { + let address_info = { + let listen_addrs = self.local_listen_addrs(); + let observed_addr = context.observed_addr(); + message::AddressInfo::new(listen_addrs, observed_addr) + }; - if reachable(&observed) && unobservable(info, observed.clone()) { - return MisbehaveResult::Disconnect; + let identity = { + let msg = message::Identity::new(self.chain_id(), address_info); + match msg.into_bytes() { + Ok(msg) => msg, + Err(err) => { + log::warn!("encode identity msg failed {}", err); + context.disconnect(); + return; + } } + }; - info.observed_addr = Some(observed); - MisbehaveResult::Continue - } + context.send_message(identity); } - pub fn received_identify( - &mut self, - _context: &mut ProtocolContextMutRef, - _identify: &[u8], - ) -> MisbehaveResult { - MisbehaveResult::Continue + pub fn send_ack(&self, context: &StateContext) { + let address_info = { + let listen_addrs = self.local_listen_addrs(); + let observed_addr = context.observed_addr(); + message::AddressInfo::new(listen_addrs, observed_addr) + }; + + let acknowledge = { + let msg = message::Acknowledge::new(address_info); + match msg.into_bytes() { + Ok(msg) => msg, + Err(err) => { + log::warn!("encode acknowledge msg failed {}", err); + context.disconnect(); + return; + } + } + }; + + context.send_message(acknowledge); } - pub fn local_listen_addrs(&self) -> Vec { - self.peer_mgr.listen_addrs() + pub fn verify_remote_identity( + &self, + identity: &message::Identity, + ) -> Result<(), super::protocol::Error> { + if identity.chain_id != self.chain_id() { + Err(super::protocol::Error::WrongChainId) + } else { + Ok(()) + } } - pub fn add_remote_listen_addrs(&mut self, peer_id: &PeerId, addrs: Vec) { - debug!("add remote listen {:?} addrs {:?}", peer_id, addrs); + pub fn process_listens(&self, context: &StateContext, listens: Vec) { + let peer_id = &context.remote_peer.id; + log::debug!("listen addresses: {:?}", listens); + let reachable_addrs = listens.into_iter().filter(reachable).collect::>(); let identified_addrs = PeerManagerEvent::IdentifiedAddrs { - pid: peer_id.to_owned(), - addrs, + pid: peer_id.to_owned(), + addrs: reachable_addrs, }; self.addr_reporter.report(identified_addrs); } + pub fn process_observed(&self, context: &StateContext, observed: Multiaddr) { + let peer_id = &context.remote_peer.id; + let session_type = context.session_context.ty; + log::debug!("observed addr {:?} from {}", observed, context.remote_peer); + + let unobservable = |observed| -> bool { + self.add_observed_addr(peer_id, observed, session_type) + .is_err() + }; + + if reachable(&observed) && unobservable(observed.clone()) { + log::warn!("unobservable {} from {}", observed, context.remote_peer); + context.disconnect(); + } + } + pub fn add_observed_addr( - &mut self, + &self, peer: &PeerId, addr: Multiaddr, ty: SessionType, - ) -> MisbehaveResult { - debug!("add observed: {:?}, addr {:?}, ty: {:?}", peer, addr, ty); + ) -> Result<(), ()> { + log::debug!("add observed: {:?}, addr {:?}, ty: {:?}", peer, addr, ty); // Noop right now - MisbehaveResult::Continue - } - - /// Report misbehavior - pub fn misbehave(&mut self, peer: &PeerId, kind: Misbehavior) -> MisbehaveResult { - match kind { - Misbehavior::DuplicateListenAddrs => { - debug!("peer {:?} misbehave: duplicatelisten addrs", peer) - } - Misbehavior::DuplicateObservedAddr => { - debug!("peer {:?} misbehave: duplicate observed addr", peer) - } - Misbehavior::TooManyAddresses(size) => { - debug!("peer {:?} misbehave: too many address {}", peer, size) - } - Misbehavior::InvalidData => debug!("peer {:?} misbehave: invalid data", peer), - Misbehavior::Timeout => debug!("peer {:?} misbehave: timeout", peer), - } - - MisbehaveResult::Disconnect + Ok(()) } } diff --git a/core/network/src/protocols/identify/identification.rs b/core/network/src/protocols/identify/identification.rs new file mode 100644 index 000000000..75bb1d453 --- /dev/null +++ b/core/network/src/protocols/identify/identification.rs @@ -0,0 +1,181 @@ +use std::borrow::Borrow; +use std::collections::HashSet; +use std::future::Future; +use std::hash::{Hash, Hasher}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll, Waker}; + +use parking_lot::Mutex; + +type Index = usize; + +pub struct WaitIdentification { + idx: Index, + ident_status: Arc>, +} + +impl WaitIdentification { + fn new(ident_status: Arc>) -> Self { + WaitIdentification { + idx: usize::MAX, + ident_status, + } + } +} + +impl Future for WaitIdentification { + type Output = Result<(), super::protocol::Error>; + + fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + let insert_idx = { + let idx = self.idx; + match &mut *self.ident_status.lock() { + IdentificationStatus::Done(ret) => return Poll::Ready(ret.to_owned()), + IdentificationStatus::Pending(_) if idx != usize::MAX => return Poll::Pending, + IdentificationStatus::Pending(wakerset) => wakerset.insert(ctx.waker().to_owned()), + } + }; + + self.idx = insert_idx; + Poll::Pending + } +} + +impl Drop for WaitIdentification { + fn drop(&mut self) { + if let IdentificationStatus::Pending(wakerset) = &mut *self.ident_status.lock() { + wakerset.remove(self.idx); + } + } +} + +pub struct Identification { + status: Arc>, +} + +impl Identification { + pub(crate) fn new() -> Self { + Identification { + status: Default::default(), + } + } + + pub fn wait(&self) -> WaitIdentification { + WaitIdentification::new(Arc::clone(&self.status)) + } + + pub fn pass(&self) { + self.done(Ok(())) + } + + pub fn failed(&self, error: super::protocol::Error) { + self.done(Err(error)) + } + + fn fail_if_not_done(&self) { + { + let status = self.status.lock(); + if let IdentificationStatus::Done(_) = &*status { + return; + } + } + + self.failed(super::protocol::Error::WaitFutDropped) + } + + fn done(&self, ret: Result<(), super::protocol::Error>) { + let mut status = self.status.lock(); + + if let IdentificationStatus::Pending(wakerset) = + std::mem::replace(&mut *status, IdentificationStatus::Done(ret)) + { + wakerset.wake() + } + } +} + +impl Drop for Identification { + fn drop(&mut self) { + self.fail_if_not_done() + } +} + +struct IndexedWaker { + idx: Index, + waker: Waker, +} + +impl IndexedWaker { + fn wake(self) { + self.waker.wake() + } +} + +impl Borrow for IndexedWaker { + fn borrow(&self) -> &Index { + &self.idx + } +} + +impl PartialEq for IndexedWaker { + fn eq(&self, other: &IndexedWaker) -> bool { + self.idx == other.idx + } +} + +impl Eq for IndexedWaker {} + +impl Hash for IndexedWaker { + fn hash(&self, state: &mut H) { + self.idx.hash(state) + } +} + +struct WakerSet { + id: Index, + wakers: HashSet, +} + +impl WakerSet { + fn new() -> WakerSet { + WakerSet { + id: 0, + wakers: HashSet::new(), + } + } + + fn insert(&mut self, waker: Waker) -> Index { + debug_assert!(self.id != std::usize::MAX); + self.id += 1; + + let indexed_waker = IndexedWaker { + idx: self.id, + waker, + }; + + self.wakers.insert(indexed_waker); + self.id + } + + fn remove(&mut self, idx: Index) { + self.wakers.remove(&idx); + } + + fn wake(self) { + for waker in self.wakers { + waker.wake() + } + } +} + +enum IdentificationStatus { + Pending(WakerSet), + Done(Result<(), super::protocol::Error>), +} + +impl Default for IdentificationStatus { + fn default() -> Self { + IdentificationStatus::Pending(WakerSet::new()) + } +} diff --git a/core/network/src/protocols/identify/message.rs b/core/network/src/protocols/identify/message.rs index a86ea2768..0ccba547a 100644 --- a/core/network/src/protocols/identify/message.rs +++ b/core/network/src/protocols/identify/message.rs @@ -1,25 +1,62 @@ +use std::convert::TryFrom; + +use derive_more::Display; use prost::{EncodeError, Message}; use protocol::{Bytes, BytesMut}; use tentacle::multiaddr::Multiaddr; -use std::convert::TryFrom; +pub const MAX_LISTEN_ADDRS: usize = 10; + +#[derive(Debug, Display)] +pub enum Error { + #[display(fmt = "too many listen addrs")] + TooManyListenAddrs, + + #[display(fmt = "no observed addrs")] + NoObservedAddr, + + #[display(fmt = "no addr info")] + NoAddrInfo, +} + +pub trait AddressInfoMessage { + fn validate(&self) -> Result<(), self::Error>; + fn listen_addrs(&self) -> Vec; + fn observed_addr(&self) -> Option; +} + +impl AddressInfoMessage for Option { + fn listen_addrs(&self) -> Vec { + self.as_ref() + .map(|ai| ai.listen_addrs()) + .unwrap_or_else(Vec::new) + } + + fn observed_addr(&self) -> Option { + self.as_ref().map(|ai| ai.observed_addr()).flatten() + } + + fn validate(&self) -> Result<(), self::Error> { + match self.as_ref() { + Some(addr_info) => addr_info.validate(), + None => Err(self::Error::NoAddrInfo), + } + } +} -#[derive(Clone, PartialEq, Eq, Message)] -pub struct IdentifyMessage { +#[derive(Message)] +pub struct AddressInfo { #[prost(bytes, repeated, tag = "1")] pub listen_addrs: Vec>, #[prost(bytes, tag = "2")] pub observed_addr: Vec, - #[prost(string, tag = "3")] - pub identify: String, } -impl IdentifyMessage { - pub fn new(listen_addrs: Vec, observed_addr: Multiaddr, identify: String) -> Self { - IdentifyMessage { - listen_addrs: listen_addrs.into_iter().map(|addr| addr.to_vec()).collect(), +impl AddressInfo { + pub fn new(listen_addrs: Vec, observed_addr: Multiaddr) -> Self { + AddressInfo { + listen_addrs: listen_addrs.into_iter().map(|addr| addr.to_vec()).collect(), observed_addr: observed_addr.to_vec(), - identify, } } @@ -33,10 +70,124 @@ impl IdentifyMessage { Multiaddr::try_from(self.observed_addr.clone()).ok() } + pub fn validate(&self) -> Result<(), self::Error> { + if self.listen_addrs.len() > MAX_LISTEN_ADDRS { + return Err(self::Error::TooManyListenAddrs); + } + + if self.observed_addr().is_none() { + return Err(self::Error::NoObservedAddr); + } + + Ok(()) + } + + #[cfg(test)] + pub fn mock_valid() -> Self { + let listen_addr: Multiaddr = "/ip4/47.111.169.36/tcp/2000".parse().unwrap(); + let observed_addr: Multiaddr = "/ip4/47.111.169.36/tcp/2001".parse().unwrap(); + + AddressInfo { + listen_addrs: vec![listen_addr.to_vec()], + observed_addr: observed_addr.to_vec(), + } + } + + #[cfg(test)] + pub fn mock_invalid() -> Self { + AddressInfo { + listen_addrs: vec![], + observed_addr: b"xxx".to_vec(), + } + } +} + +#[derive(Message)] +pub struct Identity { + #[prost(string, tag = "1")] + pub chain_id: String, + #[prost(message, tag = "2")] + pub addr_info: Option, +} + +impl Identity { + pub fn new(chain_id: String, addr_info: AddressInfo) -> Self { + Identity { + chain_id, + addr_info: Some(addr_info), + } + } + + pub fn validate(&self) -> Result<(), self::Error> { + self.addr_info.validate() + } + + pub fn into_bytes(self) -> Result { + let mut buf = BytesMut::with_capacity(self.encoded_len()); + self.encode(&mut buf)?; + + Ok(buf.freeze()) + } + + #[cfg(test)] + pub fn mock_valid() -> Self { + use protocol::types::Hash; + + Identity { + chain_id: Hash::digest(Bytes::from_static(b"hello")).as_hex(), + addr_info: Some(AddressInfo::mock_valid()), + } + } + + #[cfg(test)] + pub fn mock_invalid() -> Self { + use protocol::types::Hash; + + let identity = Identity { + chain_id: Hash::digest(Bytes::from_static(b"hello")).as_hex(), + addr_info: Some(AddressInfo::mock_invalid()), + }; + assert!(identity.validate().is_err()); + + identity + } +} + +#[derive(Message)] +pub struct Acknowledge { + #[prost(message, tag = "1")] + pub addr_info: Option, +} + +impl Acknowledge { + pub fn new(addr_info: AddressInfo) -> Self { + Acknowledge { + addr_info: Some(addr_info), + } + } + + pub fn validate(&self) -> Result<(), self::Error> { + self.addr_info.validate() + } + pub fn into_bytes(self) -> Result { let mut buf = BytesMut::with_capacity(self.encoded_len()); self.encode(&mut buf)?; Ok(buf.freeze()) } + + #[cfg(test)] + pub fn mock_valid() -> Self { + Acknowledge { + addr_info: Some(AddressInfo::mock_valid()), + } + } + + #[cfg(test)] + pub fn mock_invalid() -> Self { + Acknowledge { + addr_info: Some(AddressInfo::mock_invalid()), + } + } } diff --git a/core/network/src/protocols/identify/protocol.rs b/core/network/src/protocols/identify/protocol.rs index 8da50ee93..b9cbdd7b0 100644 --- a/core/network/src/protocols/identify/protocol.rs +++ b/core/network/src/protocols/identify/protocol.rs @@ -1,171 +1,522 @@ -use super::{ - behaviour::{IdentifyBehaviour, Misbehavior, RemoteInfo, MAX_ADDRS}, - common::reachable, - message::IdentifyMessage, -}; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; -use log::{debug, error, trace, warn}; +use derive_more::Display; +use futures::future::{self, AbortHandle}; +use futures_timer::Delay; +use lazy_static::lazy_static; +use parking_lot::RwLock; use prost::Message; -use tentacle::{ - context::{ProtocolContext, ProtocolContextMutRef}, - multiaddr::{Multiaddr, Protocol}, - traits::ServiceProtocol, - SessionId, -}; - -use std::{ - collections::HashMap, - time::{Duration, Instant}, -}; - -const DEFAULT_TIMEOUT: Duration = Duration::from_secs(8); -const CHECK_TIMEOUT_INTERVAL: Duration = Duration::from_secs(1); -const CHECK_TIMEOUT_TOKEN: u64 = 100; +use protocol::Bytes; +use tentacle::multiaddr::{Multiaddr, Protocol}; +use tentacle::secio::PeerId; +use tentacle::service::{SessionType, TargetProtocol}; +use tentacle::traits::SessionProtocol; +use tentacle::{ProtocolId, SessionId}; + +#[cfg(test)] +use crate::test::mock::{ServiceControl, SessionContext}; +#[cfg(not(test))] +use tentacle::context::{ProtocolContextMutRef, SessionContext}; +#[cfg(not(test))] +use tentacle::service::ServiceControl; + +#[cfg(not(test))] +use super::behaviour::IdentifyBehaviour; +#[cfg(test)] +use super::tests::MockIdentifyBehaviour; + +use super::identification::{Identification, WaitIdentification}; +use super::message::{Acknowledge, AddressInfoMessage, Identity}; + +pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(8); +pub const MAX_MESSAGE_SIZE: usize = 5 * 1000; // 5KB + +lazy_static! { + // NOTE: Use peer id here because trust metric integrated test run in one process + static ref PEER_IDENTIFICATION_BACKLOG: RwLock> = + RwLock::new(HashMap::new()); +} + +#[derive(Debug, Display, Clone)] +pub enum Error { + #[display(fmt = "wrong chain id")] + WrongChainId, + + #[display(fmt = "timeout")] + Timeout, + + #[display(fmt = "exceed max message size")] + ExceedMaxMessageSize, + + #[display(fmt = "decode indentity failed")] + DecodeIdentityFailed, + + #[display(fmt = "decode ack failed")] + DecodeAckFailed, + + #[display(fmt = "{}", _0)] + InvalidMessage(String), + + #[display(fmt = "wait future dropped")] + WaitFutDropped, + + #[display(fmt = "disconnected")] + Disconnected, + + #[display(fmt = "{}", _0)] + Other(String), +} + +// Wrap ProtocolContextMutRef for easy mock and test +#[cfg(not(test))] +pub struct IdentifyProtocolContext<'a>(ProtocolContextMutRef<'a>); +#[cfg(test)] +pub struct IdentifyProtocolContext<'a>(pub &'a crate::test::mock::ProtocolContext); + +#[derive(Debug, Display)] +#[display(fmt = "peer {:?} addr {:?}", id, addr)] +pub struct RemotePeer { + pub id: PeerId, + pub sid: SessionId, + pub addr: Multiaddr, +} + +pub struct NoEncryption; + +impl RemotePeer { + pub fn from_proto_context( + proto_context: &IdentifyProtocolContext, + ) -> Result { + match proto_context.0.session.remote_pubkey.as_ref() { + None => Err(NoEncryption), + Some(pubkey) => { + let remote_peer = RemotePeer { + id: pubkey.peer_id(), + sid: proto_context.0.session.id, + addr: proto_context.0.session.address.to_owned(), + }; + + Ok(remote_peer) + } + } + } +} + +pub struct StateContext { + pub remote_peer: Arc, + pub proto_id: ProtocolId, + pub service_control: ServiceControl, + pub session_context: SessionContext, + pub timeout_abort_handle: Option, +} + +impl StateContext { + pub fn from_proto_context( + proto_context: &IdentifyProtocolContext, + ) -> Result { + let remote_peer = RemotePeer::from_proto_context(proto_context)?; + let state_context = StateContext { + remote_peer: Arc::new(remote_peer), + proto_id: proto_context.0.proto_id(), + service_control: proto_context.0.control().clone(), + session_context: proto_context.0.session.clone(), + timeout_abort_handle: None, + }; + + Ok(state_context) + } + + pub fn observed_addr(&self) -> Multiaddr { + let remote_addr = self.session_context.address.iter(); + + remote_addr + .filter(|proto| match proto { + Protocol::P2P(_) => false, + _ => true, + }) + .collect() + } + + pub fn send_message(&self, msg: Bytes) { + if let Err(err) = + self.service_control + .quick_send_message_to(self.remote_peer.sid, self.proto_id, msg) + { + log::warn!( + "internal error: quick send message to {} failed {}", + self.remote_peer, + err + ); + } + } + + pub fn disconnect(&self) { + let _ = self.service_control.disconnect(self.remote_peer.sid); + } + + pub fn open_protocols(&self) { + if let Err(err) = self + .service_control + .open_protocols(self.remote_peer.sid, TargetProtocol::All) + { + log::warn!("open protocols to peer {} failed {}", self.remote_peer, err); + self.disconnect() + } + } + + pub fn set_open_protocols_timeout(&mut self, timeout: Duration) { + let service_control = self.service_control.clone(); + let remote_peer = Arc::clone(&self.remote_peer); + + tokio::spawn(async move { + Delay::new(timeout).await; + + if crate::protocols::OpenedProtocols::is_all_opened(&remote_peer.id) { + return; + } + + log::warn!("peer {} open protocols timeout, disconnect it", remote_peer); + let _ = service_control.disconnect(remote_peer.sid); + }); + } + + pub fn set_timeout(&mut self, description: &'static str, timeout: Duration) { + let service_control = self.service_control.clone(); + let remote_peer = Arc::clone(&self.remote_peer); + + let (timeout, timeout_abort_handle) = future::abortable(async move { + Delay::new(timeout).await; + + log::warn!( + "{} timeout from peer {}, disconnect it", + description, + remote_peer, + ); + + finish_identify(&remote_peer, Err(self::Error::Timeout)); + let _ = service_control.disconnect(remote_peer.sid); + }); + + self.timeout_abort_handle = Some(timeout_abort_handle); + tokio::spawn(timeout); + } + + pub fn cancel_timeout(&self) { + if let Some(timeout) = self.timeout_abort_handle.as_ref() { + timeout.abort() + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Display)] +pub enum ClientProcedure { + #[display(fmt = "client wait for server identity acknowledge")] + WaitAck, + + #[display(fmt = "client open other protocols")] + OpenOtherProtocols, + + #[display(fmt = "server failed identification")] + Failed, +} + +#[derive(Debug, Clone, PartialEq, Eq, Display)] +pub enum ServerProcedure { + #[display(fmt = "server wait for client identity")] + WaitIdentity, + + #[display(fmt = "server wait for client open protocols")] + WaitOpenProtocols, // After accept session + + #[display(fmt = "client failed identification")] + Failed, +} + +pub enum State { + SessionProtocolInited, + FailedWithoutEncryption, + FailedWithExceedMsgSize, + ClientNegotiate { + procedure: ClientProcedure, + context: StateContext, + }, + ServerNegotiate { + procedure: ServerProcedure, + context: StateContext, + }, +} pub struct IdentifyProtocol { - remote_infos: HashMap, - behaviour: IdentifyBehaviour, + pub(crate) state: State, + #[cfg(not(test))] + behaviour: Arc, + #[cfg(test)] + pub(crate) behaviour: Arc, } impl IdentifyProtocol { - pub fn new(behaviour: IdentifyBehaviour) -> Self { + #[cfg(not(test))] + pub fn new(behaviour: Arc) -> Self { IdentifyProtocol { - remote_infos: HashMap::new(), + state: State::SessionProtocolInited, behaviour, } } -} -impl ServiceProtocol for IdentifyProtocol { - fn init(&mut self, context: &mut ProtocolContext) { - let proto_id = context.proto_id; + #[cfg(test)] + pub fn new() -> Self { + IdentifyProtocol { + state: State::SessionProtocolInited, + behaviour: Arc::new(MockIdentifyBehaviour::new()), + } + } - if let Err(e) = - context.set_service_notify(proto_id, CHECK_TIMEOUT_INTERVAL, CHECK_TIMEOUT_TOKEN) - { - warn!("identify start fail {}", e); + pub fn wait(peer_id: PeerId) -> WaitIdentification { + let mut backlog = PEER_IDENTIFICATION_BACKLOG.write(); + let identification = backlog.entry(peer_id).or_insert_with(Identification::new); + + identification.wait() + } + + pub fn wait_failed(peer_id: &PeerId, error: String) { + if let Some(identification) = { PEER_IDENTIFICATION_BACKLOG.write().remove(peer_id) } { + identification.failed(self::Error::Other(error)) } } - fn connected(&mut self, context: ProtocolContextMutRef, _version: &str) { - let session = context.session; - let remote_peer_id = match &session.remote_pubkey { - Some(pubkey) => pubkey.peer_id(), - None => { - error!("IdentifyProtocol require secio enabled!"); - let _ = context.disconnect(session.id); + pub fn on_connected(&mut self, protocol_context: &IdentifyProtocolContext) { + let mut state_context = match StateContext::from_proto_context(protocol_context) { + Ok(ctx) => ctx, + Err(_no) => { + // Without peer id, there's no way to register a wait identification.No + // need to clean it. + log::warn!( + "session from {:?} without encryption, disconnect it", + protocol_context.0.session.address + ); + + self.state = State::FailedWithoutEncryption; + let _ = protocol_context.0.disconnect(protocol_context.0.session.id); return; } }; + log::debug!("connected from {:?}", state_context.remote_peer); - trace!("IdentifyProtocol connected from {:?}", remote_peer_id); - let remote_info = RemoteInfo::new(remote_peer_id, session.clone(), DEFAULT_TIMEOUT); - self.remote_infos.insert(session.id, remote_info); - - let listen_addrs: Vec = self - .behaviour - .local_listen_addrs() - .into_iter() - .filter(reachable) - .take(MAX_ADDRS) - .collect(); - - let observed_addr = session - .address - .iter() - .filter(|proto| match proto { - Protocol::P2P(_) => false, - _ => true, - }) - .collect::(); + crate::protocols::OpenedProtocols::register( + state_context.remote_peer.id.to_owned(), + state_context.proto_id, + ); - let identify = self.behaviour.identify(); - let msg = match IdentifyMessage::new(listen_addrs, observed_addr, identify.to_owned()) - .into_bytes() - { - Ok(msg) => msg, - Err(err) => { - warn!("encode {}", err); - return; + match protocol_context.0.session.ty { + SessionType::Inbound => { + state_context.set_timeout("wait client identity", DEFAULT_TIMEOUT); + + self.state = State::ServerNegotiate { + procedure: ServerProcedure::WaitIdentity, + context: state_context, + }; } + SessionType::Outbound => { + self.behaviour.send_identity(&state_context); + state_context.set_timeout("wait server ack", DEFAULT_TIMEOUT); + + self.state = State::ClientNegotiate { + procedure: ClientProcedure::WaitAck, + context: state_context, + }; + } + } + } + + pub fn on_disconnected(&mut self, protocol_context: &IdentifyProtocolContext) { + // Without peer id, there's no way to register a wait identification. No + // need to clean it. + let peer_id = match protocol_context.0.session.remote_pubkey.as_ref() { + Some(pubkey) => pubkey.peer_id(), + None => return, }; - if let Err(err) = context.quick_send_message(msg) { - warn!("quick send message {}", err); + // TODO: Remove from upper level + crate::protocols::OpenedProtocols::remove(&peer_id); + + if let Some(identification) = PEER_IDENTIFICATION_BACKLOG.write().remove(&peer_id) { + identification.failed(self::Error::Disconnected); } } - fn disconnected(&mut self, context: ProtocolContextMutRef) { - let info = self - .remote_infos - .remove(&context.session.id) - .expect("RemoteInfo must exists"); - trace!("IdentifyProtocol disconnected from {:?}", info.peer_id); - } - - fn received(&mut self, mut context: ProtocolContextMutRef, data: bytes::Bytes) { - let session = context.session; - - match IdentifyMessage::decode(data) { - Ok(message) => { - let mut remote_info = self - .remote_infos - .get_mut(&context.session.id) - .expect("RemoteInfo must exists"); - let behaviour = &mut self.behaviour; - - // Need to interrupt processing, avoid pollution - if behaviour - .received_identify(&mut context, message.identify.as_bytes()) - .is_disconnect() - || behaviour - .process_listens(&mut remote_info, message.listen_addrs()) - .is_disconnect() - || behaviour - .process_observed(&mut remote_info, message.observed_addr()) - .is_disconnect() - { - let _ = context.disconnect(session.id); + pub fn on_received(&mut self, protocol_context: &IdentifyProtocolContext, data: Bytes) { + { + if data.len() > MAX_MESSAGE_SIZE { + let peer_id = match protocol_context.0.session.remote_pubkey.as_ref() { + Some(pubkey) => pubkey.peer_id(), + None => return, + }; + + if let Some(identification) = PEER_IDENTIFICATION_BACKLOG.write().remove(&peer_id) { + identification.failed(self::Error::ExceedMaxMessageSize); + self.state = State::FailedWithExceedMsgSize; + let _ = protocol_context.0.disconnect(protocol_context.0.session.id); + return; } } - Err(_) => { - let info = self - .remote_infos - .get(&session.id) - .expect("RemoteInfo must exists"); - - warn!( - "IdentifyProtocol received invalid data from {:?}", - info.peer_id - ); + } + + match &mut self.state { + State::ServerNegotiate { + ref mut procedure, + context, + } => match procedure { + ServerProcedure::WaitIdentity => { + let identity = match Identity::decode(data) { + Ok(ident) => ident, + Err(_) => { + log::warn!("received invalid identity from {:?}", context.remote_peer); + + finish_identify( + &context.remote_peer, + Err(self::Error::DecodeIdentityFailed), + ); + *procedure = ServerProcedure::Failed; + context.disconnect(); + return; + } + }; + context.cancel_timeout(); + + if let Err(err) = identity.validate() { + finish_identify( + &context.remote_peer, + Err(self::Error::InvalidMessage(err.to_string())), + ); + *procedure = ServerProcedure::Failed; + context.disconnect(); + return; + } + + if let Err(err) = self.behaviour.verify_remote_identity(&identity) { + finish_identify(&context.remote_peer, Err(err)); + *procedure = ServerProcedure::Failed; + context.disconnect(); + return; + } + + finish_identify(&context.remote_peer, Ok(())); + + let listen_addrs = identity.addr_info.listen_addrs(); + self.behaviour.process_listens(&context, listen_addrs); - if self - .behaviour - .misbehave(&info.peer_id, Misbehavior::InvalidData) - .is_disconnect() - { - let _ = context.disconnect(session.id); + if let Some(observed_addr) = identity.addr_info.observed_addr() { + self.behaviour.process_observed(&context, observed_addr); + } + + self.behaviour.send_ack(&context); + context.set_open_protocols_timeout(DEFAULT_TIMEOUT); + *procedure = ServerProcedure::WaitOpenProtocols; + } + ServerProcedure::Failed | ServerProcedure::WaitOpenProtocols => { + log::warn!( + "should not received any more message from {} after acked identity", + context.remote_peer + ); + context.disconnect(); } + }, + State::ClientNegotiate { + ref mut procedure, + context, + } => match procedure { + ClientProcedure::WaitAck => { + let acknowledge = match Acknowledge::decode(data) { + Ok(ack) => ack, + Err(_) => { + log::warn!("received invalid ack from {:?}", context.remote_peer); + + finish_identify( + &context.remote_peer, + Err(self::Error::DecodeAckFailed), + ); + *procedure = ClientProcedure::Failed; + context.disconnect(); + return; + } + }; + context.cancel_timeout(); + + if let Err(err) = acknowledge.validate() { + finish_identify( + &context.remote_peer, + Err(self::Error::InvalidMessage(err.to_string())), + ); + *procedure = ClientProcedure::Failed; + context.disconnect(); + return; + } + + finish_identify(&context.remote_peer, Ok(())); + + let listen_addrs = acknowledge.addr_info.listen_addrs(); + self.behaviour.process_listens(&context, listen_addrs); + + if let Some(observed_addr) = acknowledge.addr_info.observed_addr() { + self.behaviour.process_observed(&context, observed_addr); + } + + context.open_protocols(); + *procedure = ClientProcedure::OpenOtherProtocols; + } + ClientProcedure::OpenOtherProtocols | ClientProcedure::Failed => { + log::warn!( + "should not received any more message from {} after open protocols", + context.remote_peer + ); + context.disconnect(); + } + }, + _ => { + log::warn!( + "should not received message from {} out of negotiate state", + protocol_context.0.session.address + ); + let _ = protocol_context.0.disconnect(protocol_context.0.session.id); } } } +} - fn notify(&mut self, context: &mut ProtocolContext, _token: u64) { - let now = Instant::now(); - - for (session_id, info) in &self.remote_infos { - if (info.listen_addrs.is_none() || info.observed_addr.is_none()) - && (info.connected_at + info.timeout) <= now - { - debug!("{:?} receive identify message timeout", info.peer_id); - if self - .behaviour - .misbehave(&info.peer_id, Misbehavior::Timeout) - .is_disconnect() - { - let _ = context.disconnect(*session_id); - } - } +#[cfg(test)] +impl SessionProtocol for IdentifyProtocol {} + +#[cfg(not(test))] +impl SessionProtocol for IdentifyProtocol { + fn connected(&mut self, protocol_context: ProtocolContextMutRef, _version: &str) { + self.on_connected(&IdentifyProtocolContext(protocol_context)); + } + + fn disconnected(&mut self, protocol_context: ProtocolContextMutRef) { + self.on_disconnected(&IdentifyProtocolContext(protocol_context)); + } + + fn received(&mut self, protocol_context: ProtocolContextMutRef, data: bytes::Bytes) { + self.on_received(&IdentifyProtocolContext(protocol_context), data) + } +} + +fn finish_identify(peer: &RemotePeer, result: Result<(), self::Error>) { + let identification = match { PEER_IDENTIFICATION_BACKLOG.write().remove(&peer.id) } { + Some(ident) => ident, + None => { + log::debug!("peer {:?} identification has finished already", peer); + return; + } + }; + + match result { + Ok(()) => identification.pass(), + Err(err) => { + log::warn!("identification for peer {} failed: {}", peer, err); + identification.failed(err); } } } diff --git a/core/network/src/protocols/identify/tests.rs b/core/network/src/protocols/identify/tests.rs new file mode 100644 index 000000000..4a28685b4 --- /dev/null +++ b/core/network/src/protocols/identify/tests.rs @@ -0,0 +1,694 @@ +use std::time::Duration; + +use futures_timer::Delay; +use parking_lot::Mutex; +use protocol::Bytes; +use tentacle::multiaddr::Multiaddr; +use tentacle::service::{SessionType, TargetProtocol}; + +use super::message; +use super::protocol::{ + ClientProcedure, Error, IdentifyProtocol, IdentifyProtocolContext, ServerProcedure, State, + StateContext, MAX_MESSAGE_SIZE, +}; +use crate::test::mock::{ControlEvent, ProtocolContext}; + +const PROTOCOL_ID: usize = 2; +const SESSION_ID: usize = 2; + +#[derive(Debug, Clone)] +pub enum BehaviourEvent { + SendIdentity, + SendAck, + ProcessListen, + ProcessObserved, + VerifyRemoteIdentity, +} + +pub struct MockIdentifyBehaviour { + event: Mutex>, + skip_chain_id_verify: Mutex, +} + +impl MockIdentifyBehaviour { + pub fn new() -> Self { + MockIdentifyBehaviour { + event: Mutex::new(None), + skip_chain_id_verify: Mutex::new(true), + } + } + + pub fn event(&self) -> Option { + self.event.lock().clone() + } + + pub fn send_identity(&self, _: &StateContext) { + *self.event.lock() = Some(BehaviourEvent::SendIdentity) + } + + pub fn send_ack(&self, _: &StateContext) { + *self.event.lock() = Some(BehaviourEvent::SendAck) + } + + pub fn process_listens(&self, _: &StateContext, _listen_addrs: Vec) { + *self.event.lock() = Some(BehaviourEvent::ProcessListen) + } + + pub fn process_observed(&self, _: &StateContext, _observed_addr: Multiaddr) { + *self.event.lock() = Some(BehaviourEvent::ProcessObserved) + } + + pub fn verify_remote_identity(&self, _identity: &message::Identity) -> Result<(), Error> { + { + *self.event.lock() = Some(BehaviourEvent::VerifyRemoteIdentity); + } + + if *self.skip_chain_id_verify.lock() { + Ok(()) + } else { + Err(Error::WrongChainId) + } + } + + pub fn skip_chain_id_verify(&self, result: bool) { + *self.skip_chain_id_verify.lock() = result; + } +} + +#[test] +fn should_reject_unencrypted_connection() { + let mut identify = IdentifyProtocol::new(); + let proto_context = ProtocolContext::make_no_encrypted( + PROTOCOL_ID.into(), + SESSION_ID.into(), + SessionType::Inbound, + ); + + identify.on_connected(&IdentifyProtocolContext(&proto_context)); + match identify.state { + State::FailedWithoutEncryption => (), + _ => panic!("should enter failed state"), + } + match proto_context.control().event() { + Some(ControlEvent::Disconnect { session_id }) if session_id == SESSION_ID.into() => (), + _ => panic!("should disconnect"), + } +} + +#[tokio::test] +async fn should_wait_client_identity_for_inbound_connection() { + let mut identify = IdentifyProtocol::new(); + let proto_context = + ProtocolContext::make(PROTOCOL_ID.into(), SESSION_ID.into(), SessionType::Inbound); + + identify.on_connected(&IdentifyProtocolContext(&proto_context)); + match identify.state { + State::ServerNegotiate { + procedure: ServerProcedure::WaitIdentity, + context, + } => assert!( + context.timeout_abort_handle.is_some(), + "should set up wait timeout" + ), + _ => panic!("should enter failed state"), + } +} + +#[tokio::test] +async fn should_disconnect_if_wait_client_identity_timeout() { + let mut identify = IdentifyProtocol::new(); + let proto_context = + ProtocolContext::make(PROTOCOL_ID.into(), SESSION_ID.into(), SessionType::Inbound); + + identify.on_connected(&IdentifyProtocolContext(&proto_context)); + let mut context = match identify.state { + State::ServerNegotiate { + procedure: ServerProcedure::WaitIdentity, + context, + } => { + assert!( + context.timeout_abort_handle.is_some(), + "should set up wait timeout" + ); + context + } + _ => panic!("should enter failed state"), + }; + + context.set_timeout("override wait identity", Duration::from_millis(300)); + Delay::new(Duration::from_millis(700)).await; + + match proto_context.control().event() { + Some(ControlEvent::Disconnect { session_id }) if session_id == SESSION_ID.into() => (), + _ => panic!("should disconnect"), + } +} + +#[tokio::test] +async fn should_register_opened_protocol() { + let mut identify = IdentifyProtocol::new(); + let proto_context = + ProtocolContext::make(PROTOCOL_ID.into(), SESSION_ID.into(), SessionType::Inbound); + + identify.on_connected(&IdentifyProtocolContext(&proto_context)); + + let peer_id = proto_context + .session + .remote_pubkey + .as_ref() + .unwrap() + .peer_id(); + assert!(crate::protocols::OpenedProtocols::is_open( + &peer_id, + &PROTOCOL_ID.into() + )); +} + +#[tokio::test] +async fn should_send_identity_to_server_for_outbound_connection() { + let mut identify = IdentifyProtocol::new(); + let proto_context = + ProtocolContext::make(PROTOCOL_ID.into(), SESSION_ID.into(), SessionType::Outbound); + + identify.on_connected(&IdentifyProtocolContext(&proto_context)); + + match identify.state { + State::ClientNegotiate { + procedure: ClientProcedure::WaitAck, + context, + } => assert!( + context.timeout_abort_handle.is_some(), + "should set up wait timeout" + ), + _ => panic!("should enter failed state"), + } + + match identify.behaviour.event() { + Some(BehaviourEvent::SendIdentity) => (), + _ => panic!("should send identity"), + } +} + +#[tokio::test] +async fn should_disconnect_if_wait_server_ack_timeout() { + let mut identify = IdentifyProtocol::new(); + let proto_context = + ProtocolContext::make(PROTOCOL_ID.into(), SESSION_ID.into(), SessionType::Outbound); + + identify.on_connected(&IdentifyProtocolContext(&proto_context)); + + let mut context = match identify.state { + State::ClientNegotiate { + procedure: ClientProcedure::WaitAck, + context, + } => { + assert!( + context.timeout_abort_handle.is_some(), + "should set up wait timeout" + ); + context + } + _ => panic!("should enter failed state"), + }; + + match identify.behaviour.event() { + Some(BehaviourEvent::SendIdentity) => (), + _ => panic!("should send identity"), + } + + context.set_timeout("override wait ack", Duration::from_millis(300)); + Delay::new(Duration::from_millis(700)).await; + + match proto_context.control().event() { + Some(ControlEvent::Disconnect { session_id }) if session_id == SESSION_ID.into() => (), + _ => panic!("should disconnect"), + } +} + +#[tokio::test] +async fn should_disconnect_if_exceed_max_message_size() { + let mut identify = IdentifyProtocol::new(); + let proto_context = + ProtocolContext::make(PROTOCOL_ID.into(), SESSION_ID.into(), SessionType::Outbound); + + let msg = Bytes::from("a".repeat(MAX_MESSAGE_SIZE + 1)); + identify.on_received(&IdentifyProtocolContext(&proto_context), msg); + + match proto_context.control().event() { + Some(ControlEvent::Disconnect { session_id }) if session_id == SESSION_ID.into() => (), + _ => panic!("should disconnect"), + } +} + +#[tokio::test] +async fn should_send_ack_if_identity_is_valid_on_server_side() { + let mut identify = IdentifyProtocol::new(); + let proto_context = + ProtocolContext::make(PROTOCOL_ID.into(), SESSION_ID.into(), SessionType::Inbound); + + identify.on_connected(&IdentifyProtocolContext(&proto_context)); + + let identity = message::Identity::mock_valid().into_bytes().unwrap(); + identify.behaviour.skip_chain_id_verify(true); + identify.on_received(&IdentifyProtocolContext(&proto_context), identity); + + match identify.state { + State::ServerNegotiate { + procedure: ServerProcedure::WaitOpenProtocols, + context, + } => assert!( + context.timeout_abort_handle.is_some(), + "should set up wait open protocols timeout" + ), + _ => panic!("should enter wait open protocols state"), + } + + match identify.behaviour.event() { + Some(BehaviourEvent::SendAck) => (), + _ => panic!("should send ack"), + } +} + +#[tokio::test] +async fn should_disconnect_if_client_open_protocols_timeout() { + let mut identify = IdentifyProtocol::new(); + let proto_context = + ProtocolContext::make(PROTOCOL_ID.into(), SESSION_ID.into(), SessionType::Inbound); + + identify.on_connected(&IdentifyProtocolContext(&proto_context)); + + let identity = message::Identity::mock_valid().into_bytes().unwrap(); + identify.behaviour.skip_chain_id_verify(true); + identify.on_received(&IdentifyProtocolContext(&proto_context), identity); + + let mut context = match identify.state { + State::ServerNegotiate { + procedure: ServerProcedure::WaitOpenProtocols, + context, + } => { + assert!( + context.timeout_abort_handle.is_some(), + "should set up wait open protocols timeout" + ); + context + } + _ => panic!("should enter wait open protocols state"), + }; + + match identify.behaviour.event() { + Some(BehaviourEvent::SendAck) => (), + _ => panic!("should send ack"), + } + + context.set_timeout("override wait open protocols", Duration::from_millis(300)); + Delay::new(Duration::from_millis(700)).await; + + match proto_context.control().event() { + Some(ControlEvent::Disconnect { session_id }) if session_id == SESSION_ID.into() => (), + _ => panic!("should disconnect"), + } +} + +#[tokio::test] +async fn should_disconnect_if_client_send_undecodeable_identity() { + let mut identify = IdentifyProtocol::new(); + let proto_context = + ProtocolContext::make(PROTOCOL_ID.into(), SESSION_ID.into(), SessionType::Inbound); + + identify.on_connected(&IdentifyProtocolContext(&proto_context)); + + let msg = Bytes::from("a"); + identify.on_received(&IdentifyProtocolContext(&proto_context), msg); + + match proto_context.control().event() { + Some(ControlEvent::Disconnect { session_id }) if session_id == SESSION_ID.into() => (), + _ => panic!("should disconnect"), + } + + match identify.state { + State::ServerNegotiate { + procedure: ServerProcedure::Failed, + .. + } => (), + _ => panic!("should enter failed state"), + } +} + +#[tokio::test] +async fn should_disconnect_if_client_send_invalid_identity() { + let mut identify = IdentifyProtocol::new(); + let proto_context = + ProtocolContext::make(PROTOCOL_ID.into(), SESSION_ID.into(), SessionType::Inbound); + + identify.on_connected(&IdentifyProtocolContext(&proto_context)); + + let msg = message::Identity::mock_invalid().into_bytes().unwrap(); + identify.on_received(&IdentifyProtocolContext(&proto_context), msg); + + match proto_context.control().event() { + Some(ControlEvent::Disconnect { session_id }) if session_id == SESSION_ID.into() => (), + _ => panic!("should disconnect"), + } + + match identify.state { + State::ServerNegotiate { + procedure: ServerProcedure::Failed, + .. + } => (), + _ => panic!("should enter failed state"), + } +} + +#[tokio::test] +async fn should_disconnect_if_client_send_different_chain_id() { + let mut identify = IdentifyProtocol::new(); + let proto_context = + ProtocolContext::make(PROTOCOL_ID.into(), SESSION_ID.into(), SessionType::Inbound); + + identify.on_connected(&IdentifyProtocolContext(&proto_context)); + + let msg = message::Identity::mock_valid().into_bytes().unwrap(); + identify.behaviour.skip_chain_id_verify(false); + identify.on_received(&IdentifyProtocolContext(&proto_context), msg); + + match proto_context.control().event() { + Some(ControlEvent::Disconnect { session_id }) if session_id == SESSION_ID.into() => (), + _ => panic!("should disconnect"), + } + + match identify.state { + State::ServerNegotiate { + procedure: ServerProcedure::Failed, + .. + } => (), + _ => panic!("should enter failed state"), + } +} + +#[tokio::test] +async fn should_disconnect_if_client_send_data_during_open_protocols() { + let mut identify = IdentifyProtocol::new(); + let proto_context = + ProtocolContext::make(PROTOCOL_ID.into(), SESSION_ID.into(), SessionType::Inbound); + + identify.on_connected(&IdentifyProtocolContext(&proto_context)); + + let identity = message::Identity::mock_valid().into_bytes().unwrap(); + identify.behaviour.skip_chain_id_verify(true); + identify.on_received(&IdentifyProtocolContext(&proto_context), identity); + + match &identify.state { + State::ServerNegotiate { + procedure: ServerProcedure::WaitOpenProtocols, + context, + } => assert!( + context.timeout_abort_handle.is_some(), + "should set up wait open protocols timeout" + ), + _ => panic!("should enter wait open protocols state"), + } + + match identify.behaviour.event() { + Some(BehaviourEvent::SendAck) => (), + _ => panic!("should send ack"), + } + + identify.on_received( + &IdentifyProtocolContext(&proto_context), + Bytes::from_static(b"test"), + ); + + match proto_context.control().event() { + Some(ControlEvent::Disconnect { session_id }) if session_id == SESSION_ID.into() => (), + _ => panic!("should disconnect"), + } +} + +#[tokio::test] +async fn should_open_protocols_after_receive_valid_ack_from_server() { + let mut identify = IdentifyProtocol::new(); + let proto_context = + ProtocolContext::make(PROTOCOL_ID.into(), SESSION_ID.into(), SessionType::Outbound); + + identify.on_connected(&IdentifyProtocolContext(&proto_context)); + + let ack = message::Acknowledge::mock_valid().into_bytes().unwrap(); + identify.on_received(&IdentifyProtocolContext(&proto_context), ack); + + match identify.state { + State::ClientNegotiate { + procedure: ClientProcedure::OpenOtherProtocols, + .. + } => (), + _ => panic!("should enter wait open protocols state"), + } + + match proto_context.control().event() { + Some(ControlEvent::OpenProtocols { + session_id, + target_proto, + }) if session_id == SESSION_ID.into() && target_proto == TargetProtocol::All => (), + _ => panic!("should open protocols"), + } +} + +#[tokio::test] +async fn should_disconnect_if_server_send_undecodeable_ack() { + let mut identify = IdentifyProtocol::new(); + let proto_context = + ProtocolContext::make(PROTOCOL_ID.into(), SESSION_ID.into(), SessionType::Outbound); + + identify.on_connected(&IdentifyProtocolContext(&proto_context)); + + identify.on_received( + &IdentifyProtocolContext(&proto_context), + Bytes::from_static(b"xxx"), + ); + + match identify.state { + State::ClientNegotiate { + procedure: ClientProcedure::Failed, + .. + } => (), + _ => panic!("should enter failed state"), + } + + match proto_context.control().event() { + Some(ControlEvent::Disconnect { session_id }) if session_id == SESSION_ID.into() => (), + _ => panic!("should disconnect"), + } +} + +#[tokio::test] +async fn should_disconnect_if_server_send_invalid_ack() { + let mut identify = IdentifyProtocol::new(); + let proto_context = + ProtocolContext::make(PROTOCOL_ID.into(), SESSION_ID.into(), SessionType::Outbound); + + identify.on_connected(&IdentifyProtocolContext(&proto_context)); + + let ack = message::Acknowledge::mock_invalid().into_bytes().unwrap(); + identify.on_received(&IdentifyProtocolContext(&proto_context), ack); + + match identify.state { + State::ClientNegotiate { + procedure: ClientProcedure::Failed, + .. + } => (), + _ => panic!("should enter failed state"), + } + + match proto_context.control().event() { + Some(ControlEvent::Disconnect { session_id }) if session_id == SESSION_ID.into() => (), + _ => panic!("should disconnect"), + } +} + +#[tokio::test] +async fn should_disconnect_if_server_send_data_during_open_protocols() { + let mut identify = IdentifyProtocol::new(); + let proto_context = + ProtocolContext::make(PROTOCOL_ID.into(), SESSION_ID.into(), SessionType::Outbound); + + identify.on_connected(&IdentifyProtocolContext(&proto_context)); + + let ack = message::Acknowledge::mock_valid().into_bytes().unwrap(); + identify.on_received(&IdentifyProtocolContext(&proto_context), ack); + + match &identify.state { + State::ClientNegotiate { + procedure: ClientProcedure::OpenOtherProtocols, + .. + } => (), + _ => panic!("should enter wait open protocols state"), + } + + match proto_context.control().event() { + Some(ControlEvent::OpenProtocols { + session_id, + target_proto, + }) if session_id == SESSION_ID.into() && target_proto == TargetProtocol::All => (), + _ => panic!("should open protocols"), + } + + identify.on_received( + &IdentifyProtocolContext(&proto_context), + Bytes::from_static(b"test"), + ); + + match proto_context.control().event() { + Some(ControlEvent::Disconnect { session_id }) if session_id == SESSION_ID.into() => (), + _ => panic!("should disconnect"), + } +} + +#[tokio::test] +async fn should_disconnect_if_either_send_data_no_in_negotiate_procedure() { + let mut identify = IdentifyProtocol::new(); + let proto_context = + ProtocolContext::make(PROTOCOL_ID.into(), SESSION_ID.into(), SessionType::Outbound); + + identify.on_received( + &IdentifyProtocolContext(&proto_context), + Bytes::from_static(b"test"), + ); + + match proto_context.control().event() { + Some(ControlEvent::Disconnect { session_id }) if session_id == SESSION_ID.into() => (), + _ => panic!("should disconnect"), + } +} + +#[tokio::test] +async fn should_wake_wait_identification_after_call_finish_identify() { + let mut identify = IdentifyProtocol::new(); + let proto_context = + ProtocolContext::make(PROTOCOL_ID.into(), SESSION_ID.into(), SessionType::Inbound); + + let peer_id = proto_context + .session + .remote_pubkey + .as_ref() + .unwrap() + .peer_id(); + + let wait_fut = IdentifyProtocol::wait(peer_id); + + tokio::spawn(async move { + identify.on_connected(&IdentifyProtocolContext(&proto_context)); + + let identity = message::Identity::mock_valid().into_bytes().unwrap(); + identify.behaviour.skip_chain_id_verify(true); + identify.on_received(&IdentifyProtocolContext(&proto_context), identity); + + match identify.state { + State::ServerNegotiate { + procedure: ServerProcedure::WaitOpenProtocols, + context, + } => assert!( + context.timeout_abort_handle.is_some(), + "should set up wait open protocols timeout" + ), + _ => panic!("should enter wait open protocols state"), + } + + match identify.behaviour.event() { + Some(BehaviourEvent::SendAck) => (), + _ => panic!("should send ack"), + } + }); + + assert!(wait_fut.await.is_ok(), "should be ok if pass identify"); +} + +#[tokio::test] +async fn should_pass_error_to_wait_identification_result_if_failed_identify() { + let mut identify = IdentifyProtocol::new(); + let proto_context = + ProtocolContext::make(PROTOCOL_ID.into(), SESSION_ID.into(), SessionType::Outbound); + + let peer_id = proto_context + .session + .remote_pubkey + .as_ref() + .unwrap() + .peer_id(); + + let wait_fut = IdentifyProtocol::wait(peer_id); + + tokio::spawn(async move { + identify.on_connected(&IdentifyProtocolContext(&proto_context)); + + identify.on_received( + &IdentifyProtocolContext(&proto_context), + Bytes::from_static(b"xxx"), + ); + + match identify.state { + State::ClientNegotiate { + procedure: ClientProcedure::Failed, + .. + } => (), + _ => panic!("should enter failed state"), + } + + match proto_context.control().event() { + Some(ControlEvent::Disconnect { session_id }) if session_id == SESSION_ID.into() => (), + _ => panic!("should disconnect"), + } + }); + + match wait_fut.await { + Err(Error::DecodeAckFailed) => (), + _ => panic!("should pass decode failed error"), + } +} + +#[tokio::test] +async fn should_pass_disconnected_to_wait_identification_result_if_still_wait_identify_but_disconnected( +) { + let mut identify = IdentifyProtocol::new(); + let proto_context = + ProtocolContext::make(PROTOCOL_ID.into(), SESSION_ID.into(), SessionType::Outbound); + + let peer_id = proto_context + .session + .remote_pubkey + .as_ref() + .unwrap() + .peer_id(); + + let wait_fut = IdentifyProtocol::wait(peer_id); + + tokio::spawn(async move { + identify.on_connected(&IdentifyProtocolContext(&proto_context)); + identify.on_disconnected(&IdentifyProtocolContext(&proto_context)); + }); + + match wait_fut.await { + Err(Error::Disconnected) => (), + _ => panic!("should pass disconnected error"), + } +} + +#[tokio::test] +async fn should_remove_from_opened_protocols_after_disconnect() { + let mut identify = IdentifyProtocol::new(); + let proto_context = + ProtocolContext::make(PROTOCOL_ID.into(), SESSION_ID.into(), SessionType::Outbound); + + let peer_id = proto_context + .session + .remote_pubkey + .as_ref() + .unwrap() + .peer_id(); + + identify.on_connected(&IdentifyProtocolContext(&proto_context)); + identify.on_disconnected(&IdentifyProtocolContext(&proto_context)); + + assert_eq!( + crate::protocols::OpenedProtocols::is_open(&peer_id, &PROTOCOL_ID.into()), + false + ); +} diff --git a/core/network/src/protocols/mod.rs b/core/network/src/protocols/mod.rs index e7f417bb9..61aa4aa0c 100644 --- a/core/network/src/protocols/mod.rs +++ b/core/network/src/protocols/mod.rs @@ -3,9 +3,9 @@ mod r#macro; mod core; mod discovery; -mod identify; mod ping; mod transmitter; -pub use self::core::{CoreProtocol, CoreProtocolBuilder}; +pub mod identify; +pub use self::core::{CoreProtocol, CoreProtocolBuilder, OpenedProtocols}; pub use transmitter::{ReceivedMessage, Recipient, Transmitter, TransmitterMessage}; diff --git a/core/network/src/protocols/ping/protocol.rs b/core/network/src/protocols/ping/protocol.rs index d55872ab0..83b7109ce 100644 --- a/core/network/src/protocols/ping/protocol.rs +++ b/core/network/src/protocols/ping/protocol.rs @@ -118,15 +118,17 @@ impl ServiceProtocol for PingProtocol { self.connected_session_ids .entry(session.id) .or_insert_with(|| PingStatus { - last_ping: SystemTime::now(), + last_ping: SystemTime::now(), processing: false, - peer_id, + peer_id: peer_id.clone(), }); debug!( "proto id [{}] open on session [{}], address: [{}], type: [{:?}], version: {}", context.proto_id, session.id, session.address, session.ty, version ); debug!("connected sessions are: {:?}", self.connected_session_ids); + + crate::protocols::OpenedProtocols::register(peer_id, context.proto_id); } None => { if context.disconnect(session.id).is_err() { diff --git a/core/network/src/protocols/transmitter.rs b/core/network/src/protocols/transmitter.rs index 446fd92fa..02b0cf17a 100644 --- a/core/network/src/protocols/transmitter.rs +++ b/core/network/src/protocols/transmitter.rs @@ -9,6 +9,8 @@ use tentacle::builder::MetaBuilder; use tentacle::service::{ProtocolHandle, ProtocolMeta}; use tentacle::ProtocolId; +use crate::peer_manager::PeerManagerHandle; + use self::behaviour::TransmitterBehaviour; use self::protocol::TransmitterProtocol; pub use message::{ReceivedMessage, Recipient, TransmitterMessage}; @@ -22,12 +24,17 @@ pub const MAX_CHUNK_SIZE: usize = 4 * 1000 * 1000; // 4MB pub struct Transmitter { data_tx: UnboundedSender, pub(crate) behaviour: TransmitterBehaviour, + peer_mgr: PeerManagerHandle, } impl Transmitter { - pub fn new(data_tx: UnboundedSender) -> Self { + pub fn new(data_tx: UnboundedSender, peer_mgr: PeerManagerHandle) -> Self { let behaviour = TransmitterBehaviour::new(); - Transmitter { data_tx, behaviour } + Transmitter { + data_tx, + behaviour, + peer_mgr, + } } pub fn build_meta(self, protocol_id: ProtocolId) -> ProtocolMeta { @@ -36,7 +43,7 @@ impl Transmitter { .name(name!(NAME)) .support_versions(support_versions!(SUPPORT_VERSIONS)) .session_handle(move || { - let proto = TransmitterProtocol::new(self.data_tx.clone()); + let proto = TransmitterProtocol::new(self.data_tx.clone(), self.peer_mgr.clone()); ProtocolHandle::Callback(Box::new(proto)) }) .build() diff --git a/core/network/src/protocols/transmitter/protocol.rs b/core/network/src/protocols/transmitter/protocol.rs index 9f730575a..acaa4531c 100644 --- a/core/network/src/protocols/transmitter/protocol.rs +++ b/core/network/src/protocols/transmitter/protocol.rs @@ -5,20 +5,24 @@ use protocol::Bytes; use tentacle::context::ProtocolContextMutRef; use tentacle::traits::SessionProtocol; +use crate::peer_manager::PeerManagerHandle; + use super::message::{ReceivedMessage, SeqChunkMessage}; use super::{DATA_SEQ_TIMEOUT, MAX_CHUNK_SIZE}; pub struct TransmitterProtocol { data_tx: UnboundedSender, + peer_mgr: PeerManagerHandle, data_buf: Vec, current_data_seq: u64, first_seq_bytes_at: Instant, } impl TransmitterProtocol { - pub fn new(data_tx: UnboundedSender) -> Self { + pub fn new(data_tx: UnboundedSender, peer_mgr: PeerManagerHandle) -> Self { TransmitterProtocol { data_tx, + peer_mgr, data_buf: Vec::new(), current_data_seq: 0, first_seq_bytes_at: Instant::now(), @@ -27,6 +31,23 @@ impl TransmitterProtocol { } impl SessionProtocol for TransmitterProtocol { + fn connected(&mut self, context: ProtocolContextMutRef, _version: &str) { + if !self.peer_mgr.contains_session(context.session.id) { + let _ = context.close_protocol(context.session.id, context.proto_id()); + return; + } + + let peer_id = match context.session.remote_pubkey.as_ref() { + Some(pubkey) => pubkey.peer_id(), + None => { + log::warn!("peer connection must be encrypted"); + let _ = context.disconnect(context.session.id); + return; + } + }; + crate::protocols::OpenedProtocols::register(peer_id, context.proto_id()); + } + fn received(&mut self, ctx: ProtocolContextMutRef, data: Bytes) { let peer_id = match ctx.session.remote_pubkey.as_ref() { Some(pk) => pk.peer_id(), diff --git a/core/network/src/service.rs b/core/network/src/service.rs index 41ff6fb7c..5170f6a6c 100644 --- a/core/network/src/service.rs +++ b/core/network/src/service.rs @@ -13,6 +13,7 @@ use protocol::traits::{ Context, Gossip, MessageCodec, MessageHandler, Network, PeerTag, PeerTrust, Priority, Rpc, TrustFeedback, }; +use protocol::types::Hash; use protocol::{Bytes, ProtocolResult}; use tentacle::secio::PeerId; @@ -222,7 +223,7 @@ impl NetworkService { .ping(config.ping_interval, config.ping_timeout, mgr_tx.clone()) .identify(peer_mgr_handle.clone(), mgr_tx.clone()) .discovery(peer_mgr_handle.clone(), mgr_tx.clone(), disc_sync_interval) - .transmitter(recv_data_tx.clone()) + .transmitter(recv_data_tx.clone(), peer_mgr_handle.clone()) .build(); let transmitter = proto.transmitter(); @@ -356,6 +357,10 @@ impl NetworkService { self.config.secio_keypair.peer_id() } + pub fn set_chain_id(&self, chain_id: Hash) { + self.peer_mgr_handle.set_chain_id(chain_id); + } + pub async fn listen(&mut self, socket_addr: SocketAddr) -> ProtocolResult<()> { if let Some(NetworkConnectionService::NoListen(conn_srv)) = &mut self.net_conn_srv { debug!("network: listen to {}", socket_addr); diff --git a/core/network/src/test/mock.rs b/core/network/src/test/mock.rs index 460714082..96645828b 100644 --- a/core/network/src/test/mock.rs +++ b/core/network/src/test/mock.rs @@ -1,10 +1,12 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use parking_lot::Mutex; +use protocol::Bytes; use tentacle::multiaddr::Multiaddr; -use tentacle::secio::PublicKey; -use tentacle::service::SessionType; -use tentacle::SessionId; +use tentacle::secio::{PublicKey, SecioKeyPair}; +use tentacle::service::{SessionType, TargetProtocol}; +use tentacle::{ProtocolId, SessionId}; #[derive(Clone, Debug)] pub struct SessionContext { @@ -16,6 +18,37 @@ pub struct SessionContext { } impl SessionContext { + pub fn no_encrypted(id: SessionId, ty: SessionType) -> Self { + let address = "/ip4/47.111.169.36/tcp/3000".parse().expect("multiaddr"); + + SessionContext { + id, + address, + ty, + remote_pubkey: None, + pending_data_size: Arc::new(AtomicUsize::new(0)), + } + } + + pub fn random(id: SessionId, ty: SessionType) -> Self { + let keypair = SecioKeyPair::secp256k1_generated(); + let pubkey = keypair.public_key(); + let peer_id = pubkey.peer_id(); + + let address = { + let addr_str = format!("/ip4/47.111.169.36/tcp/3000/p2p/{}", peer_id.to_base58()); + addr_str.parse().expect("multiaddr") + }; + + SessionContext { + id, + address, + ty, + remote_pubkey: Some(pubkey), + pending_data_size: Arc::new(AtomicUsize::new(0)), + } + } + pub fn make(id: SessionId, address: Multiaddr, ty: SessionType, pubkey: PublicKey) -> Self { SessionContext { id, @@ -46,3 +79,106 @@ impl From> for SessionContext { } } } + +#[derive(Clone, PartialEq, Eq)] +pub enum ControlEvent { + SendMessage { + proto_id: ProtocolId, + session_id: SessionId, + msg: Bytes, + }, + Disconnect { + session_id: SessionId, + }, + OpenProtocols { + session_id: SessionId, + target_proto: TargetProtocol, + }, +} + +#[derive(Clone)] +pub struct ServiceControl { + pub event: Arc>>, +} + +impl Default for ServiceControl { + fn default() -> Self { + ServiceControl { + event: Arc::new(Mutex::new(None)), + } + } +} + +impl ServiceControl { + pub fn event(&self) -> Option { + self.event.lock().clone() + } + + pub fn quick_send_message_to( + &self, + session_id: SessionId, + proto_id: ProtocolId, + msg: Bytes, + ) -> Result<(), String> { + *self.event.lock() = Some(ControlEvent::SendMessage { + session_id, + proto_id, + msg, + }); + + Ok(()) + } + + pub fn disconnect(&self, session_id: SessionId) { + *self.event.lock() = Some(ControlEvent::Disconnect { session_id }); + } + + pub fn open_protocols( + &self, + session_id: SessionId, + target_proto: TargetProtocol, + ) -> Result<(), String> { + *self.event.lock() = Some(ControlEvent::OpenProtocols { + session_id, + target_proto, + }); + + Ok(()) + } +} + +pub struct ProtocolContext { + proto_id: ProtocolId, + pub session: SessionContext, + pub control: ServiceControl, +} + +impl ProtocolContext { + pub fn make_no_encrypted(proto_id: ProtocolId, id: SessionId, ty: SessionType) -> Self { + ProtocolContext { + proto_id, + session: SessionContext::no_encrypted(id, ty), + control: ServiceControl::default(), + } + } + + pub fn make(proto_id: ProtocolId, id: SessionId, ty: SessionType) -> Self { + ProtocolContext { + proto_id, + session: SessionContext::random(id, ty), + control: ServiceControl::default(), + } + } + + pub fn proto_id(&self) -> ProtocolId { + self.proto_id + } + + pub fn control(&self) -> &ServiceControl { + &self.control + } + + pub fn disconnect(&self, session_id: SessionId) { + self.control.disconnect(session_id) + } +} diff --git a/src/default_start.rs b/src/default_start.rs index 1ac901be9..1bb8bd72c 100644 --- a/src/default_start.rs +++ b/src/default_start.rs @@ -290,6 +290,9 @@ pub async fn start( protocol::init_address_hrp(metadata.bech32_address_hrp); } + // set chain id in network + network_service.set_chain_id(metadata.chain_id.clone()); + // set args in mempool mempool.set_args( metadata.timeout_gap, diff --git a/tests/common/mod.rs b/tests/common/mod.rs new file mode 100644 index 000000000..c6619985c --- /dev/null +++ b/tests/common/mod.rs @@ -0,0 +1,40 @@ +#![allow(clippy::mutable_key_type)] + +pub mod node; + +use std::net::TcpListener; +use std::path::PathBuf; +use std::sync::atomic::{AtomicU16, Ordering}; + +use protocol::types::Hash; +use protocol::BytesMut; +use rand::{rngs::OsRng, RngCore}; + +static AVAILABLE_PORT: AtomicU16 = AtomicU16::new(2000); + +pub fn tmp_dir() -> PathBuf { + let mut tmp_dir = std::env::temp_dir(); + let sub_dir = { + let mut random_bytes = [0u8; 32]; + OsRng.fill_bytes(&mut random_bytes); + Hash::digest(BytesMut::from(random_bytes.as_ref()).freeze()).as_hex() + }; + + tmp_dir.push(sub_dir + "/"); + tmp_dir +} + +pub fn available_port_pair() -> (u16, u16) { + (available_port(), available_port()) +} + +fn available_port() -> u16 { + let is_available = |port| -> bool { TcpListener::bind(("127.0.0.1", port)).is_ok() }; + + loop { + let port = AVAILABLE_PORT.fetch_add(1, Ordering::SeqCst); + if is_available(port) { + return port; + } + } +} diff --git a/tests/trust_metric_all/node.rs b/tests/common/node.rs similarity index 61% rename from tests/trust_metric_all/node.rs rename to tests/common/node.rs index 8a3c2ea22..efd676e56 100644 --- a/tests/trust_metric_all/node.rs +++ b/tests/common/node.rs @@ -1,9 +1,7 @@ -pub mod client_node; +pub mod config; pub mod consts; +pub mod diagnostic; pub mod full_node; pub mod sync; -mod config; -mod diagnostic; - pub use diagnostic::TwinEvent; diff --git a/tests/trust_metric_all/node/config.rs b/tests/common/node/config.rs similarity index 100% rename from tests/trust_metric_all/node/config.rs rename to tests/common/node/config.rs diff --git a/tests/trust_metric_all/node/consts.rs b/tests/common/node/consts.rs similarity index 100% rename from tests/trust_metric_all/node/consts.rs rename to tests/common/node/consts.rs diff --git a/tests/trust_metric_all/node/diagnostic.rs b/tests/common/node/diagnostic.rs similarity index 100% rename from tests/trust_metric_all/node/diagnostic.rs rename to tests/common/node/diagnostic.rs diff --git a/tests/trust_metric_all/node/full_node.rs b/tests/common/node/full_node.rs similarity index 95% rename from tests/trust_metric_all/node/full_node.rs rename to tests/common/node/full_node.rs index e6c4b66dd..0985dce9b 100644 --- a/tests/trust_metric_all/node/full_node.rs +++ b/tests/common/node/full_node.rs @@ -71,12 +71,12 @@ impl From for ProtocolError { } // Note: inject runnning_status -pub async fn run(listen_port: u16, sync: Sync) { +pub async fn run(listen_port: u16, seckey: String, sync: Sync) { let builder = MutaBuilder::new() .config_path(consts::CHAIN_CONFIG_PATH) .genesis_path(consts::CHAIN_GENESIS_PATH) .service_mapping(DefaultServiceMapping {}); let muta = builder.build(listen_port).expect("build"); - muta.run(sync).await.expect("run"); + muta.run(seckey, sync).await.expect("run"); } diff --git a/tests/trust_metric_all/node/full_node/builder.rs b/tests/common/node/full_node/builder.rs similarity index 96% rename from tests/trust_metric_all/node/full_node/builder.rs rename to tests/common/node/full_node/builder.rs index 3f670bb47..cd321b2e0 100644 --- a/tests/trust_metric_all/node/full_node/builder.rs +++ b/tests/common/node/full_node/builder.rs @@ -84,7 +84,7 @@ impl Muta { } } - pub async fn run(self, sync: Sync) -> ProtocolResult<()> { + pub async fn run(self, seckey: String, sync: Sync) -> ProtocolResult<()> { // run muta let memory_db = MemoryDB::default(); @@ -93,6 +93,7 @@ impl Muta { self.config, Arc::clone(&self.service_mapping), memory_db, + seckey, sync, ) .await?; diff --git a/tests/trust_metric_all/node/full_node/default_start.rs b/tests/common/node/full_node/default_start.rs similarity index 98% rename from tests/trust_metric_all/node/full_node/default_start.rs rename to tests/common/node/full_node/default_start.rs index 2b8e216ef..fcbd3f30c 100644 --- a/tests/trust_metric_all/node/full_node/default_start.rs +++ b/tests/common/node/full_node/default_start.rs @@ -4,7 +4,6 @@ use super::diagnostic::{ }; /// Almost same as src/default_start.rs, only remove graphql service. use super::{config::Config, consts, error::MainError, memory_db::MemoryDB, Sync}; -use crate::trust_metric_all::common; use std::collections::HashMap; use std::convert::TryFrom; @@ -135,6 +134,7 @@ pub async fn start( config: Config, service_mapping: Arc, db: MemoryDB, + seckey: String, sync: Sync, ) -> ProtocolResult<()> { log::info!("node starts"); @@ -160,8 +160,6 @@ pub async fn start( .write_timeout(config.network.write_timeout) .recv_buffer_size(config.network.recv_buffer_size); - let network_privkey = config.privkey.as_string_trim0x(); - let mut bootstrap_pairs = vec![]; if let Some(bootstrap) = &config.network.bootstraps { for bootstrap in bootstrap.iter() { @@ -174,7 +172,7 @@ pub async fn start( let network_config = network_config .bootstraps(bootstrap_pairs)? .allowlist(allowlist)? - .secio_keypair(network_privkey)?; + .secio_keypair(seckey.clone())?; let mut network_service = NetworkService::new(network_config); network_service .listen(config.network.listening_address) @@ -225,7 +223,7 @@ pub async fn start( ); // Create full transactions wal - let wal_path = common::tmp_dir() + let wal_path = crate::common::tmp_dir() .to_str() .expect("wal path string") .to_string(); @@ -247,6 +245,9 @@ pub async fn start( let metadata: Metadata = serde_json::from_str(&exec_resp.succeed_data).expect("Decode metadata failed!"); + // set chain id in network + network_service.set_chain_id(Hash::from_hex(consts::CHAIN_ID).expect("chain id")); + // set args in mempool mempool.set_args( metadata.timeout_gap, diff --git a/tests/trust_metric_all/node/full_node/error.rs b/tests/common/node/full_node/error.rs similarity index 100% rename from tests/trust_metric_all/node/full_node/error.rs rename to tests/common/node/full_node/error.rs diff --git a/tests/trust_metric_all/node/full_node/memory_db.rs b/tests/common/node/full_node/memory_db.rs similarity index 100% rename from tests/trust_metric_all/node/full_node/memory_db.rs rename to tests/common/node/full_node/memory_db.rs diff --git a/tests/trust_metric_all/node/sync.rs b/tests/common/node/sync.rs similarity index 97% rename from tests/trust_metric_all/node/sync.rs rename to tests/common/node/sync.rs index 137ce2e17..b013ff978 100644 --- a/tests/trust_metric_all/node/sync.rs +++ b/tests/common/node/sync.rs @@ -120,6 +120,12 @@ impl Sync { } } +impl Default for Sync { + fn default() -> Self { + Sync::new() + } +} + impl Drop for Sync { fn drop(&mut self) { self.connected.store(false, Ordering::SeqCst); diff --git a/tests/trust_metric.rs b/tests/trust_metric.rs index f571c270d..de74643d2 100644 --- a/tests/trust_metric.rs +++ b/tests/trust_metric.rs @@ -1,3 +1,4 @@ /// NOTE: Test may panic after drop full node future, which is /// expected. +pub mod common; mod trust_metric_all; diff --git a/tests/trust_metric_all/node/client_node.rs b/tests/trust_metric_all/client_node.rs similarity index 88% rename from tests/trust_metric_all/node/client_node.rs rename to tests/trust_metric_all/client_node.rs index ffc1f5aec..87794b8c9 100644 --- a/tests/trust_metric_all/node/client_node.rs +++ b/tests/trust_metric_all/client_node.rs @@ -1,35 +1,31 @@ -use super::diagnostic::{ - TrustNewIntervalReq, TrustTwinEventReq, TwinEvent, GOSSIP_TRUST_NEW_INTERVAL, - GOSSIP_TRUST_TWIN_EVENT, -}; -use super::{ - config::Config, - consts, - sync::{Sync, SyncError, SyncEvent}, -}; - -use common_crypto::{PrivateKey, Secp256k1PrivateKey}; +use std::collections::HashSet; +use std::convert::TryFrom; +use std::iter::FromIterator; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::ops::Deref; +use std::str::FromStr; + +use common_crypto::{PrivateKey, PublicKey, Secp256k1PrivateKey, ToPublicKey}; use core_consensus::message::{ FixedBlock, FixedHeight, BROADCAST_HEIGHT, RPC_RESP_SYNC_PULL_BLOCK, RPC_SYNC_PULL_BLOCK, }; use core_network::{ - DiagnosticEvent, NetworkConfig, NetworkService, NetworkServiceHandle, PeerId, TrustReport, + DiagnosticEvent, NetworkConfig, NetworkService, NetworkServiceHandle, PeerId, PeerIdExt, + TrustReport, }; use derive_more::Display; -use protocol::{ - async_trait, - traits::{Context, Gossip, MessageCodec, MessageHandler, Priority, Rpc, TrustFeedback}, - types::{Address, Block, BlockHeader, Hash, Proof}, - Bytes, +use protocol::traits::{ + Context, Gossip, MessageCodec, MessageHandler, Priority, Rpc, TrustFeedback, }; +use protocol::types::{Address, Block, BlockHeader, Hash, Proof}; +use protocol::{async_trait, Bytes}; -use std::{ - collections::HashSet, - iter::FromIterator, - net::{IpAddr, Ipv4Addr, SocketAddr}, - ops::Deref, - str::FromStr, +use crate::common::node::consts; +use crate::common::node::diagnostic::{ + TrustNewIntervalReq, TrustTwinEventReq, TwinEvent, GOSSIP_TRUST_NEW_INTERVAL, + GOSSIP_TRUST_TWIN_EVENT, }; +use crate::common::node::sync::{Sync, SyncError, SyncEvent}; #[derive(Debug, Display)] pub enum ClientNodeError { @@ -89,8 +85,13 @@ pub struct ClientNode { pub sync: Sync, } -pub async fn connect(full_node_port: u16, listen_port: u16, sync: Sync) -> ClientNode { - let full_node_peer_id = full_node_peer_id(); +pub async fn connect( + full_node_port: u16, + full_seckey: String, + listen_port: u16, + sync: Sync, +) -> ClientNode { + let full_node_peer_id = full_node_peer_id(&full_seckey); let full_node_addr = format!("127.0.0.1:{}", full_node_port); let config = NetworkConfig::new() @@ -104,6 +105,8 @@ pub async fn connect(full_node_port: u16, listen_port: u16, sync: Sync) -> Clien let mut network = NetworkService::new(config); let handle = network.handle(); + network.set_chain_id(Hash::from_hex(consts::CHAIN_ID).expect("chain id")); + network .register_endpoint_handler( RPC_SYNC_PULL_BLOCK, @@ -290,14 +293,13 @@ impl Deref for ClientNode { } } -fn full_node_peer_id() -> PeerId { - let config: Config = - common_config_parser::parse(&consts::CHAIN_CONFIG_PATH).expect("parse chain config.toml"); - - let mut bootstraps = config.network.bootstraps.expect("config.toml full node"); - let full_node = bootstraps.pop().expect("there should be one bootstrap"); - - full_node.peer_id.parse().expect("parse peer id") +fn full_node_peer_id(full_seckey: &str) -> PeerId { + let seckey = { + let key = hex::decode(full_seckey).expect("hex private key string"); + Secp256k1PrivateKey::try_from(key.as_ref()).expect("valid private key") + }; + let pubkey = seckey.pub_key(); + PeerId::from_pubkey_bytes(pubkey.to_bytes()).expect("valid public key") } fn mock_block(height: u64) -> Block { diff --git a/tests/trust_metric_all/common.rs b/tests/trust_metric_all/common.rs index 8cfa6d129..166c443f2 100644 --- a/tests/trust_metric_all/common.rs +++ b/tests/trust_metric_all/common.rs @@ -1,34 +1,14 @@ -use super::node::consts; - use common_crypto::{ Crypto, PrivateKey, PublicKey, Secp256k1, Secp256k1PrivateKey, Signature, ToPublicKey, }; -use protocol::{ - fixed_codec::FixedCodec, - types::{Address, Hash, JsonString, RawTransaction, SignedTransaction, TransactionRequest}, - Bytes, BytesMut, +use protocol::fixed_codec::FixedCodec; +use protocol::types::{ + Address, Hash, JsonString, RawTransaction, SignedTransaction, TransactionRequest, }; +use protocol::{Bytes, BytesMut}; use rand::{rngs::OsRng, RngCore}; -use std::{ - net::TcpListener, - path::PathBuf, - sync::atomic::{AtomicU16, Ordering}, -}; - -static AVAILABLE_PORT: AtomicU16 = AtomicU16::new(2000); - -pub fn tmp_dir() -> PathBuf { - let mut tmp_dir = std::env::temp_dir(); - let sub_dir = { - let mut random_bytes = [0u8; 32]; - OsRng.fill_bytes(&mut random_bytes); - Hash::digest(BytesMut::from(random_bytes.as_ref()).freeze()).as_hex() - }; - - tmp_dir.push(sub_dir + "/"); - tmp_dir -} +use crate::common::node::consts; pub struct SignedTransactionBuilder { chain_id: Hash, @@ -112,18 +92,3 @@ impl SignedTransactionBuilder { pub fn stx_builder() -> SignedTransactionBuilder { SignedTransactionBuilder::default() } - -pub fn available_port_pair() -> (u16, u16) { - (available_port(), available_port()) -} - -fn available_port() -> u16 { - let is_available = |port| -> bool { TcpListener::bind(("127.0.0.1", port)).is_ok() }; - - loop { - let port = AVAILABLE_PORT.fetch_add(1, Ordering::SeqCst); - if is_available(port) { - return port; - } - } -} diff --git a/tests/trust_metric_all/consensus.rs b/tests/trust_metric_all/consensus.rs index 1a592612b..ba7d3b74f 100644 --- a/tests/trust_metric_all/consensus.rs +++ b/tests/trust_metric_all/consensus.rs @@ -1,11 +1,12 @@ -use super::{node::client_node::ClientNodeError, trust_test}; - use core_consensus::message::{ Choke, Proposal, Vote, BROADCAST_HEIGHT, END_GOSSIP_AGGREGATED_VOTE, END_GOSSIP_SIGNED_CHOKE, END_GOSSIP_SIGNED_PROPOSAL, END_GOSSIP_SIGNED_VOTE, QC, }; use protocol::traits::TrustFeedback; +use super::client_node::ClientNodeError; +use super::trust_test; + #[test] fn should_be_disconnected_for_repeated_undecodeable_proposal_within_four_intervals() { trust_test(move |client_node| { diff --git a/tests/trust_metric_all/mempool.rs b/tests/trust_metric_all/mempool.rs index 5b3496a08..0e83ef824 100644 --- a/tests/trust_metric_all/mempool.rs +++ b/tests/trust_metric_all/mempool.rs @@ -1,8 +1,10 @@ -use super::{common, node::client_node::ClientNodeError, trust_test}; - use core_mempool::{MsgNewTxs, END_GOSSIP_NEW_TXS}; use protocol::{traits::TrustFeedback, types::Hash, Bytes}; +use super::client_node::ClientNodeError; +use super::common; +use super::trust_test; + #[test] fn should_report_good_on_valid_transaction() { trust_test(move |client_node| { diff --git a/tests/trust_metric_all/mod.rs b/tests/trust_metric_all/mod.rs index d1e811977..c78bedc15 100644 --- a/tests/trust_metric_all/mod.rs +++ b/tests/trust_metric_all/mod.rs @@ -1,31 +1,42 @@ #![allow(clippy::mutable_key_type)] +mod client_node; mod common; mod consensus; mod logger; mod mempool; -mod node; +use std::panic; + +use common_crypto::{PrivateKey, Secp256k1PrivateKey}; use futures::future::BoxFuture; -use node::client_node::{ClientNode, ClientNodeError}; -use node::sync::Sync; -use std::panic; +use crate::common::node::sync::Sync; +use crate::common::{available_port_pair, node}; +use client_node::{ClientNode, ClientNodeError}; fn trust_test(test: impl FnOnce(ClientNode) -> BoxFuture<'static, ()> + Send + 'static) { - let (full_port, client_port) = common::available_port_pair(); + let (full_port, client_port) = available_port_pair(); let mut rt = tokio::runtime::Runtime::new().expect("create runtime"); let local = tokio::task::LocalSet::new(); local.block_on(&mut rt, async move { let sync = Sync::new(); - tokio::task::spawn_local(node::full_node::run(full_port, sync.clone())); + let full_seckey = { + let key = Secp256k1PrivateKey::generate(&mut rand::rngs::OsRng); + hex::encode(key.to_bytes()).to_string() + }; + tokio::task::spawn_local(node::full_node::run( + full_port, + full_seckey.clone(), + sync.clone(), + )); // Wait full node network initialization sync.wait().await; let handle = tokio::spawn(async move { - let client_node = node::client_node::connect(full_port, client_port, sync).await; + let client_node = client_node::connect(full_port, full_seckey, client_port, sync).await; test(client_node).await; }); diff --git a/tests/verify_chain_id.rs b/tests/verify_chain_id.rs new file mode 100644 index 000000000..13533fc4d --- /dev/null +++ b/tests/verify_chain_id.rs @@ -0,0 +1,285 @@ +/// NOTE: Test may panic after drop full node future, which is +/// expected. +pub mod common; + +use std::convert::TryFrom; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::ops::Deref; + +use common_crypto::{PrivateKey, PublicKey, Secp256k1PrivateKey, ToPublicKey}; +use core_consensus::message::{ + FixedBlock, FixedHeight, BROADCAST_HEIGHT, RPC_RESP_SYNC_PULL_BLOCK, RPC_SYNC_PULL_BLOCK, +}; +use core_network::{ + DiagnosticEvent, NetworkConfig, NetworkService, NetworkServiceHandle, PeerId, PeerIdExt, +}; +use derive_more::Display; +use protocol::traits::{Context, MessageCodec, MessageHandler, Priority, Rpc, TrustFeedback}; +use protocol::types::{Block, Hash}; +use protocol::{async_trait, Bytes}; + +use crate::common::available_port_pair; +use crate::common::node::consts; +use crate::common::node::full_node; +use crate::common::node::sync::{Sync, SyncError}; + +#[test] +fn should_be_disconnected_due_to_different_chain_id() { + let (full_port, client_port) = available_port_pair(); + let mut rt = tokio::runtime::Runtime::new().expect("create runtime"); + let local = tokio::task::LocalSet::new(); + + local.block_on(&mut rt, async move { + let sync = Sync::new(); + let full_seckey = { + let key = Secp256k1PrivateKey::generate(&mut rand::rngs::OsRng); + hex::encode(key.to_bytes()).to_string() + }; + tokio::task::spawn_local(full_node::run(full_port, full_seckey.clone(), sync.clone())); + + // Wait full node network initialization + sync.wait().await; + + let chain_id = Hash::digest(Bytes::from_static(b"beautiful world")); + let full_node_peer_id = full_node_peer_id(&full_seckey); + let full_node_addr = format!("127.0.0.1:{}", full_port); + + let config = NetworkConfig::new() + .ping_interval(consts::NETWORK_PING_INTERVAL) + .peer_trust_metric(consts::NETWORK_TRUST_METRIC_INTERVAL, None) + .expect("peer trust") + .bootstraps(vec![(full_node_peer_id.to_base58(), full_node_addr)]) + .expect("test node config"); + + let mut network = NetworkService::new(config); + + network.set_chain_id(chain_id); + + network + .register_endpoint_handler( + BROADCAST_HEIGHT, + Box::new(ReceiveRemoteHeight(sync.clone())), + ) + .expect("register remote height"); + + let hook_fn = |sync: Sync| -> _ { + Box::new(move |event: DiagnosticEvent| { + // We only care connected event on client node + if let DiagnosticEvent::NewSession = event { + sync.emit(event) + } + }) + }; + network.register_diagnostic_hook(hook_fn(sync.clone())); + + network + .listen(SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), + client_port, + )) + .await + .expect("test node listen"); + tokio::spawn(network); + + match sync.recv().await { + Err(SyncError::Disconected) => (), + Err(err) => panic!("unexpected err {}", err), + Ok(event) => panic!("unexpected event {}", event), + } + }); +} + +#[test] +fn should_be_connected_with_same_chain_id() { + let (full_port, client_port) = available_port_pair(); + let mut rt = tokio::runtime::Runtime::new().expect("create runtime"); + let local = tokio::task::LocalSet::new(); + + local.block_on(&mut rt, async move { + let sync = Sync::new(); + let full_seckey = { + let key = Secp256k1PrivateKey::generate(&mut rand::rngs::OsRng); + hex::encode(key.to_bytes()).to_string() + }; + tokio::task::spawn_local(full_node::run(full_port, full_seckey.clone(), sync.clone())); + + // Wait full node network initialization + sync.wait().await; + let chain_id = Hash::from_hex(consts::CHAIN_ID).expect("chain id"); + let client_node = + connect(full_port, full_seckey, chain_id, client_port, sync.clone()).await; + + let block = client_node.get_block(0).await.expect("get genesis"); + assert_eq!(block.header.height, 0); + }); +} + +#[derive(Debug, Display)] +enum ClientNodeError { + #[display(fmt = "not connected")] + NotConnected, + + #[display(fmt = "unexpected {}", _0)] + Unexpected(String), +} +impl std::error::Error for ClientNodeError {} + +impl From for ClientNodeError { + fn from(err: SyncError) -> Self { + match err { + SyncError::Recv(err) => ClientNodeError::Unexpected(err.to_string()), + SyncError::Timeout => ClientNodeError::Unexpected(err.to_string()), + SyncError::Disconected => ClientNodeError::NotConnected, + } + } +} + +type ClientResult = Result; + +struct ReceiveRemoteHeight(Sync); + +#[async_trait] +impl MessageHandler for ReceiveRemoteHeight { + type Message = u64; + + async fn process(&self, _: Context, msg: u64) -> TrustFeedback { + self.0.emit(DiagnosticEvent::RemoteHeight { height: msg }); + TrustFeedback::Neutral + } +} +struct ClientNode { + pub network: NetworkServiceHandle, + pub remote_peer_id: PeerId, + pub priv_key: Secp256k1PrivateKey, + pub sync: Sync, +} + +async fn connect( + full_node_port: u16, + full_seckey: String, + chain_id: Hash, + listen_port: u16, + sync: Sync, +) -> ClientNode { + let full_node_peer_id = full_node_peer_id(&full_seckey); + let full_node_addr = format!("127.0.0.1:{}", full_node_port); + + let config = NetworkConfig::new() + .ping_interval(consts::NETWORK_PING_INTERVAL) + .peer_trust_metric(consts::NETWORK_TRUST_METRIC_INTERVAL, None) + .expect("peer trust") + .bootstraps(vec![(full_node_peer_id.to_base58(), full_node_addr)]) + .expect("test node config"); + let priv_key = Secp256k1PrivateKey::generate(&mut rand::rngs::OsRng); + + let mut network = NetworkService::new(config); + let handle = network.handle(); + + network.set_chain_id(chain_id); + + network + .register_rpc_response::(RPC_RESP_SYNC_PULL_BLOCK) + .expect("register consensus rpc response pull block"); + + network + .register_endpoint_handler( + BROADCAST_HEIGHT, + Box::new(ReceiveRemoteHeight(sync.clone())), + ) + .expect("register remote height"); + + let hook_fn = |sync: Sync| -> _ { + Box::new(move |event: DiagnosticEvent| { + // We only care connected event on client node + if let DiagnosticEvent::NewSession = event { + sync.emit(event) + } + }) + }; + network.register_diagnostic_hook(hook_fn(sync.clone())); + + network + .listen(SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), + listen_port, + )) + .await + .expect("test node listen"); + + tokio::spawn(network); + sync.wait_connected().await; + + ClientNode { + network: handle, + remote_peer_id: full_node_peer_id, + priv_key, + sync, + } +} + +impl ClientNode { + pub fn connected(&self) -> bool { + let diagnostic = &self.network.diagnostic; + let opt_session = diagnostic.session(&self.remote_peer_id); + + self.sync.is_connected() && opt_session.is_some() + } + + pub fn connected_session(&self, peer_id: &PeerId) -> Option { + if !self.connected() { + None + } else { + let diagnostic = &self.network.diagnostic; + let opt_session = diagnostic.session(peer_id); + + opt_session.map(|sid| sid.value()) + } + } + + pub async fn rpc(&self, endpoint: &str, msg: M) -> ClientResult + where + M: MessageCodec, + R: MessageCodec, + { + let sid = match self.connected_session(&self.remote_peer_id) { + Some(sid) => sid, + None => return Err(ClientNodeError::NotConnected), + }; + + let ctx = Context::new().with_value::("session_id", sid); + match self.call::(ctx, endpoint, msg, Priority::High).await { + Ok(resp) => Ok(resp), + Err(e) if e.to_string().to_lowercase().contains("timeout") && !self.connected() => { + Err(ClientNodeError::NotConnected) + } + Err(e) => { + let err_msg = format!("rpc to {} {}", endpoint, e); + Err(ClientNodeError::Unexpected(err_msg)) + } + } + } + + pub async fn get_block(&self, height: u64) -> ClientResult { + let resp = self + .rpc::<_, FixedBlock>(RPC_SYNC_PULL_BLOCK, FixedHeight::new(height)) + .await?; + Ok(resp.inner) + } +} + +impl Deref for ClientNode { + type Target = NetworkServiceHandle; + + fn deref(&self) -> &Self::Target { + &self.network + } +} + +fn full_node_peer_id(full_seckey: &str) -> PeerId { + let seckey = { + let key = hex::decode(full_seckey).expect("hex private key string"); + Secp256k1PrivateKey::try_from(key.as_ref()).expect("valid private key") + }; + let pubkey = seckey.pub_key(); + PeerId::from_pubkey_bytes(pubkey.to_bytes()).expect("valid public key") +}