diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlag.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlag.kt
new file mode 100644
index 000000000..34f9a10d2
--- /dev/null
+++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlag.kt
@@ -0,0 +1,34 @@
+package io.libp2p.mux.yamux
+
+import io.libp2p.mux.InvalidFrameMuxerException
+
+/**
+ * Contains all the permissible values for flags in the yamux
protocol.
+ */
+enum class YamuxFlag(val intFlag: Int) {
+ SYN(1),
+ ACK(2),
+ FIN(4),
+ RST(8);
+
+ val asSet: Set = setOf(this)
+
+ companion object {
+ val NONE = emptySet()
+
+ private val validFlagCombinations = mapOf(
+ 0 to NONE,
+ SYN.intFlag to SYN.asSet,
+ ACK.intFlag to ACK.asSet,
+ FIN.intFlag to FIN.asSet,
+ RST.intFlag to RST.asSet,
+ )
+
+ fun fromInt(flags: Int) =
+ validFlagCombinations[flags] ?: throw InvalidFrameMuxerException("Invalid Yamux flags value: $flags")
+
+ fun Set.toInt() = this
+ .fold(0) { acc, flag -> acc or flag.intFlag }
+ .also { require(it in validFlagCombinations) { "Invalid Yamux flags combination: $this" } }
+ }
+}
diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlags.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlags.kt
deleted file mode 100644
index 85499d0dd..000000000
--- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlags.kt
+++ /dev/null
@@ -1,11 +0,0 @@
-package io.libp2p.mux.yamux
-
-/**
- * Contains all the permissible values for flags in the yamux
protocol.
- */
-object YamuxFlags {
- const val SYN = 1
- const val ACK = 2
- const val FIN = 4
- const val RST = 8
-}
diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrame.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrame.kt
index 32bd32e6a..c35dcea88 100644
--- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrame.kt
+++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrame.kt
@@ -1,6 +1,5 @@
package io.libp2p.mux.yamux
-import io.libp2p.etc.types.toByteArray
import io.libp2p.etc.util.netty.mux.MuxId
import io.netty.buffer.ByteBuf
import io.netty.buffer.DefaultByteBufHolder
@@ -9,17 +8,16 @@ import io.netty.buffer.Unpooled
/**
* Contains the fields that comprise a yamux frame.
* @param id the ID of the stream.
- * @param flags the flags value for this frame.
+ * @param flags the flags for this frame.
* @param length the length field for this frame.
* @param data the data segment.
*/
-class YamuxFrame(val id: MuxId, val type: Int, val flags: Int, val length: Long, val data: ByteBuf? = null) :
+class YamuxFrame(val id: MuxId, val type: YamuxType, val flags: Set, val length: Long, val data: ByteBuf? = null) :
DefaultByteBufHolder(data ?: Unpooled.EMPTY_BUFFER) {
override fun toString(): String {
- if (data == null) {
- return "YamuxFrame(id=$id, type=$type, flags=$flags, length=$length)"
- }
- return "YamuxFrame(id=$id, type=$type, flags=$flags, length=$length, data=${String(data.toByteArray())})"
+ val dataString = if (data == null) "" else ", len=${data.readableBytes()}, $data"
+ val flagsString = flags.joinToString("+")
+ return "YamuxFrame(id=$id, type=$type, flags=$flagsString, length=$length$dataString)"
}
}
diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrameCodec.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrameCodec.kt
index d85696508..f2db941ec 100644
--- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrameCodec.kt
+++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrameCodec.kt
@@ -1,6 +1,7 @@
package io.libp2p.mux.yamux
import io.libp2p.core.ProtocolViolationException
+import io.libp2p.mux.yamux.YamuxFlag.Companion.toInt
import io.netty.buffer.ByteBuf
import io.netty.buffer.Unpooled
import io.netty.channel.ChannelHandlerContext
@@ -24,8 +25,8 @@ class YamuxFrameCodec(
*/
override fun encode(ctx: ChannelHandlerContext, msg: YamuxFrame, out: ByteBuf) {
out.writeByte(0) // version
- out.writeByte(msg.type)
- out.writeShort(msg.flags)
+ out.writeByte(msg.type.intValue)
+ out.writeShort(msg.flags.toInt())
out.writeInt(msg.id.id.toInt())
out.writeInt(msg.data?.readableBytes() ?: msg.length.toInt())
out.writeBytes(msg.data ?: Unpooled.EMPTY_BUFFER)
@@ -46,15 +47,17 @@ class YamuxFrameCodec(
val readerIndex = msg.readerIndex()
msg.readByte(); // version always 0
val type = msg.readUnsignedByte()
+ val yamuxType = YamuxType.fromInt(type.toInt())
val flags = msg.readUnsignedShort()
val streamId = msg.readUnsignedInt()
val length = msg.readUnsignedInt()
val yamuxId = YamuxId(ctx.channel().id(), streamId)
- if (type.toInt() != YamuxType.DATA) {
+ val yamuxFlags = YamuxFlag.fromInt(flags)
+ if (yamuxType != YamuxType.DATA) {
val yamuxFrame = YamuxFrame(
yamuxId,
- type.toInt(),
- flags,
+ yamuxType,
+ yamuxFlags,
length
)
out.add(yamuxFrame)
@@ -74,8 +77,8 @@ class YamuxFrameCodec(
data.retain() // MessageToMessageCodec releases original buffer, but it needs to be relayed
val yamuxFrame = YamuxFrame(
yamuxId,
- type.toInt(),
- flags,
+ yamuxType,
+ yamuxFlags,
length,
data
)
diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt
index 659c40ab6..65339c57f 100644
--- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt
+++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt
@@ -53,6 +53,7 @@ open class YamuxHandler(
when (msg.type) {
YamuxType.DATA -> handleDataRead(msg)
YamuxType.WINDOW_UPDATE -> handleWindowUpdate(msg)
+ else -> { /* ignore */ }
}
}
@@ -67,7 +68,7 @@ open class YamuxHandler(
if (newWindow < initialWindowSize / 2) {
val delta = initialWindowSize - newWindow
receiveWindowSize.addAndGet(delta)
- writeAndFlushFrame(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, 0, delta.toLong()))
+ writeAndFlushFrame(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlag.NONE, delta.toLong()))
}
childRead(msg.id, msg.data!!)
}
@@ -80,14 +81,14 @@ open class YamuxHandler(
}
private fun handleFlags(msg: YamuxFrame) {
- when (msg.flags) {
- YamuxFlags.SYN -> {
+ when {
+ YamuxFlag.SYN in msg.flags -> {
// ACK the new stream
- writeAndFlushFrame(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 0))
+ writeAndFlushFrame(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlag.ACK.asSet, 0))
}
- YamuxFlags.FIN -> onRemoteDisconnect(msg.id)
- YamuxFlags.RST -> onRemoteClose(msg.id)
+ YamuxFlag.FIN in msg.flags -> onRemoteDisconnect(msg.id)
+ YamuxFlag.RST in msg.flags -> onRemoteClose(msg.id)
}
}
@@ -109,11 +110,11 @@ open class YamuxHandler(
data.sliceMaxSize(maxFrameDataLength)
.forEach { slicedData ->
val length = slicedData.readableBytes()
- writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), slicedData))
+ writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlag.NONE, length.toLong(), slicedData))
}
if (closedForWriting && sendBuffer.readableBytes() == 0) {
- writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.FIN, 0))
+ writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlag.FIN.asSet, 0))
}
}
@@ -126,7 +127,7 @@ open class YamuxHandler(
}
fun onLocalOpen() {
- writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.SYN, 0))
+ writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlag.SYN.asSet, 0))
}
fun onRemoteOpen() {
@@ -141,7 +142,7 @@ open class YamuxHandler(
fun onLocalClose() {
// close stream immediately so not transferring buffered data
sendBuffer.dispose()
- writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.RST, 0))
+ writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlag.RST.asSet, 0))
}
}
@@ -181,7 +182,7 @@ open class YamuxHandler(
YamuxType.PING -> handlePing(msg)
YamuxType.GO_AWAY -> handleGoAway(msg)
else -> {
- if (msg.flags == YamuxFlags.SYN) {
+ if (YamuxFlag.SYN in msg.flags) {
// remote opens a new stream
validateSynRemoteMuxId(msg.id)
onRemoteYamuxOpen(msg.id)
@@ -247,17 +248,15 @@ open class YamuxHandler(
if (msg.id.id != YamuxId.SESSION_STREAM_ID) {
throw InvalidFrameMuxerException("Invalid StreamId for Ping frame type: ${msg.id}")
}
- when (msg.flags) {
- YamuxFlags.SYN -> writeAndFlushFrame(
+ if (YamuxFlag.SYN in msg.flags) {
+ writeAndFlushFrame(
YamuxFrame(
YamuxId.sessionId(msg.id.parentId),
YamuxType.PING,
- YamuxFlags.ACK,
+ YamuxFlag.ACK.asSet,
msg.length
)
)
-
- YamuxFlags.ACK -> {}
}
}
diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxType.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxType.kt
index 0746c8cf8..db779e7f9 100644
--- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxType.kt
+++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxType.kt
@@ -1,11 +1,20 @@
package io.libp2p.mux.yamux
+import io.libp2p.mux.InvalidFrameMuxerException
+
/**
* Contains all the permissible values for types in the yamux
protocol.
*/
-object YamuxType {
- const val DATA = 0
- const val WINDOW_UPDATE = 1
- const val PING = 2
- const val GO_AWAY = 3
+enum class YamuxType(val intValue: Int) {
+ DATA(0),
+ WINDOW_UPDATE(1),
+ PING(2),
+ GO_AWAY(3);
+
+ companion object {
+ private val intToTypeCache = values().associateBy { it.intValue }
+
+ fun fromInt(intValue: Int): YamuxType =
+ intToTypeCache[intValue] ?: throw InvalidFrameMuxerException("Invalid Yamux type value: $intValue")
+ }
}
diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt
index 8b7218dca..b85e95733 100644
--- a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt
+++ b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt
@@ -46,20 +46,20 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
override fun writeFrame(frame: AbstractTestMuxFrame) {
val muxId = frame.streamId.toMuxId()
val yamuxFrame = when (frame.flag) {
- Open -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlags.SYN, 0)
+ Open -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlag.SYN.asSet, 0)
Data -> {
val data = frame.data.fromHex()
YamuxFrame(
muxId,
YamuxType.DATA,
- 0,
+ YamuxFlag.NONE,
data.size.toLong(),
data.toByteBuf(allocateBuf())
)
}
- Close -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlags.FIN, 0)
- Reset -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlags.RST, 0)
+ Close -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlag.FIN.asSet, 0)
+ Reset -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlag.RST.asSet, 0)
}
ech.writeInbound(yamuxFrame)
}
@@ -67,8 +67,8 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
override fun readFrame(): AbstractTestMuxFrame? {
val yamuxFrame = readYamuxFrame()
if (yamuxFrame != null) {
- when (yamuxFrame.flags) {
- YamuxFlags.SYN -> readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Open)
+ when {
+ YamuxFlag.SYN in yamuxFrame.flags -> readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Open)
}
val data = yamuxFrame.data?.readAllBytesAndRelease()?.toHex() ?: ""
@@ -77,9 +77,9 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Data, data)
}
- when (yamuxFrame.flags) {
- YamuxFlags.FIN -> readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Close)
- YamuxFlags.RST -> readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Reset)
+ when {
+ YamuxFlag.FIN in yamuxFrame.flags -> readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Close)
+ YamuxFlag.RST in yamuxFrame.flags -> readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Reset)
}
}
@@ -102,7 +102,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
val ackFrame = readYamuxFrameOrThrow()
// receives ack stream
- assertThat(ackFrame.flags).isEqualTo(YamuxFlags.ACK)
+ assertThat(ackFrame.flags).containsExactly(YamuxFlag.ACK)
assertThat(ackFrame.type).isEqualTo(YamuxType.WINDOW_UPDATE)
closeStream(12)
@@ -119,7 +119,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
YamuxFrame(
streamId.toMuxId(),
YamuxType.DATA,
- 0,
+ YamuxFlag.NONE,
length.toLong(),
"42".repeat(length).fromHex().toByteBuf(allocateBuf())
)
@@ -128,7 +128,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
val windowUpdateFrame = readYamuxFrameOrThrow()
// window frame is sent based on the new window
- assertThat(windowUpdateFrame.flags).isZero()
+ assertThat(windowUpdateFrame.flags).isEmpty()
assertThat(windowUpdateFrame.type).isEqualTo(YamuxType.WINDOW_UPDATE)
assertThat(windowUpdateFrame.length).isEqualTo(length.toLong())
}
@@ -142,7 +142,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
YamuxFrame(
streamId.toMuxId(),
YamuxType.WINDOW_UPDATE,
- YamuxFlags.ACK,
+ YamuxFlag.ACK.asSet,
-initialWindowSize.toLong()
)
)
@@ -151,7 +151,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
assertThat(readFrame()).isNull()
- ech.writeInbound(YamuxFrame(streamId.toMuxId(), YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 5000))
+ ech.writeInbound(YamuxFrame(streamId.toMuxId(), YamuxType.WINDOW_UPDATE, YamuxFlag.ACK.asSet, 5000))
val frame = readFrameOrThrow()
assertThat(frame.data).isEqualTo("1984")
}
@@ -165,7 +165,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
YamuxFrame(
streamId.toMuxId(),
YamuxType.WINDOW_UPDATE,
- YamuxFlags.ACK,
+ YamuxFlag.ACK.asSet,
-initialWindowSize.toLong()
)
)
@@ -181,7 +181,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
YamuxFrame(
streamId.toMuxId(),
YamuxType.WINDOW_UPDATE,
- YamuxFlags.ACK,
+ YamuxFlag.ACK.asSet,
2
)
)
@@ -196,7 +196,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
YamuxFrame(
streamId.toMuxId(),
YamuxType.WINDOW_UPDATE,
- YamuxFlags.ACK,
+ YamuxFlag.ACK.asSet,
1
)
)
@@ -207,7 +207,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
YamuxFrame(
streamId.toMuxId(),
YamuxType.WINDOW_UPDATE,
- YamuxFlags.ACK,
+ YamuxFlag.ACK.asSet,
10000
)
)
@@ -224,7 +224,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
YamuxFrame(
muxId,
YamuxType.WINDOW_UPDATE,
- YamuxFlags.ACK,
+ YamuxFlag.ACK.asSet,
-initialWindowSize.toLong()
)
)
@@ -245,7 +245,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
.hasMessage("Overflowed send buffer (612/512). Last stream attempting to write: $muxId")
val frame = readYamuxFrameOrThrow()
- assertThat(frame.flags).isEqualTo(YamuxFlags.RST)
+ assertThat(frame.flags).containsExactly(YamuxFlag.RST)
}
@Test
@@ -261,7 +261,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
YamuxFrame(
streamId.toMuxId(),
YamuxType.WINDOW_UPDATE,
- YamuxFlags.ACK,
+ YamuxFlag.ACK.asSet,
it.toLong()
)
)
@@ -308,7 +308,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
YamuxFrame(
id.toMuxId(),
YamuxType.PING,
- YamuxFlags.SYN,
+ YamuxFlag.SYN.asSet,
// opaque value, echoed back
3
)
@@ -316,7 +316,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
val pingFrame = readYamuxFrameOrThrow()
- assertThat(pingFrame.flags).isEqualTo(YamuxFlags.ACK)
+ assertThat(pingFrame.flags).containsExactly(YamuxFlag.ACK)
assertThat(pingFrame.type).isEqualTo(YamuxType.PING)
assertThat(pingFrame.length).isEqualTo(3)
}
@@ -328,7 +328,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
YamuxFrame(
id.toMuxId(),
YamuxType.GO_AWAY,
- 0,
+ YamuxFlag.NONE,
// normal termination
0x2
)
@@ -374,7 +374,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
YamuxFrame(
muxId,
YamuxType.WINDOW_UPDATE,
- YamuxFlags.ACK,
+ YamuxFlag.ACK.asSet,
-10
)
)
@@ -396,7 +396,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
YamuxFrame(
muxId,
YamuxType.WINDOW_UPDATE,
- YamuxFlags.ACK,
+ YamuxFlag.ACK.asSet,
initialWindowSize.toLong()
)
)
@@ -434,7 +434,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
YamuxFrame(
muxId,
YamuxType.WINDOW_UPDATE,
- YamuxFlags.ACK,
+ YamuxFlag.ACK.asSet,
initialWindowSize.toLong()
)
)
@@ -445,7 +445,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
msgPart3.data!!.release()
val closeFrame = readYamuxFrameOrThrow()
- assertThat(closeFrame.flags).isEqualTo(YamuxFlags.FIN)
+ assertThat(closeFrame.flags).containsExactly(YamuxFlag.FIN)
assertThat(closeFrame.length).isEqualTo(0L)
assertThat(closeFrame.data).isNull()
}