From 35ce15dd2e769c2fb0a517c804a095f30556faf8 Mon Sep 17 00:00:00 2001 From: Alex Potsides Date: Thu, 29 Aug 2024 17:20:42 +0200 Subject: [PATCH] fix: make handshake abortable (#442) To allow doing things like having a single `AbortSignal` that can be used as a timeout for incoming connection establishment, allow passing it as an option to the `ConnectionEncrypter` `secureOutbound` and `secureInbound` methods. Previously we'd wrap the stream to be secured in an `AbortableSource`, however this has some [serious performance implications](https://github.com/ChainSafe/js-libp2p-gossipsub/pull/361) and it's generally better to just use a signal to cancel an ongoing operation instead of racing every chunk that comes out of the source. --- src/noise.ts | 33 ++++++++++++++++++++------------- src/performHandshake.ts | 17 +++++++++-------- test/noise.spec.ts | 27 +++++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 21 deletions(-) diff --git a/src/noise.ts b/src/noise.ts index b469792..640b4a2 100644 --- a/src/noise.ts +++ b/src/noise.ts @@ -1,5 +1,5 @@ import { unmarshalPrivateKey } from '@libp2p/crypto/keys' -import { type MultiaddrConnection, type SecuredConnection, type PeerId, CodeError, type PrivateKey, serviceCapabilities, isPeerId } from '@libp2p/interface' +import { type MultiaddrConnection, type SecuredConnection, type PeerId, CodeError, type PrivateKey, serviceCapabilities, isPeerId, type AbortOptions } from '@libp2p/interface' import { peerIdFromKeys } from '@libp2p/peer-id' import { decode } from 'it-length-prefixed' import { lpStream, type LengthPrefixedStream } from 'it-length-prefixed-stream' @@ -72,10 +72,10 @@ export class Noise implements INoiseConnection { * @param connection - streaming iterable duplex that will be encrypted * @param remotePeer - PeerId of the remote peer. Used to validate the integrity of the remote peer. */ - public async secureOutbound > = MultiaddrConnection> (connection: Stream, remotePeer?: PeerId): Promise> + public async secureOutbound > = MultiaddrConnection> (connection: Stream, options?: { remotePeer?: PeerId, signal?: AbortSignal }): Promise> public async secureOutbound > = MultiaddrConnection> (localPeer: PeerId, connection: Stream, remotePeer?: PeerId): Promise> public async secureOutbound > = MultiaddrConnection> (...args: any[]): Promise> { - const { localPeer, connection, remotePeer } = this.parseArgs(args) + const { localPeer, connection, remotePeer, signal } = this.parseArgs(args) const wrappedConnection = lpStream( connection, @@ -96,7 +96,9 @@ export class Noise implements INoiseConnection { const handshake = await this.performHandshakeInitiator( wrappedConnection, privateKey, - remoteIdentityKey + remoteIdentityKey, { + signal + } ) const conn = await this.createSecureConnection(wrappedConnection, handshake) @@ -117,10 +119,10 @@ export class Noise implements INoiseConnection { * @param connection - streaming iterable duplex that will be encrypted. * @param remotePeer - optional PeerId of the initiating peer, if known. This may only exist during transport upgrades. */ - public async secureInbound > = MultiaddrConnection> (connection: Stream, remotePeer?: PeerId): Promise> + public async secureInbound > = MultiaddrConnection> (connection: Stream, options?: { remotePeer?: PeerId, signal?: AbortSignal }): Promise> public async secureInbound > = MultiaddrConnection> (localPeer: PeerId, connection: Stream, remotePeer?: PeerId): Promise> public async secureInbound > = MultiaddrConnection> (...args: any[]): Promise> { - const { localPeer, connection, remotePeer } = this.parseArgs(args) + const { localPeer, connection, remotePeer, signal } = this.parseArgs(args) const wrappedConnection = lpStream( connection, @@ -141,7 +143,9 @@ export class Noise implements INoiseConnection { const handshake = await this.performHandshakeResponder( wrappedConnection, privateKey, - remoteIdentityKey + remoteIdentityKey, { + signal + } ) const conn = await this.createSecureConnection(wrappedConnection, handshake) @@ -162,7 +166,8 @@ export class Noise implements INoiseConnection { connection: LengthPrefixedStream, // TODO: pass private key in noise constructor via Components privateKey: PrivateKey, - remoteIdentityKey?: Uint8Array | Uint8ArrayList + remoteIdentityKey?: Uint8Array | Uint8ArrayList, + options?: AbortOptions ): Promise { let result: HandshakeResult try { @@ -175,7 +180,7 @@ export class Noise implements INoiseConnection { prologue: this.prologue, s: this.staticKey, extensions: this.extensions - }) + }, options) this.metrics?.xxHandshakeSuccesses.increment() } catch (e: unknown) { this.metrics?.xxHandshakeErrors.increment() @@ -192,7 +197,8 @@ export class Noise implements INoiseConnection { connection: LengthPrefixedStream, // TODO: pass private key in noise constructor via Components privateKey: PrivateKey, - remoteIdentityKey?: Uint8Array | Uint8ArrayList + remoteIdentityKey?: Uint8Array | Uint8ArrayList, + options?: AbortOptions ): Promise { let result: HandshakeResult try { @@ -205,7 +211,7 @@ export class Noise implements INoiseConnection { prologue: this.prologue, s: this.staticKey, extensions: this.extensions - }) + }, options) this.metrics?.xxHandshakeSuccesses.increment() } catch (e: unknown) { this.metrics?.xxHandshakeErrors.increment() @@ -241,7 +247,7 @@ export class Noise implements INoiseConnection { * TODO: remove this after `libp2p@2.x.x` is released and only support the * newer style */ - private parseArgs > = MultiaddrConnection> (args: any[]): { localPeer: PeerId, connection: Stream, remotePeer?: PeerId } { + private parseArgs > = MultiaddrConnection> (args: any[]): { localPeer: PeerId, connection: Stream, remotePeer?: PeerId, signal?: AbortSignal } { // if the first argument is a peer id, we're using the libp2p@1.x.x style if (isPeerId(args[0])) { return { @@ -256,7 +262,8 @@ export class Noise implements INoiseConnection { return { localPeer: this.components.peerId, connection: args[0], - remotePeer: args[1] + remotePeer: args[1]?.remotePeer, + signal: args[1]?.signal } } } diff --git a/src/performHandshake.ts b/src/performHandshake.ts index e0bfd53..dc197ba 100644 --- a/src/performHandshake.ts +++ b/src/performHandshake.ts @@ -8,8 +8,9 @@ import { import { ZEROLEN, XXHandshakeState } from './protocol.js' import { createHandshakePayload, decodeHandshakePayload } from './utils.js' import type { HandshakeResult, HandshakeParams } from './types.js' +import type { AbortOptions } from '@libp2p/interface' -export async function performHandshakeInitiator (init: HandshakeParams): Promise { +export async function performHandshakeInitiator (init: HandshakeParams, options?: AbortOptions): Promise { const { log, connection, crypto, privateKey, prologue, s, remoteIdentityKey, extensions } = init const payload = await createHandshakePayload(privateKey, s.publicKey, extensions) @@ -23,12 +24,12 @@ export async function performHandshakeInitiator (init: HandshakeParams): Promise logLocalStaticKeys(xx.s, log) log.trace('Stage 0 - Initiator starting to send first message.') - await connection.write(xx.writeMessageA(ZEROLEN)) + await connection.write(xx.writeMessageA(ZEROLEN), options) log.trace('Stage 0 - Initiator finished sending first message.') logLocalEphemeralKeys(xx.e, log) log.trace('Stage 1 - Initiator waiting to receive first message from responder...') - const plaintext = xx.readMessageB(await connection.read()) + const plaintext = xx.readMessageB(await connection.read(options)) log.trace('Stage 1 - Initiator received the message.') logRemoteEphemeralKey(xx.re, log) logRemoteStaticKey(xx.rs, log) @@ -38,7 +39,7 @@ export async function performHandshakeInitiator (init: HandshakeParams): Promise log.trace('All good with the signature!') log.trace('Stage 2 - Initiator sending third handshake message.') - await connection.write(xx.writeMessageC(payload)) + await connection.write(xx.writeMessageC(payload), options) log.trace('Stage 2 - Initiator sent message with signed payload.') const [cs1, cs2] = xx.ss.split() @@ -51,7 +52,7 @@ export async function performHandshakeInitiator (init: HandshakeParams): Promise } } -export async function performHandshakeResponder (init: HandshakeParams): Promise { +export async function performHandshakeResponder (init: HandshakeParams, options?: AbortOptions): Promise { const { log, connection, crypto, privateKey, prologue, s, remoteIdentityKey, extensions } = init const payload = await createHandshakePayload(privateKey, s.publicKey, extensions) @@ -65,17 +66,17 @@ export async function performHandshakeResponder (init: HandshakeParams): Promise logLocalStaticKeys(xx.s, log) log.trace('Stage 0 - Responder waiting to receive first message.') - xx.readMessageA(await connection.read()) + xx.readMessageA(await connection.read(options)) log.trace('Stage 0 - Responder received first message.') logRemoteEphemeralKey(xx.re, log) log.trace('Stage 1 - Responder sending out first message with signed payload and static key.') - await connection.write(xx.writeMessageB(payload)) + await connection.write(xx.writeMessageB(payload), options) log.trace('Stage 1 - Responder sent the second handshake message with signed payload.') logLocalEphemeralKeys(xx.e, log) log.trace('Stage 2 - Responder waiting for third handshake message...') - const plaintext = xx.readMessageC(await connection.read()) + const plaintext = xx.readMessageC(await connection.read(options)) log.trace('Stage 2 - Responder received the message, finished handshake.') const receivedPayload = await decodeHandshakePayload(plaintext, xx.rs, remoteIdentityKey) diff --git a/test/noise.spec.ts b/test/noise.spec.ts index b3e5882..dbe5539 100644 --- a/test/noise.spec.ts +++ b/test/noise.spec.ts @@ -179,4 +179,31 @@ describe('Noise', () => { assert(false, err.message) } }) + + it('should abort noise handshake', async () => { + const abortController = new AbortController() + abortController.abort() + + const noiseInit = new Noise({ + peerId: localPeer, + logger: defaultLogger() + }, { staticNoiseKey: undefined, extensions: undefined }) + const noiseResp = new Noise({ + peerId: remotePeer, + logger: defaultLogger() + }, { staticNoiseKey: undefined, extensions: undefined }) + + const [inboundConnection, outboundConnection] = duplexPair() + + await expect(Promise.all([ + noiseInit.secureOutbound(outboundConnection, { + remotePeer, + signal: abortController.signal + }), + noiseResp.secureInbound(inboundConnection, { + remotePeer: localPeer + }) + ])).to.eventually.be.rejected + .with.property('name', 'AbortError') + }) })