From ee02cf94d5bcc467e1f116b4b012d916a19009a2 Mon Sep 17 00:00:00 2001 From: Anton Nashatyrev Date: Wed, 11 Oct 2023 15:37:48 +0300 Subject: [PATCH] Fix the case when a stream is closed while still having buffered data for write (#330) * Fix the case when a stream is closed while still having buffered data for write * Add unit test for close case when outbound data buffered --- .../kotlin/io/libp2p/etc/types/Delegates.kt | 16 +++++++ .../kotlin/io/libp2p/mux/MuxerException.kt | 1 + .../io/libp2p/mux/yamux/YamuxHandler.kt | 23 +++++++--- .../io/libp2p/mux/yamux/YamuxHandlerTest.kt | 43 +++++++++++++++++++ 4 files changed, 76 insertions(+), 7 deletions(-) diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/types/Delegates.kt b/libp2p/src/main/kotlin/io/libp2p/etc/types/Delegates.kt index ea44d904..a67a3c86 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/types/Delegates.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/types/Delegates.kt @@ -1,5 +1,6 @@ package io.libp2p.etc.types +import kotlin.properties.Delegates import kotlin.properties.ReadWriteProperty import kotlin.reflect.KProperty @@ -92,3 +93,18 @@ data class CappedValueDelegate>( } } } + +fun Delegates.writeOnce(initialValue: T): ReadWriteProperty = object : ReadWriteProperty { + private var value: T = initialValue + private var wasSet = false + + public override fun getValue(thisRef: Any?, property: KProperty<*>): T { + return value + } + + public override fun setValue(thisRef: Any?, property: KProperty<*>, value: T) { + if (wasSet) throw IllegalStateException("Property ${property.name} cannot be set more than once.") + this.value = value + wasSet = true + } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/MuxerException.kt b/libp2p/src/main/kotlin/io/libp2p/mux/MuxerException.kt index 1ba4eaa1..b156aaf3 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/MuxerException.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/MuxerException.kt @@ -13,3 +13,4 @@ class UnknownStreamIdMuxerException(muxId: MuxId) : ReadMuxerException("Stream w class InvalidFrameMuxerException(message: String) : ReadMuxerException(message, null) class WriteBufferOverflowMuxerException(message: String) : WriteMuxerException(message, null) +class ClosedForWritingMuxerException(muxId: MuxId) : WriteMuxerException("Couldn't write, stream was closed for writing: $muxId", null) diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt index ece12d4e..659c40ab 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt @@ -6,9 +6,11 @@ import io.libp2p.core.StreamHandler import io.libp2p.core.multistream.MultistreamProtocol import io.libp2p.core.mux.StreamMuxer import io.libp2p.etc.types.sliceMaxSize +import io.libp2p.etc.types.writeOnce import io.libp2p.etc.util.netty.ByteBufQueue import io.libp2p.etc.util.netty.mux.MuxChannel import io.libp2p.etc.util.netty.mux.MuxId +import io.libp2p.mux.ClosedForWritingMuxerException import io.libp2p.mux.InvalidFrameMuxerException import io.libp2p.mux.MuxHandler import io.libp2p.mux.UnknownStreamIdMuxerException @@ -19,6 +21,7 @@ import java.util.concurrent.CompletableFuture import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicInteger import kotlin.math.max +import kotlin.properties.Delegates const val INITIAL_WINDOW_SIZE = 256 * 1024 const val DEFAULT_MAX_BUFFERED_CONNECTION_WRITES = 10 * 1024 * 1024 // 10 MiB @@ -39,6 +42,7 @@ open class YamuxHandler( val sendWindowSize = AtomicInteger(initialWindowSize) val receiveWindowSize = AtomicInteger(initialWindowSize) val sendBuffer = ByteBufQueue() + var closedForWriting by Delegates.writeOnce(false) fun dispose() { sendBuffer.dispose() @@ -72,7 +76,7 @@ open class YamuxHandler( val delta = msg.length.toInt() sendWindowSize.addAndGet(delta) // try to send any buffered messages after the window update - drainBuffer() + drainBufferAndMaybeClose() } private fun handleFlags(msg: YamuxFrame) { @@ -98,7 +102,7 @@ open class YamuxHandler( } } - private fun drainBuffer() { + private fun drainBufferAndMaybeClose() { val maxSendLength = max(0, sendWindowSize.get()) val data = sendBuffer.take(maxSendLength) sendWindowSize.addAndGet(-data.readableBytes()) @@ -107,11 +111,18 @@ open class YamuxHandler( val length = slicedData.readableBytes() writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), slicedData)) } + + if (closedForWriting && sendBuffer.readableBytes() == 0) { + writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.FIN, 0)) + } } fun sendData(data: ByteBuf) { + if (closedForWriting) { + throw ClosedForWritingMuxerException(id) + } fillBuffer(data) - drainBuffer() + drainBufferAndMaybeClose() } fun onLocalOpen() { @@ -123,10 +134,8 @@ open class YamuxHandler( } fun onLocalDisconnect() { - // TODO: this implementation drops remaining data - drainBuffer() - sendBuffer.dispose() - writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.FIN, 0)) + closedForWriting = true + drainBufferAndMaybeClose() } fun onLocalClose() { diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt index 5f239f08..8b7218dc 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt @@ -407,6 +407,49 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { msgPart3.data!!.release() } + @Test + fun `local close for writing should flush buffered data and send close frame on writeWindow update`() { + val handler = openStreamLocal() + val muxId = readFrameOrThrow().streamId.toMuxId() + + val msg = "42".repeat(initialWindowSize + 1).fromHex().toByteBuf(allocateBuf()) + // writing a message which is larger than sendWindowSize + handler.ctx.writeAndFlush(msg) + + val msgPart1 = readYamuxFrameOrThrow() + assertThat(msgPart1.length).isEqualTo(256L) + assertThat(msgPart1.data!!.readableBytes()).isEqualTo(256) + msgPart1.data!!.release() + + val msgPart2 = readYamuxFrameOrThrow() + assertThat(msgPart2.length.toInt()).isEqualTo(initialWindowSize - 256) + assertThat(msgPart2.data!!.readableBytes()).isEqualTo(initialWindowSize - 256) + msgPart2.data!!.release() + + // locally close for writing while some outbound data is still buffered + handler.ctx.disconnect() + + // ACKing message receive + ech.writeInbound( + YamuxFrame( + muxId, + YamuxType.WINDOW_UPDATE, + YamuxFlags.ACK, + initialWindowSize.toLong() + ) + ) + + val msgPart3 = readYamuxFrameOrThrow() + assertThat(msgPart3.length).isEqualTo(1L) + assertThat(msgPart3.data!!.readableBytes()).isEqualTo(1) + msgPart3.data!!.release() + + val closeFrame = readYamuxFrameOrThrow() + assertThat(closeFrame.flags).isEqualTo(YamuxFlags.FIN) + assertThat(closeFrame.length).isEqualTo(0L) + assertThat(closeFrame.data).isNull() + } + companion object { private fun YamuxStreamIdGenerator.toIterator() = iterator { while (true) {