Skip to content

Commit

Permalink
Revert "Treat padding as unknown TLV"
Browse files Browse the repository at this point in the history
This reverts commit c469ac4.
  • Loading branch information
thomash-acinq committed Apr 13, 2023
1 parent c619eaf commit 1f0d1dd
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,20 @@ object OnionMessages {
timeout: FiniteDuration,
maxAttempts: Int)

case class IntermediateNode(nodeId: PublicKey, customTlvs: Set[GenericTlv] = Set.empty)
case class IntermediateNode(nodeId: PublicKey, padding: Option[ByteVector] = None, customTlvs: Set[GenericTlv] = Set.empty)

// @formatter:off
sealed trait Destination
case class BlindedPath(route: Sphinx.RouteBlinding.BlindedRoute) extends Destination
case class Recipient(nodeId: PublicKey, pathId: Option[ByteVector], customTlvs: Set[GenericTlv] = Set.empty) extends Destination
case class Recipient(nodeId: PublicKey, pathId: Option[ByteVector], padding: Option[ByteVector] = None, customTlvs: Set[GenericTlv] = Set.empty) extends Destination
// @formatter:on

private def buildIntermediatePayloads(intermediateNodes: Seq[IntermediateNode], nextTlvs: Set[RouteBlindingEncryptedDataTlv]): Seq[ByteVector] = {
if (intermediateNodes.isEmpty) {
Nil
} else {
(intermediateNodes.tail.map(node => Set[RouteBlindingEncryptedDataTlv](OutgoingNodeId(node.nodeId))) :+ nextTlvs)
.zip(intermediateNodes).map { case (tlvs, hop) => TlvStream(tlvs, hop.customTlvs) }
(intermediateNodes.tail.map(node => Set(OutgoingNodeId(node.nodeId))) :+ nextTlvs)
.zip(intermediateNodes).map { case (tlvs, hop) => TlvStream(hop.padding.map(Padding).toSet[RouteBlindingEncryptedDataTlv] ++ tlvs, hop.customTlvs) }
.map(tlvs => RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(tlvs).require.bytes)
}
}
Expand All @@ -57,7 +57,7 @@ object OnionMessages {
intermediateNodes: Seq[IntermediateNode],
recipient: Recipient): Sphinx.RouteBlinding.BlindedRoute = {
val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, Set(OutgoingNodeId(recipient.nodeId)))
val tlvs: Set[RouteBlindingEncryptedDataTlv] = Set(recipient.pathId.map(PathId)).flatten
val tlvs: Set[RouteBlindingEncryptedDataTlv] = Set(recipient.padding.map(Padding), recipient.pathId.map(PathId)).flatten
val lastPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream(tlvs, recipient.customTlvs)).require.bytes
Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.nodeId) :+ recipient.nodeId, intermediatePayloads :+ lastPayload).route
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ object OfferPayment {
val blindedRoute = blindedRoutes(attemptNumber % blindedRoutes.length)
OnionMessages.BlindedPath(blindedRoute)
case Right(nodeId) =>
OnionMessages.Recipient(nodeId, None, Set.empty)
OnionMessages.Recipient(nodeId, None, None)
}
// TODO: Find a path made of channels as some nodes may refuse to relay messages to nodes with which they don't have a channel.
val intermediateNodesToRecipient = Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ sealed trait RouteBlindingEncryptedDataTlv extends Tlv

