From 6f7e4852d2934eee938806087a45a5a6b3639fc7 Mon Sep 17 00:00:00 2001 From: Stefan Bratanov Date: Thu, 14 Sep 2023 10:36:28 +0100 Subject: [PATCH] [Yamux] Send whole data if window size is > 0 (#319) --- .../io/libp2p/core/mux/StreamMuxerProtocol.kt | 2 +- .../io/libp2p/mux/yamux/YamuxHandler.kt | 86 +++++++++---------- .../io/libp2p/mux/yamux/YamuxHandlerTest.kt | 15 ++-- 3 files changed, 47 insertions(+), 56 deletions(-) diff --git a/libp2p/src/main/kotlin/io/libp2p/core/mux/StreamMuxerProtocol.kt b/libp2p/src/main/kotlin/io/libp2p/core/mux/StreamMuxerProtocol.kt index bb893970..3f7f460a 100644 --- a/libp2p/src/main/kotlin/io/libp2p/core/mux/StreamMuxerProtocol.kt +++ b/libp2p/src/main/kotlin/io/libp2p/core/mux/StreamMuxerProtocol.kt @@ -22,7 +22,7 @@ fun interface StreamMuxerProtocol { } /** - * @param maxBufferedConnectionWrites the maximum amount of bytes in the write buffer per connection before termination + * @param maxBufferedConnectionWrites the maximum amount of bytes in the write buffer per connection */ @JvmStatic @JvmOverloads diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt index ffe3ade9..e4c1fac3 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt @@ -43,26 +43,17 @@ open class YamuxHandler( } fun flush(windowSize: AtomicInteger) { - while (!bufferedData.isEmpty()) { - val data = bufferedData.first() + while (!bufferedData.isEmpty() && windowSize.get() > 0) { + val data = bufferedData.removeFirst() val length = data.readableBytes() - if (length <= windowSize.get()) { - sendFrames(ctx, data, windowSize, id) - bufferedData.removeFirst() - } else { - // partial write to fit within window - val toRead = windowSize.get() - if (toRead > 0) { - val partialData = data.readRetainedSlice(toRead) - sendFrames(ctx, partialData, windowSize, id) - } - break - } + windowSize.addAndGet(-length) + val frame = YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), data) + ctx.writeAndFlush(frame) } } fun close() { - bufferedData.forEach { it.release() } + bufferedData.forEach { releaseMessage(it) } bufferedData.clear() } } @@ -99,8 +90,7 @@ open class YamuxHandler( if (size == 0) { return } - val windowSize = receiveWindowSizes[msg.id] - if (windowSize == null) { + val windowSize = receiveWindowSizes[msg.id] ?: run { releaseMessage(msg.data!!) throw Libp2pException("Unable to retrieve receive window size for ${msg.id}") } @@ -110,8 +100,8 @@ open class YamuxHandler( if (newWindow < INITIAL_WINDOW_SIZE / 2) { val delta = INITIAL_WINDOW_SIZE - newWindow windowSize.addAndGet(delta) - ctx.write(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, 0, delta.toLong())) - ctx.flush() + val frame = YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, 0, delta.toLong()) + ctx.writeAndFlush(frame) } childRead(msg.id, msg.data!!) } @@ -150,41 +140,45 @@ open class YamuxHandler( override fun onChildWrite(child: MuxChannel, data: ByteBuf) { val ctx = getChannelHandlerContext() - val windowSize = - sendWindowSizes[child.id] ?: throw Libp2pException("Unable to retrieve send window size for ${child.id}") - - if (windowSize.get() <= 0) { - // wait until the window is increased to send more data - val buffer = sendBuffers.getOrPut(child.id) { SendBuffer(child.id) } - buffer.add(data) - val totalBufferedWrites = calculateTotalBufferedWrites() - if (totalBufferedWrites > maxBufferedConnectionWrites) { - buffer.close() - throw Libp2pException( - "Overflowed send buffer ($totalBufferedWrites/$maxBufferedConnectionWrites) for connection ${ - ctx.channel().id().asLongText() - }" - ) - } - return + val windowSize = sendWindowSizes[child.id] ?: run { + releaseMessage(data) + throw Libp2pException("Unable to retrieve receive send window size for ${child.id}") } - sendFrames(ctx, data, windowSize, child.id) + + sendData(ctx, data, windowSize, child.id) } private fun calculateTotalBufferedWrites(): Int { return sendBuffers.values.sumOf { it.bufferedBytes() } } - fun sendFrames(ctx: ChannelHandlerContext, data: ByteBuf, windowSize: AtomicInteger, id: MuxId) { - data.sliceMaxSize(minOf(windowSize.get(), maxFrameDataLength)) - .map { slicedData -> - val length = slicedData.readableBytes() - windowSize.addAndGet(-length) - YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), slicedData) - }.forEach { frame -> - ctx.write(frame) + private fun sendData(ctx: ChannelHandlerContext, data: ByteBuf, windowSize: AtomicInteger, id: MuxId) { + data.sliceMaxSize(maxFrameDataLength) + .forEach { slicedData -> + if (windowSize.get() > 0) { + val length = slicedData.readableBytes() + windowSize.addAndGet(-length) + val frame = YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), slicedData) + ctx.writeAndFlush(frame) + } else { + // wait until the window is increased to send + addToSendBuffer(id, data, ctx) + } } - ctx.flush() + } + + private fun addToSendBuffer(id: MuxId, data: ByteBuf, ctx: ChannelHandlerContext) { + val buffer = sendBuffers.getOrPut(id) { SendBuffer(id) } + buffer.add(data) + val totalBufferedWrites = calculateTotalBufferedWrites() + if (totalBufferedWrites > maxBufferedConnectionWrites) { + buffer.close() + throw Libp2pException( + "Overflowed send buffer ($totalBufferedWrites/$maxBufferedConnectionWrites) for connection ${ + ctx.channel().id().asLongText() + }" + ) + } } override fun onLocalOpen(child: MuxChannel) { diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt index c3a753cd..9f67cd7e 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt @@ -151,7 +151,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { } @Test - fun `buffered data should be partially sent if it does not fit within window`() { + fun `buffered data should not be sent if it does not fit within window`() { val handler = openStreamByLocal() val streamId = readFrameOrThrow().streamId @@ -176,19 +176,16 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { streamId.toMuxId(), YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, - 3 + 2 ) ) var frame = readFrameOrThrow() - // one message is fully received + // one message is received assertThat(frame.data).isEqualTo("1984") - frame = readFrameOrThrow() - // the other message is partially received - assertThat(frame.data).isEqualTo("19") - // need to wait for another window update to receive more data + // need to wait for another window update to send more data assertThat(readFrame()).isNull() - // sending window update to read the final part of the buffer + // sending window update ech.writeInbound( YamuxFrame( streamId.toMuxId(), @@ -198,7 +195,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { ) ) frame = readFrameOrThrow() - assertThat(frame.data).isEqualTo("84") + assertThat(frame.data).isEqualTo("1984") } @Test