Skip to content

Commit

Permalink
[Yamux] Send whole data if window size is > 0 (#319)
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanBratanov authored Sep 14, 2023
1 parent a735c39 commit 6f7e485
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ fun interface StreamMuxerProtocol {
}

/**
* @param maxBufferedConnectionWrites the maximum amount of bytes in the write buffer per connection before termination
* @param maxBufferedConnectionWrites the maximum amount of bytes in the write buffer per connection
*/
@JvmStatic
@JvmOverloads
Expand Down
86 changes: 40 additions & 46 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -43,26 +43,17 @@ open class YamuxHandler(
}

fun flush(windowSize: AtomicInteger) {
while (!bufferedData.isEmpty()) {
val data = bufferedData.first()
while (!bufferedData.isEmpty() && windowSize.get() > 0) {
val data = bufferedData.removeFirst()
val length = data.readableBytes()
if (length <= windowSize.get()) {
sendFrames(ctx, data, windowSize, id)
bufferedData.removeFirst()
} else {
// partial write to fit within window
val toRead = windowSize.get()
if (toRead > 0) {
val partialData = data.readRetainedSlice(toRead)
sendFrames(ctx, partialData, windowSize, id)
}
break
}
windowSize.addAndGet(-length)
val frame = YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), data)
ctx.writeAndFlush(frame)
}
}

fun close() {
bufferedData.forEach { it.release() }
bufferedData.forEach { releaseMessage(it) }
bufferedData.clear()
}
}
Expand Down Expand Up @@ -99,8 +90,7 @@ open class YamuxHandler(
if (size == 0) {
return
}
val windowSize = receiveWindowSizes[msg.id]
if (windowSize == null) {
val windowSize = receiveWindowSizes[msg.id] ?: run {
releaseMessage(msg.data!!)
throw Libp2pException("Unable to retrieve receive window size for ${msg.id}")
}
Expand All @@ -110,8 +100,8 @@ open class YamuxHandler(
if (newWindow < INITIAL_WINDOW_SIZE / 2) {
val delta = INITIAL_WINDOW_SIZE - newWindow
windowSize.addAndGet(delta)
ctx.write(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, 0, delta.toLong()))
ctx.flush()
val frame = YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, 0, delta.toLong())
ctx.writeAndFlush(frame)
}
childRead(msg.id, msg.data!!)
}
Expand Down Expand Up @@ -150,41 +140,45 @@ open class YamuxHandler(
override fun onChildWrite(child: MuxChannel<ByteBuf>, data: ByteBuf) {
val ctx = getChannelHandlerContext()

val windowSize =
sendWindowSizes[child.id] ?: throw Libp2pException("Unable to retrieve send window size for ${child.id}")

if (windowSize.get() <= 0) {
// wait until the window is increased to send more data
val buffer = sendBuffers.getOrPut(child.id) { SendBuffer(child.id) }
buffer.add(data)
val totalBufferedWrites = calculateTotalBufferedWrites()
if (totalBufferedWrites > maxBufferedConnectionWrites) {
buffer.close()
throw Libp2pException(
"Overflowed send buffer ($totalBufferedWrites/$maxBufferedConnectionWrites) for connection ${
ctx.channel().id().asLongText()
}"
)
}
return
val windowSize = sendWindowSizes[child.id] ?: run {
releaseMessage(data)
throw Libp2pException("Unable to retrieve receive send window size for ${child.id}")
}
sendFrames(ctx, data, windowSize, child.id)

sendData(ctx, data, windowSize, child.id)
}

private fun calculateTotalBufferedWrites(): Int {
return sendBuffers.values.sumOf { it.bufferedBytes() }
}

fun sendFrames(ctx: ChannelHandlerContext, data: ByteBuf, windowSize: AtomicInteger, id: MuxId) {
data.sliceMaxSize(minOf(windowSize.get(), maxFrameDataLength))
.map { slicedData ->
val length = slicedData.readableBytes()
windowSize.addAndGet(-length)
YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), slicedData)
}.forEach { frame ->
ctx.write(frame)
private fun sendData(ctx: ChannelHandlerContext, data: ByteBuf, windowSize: AtomicInteger, id: MuxId) {
data.sliceMaxSize(maxFrameDataLength)
.forEach { slicedData ->
if (windowSize.get() > 0) {
val length = slicedData.readableBytes()
windowSize.addAndGet(-length)
val frame = YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), slicedData)
ctx.writeAndFlush(frame)
} else {
// wait until the window is increased to send
addToSendBuffer(id, data, ctx)
}
}
ctx.flush()
}

private fun addToSendBuffer(id: MuxId, data: ByteBuf, ctx: ChannelHandlerContext) {
val buffer = sendBuffers.getOrPut(id) { SendBuffer(id) }
buffer.add(data)
val totalBufferedWrites = calculateTotalBufferedWrites()
if (totalBufferedWrites > maxBufferedConnectionWrites) {
buffer.close()
throw Libp2pException(
"Overflowed send buffer ($totalBufferedWrites/$maxBufferedConnectionWrites) for connection ${
ctx.channel().id().asLongText()
}"
)
}
}

override fun onLocalOpen(child: MuxChannel<ByteBuf>) {
Expand Down
15 changes: 6 additions & 9 deletions libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
}

@Test
fun `buffered data should be partially sent if it does not fit within window`() {
fun `buffered data should not be sent if it does not fit within window`() {
val handler = openStreamByLocal()
val streamId = readFrameOrThrow().streamId

Expand All @@ -176,19 +176,16 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
streamId.toMuxId(),
YamuxType.WINDOW_UPDATE,
YamuxFlags.ACK,
3
2
)
)

var frame = readFrameOrThrow()
// one message is fully received
// one message is received
assertThat(frame.data).isEqualTo("1984")
frame = readFrameOrThrow()
// the other message is partially received
assertThat(frame.data).isEqualTo("19")
// need to wait for another window update to receive more data
// need to wait for another window update to send more data
assertThat(readFrame()).isNull()
// sending window update to read the final part of the buffer
// sending window update
ech.writeInbound(
YamuxFrame(
streamId.toMuxId(),
Expand All @@ -198,7 +195,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
)
)
frame = readFrameOrThrow()
assertThat(frame.data).isEqualTo("84")
assertThat(frame.data).isEqualTo("1984")
}

@Test
Expand Down

0 comments on commit 6f7e485

Please sign in to comment.