Skip to content

Commit

Permalink
OutgoingNodeId in a blinded path may not be a wallet (#2970)
Browse files Browse the repository at this point in the history
`OutgoingNodeId` was assumed to be a wallet. While it probably shouldn't cause problems, it's better to keep both cases distinct.
  • Loading branch information
thomash-acinq authored Dec 19, 2024
1 parent c390560 commit 27ba60f
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import fr.acinq.eclair.payment.{ChannelPaymentRelayed, IncomingPaymentPacket}
import fr.acinq.eclair.wire.protocol.FailureMessageCodecs.createBadOnionFailure
import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{Features, InitFeature, Logs, NodeParams, TimestampMilli, TimestampSecond, channel, nodeFee}
import fr.acinq.eclair.{EncodedNodeId, Features, InitFeature, Logs, NodeParams, TimestampMilli, TimestampSecond, channel, nodeFee}

import java.util.UUID
import java.util.concurrent.TimeUnit
Expand Down Expand Up @@ -142,7 +142,8 @@ class ChannelRelay private(nodeParams: NodeParams,
}

private val (requestedShortChannelId_opt, walletNodeId_opt) = r.payload.outgoing match {
case Left(walletNodeId) => (None, Some(walletNodeId))
case Left(EncodedNodeId.WithPublicKey.Wallet(walletNodeId)) => (None, Some(walletNodeId))
case Left(_) => (None, None)
case Right(shortChannelId) => (Some(shortChannelId), None)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ object ChannelRelayer {
case Relay(channelRelayPacket, originNode) =>
val relayId = UUID.randomUUID()
val nextNodeId_opt: Option[PublicKey] = channelRelayPacket.payload.outgoing match {
case Left(walletNodeId) => Some(walletNodeId)
case Left(outgoingNodeId) => Some(outgoingNodeId.publicKey)
case Right(outgoingChannelId) => scid2channels.get(outgoingChannelId) match {
case Some(channelId) => channels.get(channelId).map(_.nextNodeId)
case None => None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,7 @@ private class BlindedPathsResolver(nodeParams: NodeParams,
)
paymentRelayData.outgoing match {
case Left(outgoingNodeId) =>
// The next node seems to be a wallet node directly connected to us.
validateRelay(EncodedNodeId.WithPublicKey.Wallet(outgoingNodeId), nextPaymentInfo, paymentRelayData, nextPathKey, paymentRoute.route.subsequentNodes, toResolve.tail, resolved)
validateRelay(outgoingNodeId, nextPaymentInfo, paymentRelayData, nextPathKey, paymentRoute.route.subsequentNodes, toResolve.tail, resolved)
case Right(outgoingChannelId) =>
register ! Register.GetNextNodeId(context.messageAdapter(WrappedNodeId), outgoingChannelId)
waitForNextNodeId(outgoingChannelId, nextPaymentInfo, paymentRelayData, nextPathKey, paymentRoute.route.subsequentNodes, toResolve.tail, resolved)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import fr.acinq.eclair.wire.protocol.BlindedRouteData.PaymentRelayData
import fr.acinq.eclair.wire.protocol.CommonCodecs._
import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv}
import fr.acinq.eclair.wire.protocol.TlvCodecs._
import fr.acinq.eclair.{CltvExpiry, Features, MilliSatoshi, ShortChannelId, UInt64}
import fr.acinq.eclair.{CltvExpiry, EncodedNodeId, Features, MilliSatoshi, ShortChannelId, UInt64}
import scodec.bits.{BitVector, ByteVector}

/**
Expand Down Expand Up @@ -228,7 +228,7 @@ object PaymentOnion {
sealed trait ChannelRelay extends IntermediatePayload {
// @formatter:off
/** The outgoing channel, or the nodeId of one of our peers. */
def outgoing: Either[PublicKey, ShortChannelId]
def outgoing: Either[EncodedNodeId.WithPublicKey, ShortChannelId]
def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi
def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry
// @formatter:on
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,9 @@ object BlindedRouteData {
}

case class PaymentRelayData(records: TlvStream[RouteBlindingEncryptedDataTlv]) {
// This is usually a channel, unless the next node is a mobile wallet connected to our node.
val outgoing: Either[PublicKey, ShortChannelId] = records.get[RouteBlindingEncryptedDataTlv.OutgoingChannelId] match {
val outgoing: Either[EncodedNodeId.WithPublicKey, ShortChannelId] = records.get[RouteBlindingEncryptedDataTlv.OutgoingChannelId] match {
case Some(r) => Right(r.shortChannelId)
case None => Left(records.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId].get.nodeId.asInstanceOf[EncodedNodeId.WithPublicKey].publicKey)
case None => Left(records.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId].get.nodeId.asInstanceOf[EncodedNodeId.WithPublicKey])
}
val paymentRelay: PaymentRelay = records.get[RouteBlindingEncryptedDataTlv.PaymentRelay].get
val paymentConstraints: PaymentConstraints = records.get[RouteBlindingEncryptedDataTlv.PaymentConstraints].get
Expand All @@ -114,7 +113,6 @@ object BlindedRouteData {
}

def validatePaymentRelayData(records: TlvStream[RouteBlindingEncryptedDataTlv]): Either[InvalidTlvPayload, PaymentRelayData] = {
// Note that the BOLTs require using an OutgoingChannelId, but we optionally support using a node_id.
if (records.get[OutgoingChannelId].isEmpty && records.get[OutgoingNodeId].isEmpty) return Left(MissingRequiredTlv(UInt64(2)))
if (records.get[OutgoingNodeId].nonEmpty && !records.get[OutgoingNodeId].get.nodeId.isInstanceOf[EncodedNodeId.WithPublicKey]) return Left(ForbiddenTlv(UInt64(4)))
if (records.get[PaymentRelay].isEmpty) return Left(MissingRequiredTlv(UInt64(10)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ class PaymentOnionSpec extends AnyFunSuite {
RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat),
)
val Right(payload) = IntermediatePayload.ChannelRelay.Blinded.validate(TlvStream(EncryptedRecipientData(hex"deadbeef")), blindedTlvs, randomKey().publicKey)
assert(payload.outgoing == Left(nextNodeId))
val Left(nodeId) = payload.outgoing
assert(nodeId.publicKey == nextNodeId)
assert(payload.amountToForward(10_000 msat) == 9990.msat)
assert(payload.outgoingCltv(CltvExpiry(1000)) == CltvExpiry(856))
assert(payload.paymentRelayData.allowedFeatures.isEmpty)
Expand Down

0 comments on commit 27ba60f

Please sign in to comment.