object RouteBlindingEncryptedDataTlv {

/** Some padding can be added to ensure all payloads are the same size to improve privacy. */
case class Padding(dummy: ByteVector) extends RouteBlindingEncryptedDataTlv

/** Id of the outgoing channel, used to identify the next node. */
case class OutgoingChannelId(shortChannelId: ShortChannelId) extends RouteBlindingEncryptedDataTlv

Expand Down Expand Up @@ -108,6 +111,7 @@ object RouteBlindingEncryptedDataCodecs {
import scodec.codecs._
import scodec.{Attempt, Codec, DecodeResult}

private val padding: Codec[Padding] = tlvField(bytes)
private val outgoingChannelId: Codec[OutgoingChannelId] = tlvField(shortchannelid)
private val outgoingNodeId: Codec[OutgoingNodeId] = fixedLengthTlvField(33, publicKey)
private val pathId: Codec[PathId] = tlvField(bytes)
Expand All @@ -117,6 +121,7 @@ object RouteBlindingEncryptedDataCodecs {
private val allowedFeatures: Codec[AllowedFeatures] = tlvField(featuresCodec)

private val encryptedDataTlvCodec = discriminated[RouteBlindingEncryptedDataTlv].by(varint)
.typecase(UInt64(1), padding)
.typecase(UInt64(2), outgoingChannelId)
.typecase(UInt64(4), outgoingNodeId)
.typecase(UInt64(6), pathId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class OnionMessagesSpec extends AnyFunSuite {
val messageForBob = TlvStream[RouteBlindingEncryptedDataTlv](OutgoingNodeId(carol.publicKey), NextBlinding(blindingOverride.publicKey))
val encodedForBob = blindedRouteDataCodec.encode(messageForBob).require.bytes
assert(encodedForBob == hex"0421027f31ebc5462c1fdce1b737ecff52d37d75dea43ce11c74d25aa297165faa2007082102989c0b76cb563971fdc9bef31ec06c3560f3249d6ee9e5d83c57625596e05f6f")
val messageForCarol = TlvStream(Set[RouteBlindingEncryptedDataTlv](OutgoingNodeId(dave.publicKey)), Set(GenericTlv(UInt64(1), hex"0000000000000000000000000000000000000000000000000000000000000000000000")))
val messageForCarol = TlvStream[RouteBlindingEncryptedDataTlv](Padding(hex"0000000000000000000000000000000000000000000000000000000000000000000000"), OutgoingNodeId(dave.publicKey))
val encodedForCarol = blindedRouteDataCodec.encode(messageForCarol).require.bytes
assert(encodedForCarol == hex"012300000000000000000000000000000000000000000000000000000000000000000000000421032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991")
val messageForDave = TlvStream[RouteBlindingEncryptedDataTlv](PathId(hex"01234567"))
Expand Down Expand Up @@ -108,7 +108,7 @@ class OnionMessagesSpec extends AnyFunSuite {
val onionForAlice = OnionMessage(blindingSecret.publicKey, packet)

// Building the onion with functions from `OnionMessages`
val replyPath = buildRoute(blindingOverride, IntermediateNode(carol.publicKey, Set(GenericTlv(UInt64(1), hex"0000000000000000000000000000000000000000000000000000000000000000000000"))) :: Nil, Recipient(dave.publicKey, pathId = Some(hex"01234567")))
val replyPath = buildRoute(blindingOverride, IntermediateNode(carol.publicKey, padding = Some(hex"0000000000000000000000000000000000000000000000000000000000000000000000")) :: Nil, Recipient(dave.publicKey, pathId = Some(hex"01234567")))
assert(replyPath == routeFromCarol)
val Right((_, message)) = buildMessage(randomKey(), sessionKey, blindingSecret, IntermediateNode(alice.publicKey) :: IntermediateNode(bob.publicKey) :: Nil, BlindedPath(replyPath), TlvStream.empty)
assert(message == onionForAlice)
Expand Down Expand Up @@ -192,7 +192,7 @@ class OnionMessagesSpec extends AnyFunSuite {
assert(Sphinx.computeSharedSecret(blindingKey, carol) == sharedSecret)
assert(Sphinx.mac(ByteVector("blinded_node_id".getBytes), sharedSecret) == ByteVector32(hex"02afb2187075c8af51488242194b44c02624785ccd6fd43b5796c68f3025bf88"))
val blindedCarol = PublicKey(hex"02f4f524562868a09d5f54fb956ade3fa51ef071d64d923e395cc6db5e290ec67b")
val blindedPayload = TlvStream(Set[RouteBlindingEncryptedDataTlv](OutgoingNodeId(dave.publicKey)), Set(GenericTlv(UInt64(1), hex"0000000000000000000000000000000000000000000000000000000000000000000000")))
val blindedPayload = TlvStream[RouteBlindingEncryptedDataTlv](Padding(hex"0000000000000000000000000000000000000000000000000000000000000000000000"), OutgoingNodeId(dave.publicKey))
val encodedBlindedPayload = blindedRouteDataCodec.encode(blindedPayload).require.bytes
assert(encodedBlindedPayload == hex"012300000000000000000000000000000000000000000000000000000000000000000000000421032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991")
val blindedRoute = Sphinx.RouteBlinding.create(blindingSecret, carol.publicKey :: Nil, encodedBlindedPayload :: Nil).route
Expand Down Expand Up @@ -283,14 +283,20 @@ class OnionMessagesSpec extends AnyFunSuite {
).flatten
}

def makeRecipient(nodeKey: PrivateKey, json: JValue): Recipient =
Recipient(nodeKey.publicKey, Some(ByteVector.fromValidHex((json \ "path_id").extract[String])), (json \ "padding").extract[Option[String]].map(ByteVector.fromValidHex(_)), getCustomTlvs(json))

def makeIntermediateNode(nodeKey: PrivateKey, json: JValue): IntermediateNode =
IntermediateNode(nodeKey.publicKey, (json \ "padding").extract[Option[String]].map(ByteVector.fromValidHex(_)), getCustomTlvs(json))

val blindingSecretBob = PrivateKey(ByteVector32.fromValidHex(((testVector \ "generate" \ "hops")(1) \ "blinding_secret").extract[String]))
val pathId = ByteVector.fromValidHex(((testVector \ "generate" \ "hops")(3) \ "tlvs" \ "path_id").extract[String])
val pathBobToDave =
buildRoute(blindingSecretBob,
Seq(IntermediateNode(bob.publicKey, getCustomTlvs((testVector \ "generate" \ "hops")(1) \ "tlvs")), IntermediateNode(carol.publicKey, getCustomTlvs((testVector \ "generate" \ "hops")(2) \ "tlvs"))),
Recipient(dave.publicKey, Some(pathId), getCustomTlvs((testVector \ "generate" \ "hops")(3) \ "tlvs")))
Seq(makeIntermediateNode(bob, (testVector \ "generate" \ "hops")(1) \ "tlvs"), makeIntermediateNode(carol, (testVector \ "generate" \ "hops")(2) \ "tlvs")),
makeRecipient(dave, (testVector \ "generate" \ "hops")(3) \ "tlvs"))
val blindingSecretAlice = PrivateKey(ByteVector32.fromValidHex(((testVector \ "generate" \ "hops")(0) \ "blinding_secret").extract[String]))
val intermediateAlice = Seq(IntermediateNode(alice.publicKey, getCustomTlvs((testVector \ "generate" \ "hops")(0) \ "tlvs")))
val intermediateAlice = Seq(makeIntermediateNode(alice, (testVector \ "generate" \ "hops")(0) \ "tlvs"))
val Some(pathAliceToDave) = buildRouteFrom(alice, blindingSecretAlice, intermediateAlice, BlindedPath(pathBobToDave))

val expectedPath = BlindedRoute(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class OfferPaymentSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app

val offer = Offer(None, "amountless offer", merchantKey.publicKey, Features.empty, nodeParams.chainHash)
offerPayment ! PayOffer(probe.ref, offer, 40_000_000 msat, 1, SendPaymentConfig(None, 1, routeParams, blocking = false))
val Postman.SendMessage(_, Recipient(recipientId, _, _), _, message, replyTo, _) = postman.expectMessageType[Postman.SendMessage]
val Postman.SendMessage(_, Recipient(recipientId, _, _, _), _, message, replyTo, _) = postman.expectMessageType[Postman.SendMessage]
assert(recipientId == merchantKey.publicKey)
assert(message.get[OnionMessagePayloadTlv.InvoiceRequest].nonEmpty)
val Right(invoiceRequest) = InvoiceRequest.validate(message.get[OnionMessagePayloadTlv.InvoiceRequest].get.tlvs)
Expand All @@ -90,7 +90,7 @@ class OfferPaymentSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app
val offer = Offer(None, "amountless offer", merchantKey.publicKey, Features.empty, nodeParams.chainHash)
offerPayment ! PayOffer(probe.ref, offer, 40_000_000 msat, 1, SendPaymentConfig(None, 1, routeParams, blocking = false))
for (_ <- 1 to nodeParams.onionMessageConfig.maxAttempts) {
val Postman.SendMessage(_, Recipient(recipientId, _, _), _, message, replyTo, _) = postman.expectMessageType[Postman.SendMessage]
val Postman.SendMessage(_, Recipient(recipientId, _, _, _), _, message, replyTo, _) = postman.expectMessageType[Postman.SendMessage]
assert(recipientId == merchantKey.publicKey)
assert(message.get[OnionMessagePayloadTlv.InvoiceRequest].nonEmpty)
val Right(invoiceRequest) = InvoiceRequest.validate(message.get[OnionMessagePayloadTlv.InvoiceRequest].get.tlvs)
Expand All @@ -111,7 +111,7 @@ class OfferPaymentSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app

val offer = Offer(None, "amountless offer", merchantKey.publicKey, Features.empty, nodeParams.chainHash)
offerPayment ! PayOffer(probe.ref, offer, 40_000_000 msat, 1, SendPaymentConfig(None, 1, routeParams, blocking = false))
val Postman.SendMessage(_, Recipient(recipientId, _, _), _, message, replyTo, _) = postman.expectMessageType[Postman.SendMessage]
val Postman.SendMessage(_, Recipient(recipientId, _, _, _), _, message, replyTo, _) = postman.expectMessageType[Postman.SendMessage]
assert(recipientId == merchantKey.publicKey)
assert(message.get[OnionMessagePayloadTlv.InvoiceRequest].nonEmpty)
val Right(invoiceRequest) = InvoiceRequest.validate(message.get[OnionMessagePayloadTlv.InvoiceRequest].get.tlvs)
Expand Down
Loading

0 comments on commit 1f0d1dd

Please sign in to comment.