diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlag.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlag.kt new file mode 100644 index 000000000..34f9a10d2 --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlag.kt @@ -0,0 +1,34 @@ +package io.libp2p.mux.yamux + +import io.libp2p.mux.InvalidFrameMuxerException + +/** + * Contains all the permissible values for flags in the yamux protocol. + */ +enum class YamuxFlag(val intFlag: Int) { + SYN(1), + ACK(2), + FIN(4), + RST(8); + + val asSet: Set = setOf(this) + + companion object { + val NONE = emptySet() + + private val validFlagCombinations = mapOf( + 0 to NONE, + SYN.intFlag to SYN.asSet, + ACK.intFlag to ACK.asSet, + FIN.intFlag to FIN.asSet, + RST.intFlag to RST.asSet, + ) + + fun fromInt(flags: Int) = + validFlagCombinations[flags] ?: throw InvalidFrameMuxerException("Invalid Yamux flags value: $flags") + + fun Set.toInt() = this + .fold(0) { acc, flag -> acc or flag.intFlag } + .also { require(it in validFlagCombinations) { "Invalid Yamux flags combination: $this" } } + } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlags.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlags.kt deleted file mode 100644 index 85499d0dd..000000000 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlags.kt +++ /dev/null @@ -1,11 +0,0 @@ -package io.libp2p.mux.yamux - -/** - * Contains all the permissible values for flags in the yamux protocol. - */ -object YamuxFlags { - const val SYN = 1 - const val ACK = 2 - const val FIN = 4 - const val RST = 8 -} diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrame.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrame.kt index 32bd32e6a..c35dcea88 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrame.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrame.kt @@ -1,6 +1,5 @@ package io.libp2p.mux.yamux -import io.libp2p.etc.types.toByteArray import io.libp2p.etc.util.netty.mux.MuxId import io.netty.buffer.ByteBuf import io.netty.buffer.DefaultByteBufHolder @@ -9,17 +8,16 @@ import io.netty.buffer.Unpooled /** * Contains the fields that comprise a yamux frame. * @param id the ID of the stream. - * @param flags the flags value for this frame. + * @param flags the flags for this frame. * @param length the length field for this frame. * @param data the data segment. */ -class YamuxFrame(val id: MuxId, val type: Int, val flags: Int, val length: Long, val data: ByteBuf? = null) : +class YamuxFrame(val id: MuxId, val type: YamuxType, val flags: Set, val length: Long, val data: ByteBuf? = null) : DefaultByteBufHolder(data ?: Unpooled.EMPTY_BUFFER) { override fun toString(): String { - if (data == null) { - return "YamuxFrame(id=$id, type=$type, flags=$flags, length=$length)" - } - return "YamuxFrame(id=$id, type=$type, flags=$flags, length=$length, data=${String(data.toByteArray())})" + val dataString = if (data == null) "" else ", len=${data.readableBytes()}, $data" + val flagsString = flags.joinToString("+") + return "YamuxFrame(id=$id, type=$type, flags=$flagsString, length=$length$dataString)" } } diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrameCodec.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrameCodec.kt index d85696508..f2db941ec 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrameCodec.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrameCodec.kt @@ -1,6 +1,7 @@ package io.libp2p.mux.yamux import io.libp2p.core.ProtocolViolationException +import io.libp2p.mux.yamux.YamuxFlag.Companion.toInt import io.netty.buffer.ByteBuf import io.netty.buffer.Unpooled import io.netty.channel.ChannelHandlerContext @@ -24,8 +25,8 @@ class YamuxFrameCodec( */ override fun encode(ctx: ChannelHandlerContext, msg: YamuxFrame, out: ByteBuf) { out.writeByte(0) // version - out.writeByte(msg.type) - out.writeShort(msg.flags) + out.writeByte(msg.type.intValue) + out.writeShort(msg.flags.toInt()) out.writeInt(msg.id.id.toInt()) out.writeInt(msg.data?.readableBytes() ?: msg.length.toInt()) out.writeBytes(msg.data ?: Unpooled.EMPTY_BUFFER) @@ -46,15 +47,17 @@ class YamuxFrameCodec( val readerIndex = msg.readerIndex() msg.readByte(); // version always 0 val type = msg.readUnsignedByte() + val yamuxType = YamuxType.fromInt(type.toInt()) val flags = msg.readUnsignedShort() val streamId = msg.readUnsignedInt() val length = msg.readUnsignedInt() val yamuxId = YamuxId(ctx.channel().id(), streamId) - if (type.toInt() != YamuxType.DATA) { + val yamuxFlags = YamuxFlag.fromInt(flags) + if (yamuxType != YamuxType.DATA) { val yamuxFrame = YamuxFrame( yamuxId, - type.toInt(), - flags, + yamuxType, + yamuxFlags, length ) out.add(yamuxFrame) @@ -74,8 +77,8 @@ class YamuxFrameCodec( data.retain() // MessageToMessageCodec releases original buffer, but it needs to be relayed val yamuxFrame = YamuxFrame( yamuxId, - type.toInt(), - flags, + yamuxType, + yamuxFlags, length, data ) diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt index 659c40ab6..65339c57f 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt @@ -53,6 +53,7 @@ open class YamuxHandler( when (msg.type) { YamuxType.DATA -> handleDataRead(msg) YamuxType.WINDOW_UPDATE -> handleWindowUpdate(msg) + else -> { /* ignore */ } } } @@ -67,7 +68,7 @@ open class YamuxHandler( if (newWindow < initialWindowSize / 2) { val delta = initialWindowSize - newWindow receiveWindowSize.addAndGet(delta) - writeAndFlushFrame(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, 0, delta.toLong())) + writeAndFlushFrame(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlag.NONE, delta.toLong())) } childRead(msg.id, msg.data!!) } @@ -80,14 +81,14 @@ open class YamuxHandler( } private fun handleFlags(msg: YamuxFrame) { - when (msg.flags) { - YamuxFlags.SYN -> { + when { + YamuxFlag.SYN in msg.flags -> { // ACK the new stream - writeAndFlushFrame(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 0)) + writeAndFlushFrame(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlag.ACK.asSet, 0)) } - YamuxFlags.FIN -> onRemoteDisconnect(msg.id) - YamuxFlags.RST -> onRemoteClose(msg.id) + YamuxFlag.FIN in msg.flags -> onRemoteDisconnect(msg.id) + YamuxFlag.RST in msg.flags -> onRemoteClose(msg.id) } } @@ -109,11 +110,11 @@ open class YamuxHandler( data.sliceMaxSize(maxFrameDataLength) .forEach { slicedData -> val length = slicedData.readableBytes() - writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), slicedData)) + writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlag.NONE, length.toLong(), slicedData)) } if (closedForWriting && sendBuffer.readableBytes() == 0) { - writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.FIN, 0)) + writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlag.FIN.asSet, 0)) } } @@ -126,7 +127,7 @@ open class YamuxHandler( } fun onLocalOpen() { - writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.SYN, 0)) + writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlag.SYN.asSet, 0)) } fun onRemoteOpen() { @@ -141,7 +142,7 @@ open class YamuxHandler( fun onLocalClose() { // close stream immediately so not transferring buffered data sendBuffer.dispose() - writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.RST, 0)) + writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlag.RST.asSet, 0)) } } @@ -181,7 +182,7 @@ open class YamuxHandler( YamuxType.PING -> handlePing(msg) YamuxType.GO_AWAY -> handleGoAway(msg) else -> { - if (msg.flags == YamuxFlags.SYN) { + if (YamuxFlag.SYN in msg.flags) { // remote opens a new stream validateSynRemoteMuxId(msg.id) onRemoteYamuxOpen(msg.id) @@ -247,17 +248,15 @@ open class YamuxHandler( if (msg.id.id != YamuxId.SESSION_STREAM_ID) { throw InvalidFrameMuxerException("Invalid StreamId for Ping frame type: ${msg.id}") } - when (msg.flags) { - YamuxFlags.SYN -> writeAndFlushFrame( + if (YamuxFlag.SYN in msg.flags) { + writeAndFlushFrame( YamuxFrame( YamuxId.sessionId(msg.id.parentId), YamuxType.PING, - YamuxFlags.ACK, + YamuxFlag.ACK.asSet, msg.length ) ) - - YamuxFlags.ACK -> {} } } diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxType.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxType.kt index 0746c8cf8..db779e7f9 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxType.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxType.kt @@ -1,11 +1,20 @@ package io.libp2p.mux.yamux +import io.libp2p.mux.InvalidFrameMuxerException + /** * Contains all the permissible values for types in the yamux protocol. */ -object YamuxType { - const val DATA = 0 - const val WINDOW_UPDATE = 1 - const val PING = 2 - const val GO_AWAY = 3 +enum class YamuxType(val intValue: Int) { + DATA(0), + WINDOW_UPDATE(1), + PING(2), + GO_AWAY(3); + + companion object { + private val intToTypeCache = values().associateBy { it.intValue } + + fun fromInt(intValue: Int): YamuxType = + intToTypeCache[intValue] ?: throw InvalidFrameMuxerException("Invalid Yamux type value: $intValue") + } } diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt index 8b7218dca..b85e95733 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt @@ -46,20 +46,20 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { override fun writeFrame(frame: AbstractTestMuxFrame) { val muxId = frame.streamId.toMuxId() val yamuxFrame = when (frame.flag) { - Open -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlags.SYN, 0) + Open -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlag.SYN.asSet, 0) Data -> { val data = frame.data.fromHex() YamuxFrame( muxId, YamuxType.DATA, - 0, + YamuxFlag.NONE, data.size.toLong(), data.toByteBuf(allocateBuf()) ) } - Close -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlags.FIN, 0) - Reset -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlags.RST, 0) + Close -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlag.FIN.asSet, 0) + Reset -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlag.RST.asSet, 0) } ech.writeInbound(yamuxFrame) } @@ -67,8 +67,8 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { override fun readFrame(): AbstractTestMuxFrame? { val yamuxFrame = readYamuxFrame() if (yamuxFrame != null) { - when (yamuxFrame.flags) { - YamuxFlags.SYN -> readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Open) + when { + YamuxFlag.SYN in yamuxFrame.flags -> readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Open) } val data = yamuxFrame.data?.readAllBytesAndRelease()?.toHex() ?: "" @@ -77,9 +77,9 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Data, data) } - when (yamuxFrame.flags) { - YamuxFlags.FIN -> readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Close) - YamuxFlags.RST -> readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Reset) + when { + YamuxFlag.FIN in yamuxFrame.flags -> readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Close) + YamuxFlag.RST in yamuxFrame.flags -> readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Reset) } } @@ -102,7 +102,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { val ackFrame = readYamuxFrameOrThrow() // receives ack stream - assertThat(ackFrame.flags).isEqualTo(YamuxFlags.ACK) + assertThat(ackFrame.flags).containsExactly(YamuxFlag.ACK) assertThat(ackFrame.type).isEqualTo(YamuxType.WINDOW_UPDATE) closeStream(12) @@ -119,7 +119,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( streamId.toMuxId(), YamuxType.DATA, - 0, + YamuxFlag.NONE, length.toLong(), "42".repeat(length).fromHex().toByteBuf(allocateBuf()) ) @@ -128,7 +128,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { val windowUpdateFrame = readYamuxFrameOrThrow() // window frame is sent based on the new window - assertThat(windowUpdateFrame.flags).isZero() + assertThat(windowUpdateFrame.flags).isEmpty() assertThat(windowUpdateFrame.type).isEqualTo(YamuxType.WINDOW_UPDATE) assertThat(windowUpdateFrame.length).isEqualTo(length.toLong()) } @@ -142,7 +142,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( streamId.toMuxId(), YamuxType.WINDOW_UPDATE, - YamuxFlags.ACK, + YamuxFlag.ACK.asSet, -initialWindowSize.toLong() ) ) @@ -151,7 +151,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { assertThat(readFrame()).isNull() - ech.writeInbound(YamuxFrame(streamId.toMuxId(), YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 5000)) + ech.writeInbound(YamuxFrame(streamId.toMuxId(), YamuxType.WINDOW_UPDATE, YamuxFlag.ACK.asSet, 5000)) val frame = readFrameOrThrow() assertThat(frame.data).isEqualTo("1984") } @@ -165,7 +165,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( streamId.toMuxId(), YamuxType.WINDOW_UPDATE, - YamuxFlags.ACK, + YamuxFlag.ACK.asSet, -initialWindowSize.toLong() ) ) @@ -181,7 +181,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( streamId.toMuxId(), YamuxType.WINDOW_UPDATE, - YamuxFlags.ACK, + YamuxFlag.ACK.asSet, 2 ) ) @@ -196,7 +196,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( streamId.toMuxId(), YamuxType.WINDOW_UPDATE, - YamuxFlags.ACK, + YamuxFlag.ACK.asSet, 1 ) ) @@ -207,7 +207,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( streamId.toMuxId(), YamuxType.WINDOW_UPDATE, - YamuxFlags.ACK, + YamuxFlag.ACK.asSet, 10000 ) ) @@ -224,7 +224,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( muxId, YamuxType.WINDOW_UPDATE, - YamuxFlags.ACK, + YamuxFlag.ACK.asSet, -initialWindowSize.toLong() ) ) @@ -245,7 +245,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { .hasMessage("Overflowed send buffer (612/512). Last stream attempting to write: $muxId") val frame = readYamuxFrameOrThrow() - assertThat(frame.flags).isEqualTo(YamuxFlags.RST) + assertThat(frame.flags).containsExactly(YamuxFlag.RST) } @Test @@ -261,7 +261,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( streamId.toMuxId(), YamuxType.WINDOW_UPDATE, - YamuxFlags.ACK, + YamuxFlag.ACK.asSet, it.toLong() ) ) @@ -308,7 +308,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( id.toMuxId(), YamuxType.PING, - YamuxFlags.SYN, + YamuxFlag.SYN.asSet, // opaque value, echoed back 3 ) @@ -316,7 +316,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { val pingFrame = readYamuxFrameOrThrow() - assertThat(pingFrame.flags).isEqualTo(YamuxFlags.ACK) + assertThat(pingFrame.flags).containsExactly(YamuxFlag.ACK) assertThat(pingFrame.type).isEqualTo(YamuxType.PING) assertThat(pingFrame.length).isEqualTo(3) } @@ -328,7 +328,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( id.toMuxId(), YamuxType.GO_AWAY, - 0, + YamuxFlag.NONE, // normal termination 0x2 ) @@ -374,7 +374,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( muxId, YamuxType.WINDOW_UPDATE, - YamuxFlags.ACK, + YamuxFlag.ACK.asSet, -10 ) ) @@ -396,7 +396,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( muxId, YamuxType.WINDOW_UPDATE, - YamuxFlags.ACK, + YamuxFlag.ACK.asSet, initialWindowSize.toLong() ) ) @@ -434,7 +434,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( muxId, YamuxType.WINDOW_UPDATE, - YamuxFlags.ACK, + YamuxFlag.ACK.asSet, initialWindowSize.toLong() ) ) @@ -445,7 +445,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { msgPart3.data!!.release() val closeFrame = readYamuxFrameOrThrow() - assertThat(closeFrame.flags).isEqualTo(YamuxFlags.FIN) + assertThat(closeFrame.flags).containsExactly(YamuxFlag.FIN) assertThat(closeFrame.length).isEqualTo(0L) assertThat(closeFrame.data).isNull() }