Skip to content

Commit

Permalink
Fix the case when a stream is closed while still having buffered data…
Browse files Browse the repository at this point in the history
… for write (#330)

* Fix the case when a stream is closed while still having buffered data for write
* Add unit test for close case when outbound data buffered
  • Loading branch information
Nashatyrev authored Oct 11, 2023
1 parent 3ac83c4 commit ee02cf9
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 7 deletions.
16 changes: 16 additions & 0 deletions libp2p/src/main/kotlin/io/libp2p/etc/types/Delegates.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.libp2p.etc.types

import kotlin.properties.Delegates
import kotlin.properties.ReadWriteProperty
import kotlin.reflect.KProperty

Expand Down Expand Up @@ -92,3 +93,18 @@ data class CappedValueDelegate<C : Comparable<C>>(
}
}
}

fun <T : Any> Delegates.writeOnce(initialValue: T): ReadWriteProperty<Any?, T> = object : ReadWriteProperty<Any?, T> {
private var value: T = initialValue
private var wasSet = false

public override fun getValue(thisRef: Any?, property: KProperty<*>): T {
return value
}

public override fun setValue(thisRef: Any?, property: KProperty<*>, value: T) {
if (wasSet) throw IllegalStateException("Property ${property.name} cannot be set more than once.")
this.value = value
wasSet = true
}
}
1 change: 1 addition & 0 deletions libp2p/src/main/kotlin/io/libp2p/mux/MuxerException.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ class UnknownStreamIdMuxerException(muxId: MuxId) : ReadMuxerException("Stream w
class InvalidFrameMuxerException(message: String) : ReadMuxerException(message, null)

class WriteBufferOverflowMuxerException(message: String) : WriteMuxerException(message, null)
class ClosedForWritingMuxerException(muxId: MuxId) : WriteMuxerException("Couldn't write, stream was closed for writing: $muxId", null)
23 changes: 16 additions & 7 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ import io.libp2p.core.StreamHandler
import io.libp2p.core.multistream.MultistreamProtocol
import io.libp2p.core.mux.StreamMuxer
import io.libp2p.etc.types.sliceMaxSize
import io.libp2p.etc.types.writeOnce
import io.libp2p.etc.util.netty.ByteBufQueue
import io.libp2p.etc.util.netty.mux.MuxChannel
import io.libp2p.etc.util.netty.mux.MuxId
import io.libp2p.mux.ClosedForWritingMuxerException
import io.libp2p.mux.InvalidFrameMuxerException
import io.libp2p.mux.MuxHandler
import io.libp2p.mux.UnknownStreamIdMuxerException
Expand All @@ -19,6 +21,7 @@ import java.util.concurrent.CompletableFuture
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger
import kotlin.math.max
import kotlin.properties.Delegates

const val INITIAL_WINDOW_SIZE = 256 * 1024
const val DEFAULT_MAX_BUFFERED_CONNECTION_WRITES = 10 * 1024 * 1024 // 10 MiB
Expand All @@ -39,6 +42,7 @@ open class YamuxHandler(
val sendWindowSize = AtomicInteger(initialWindowSize)
val receiveWindowSize = AtomicInteger(initialWindowSize)
val sendBuffer = ByteBufQueue()
var closedForWriting by Delegates.writeOnce(false)

fun dispose() {
sendBuffer.dispose()
Expand Down Expand Up @@ -72,7 +76,7 @@ open class YamuxHandler(
val delta = msg.length.toInt()
sendWindowSize.addAndGet(delta)
// try to send any buffered messages after the window update
drainBuffer()
drainBufferAndMaybeClose()
}

private fun handleFlags(msg: YamuxFrame) {
Expand All @@ -98,7 +102,7 @@ open class YamuxHandler(
}
}

private fun drainBuffer() {
private fun drainBufferAndMaybeClose() {
val maxSendLength = max(0, sendWindowSize.get())
val data = sendBuffer.take(maxSendLength)
sendWindowSize.addAndGet(-data.readableBytes())
Expand All @@ -107,11 +111,18 @@ open class YamuxHandler(
val length = slicedData.readableBytes()
writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), slicedData))
}

if (closedForWriting && sendBuffer.readableBytes() == 0) {
writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.FIN, 0))
}
}

fun sendData(data: ByteBuf) {
if (closedForWriting) {
throw ClosedForWritingMuxerException(id)
}
fillBuffer(data)
drainBuffer()
drainBufferAndMaybeClose()
}

fun onLocalOpen() {
Expand All @@ -123,10 +134,8 @@ open class YamuxHandler(
}

fun onLocalDisconnect() {
// TODO: this implementation drops remaining data
drainBuffer()
sendBuffer.dispose()
writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.FIN, 0))
closedForWriting = true
drainBufferAndMaybeClose()
}

fun onLocalClose() {
Expand Down
43 changes: 43 additions & 0 deletions libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,49 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
msgPart3.data!!.release()
}

@Test
fun `local close for writing should flush buffered data and send close frame on writeWindow update`() {
val handler = openStreamLocal()
val muxId = readFrameOrThrow().streamId.toMuxId()

val msg = "42".repeat(initialWindowSize + 1).fromHex().toByteBuf(allocateBuf())
// writing a message which is larger than sendWindowSize
handler.ctx.writeAndFlush(msg)

val msgPart1 = readYamuxFrameOrThrow()
assertThat(msgPart1.length).isEqualTo(256L)
assertThat(msgPart1.data!!.readableBytes()).isEqualTo(256)
msgPart1.data!!.release()

val msgPart2 = readYamuxFrameOrThrow()
assertThat(msgPart2.length.toInt()).isEqualTo(initialWindowSize - 256)
assertThat(msgPart2.data!!.readableBytes()).isEqualTo(initialWindowSize - 256)
msgPart2.data!!.release()

// locally close for writing while some outbound data is still buffered
handler.ctx.disconnect()

// ACKing message receive
ech.writeInbound(
YamuxFrame(
muxId,
YamuxType.WINDOW_UPDATE,
YamuxFlags.ACK,
initialWindowSize.toLong()
)
)

val msgPart3 = readYamuxFrameOrThrow()
assertThat(msgPart3.length).isEqualTo(1L)
assertThat(msgPart3.data!!.readableBytes()).isEqualTo(1)
msgPart3.data!!.release()

val closeFrame = readYamuxFrameOrThrow()
assertThat(closeFrame.flags).isEqualTo(YamuxFlags.FIN)
assertThat(closeFrame.length).isEqualTo(0L)
assertThat(closeFrame.data).isNull()
}

companion object {
private fun YamuxStreamIdGenerator.toIterator() = iterator {
while (true) {
Expand Down

0 comments on commit ee02cf9

Please sign in to comment.