Skip to content

Commit

Permalink
Set topicID on outbound IHAVE and ignore inbound IHAVE for unknown to…
Browse files Browse the repository at this point in the history
…pic (#365)

Co-authored-by: Anton Nashatyrev <anton.nashatyrev@gmail.com>
  • Loading branch information
StefanBratanov and Nashatyrev authored May 22, 2024
1 parent d0552c7 commit 640cc5d
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 21 deletions.
10 changes: 7 additions & 3 deletions libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,10 @@ open class GossipRouter(
}

private fun handleIHave(msg: Rpc.ControlIHave, peer: PeerHandler) {
// we ignore IHAVE gossip for unknown topics
if (msg.hasTopicID() && !mesh.containsKey(msg.topicID)) {
return
}
val peerScore = score.score(peer.peerId)
// we ignore IHAVE gossip from any peer whose score is below the gossip threshold
if (peerScore < scoreParams.gossipThreshold) return
Expand Down Expand Up @@ -544,7 +548,7 @@ open class GossipRouter(

peers.shuffled(random)
.take(max((params.gossipFactor * peers.size).toInt(), params.DLazy))
.forEach { enqueueIhave(it, shuffledMessageIds) }
.forEach { enqueueIhave(it, shuffledMessageIds, topic) }
}

private fun graft(peer: PeerHandler, topic: Topic) {
Expand Down Expand Up @@ -587,8 +591,8 @@ open class GossipRouter(
private fun enqueueIwant(peer: PeerHandler, messageIds: List<MessageId>) =
pendingRpcParts.getQueue(peer).addIWants(messageIds)

private fun enqueueIhave(peer: PeerHandler, messageIds: List<MessageId>) =
pendingRpcParts.getQueue(peer).addIHaves(messageIds)
private fun enqueueIhave(peer: PeerHandler, messageIds: List<MessageId>, topic: Topic) =
pendingRpcParts.getQueue(peer).addIHaves(messageIds, topic)

data class AcceptRequestsWhitelistEntry(val whitelistedTill: Long, val messagesAccepted: Int = 0) {
fun incrementMessageCount() = AcceptRequestsWhitelistEntry(whitelistedTill, messagesAccepted + 1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import pubsub.pb.Rpc

interface GossipRpcPartsQueue : RpcPartsQueue {

fun addIHave(messageId: MessageId)
fun addIHaves(messageIds: Collection<MessageId>) = messageIds.forEach { addIHave(it) }
fun addIHave(messageId: MessageId, topic: Topic)
fun addIHaves(messageIds: Collection<MessageId>, topic: Topic) = messageIds.forEach { addIHave(it, topic) }
fun addIWant(messageId: MessageId)
fun addIWants(messageIds: Collection<MessageId>) = messageIds.forEach { addIWant(it) }

Expand All @@ -37,14 +37,13 @@ open class DefaultGossipRpcPartsQueue(
private val params: GossipParams
) : DefaultRpcPartsQueue(), GossipRpcPartsQueue {

protected data class IHavePart(val messageId: MessageId) : AbstractPart {
protected data class IHavePart(val messageId: MessageId, val topic: Topic) : AbstractPart {
override fun appendToBuilder(builder: Rpc.RPC.Builder) {
val ctrlBuilder = builder.controlBuilder
val iHaveBuilder = if (ctrlBuilder.ihaveBuilderList.isEmpty()) {
ctrlBuilder.addIhaveBuilder()
} else {
ctrlBuilder.getIhaveBuilder(0)
}
val iHaveBuilder = ctrlBuilder.ihaveBuilderList
.find { it.topicID == topic }
?: ctrlBuilder.addIhaveBuilder().setTopicID(topic)

iHaveBuilder.addMessageIDs(messageId.toProtobuf())
}
}
Expand Down Expand Up @@ -82,8 +81,8 @@ open class DefaultGossipRpcPartsQueue(
}
}

override fun addIHave(messageId: MessageId) {
addPart(IHavePart(messageId))
override fun addIHave(messageId: MessageId, topic: Topic) {
addPart(IHavePart(messageId, topic))
}

override fun addIWant(messageId: MessageId) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.libp2p.pubsub.gossip

import io.libp2p.pubsub.Topic
import io.libp2p.pubsub.gossip.builders.GossipParamsBuilder
import io.libp2p.pubsub.gossip.builders.GossipRouterBuilder
import io.libp2p.tools.protobuf.RpcBuilder
Expand Down Expand Up @@ -35,6 +36,8 @@ class GossipRouterListLimitsTest {
private val routerWithLimits = GossipRouterBuilder(params = gossipParamsWithLimits).build()
private val routerWithNoLimits = GossipRouterBuilder(params = gossipParamsNoLimits).build()

private val topic: Topic = "topic1"

@Test
fun validateProtobufLists_validMessage() {
val msg = fullMsgBuilder().build()
Expand Down Expand Up @@ -96,7 +99,7 @@ class GossipRouterListLimitsTest {
@Test
fun validateProtobufLists_tooManyIHaves() {
val builder = fullMsgBuilder()
builder.addIHaves(maxIHaveLength, 1)
builder.addIHaves(maxIHaveLength, 1, topic)
val msg = builder.build()

Assertions.assertThat(routerWithLimits.validateMessageListLimits(msg)).isFalse()
Expand All @@ -105,7 +108,7 @@ class GossipRouterListLimitsTest {
@Test
fun validateProtobufLists_tooManyIHaveMsgIds() {
val builder = fullMsgBuilder()
builder.addIHaves(1, maxIHaveLength)
builder.addIHaves(1, maxIHaveLength, topic)
val msg = builder.build()

Assertions.assertThat(routerWithLimits.validateMessageListLimits(msg)).isFalse()
Expand Down Expand Up @@ -186,7 +189,7 @@ class GossipRouterListLimitsTest {
@Test
fun validateProtobufLists_maxIHaves() {
val builder = fullMsgBuilder()
builder.addIHaves(maxIHaveLength - 1, 1)
builder.addIHaves(maxIHaveLength - 1, 1, topic)
val msg = builder.build()

Assertions.assertThat(routerWithLimits.validateMessageListLimits(msg)).isTrue()
Expand All @@ -195,7 +198,7 @@ class GossipRouterListLimitsTest {
@Test
fun validateProtobufLists_maxIHaveMsgIds() {
val builder = fullMsgBuilder()
builder.addIHaves(1, maxIHaveLength - 1)
builder.addIHaves(1, maxIHaveLength - 1, topic)
val msg = builder.build()

Assertions.assertThat(routerWithLimits.validateMessageListLimits(msg)).isTrue()
Expand Down Expand Up @@ -256,7 +259,7 @@ class GossipRouterListLimitsTest {
// Add some data to all possible fields
builder.addSubscriptions(listSize)
builder.addPublishMessages(listSize, listSize)
builder.addIHaves(listSize, listSize)
builder.addIHaves(listSize, listSize, topic)
builder.addIWants(listSize, listSize)
builder.addGrafts(listSize)
builder.addPrunes(listSize, listSize)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package io.libp2p.pubsub.gossip
import io.libp2p.core.PeerId
import io.libp2p.etc.types.toProtobuf
import io.libp2p.etc.types.toWBytes
import io.libp2p.pubsub.Topic
import io.libp2p.pubsub.gossip.builders.GossipParamsBuilder
import io.libp2p.pubsub.gossip.builders.GossipRouterBuilder
import org.assertj.core.api.Assertions.assertThat
Expand Down Expand Up @@ -49,7 +50,7 @@ class GossipRpcPartsQueueTest {
queue.addPublish(createRpcMessage("topic-$it", "data"))
}
(1..iHaves).forEach {
queue.addIHave(byteArrayOf(it.toByte()).toWBytes())
queue.addIHave(byteArrayOf(it.toByte()).toWBytes(), "topic-$it")
}
(1..iWants).forEach {
queue.addIWant(byteArrayOf(it.toByte()).toWBytes())
Expand Down Expand Up @@ -259,4 +260,50 @@ class GossipRpcPartsQueueTest {
assertThat(msgs).hasSize(3)
assertThat(msgs.merge()).isEqualTo(single)
}

@Test
fun `check that resulting IHAVE sets the topic ID`() {
val topic1: Topic = "topic1"
val messageId1 = "1111".toWBytes()
val topic2: Topic = "topic2"
val messageId2 = "2222".toWBytes()
val partsQueue = TestGossipQueue(gossipParamsWithLimits)
partsQueue.addIHave(messageId1, topic1)
partsQueue.addIHave(messageId2, topic2)
val res = partsQueue.takeMerged().first()

val serialized = res.toByteArray()
val deserializedRpc = Rpc.RPC.parseFrom(serialized)
assertThat(deserializedRpc.control.ihaveList).containsExactlyInAnyOrder(
Rpc.ControlIHave.newBuilder().setTopicID(topic1).addMessageIDs(messageId1.toProtobuf()).build(),
Rpc.ControlIHave.newBuilder().setTopicID(topic2).addMessageIDs(messageId2.toProtobuf()).build(),
)
}

@Test
fun `check that resulting IHAVE correctly groups topics`() {
val partsQueue = TestGossipQueue(gossipParamsWithLimits)

partsQueue.addIHave("1111".toWBytes(), "topic1")
partsQueue.addIHave("2222".toWBytes(), "topic2")
partsQueue.addIHave("3333".toWBytes(), "topic1")

val res = partsQueue.takeMerged().first()

val serialized = res.toByteArray()
val deserializedRpc = Rpc.RPC.parseFrom(serialized)
assertThat(deserializedRpc.control.ihaveList).containsExactlyInAnyOrder(
Rpc.ControlIHave.newBuilder()
.setTopicID("topic1")
.addAllMessageIDs(
listOf(
"1111".toWBytes().toProtobuf(),
"3333".toWBytes().toProtobuf()
)
).build(),
Rpc.ControlIHave.newBuilder()
.setTopicID("topic2")
.addMessageIDs("2222".toWBytes().toProtobuf()).build(),
)
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.libp2p.tools.protobuf

import io.libp2p.etc.types.toProtobuf
import io.libp2p.pubsub.Topic
import pubsub.pb.Rpc
import kotlin.random.Random

Expand Down Expand Up @@ -28,9 +29,9 @@ class RpcBuilder {
}
}

fun addIHaves(iHaveCount: Int, messageIdCount: Int) {
fun addIHaves(iHaveCount: Int, messageIdCount: Int, topic: Topic) {
for (i in 0 until iHaveCount) {
val iHaveBuilder = Rpc.ControlIHave.newBuilder()
val iHaveBuilder = Rpc.ControlIHave.newBuilder().setTopicID(topic)
for (j in 0 until messageIdCount) {
iHaveBuilder.addMessageIDs(Random.nextBytes(6).toProtobuf())
}
Expand Down

0 comments on commit 640cc5d

Please sign in to comment.