Skip to content

Commit

Permalink
Refactor Yamux flags (#338)
Browse files Browse the repository at this point in the history
* Convert YamuxType to enum
* Refactor YamuxFlags: convert them to Set of enum values.
  • Loading branch information
Nashatyrev authored Oct 11, 2023
1 parent ee02cf9 commit e843836
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 74 deletions.
34 changes: 34 additions & 0 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlag.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package io.libp2p.mux.yamux

import io.libp2p.mux.InvalidFrameMuxerException

/**
* Contains all the permissible values for flags in the <code>yamux</code> protocol.
*/
enum class YamuxFlag(val intFlag: Int) {
SYN(1),
ACK(2),
FIN(4),
RST(8);

val asSet: Set<YamuxFlag> = setOf(this)

companion object {
val NONE = emptySet<YamuxFlag>()

private val validFlagCombinations = mapOf(
0 to NONE,
SYN.intFlag to SYN.asSet,
ACK.intFlag to ACK.asSet,
FIN.intFlag to FIN.asSet,
RST.intFlag to RST.asSet,
)

fun fromInt(flags: Int) =
validFlagCombinations[flags] ?: throw InvalidFrameMuxerException("Invalid Yamux flags value: $flags")

fun Set<YamuxFlag>.toInt() = this
.fold(0) { acc, flag -> acc or flag.intFlag }
.also { require(it in validFlagCombinations) { "Invalid Yamux flags combination: $this" } }
}
}
11 changes: 0 additions & 11 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlags.kt

This file was deleted.

12 changes: 5 additions & 7 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrame.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package io.libp2p.mux.yamux

import io.libp2p.etc.types.toByteArray
import io.libp2p.etc.util.netty.mux.MuxId
import io.netty.buffer.ByteBuf
import io.netty.buffer.DefaultByteBufHolder
Expand All @@ -9,17 +8,16 @@ import io.netty.buffer.Unpooled
/**
* Contains the fields that comprise a yamux frame.
* @param id the ID of the stream.
* @param flags the flags value for this frame.
* @param flags the flags for this frame.
* @param length the length field for this frame.
* @param data the data segment.
*/
class YamuxFrame(val id: MuxId, val type: Int, val flags: Int, val length: Long, val data: ByteBuf? = null) :
class YamuxFrame(val id: MuxId, val type: YamuxType, val flags: Set<YamuxFlag>, val length: Long, val data: ByteBuf? = null) :
DefaultByteBufHolder(data ?: Unpooled.EMPTY_BUFFER) {

override fun toString(): String {
if (data == null) {
return "YamuxFrame(id=$id, type=$type, flags=$flags, length=$length)"
}
return "YamuxFrame(id=$id, type=$type, flags=$flags, length=$length, data=${String(data.toByteArray())})"
val dataString = if (data == null) "" else ", len=${data.readableBytes()}, $data"
val flagsString = flags.joinToString("+")
return "YamuxFrame(id=$id, type=$type, flags=$flagsString, length=$length$dataString)"
}
}
17 changes: 10 additions & 7 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrameCodec.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.libp2p.mux.yamux

import io.libp2p.core.ProtocolViolationException
import io.libp2p.mux.yamux.YamuxFlag.Companion.toInt
import io.netty.buffer.ByteBuf
import io.netty.buffer.Unpooled
import io.netty.channel.ChannelHandlerContext
Expand All @@ -24,8 +25,8 @@ class YamuxFrameCodec(
*/
override fun encode(ctx: ChannelHandlerContext, msg: YamuxFrame, out: ByteBuf) {
out.writeByte(0) // version
out.writeByte(msg.type)
out.writeShort(msg.flags)
out.writeByte(msg.type.intValue)
out.writeShort(msg.flags.toInt())
out.writeInt(msg.id.id.toInt())
out.writeInt(msg.data?.readableBytes() ?: msg.length.toInt())
out.writeBytes(msg.data ?: Unpooled.EMPTY_BUFFER)
Expand All @@ -46,15 +47,17 @@ class YamuxFrameCodec(
val readerIndex = msg.readerIndex()
msg.readByte(); // version always 0
val type = msg.readUnsignedByte()
val yamuxType = YamuxType.fromInt(type.toInt())
val flags = msg.readUnsignedShort()
val streamId = msg.readUnsignedInt()
val length = msg.readUnsignedInt()
val yamuxId = YamuxId(ctx.channel().id(), streamId)
if (type.toInt() != YamuxType.DATA) {
val yamuxFlags = YamuxFlag.fromInt(flags)
if (yamuxType != YamuxType.DATA) {
val yamuxFrame = YamuxFrame(
yamuxId,
type.toInt(),
flags,
yamuxType,
yamuxFlags,
length
)
out.add(yamuxFrame)
Expand All @@ -74,8 +77,8 @@ class YamuxFrameCodec(
data.retain() // MessageToMessageCodec releases original buffer, but it needs to be relayed
val yamuxFrame = YamuxFrame(
yamuxId,
type.toInt(),
flags,
yamuxType,
yamuxFlags,
length,
data
)
Expand Down
31 changes: 15 additions & 16 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ open class YamuxHandler(
when (msg.type) {
YamuxType.DATA -> handleDataRead(msg)
YamuxType.WINDOW_UPDATE -> handleWindowUpdate(msg)
else -> { /* ignore */ }
}
}

Expand All @@ -67,7 +68,7 @@ open class YamuxHandler(
if (newWindow < initialWindowSize / 2) {
val delta = initialWindowSize - newWindow
receiveWindowSize.addAndGet(delta)
writeAndFlushFrame(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, 0, delta.toLong()))
writeAndFlushFrame(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlag.NONE, delta.toLong()))
}
childRead(msg.id, msg.data!!)
}
Expand All @@ -80,14 +81,14 @@ open class YamuxHandler(
}

private fun handleFlags(msg: YamuxFrame) {
when (msg.flags) {
YamuxFlags.SYN -> {
when {
YamuxFlag.SYN in msg.flags -> {
// ACK the new stream
writeAndFlushFrame(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 0))
writeAndFlushFrame(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlag.ACK.asSet, 0))
}

YamuxFlags.FIN -> onRemoteDisconnect(msg.id)
YamuxFlags.RST -> onRemoteClose(msg.id)
YamuxFlag.FIN in msg.flags -> onRemoteDisconnect(msg.id)
YamuxFlag.RST in msg.flags -> onRemoteClose(msg.id)
}
}

Expand All @@ -109,11 +110,11 @@ open class YamuxHandler(
data.sliceMaxSize(maxFrameDataLength)
.forEach { slicedData ->
val length = slicedData.readableBytes()
writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), slicedData))
writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlag.NONE, length.toLong(), slicedData))
}

if (closedForWriting && sendBuffer.readableBytes() == 0) {
writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.FIN, 0))
writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlag.FIN.asSet, 0))
}
}

