Skip to content

Commit

Permalink
[Yamux] Fix sending of buffered messages after a window update
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanBratanov committed Sep 6, 2023
1 parent 086952d commit 75e6aac
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 45 deletions.
106 changes: 61 additions & 45 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,13 @@ import io.libp2p.etc.util.netty.mux.MuxId
import io.libp2p.mux.MuxHandler
import io.netty.buffer.ByteBuf
import io.netty.channel.ChannelHandlerContext
import org.slf4j.LoggerFactory
import java.util.concurrent.CompletableFuture
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger

const val INITIAL_WINDOW_SIZE = 256 * 1024
const val MAX_BUFFERED_CONNECTION_WRITES = 1024 * 1024

private val log = LoggerFactory.getLogger(YamuxHandler::class.java)

open class YamuxHandler(
override val multistreamProtocol: MultistreamProtocol,
override val maxFrameDataLength: Int,
Expand All @@ -44,22 +41,29 @@ open class YamuxHandler(
var written = 0
while (!buffered.isEmpty()) {
val buf = buffered.first()
val readableBytes = buf.readableBytes()
if (readableBytes + written < sendWindow.get()) {
val bufLength = buf.readableBytes()
if (bufLength <= sendWindow.get()) {
sendBlocks(ctx, buf, sendWindow, id)
written += readableBytes
written += bufLength
buf.release()
buffered.removeFirst()
} else {
// partial write to fit within window
val toRead = sendWindow.get() - written
sendBlocks(ctx, buf.readSlice(toRead), sendWindow, id)
written += toRead
val toRead = sendWindow.get()
if (toRead > 0) {
val partialBuf = buf.readSlice(toRead)
sendBlocks(ctx, partialBuf, sendWindow, id)
written += toRead
}
break
}
}
return written
}

fun totalBytes(): Int {
return buffered.sumOf { it.readableBytes() }
}
}

override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {
Expand Down Expand Up @@ -88,33 +92,20 @@ open class YamuxHandler(
}
}

private fun handleFlags(msg: YamuxFrame) {
val ctx = getChannelHandlerContext()
when (msg.flags) {
YamuxFlags.SYN -> {
// ACK the new stream
onRemoteYamuxOpen(msg.id)
ctx.writeAndFlush(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 0))
}

YamuxFlags.FIN -> onRemoteDisconnect(msg.id)
YamuxFlags.RST -> onRemoteClose(msg.id)
}
}

