diff --git a/packages/transport/stream/src/__tests__/stream.test.ts b/packages/transport/stream/src/__tests__/stream.test.ts index 844405988..cabf3aefb 100644 --- a/packages/transport/stream/src/__tests__/stream.test.ts +++ b/packages/transport/stream/src/__tests__/stream.test.ts @@ -814,6 +814,73 @@ describe("streams", function () { }); it("through relay if fails", async () => { + const dialFn = + streams[0].stream.components.connectionManager.openConnection.bind( + streams[0].stream.components.connectionManager + ); + + let directlyDialded = false; + const filteredDial = (address: PeerId | Multiaddr | Multiaddr[]) => { + if ( + isPeerId(address) && + address.toString() === streams[3].stream.peerIdStr + ) { + throw new Error("Mock fail"); // don't allow connect directly + } + + let addresses: Multiaddr[] = Array.isArray(address) + ? address + : [address as Multiaddr]; + for (const a of addresses) { + if ( + !a.protoNames().includes("p2p-circuit") && + a.toString().includes(streams[3].stream.peerIdStr) + ) { + throw new Error("Mock fail"); // don't allow connect directly + } + } + addresses = addresses.map((x) => + x.protoNames().includes("p2p-circuit") + ? multiaddr(x.toString().replace("/webrtc/", "/")) + : x + ); // TODO use webrtc in node + + directlyDialded = true; + return dialFn(addresses); + }; + + streams[0].stream.components.connectionManager.openConnection = + filteredDial; + expect(streams[0].stream.peers.size).toEqual(1); + await streams[0].stream.publish(data, { + to: [streams[3].stream.components.peerId] + }); + await waitFor(() => streams[3].received.length === 1); + await waitForResolved(() => expect(directlyDialded).toBeTrue()); + }); + + it("tries multiple relays", async () => { + await session.connect([[session.peers[1], session.peers[3]]]); + await waitForPeerStreams(streams[1].stream, streams[3].stream); + + /* + ┌───┐ + │ 0 │ + └┬─┬┘ + │┌▽┐ + ││1│ + │└┬┘ + ┌▽┐│ + │2││ + └┬┘│ + ┌▽─▽─┐ + │ 3 │ + └────┘ + + */ + + const dialedCircuitRelayAddresses: Set = new Set(); + const dialFn = streams[0].stream.components.connectionManager.openConnection.bind( streams[0].stream.components.connectionManager @@ -837,22 +904,32 @@ describe("streams", function () { throw new Error("Mock fail"); // don't allow connect directly } } - const q = 123; + addresses + .filter((x) => x.protoNames().includes("p2p-circuit")) + .forEach((x) => { + dialedCircuitRelayAddresses.add(x.toString()); + }); addresses = addresses.map((x) => - x.protoCodes().includes(281) + x.protoNames().includes("p2p-circuit") ? multiaddr(x.toString().replace("/webrtc/", "/")) : x ); // TODO use webrtc in node + + if (dialedCircuitRelayAddresses.size === 1) { + throw new Error("Mock fail"); // only succeed with the dial once we have tried two unique addresses (both neighbors) + } return dialFn(addresses); }; streams[0].stream.components.connectionManager.openConnection = filteredDial; + expect(streams[0].stream.peers.size).toEqual(1); await streams[0].stream.publish(data, { to: [streams[3].stream.components.peerId] }); await waitFor(() => streams[3].received.length === 1); + expect(dialedCircuitRelayAddresses.size).toEqual(2); }); }); diff --git a/packages/transport/stream/src/index.ts b/packages/transport/stream/src/index.ts index 9ea7bedb0..1c66889a7 100644 --- a/packages/transport/stream/src/index.ts +++ b/packages/transport/stream/src/index.ts @@ -1412,7 +1412,7 @@ export abstract class DirectStream< ) { // Dont await this even if it is async since this method can fail // and might take some time to run - this.maybeConnectDirectly(path).catch((e) => { + this.maybeConnectDirectly(to).catch((e) => { logger.error( "Failed to request direct connection: " + e.message ); @@ -1499,13 +1499,7 @@ export abstract class DirectStream< } } - async maybeConnectDirectly(path: string[]) { - if (path.length < 3) { - return; - } - - const toHash = path[path.length - 1]; - + async maybeConnectDirectly(toHash: string) { if (this.peers.has(toHash)) { return; // TODO, is this expected, or are we to dial more addresses? } @@ -1527,49 +1521,51 @@ export abstract class DirectStream< } // Connect through a closer relay that maybe does holepunch for us - const nextToHash = path[path.length - 2]; - const routeKey = nextToHash + toHash; - if (!this.recentDials.has(routeKey)) { - this.recentDials.add(routeKey); - const to = this.peerKeyHashToPublicKey.get(toHash)! as Ed25519PublicKey; - const toPeerId = await to.toPeerId(); - const addrs = this.multiaddrsMap.get(path[path.length - 2]); - if (addrs && addrs.length > 0) { - const addressesToDial = addrs.sort((a, b) => { - if (a.includes("/wss/")) { - if (b.includes("/wss/")) { - return 0; - } - return -1; - } - if (a.includes("/ws/")) { - if (b.includes("/ws/")) { - return 0; + const neighbours = this.routes.graph.neighbors(toHash); + outer: for (const neighbour of neighbours) { + const routeKey = neighbour + toHash; + if (!this.recentDials.has(routeKey)) { + this.recentDials.add(routeKey); + const to = this.peerKeyHashToPublicKey.get(toHash)! as Ed25519PublicKey; + const toPeerId = await to.toPeerId(); + const addrs = this.multiaddrsMap.get(neighbour); + if (addrs && addrs.length > 0) { + const addressesToDial = addrs.sort((a, b) => { + if (a.includes("/wss/")) { + if (b.includes("/wss/")) { + return 0; + } + return -1; } - if (b.includes("/wss/")) { - return 1; + if (a.includes("/ws/")) { + if (b.includes("/ws/")) { + return 0; + } + if (b.includes("/wss/")) { + return 1; + } + return -1; } - return -1; - } - return 0; - }); + return 0; + }); - for (const addr of addressesToDial) { - const circuitAddress = multiaddr( - addr + "/p2p-circuit/webrtc/p2p/" + toPeerId.toString() - ); - try { - await this.components.connectionManager.openConnection( - circuitAddress - ); - return; - } catch (error: any) { - logger.error( - "Failed to connect directly to: " + - circuitAddress.toString() + - ". " + - error?.message + for (const addr of addressesToDial) { + const circuitAddress = multiaddr( + addr + "/p2p-circuit/webrtc/p2p/" + toPeerId.toString() ); + try { + await this.components.connectionManager.openConnection( + circuitAddress + ); + break outer; // We succeeded! that means we dont have to try anymore + } catch (error: any) { + logger.warn( + "Failed to connect directly to: " + + circuitAddress.toString() + + ". " + + error?.message + ); + } } } }