Expand All @@ -126,7 +127,7 @@ open class YamuxHandler(
}

fun onLocalOpen() {
writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.SYN, 0))
writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlag.SYN.asSet, 0))
}

fun onRemoteOpen() {
Expand All @@ -141,7 +142,7 @@ open class YamuxHandler(
fun onLocalClose() {
// close stream immediately so not transferring buffered data
sendBuffer.dispose()
writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.RST, 0))
writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlag.RST.asSet, 0))
}
}

Expand Down Expand Up @@ -181,7 +182,7 @@ open class YamuxHandler(
YamuxType.PING -> handlePing(msg)
YamuxType.GO_AWAY -> handleGoAway(msg)
else -> {
if (msg.flags == YamuxFlags.SYN) {
if (YamuxFlag.SYN in msg.flags) {
// remote opens a new stream
validateSynRemoteMuxId(msg.id)
onRemoteYamuxOpen(msg.id)
Expand Down Expand Up @@ -247,17 +248,15 @@ open class YamuxHandler(
if (msg.id.id != YamuxId.SESSION_STREAM_ID) {
throw InvalidFrameMuxerException("Invalid StreamId for Ping frame type: ${msg.id}")
}
when (msg.flags) {
YamuxFlags.SYN -> writeAndFlushFrame(
if (YamuxFlag.SYN in msg.flags) {
writeAndFlushFrame(
YamuxFrame(
YamuxId.sessionId(msg.id.parentId),
YamuxType.PING,
YamuxFlags.ACK,
YamuxFlag.ACK.asSet,
msg.length
)
)

YamuxFlags.ACK -> {}
}
}

Expand Down
19 changes: 14 additions & 5 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxType.kt
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
package io.libp2p.mux.yamux

import io.libp2p.mux.InvalidFrameMuxerException

/**
* Contains all the permissible values for types in the <code>yamux</code> protocol.
*/
object YamuxType {
const val DATA = 0
const val WINDOW_UPDATE = 1
const val PING = 2
const val GO_AWAY = 3
enum class YamuxType(val intValue: Int) {
DATA(0),
WINDOW_UPDATE(1),
PING(2),
GO_AWAY(3);

companion object {
private val intToTypeCache = values().associateBy { it.intValue }

fun fromInt(intValue: Int): YamuxType =
intToTypeCache[intValue] ?: throw InvalidFrameMuxerException("Invalid Yamux type value: $intValue")
}
}
Loading

0 comments on commit e843836

Please sign in to comment.