private fun handleDataRead(msg: YamuxFrame) {
val ctx = getChannelHandlerContext()
val size = msg.length
handleFlags(msg)
if (size.toInt() == 0) {
val size = msg.length.toInt()
if (size == 0) {
return
}
val recWindow = receiveWindows[msg.id]
if (recWindow == null) {
releaseMessage(msg.data!!)
throw Libp2pException("No receive window for " + msg.id)
throw Libp2pException("No receive window for ${msg.id}")
}
val newWindow = recWindow.addAndGet(-size.toInt())
val newWindow = recWindow.addAndGet(-size)
// send a window update frame once half of the window is depleted
if (newWindow < INITIAL_WINDOW_SIZE / 2) {
val delta = INITIAL_WINDOW_SIZE - newWindow
recWindow.addAndGet(delta)
Expand All @@ -126,21 +117,33 @@ open class YamuxHandler(

private fun handleWindowUpdate(msg: YamuxFrame) {
handleFlags(msg)
val size = msg.length.toInt()
if (size == 0) {
val delta = msg.length.toInt()
if (delta == 0) {
return
}
val sendWindow = sendWindows[msg.id] ?: return
sendWindow.addAndGet(size)
val buffer = sendBuffers[msg.id]
if (buffer != null) {
val writtenBytes = buffer.flush(sendWindow, msg.id)
totalBufferedWrites.addAndGet(-writtenBytes)
sendWindow.addAndGet(delta)
val buffer = sendBuffers[msg.id] ?: return
// try to send any buffered messages after the window update
val writtenBytes = buffer.flush(sendWindow, msg.id)
totalBufferedWrites.addAndGet(-writtenBytes)
}

private fun handleFlags(msg: YamuxFrame) {
val ctx = getChannelHandlerContext()
when (msg.flags) {
YamuxFlags.SYN -> {
onRemoteYamuxOpen(msg.id)
// ACK the new stream
ctx.writeAndFlush(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 0))
}

YamuxFlags.FIN -> onRemoteDisconnect(msg.id)
YamuxFlags.RST -> onRemoteClose(msg.id)
}
}

private fun handleGoAway(msg: YamuxFrame) {
log.debug("Session will be terminated. Go Away message with with error code ${msg.length} has been received.")
onRemoteClose(msg.id)
}

Expand All @@ -162,10 +165,11 @@ open class YamuxHandler(
}

fun sendBlocks(ctx: ChannelHandlerContext, data: ByteBuf, sendWindow: AtomicInteger, id: MuxId) {
data.sliceMaxSize(minOf(maxFrameDataLength, sendWindow.get()))
.map { frameSliceBuf ->
sendWindow.addAndGet(-frameSliceBuf.readableBytes())
YamuxFrame(id, YamuxType.DATA, 0, frameSliceBuf.readableBytes().toLong(), frameSliceBuf)
data.sliceMaxSize(minOf(sendWindow.get(), maxFrameDataLength))
.map { slicedData ->
val length = slicedData.readableBytes()
sendWindow.addAndGet(-length)
YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), slicedData)
}.forEach { muxFrame ->
ctx.write(muxFrame)
}
Expand All @@ -188,22 +192,34 @@ open class YamuxHandler(
}

override fun onLocalDisconnect(child: MuxChannel<ByteBuf>) {
val sendWindow = sendWindows.remove(child.id)
val buffered = sendBuffers.remove(child.id)
if (buffered != null && sendWindow != null) {
buffered.flush(sendWindow, child.id)
}
flushAndClearSendBuffers(child.id)
getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.FIN, 0))
}

private fun flushAndClearSendBuffers(id: MuxId) {
val sendWindow = sendWindows.remove(id)
val buffer = sendBuffers.remove(id)
if (buffer != null && sendWindow != null) {
val writtenBytes = buffer.flush(sendWindow, id)
totalBufferedWrites.addAndGet(-writtenBytes)
}
}

override fun onLocalClose(child: MuxChannel<ByteBuf>) {
clearSendBuffers(child.id)
getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.RST, 0))
}

override fun onChildClosed(child: MuxChannel<ByteBuf>) {
sendWindows.remove(child.id)
receiveWindows.remove(child.id)
sendBuffers.remove(child.id)
clearSendBuffers(child.id)
}

private fun clearSendBuffers(id: MuxId) {
sendBuffers.remove(id)?.let {
totalBufferedWrites.addAndGet(-it.totalBytes())
}
}

override fun generateNextId() =
Expand Down
35 changes: 35 additions & 0 deletions libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,41 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
assertThat(frame.data).isEqualTo("1984")
}

@Test
fun `partial data is written if it fits windows`() {
val handler = openStreamByLocal()
val streamId = readFrameOrThrow().streamId

ech.writeInbound(
YamuxFrame(
streamId.toMuxId(),
YamuxType.WINDOW_UPDATE,
YamuxFlags.ACK,
-INITIAL_WINDOW_SIZE.toLong()
)
)

val message = "1984".fromHex().toByteBuf(allocateBuf())
// 2 bytes per message
handler.ctx.writeAndFlush(message)
handler.ctx.writeAndFlush(message.copy())

assertThat(readFrame()).isNull()

ech.writeInbound(
YamuxFrame(
streamId.toMuxId(),
YamuxType.WINDOW_UPDATE,
YamuxFlags.ACK,
3
)
)

val frame = readFrameOrThrow()
// one message is fully received
assertThat(frame.data).isEqualTo("1984")
}

@Test
fun `test ping`() {
val id: Long = 0
Expand Down

0 comments on commit 75e6aac

Please sign in to comment.