Skip to content

Commit

Permalink
Make trampoline payments use per-channel fee and cltv (#1853)
Browse files Browse the repository at this point in the history
Trampoline payments used to ignore the fee and cltv set for the local channel and use a global default value instead. We now use the correct fee and cltv for the specific local channel that we take.
  • Loading branch information
thomash-acinq authored Jul 1, 2021
1 parent 85ed433 commit f857368
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,13 @@ object NodeRelay {

/** Compute route params that honor our fee and cltv requirements. */
def computeRouteParams(nodeParams: NodeParams, amountIn: MilliSatoshi, expiryIn: CltvExpiry, amountOut: MilliSatoshi, expiryOut: CltvExpiry): RouteParams = {
val routeMaxCltv = expiryIn - expiryOut - nodeParams.expiryDelta
val routeMaxFee = amountIn - amountOut - nodeFee(nodeParams.feeBase, nodeParams.feeProportionalMillionth, amountOut)
val routeMaxCltv = expiryIn - expiryOut
val routeMaxFee = amountIn - amountOut
RouteCalculation.getDefaultRouteParams(nodeParams.routerConf).copy(
maxFeeBase = routeMaxFee,
routeMaxCltv = routeMaxCltv,
maxFeePct = 0 // we disable percent-based max fee calculation, we're only interested in collecting our node fee
maxFeePct = 0, // we disable percent-based max fee calculation, we're only interested in collecting our node fee
includeLocalChannelCost = true,
)
}

Expand Down
94 changes: 50 additions & 44 deletions eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,18 @@ object Graph {
* Yen's algorithm to find the k-shortest (loop-less) paths in a graph, uses dijkstra as search algo. Is guaranteed to
* terminate finding at most @pathsToFind paths sorted by cost (the cheapest is in position 0).
*
* @param graph the graph on which will be performed the search
* @param sourceNode the starting node of the path we're looking for (payer)
* @param targetNode the destination node of the path (recipient)
* @param amount amount to send to the last node
* @param ignoredEdges channels that should be avoided
* @param ignoredVertices nodes that should be avoided
* @param extraEdges additional edges that can be used (e.g. private channels from invoices)
* @param pathsToFind number of distinct paths to be returned
* @param wr ratios used to 'weight' edges when searching for the shortest path
* @param currentBlockHeight the height of the chain tip (latest block)
* @param boundaries a predicate function that can be used to impose limits on the outcome of the search
* @param graph the graph on which will be performed the search
* @param sourceNode the starting node of the path we're looking for (payer)
* @param targetNode the destination node of the path (recipient)
* @param amount amount to send to the last node
* @param ignoredEdges channels that should be avoided
* @param ignoredVertices nodes that should be avoided
* @param extraEdges additional edges that can be used (e.g. private channels from invoices)
* @param pathsToFind number of distinct paths to be returned
* @param wr ratios used to 'weight' edges when searching for the shortest path
* @param currentBlockHeight the height of the chain tip (latest block)
* @param boundaries a predicate function that can be used to impose limits on the outcome of the search
* @param includeLocalChannelCost if the path is for relaying and we need to include the cost of the local channel
*/
def yenKshortestPaths(graph: DirectedGraph,
sourceNode: PublicKey,
Expand All @@ -95,10 +96,11 @@ object Graph {
pathsToFind: Int,
wr: Option[WeightRatios],
currentBlockHeight: Long,
boundaries: RichWeight => Boolean): Seq[WeightedPath] = {
boundaries: RichWeight => Boolean,
includeLocalChannelCost: Boolean): Seq[WeightedPath] = {
// find the shortest path (k = 0)
val targetWeight = RichWeight(amount, 0, CltvExpiryDelta(0), 0)
val shortestPath = dijkstraShortestPath(graph, sourceNode, targetNode, ignoredEdges, ignoredVertices, extraEdges, targetWeight, boundaries, currentBlockHeight, wr)
val shortestPath = dijkstraShortestPath(graph, sourceNode, targetNode, ignoredEdges, ignoredVertices, extraEdges, targetWeight, boundaries, currentBlockHeight, wr, includeLocalChannelCost)
if (shortestPath.isEmpty) {
return Seq.empty // if we can't even find a single path, avoid returning a Seq(Seq.empty)
}
Expand All @@ -110,7 +112,7 @@ object Graph {

var allSpurPathsFound = false
val shortestPaths = new mutable.Queue[PathWithSpur]
shortestPaths.enqueue(PathWithSpur(WeightedPath(shortestPath, pathWeight(sourceNode, shortestPath, amount, currentBlockHeight, wr)), 0))
shortestPaths.enqueue(PathWithSpur(WeightedPath(shortestPath, pathWeight(sourceNode, shortestPath, amount, currentBlockHeight, wr, includeLocalChannelCost)), 0))
// stores the candidates for the k-th shortest path, sorted by path cost
val candidates = new mutable.PriorityQueue[PathWithSpur]

Expand All @@ -135,12 +137,12 @@ object Graph {
val alreadyExploredEdges = shortestPaths.collect { case p if p.p.path.takeRight(i) == rootPathEdges => p.p.path(p.p.path.length - 1 - i).desc }.toSet
// we also want to ignore any vertex on the root path to prevent loops
val alreadyExploredVertices = rootPathEdges.map(_.desc.b).toSet
val rootPathWeight = pathWeight(sourceNode, rootPathEdges, amount, currentBlockHeight, wr)
val rootPathWeight = pathWeight(sourceNode, rootPathEdges, amount, currentBlockHeight, wr, includeLocalChannelCost)
// find the "spur" path, a sub-path going from the spur node to the target avoiding previously found sub-paths
val spurPath = dijkstraShortestPath(graph, sourceNode, spurNode, ignoredEdges ++ alreadyExploredEdges, ignoredVertices ++ alreadyExploredVertices, extraEdges, rootPathWeight, boundaries, currentBlockHeight, wr)
val spurPath = dijkstraShortestPath(graph, sourceNode, spurNode, ignoredEdges ++ alreadyExploredEdges, ignoredVertices ++ alreadyExploredVertices, extraEdges, rootPathWeight, boundaries, currentBlockHeight, wr, includeLocalChannelCost)
if (spurPath.nonEmpty) {
val completePath = spurPath ++ rootPathEdges
val candidatePath = WeightedPath(completePath, pathWeight(sourceNode, completePath, amount, currentBlockHeight, wr))
val candidatePath = WeightedPath(completePath, pathWeight(sourceNode, completePath, amount, currentBlockHeight, wr, includeLocalChannelCost))
candidates.enqueue(PathWithSpur(candidatePath, i))
}
}
Expand All @@ -163,16 +165,17 @@ object Graph {
* path from the target to the source (this is because we want to calculate the weight of the edges correctly). The
* graph @param g is optimized for querying the incoming edges given a vertex.
*
* @param g the graph on which will be performed the search
* @param sourceNode the starting node of the path we're looking for (payer)
* @param targetNode the destination node of the path
* @param ignoredEdges channels that should be avoided
* @param ignoredVertices nodes that should be avoided
* @param extraEdges additional edges that can be used (e.g. private channels from invoices)
* @param initialWeight weight that will be applied to the target node
* @param boundaries a predicate function that can be used to impose limits on the outcome of the search
* @param currentBlockHeight the height of the chain tip (latest block)
* @param wr ratios used to 'weight' edges when searching for the shortest path
* @param g the graph on which will be performed the search
* @param sourceNode the starting node of the path we're looking for (payer)
* @param targetNode the destination node of the path
* @param ignoredEdges channels that should be avoided
* @param ignoredVertices nodes that should be avoided
* @param extraEdges additional edges that can be used (e.g. private channels from invoices)
* @param initialWeight weight that will be applied to the target node
* @param boundaries a predicate function that can be used to impose limits on the outcome of the search
* @param currentBlockHeight the height of the chain tip (latest block)
* @param wr ratios used to 'weight' edges when searching for the shortest path
* @param includeLocalChannelCost if the path is for relaying and we need to include the cost of the local channel
*/
private def dijkstraShortestPath(g: DirectedGraph,
sourceNode: PublicKey,
Expand All @@ -183,7 +186,8 @@ object Graph {
initialWeight: RichWeight,
boundaries: RichWeight => Boolean,
currentBlockHeight: Long,
wr: Option[WeightRatios]): Seq[GraphEdge] = {
wr: Option[WeightRatios],
includeLocalChannelCost: Boolean): Seq[GraphEdge] = {
// the graph does not contain source/destination nodes
val sourceNotInGraph = !g.containsVertex(sourceNode) && !extraEdges.exists(_.desc.a == sourceNode)
val targetNotInGraph = !g.containsVertex(targetNode) && !extraEdges.exists(_.desc.b == targetNode)
Expand Down Expand Up @@ -221,7 +225,7 @@ object Graph {
val neighbor = edge.desc.a
// NB: this contains the amount (including fees) that will need to be sent to `neighbor`, but the amount that
// will be relayed through that edge is the one in `currentWeight`.
val neighborWeight = addEdgeWeight(sourceNode, edge, current.weight, currentBlockHeight, wr)
val neighborWeight = addEdgeWeight(sourceNode, edge, current.weight, currentBlockHeight, wr, includeLocalChannelCost)
val canRelayAmount = current.weight.cost <= edge.capacity &&
edge.balance_opt.forall(current.weight.cost <= _) &&
edge.update.htlcMaximumMsat.forall(current.weight.cost <= _) &&
Expand Down Expand Up @@ -258,16 +262,17 @@ object Graph {
/**
* Add the given edge to the path and compute the new weight.
*
* @param sender node sending the payment
* @param edge the edge we want to cross
* @param prev weight of the rest of the path
* @param currentBlockHeight the height of the chain tip (latest block).
* @param weightRatios ratios used to 'weight' edges when searching for the shortest path
* @param sender node sending the payment
* @param edge the edge we want to cross
* @param prev weight of the rest of the path
* @param currentBlockHeight the height of the chain tip (latest block).
* @param weightRatios ratios used to 'weight' edges when searching for the shortest path
* @param includeLocalChannelCost if the path is for relaying and we need to include the cost of the local channel
*/
private def addEdgeWeight(sender: PublicKey, edge: GraphEdge, prev: RichWeight, currentBlockHeight: Long, weightRatios: Option[WeightRatios]): RichWeight = {
val totalCost = if (edge.desc.a == sender) prev.cost else addEdgeFees(edge, prev.cost)
private def addEdgeWeight(sender: PublicKey, edge: GraphEdge, prev: RichWeight, currentBlockHeight: Long, weightRatios: Option[WeightRatios], includeLocalChannelCost: Boolean): RichWeight = {
val totalCost = if (edge.desc.a == sender && !includeLocalChannelCost) prev.cost else addEdgeFees(edge, prev.cost)
val fee = totalCost - prev.cost
val totalCltv = if (edge.desc.a == sender) prev.cltv else prev.cltv + edge.update.cltvExpiryDelta
val totalCltv = if (edge.desc.a == sender && !includeLocalChannelCost) prev.cltv else prev.cltv + edge.update.cltvExpiryDelta
val factor = weightRatios match {
case None =>
1.0
Expand Down Expand Up @@ -322,15 +327,16 @@ object Graph {
* Calculates the total weighted cost of a path.
* Note that the first hop from the sender is ignored: we don't pay a routing fee to ourselves.
*
* @param sender node sending the payment
* @param path candidate path.
* @param amount amount to send to the last node.
* @param currentBlockHeight the height of the chain tip (latest block).
* @param wr ratios used to 'weight' edges when searching for the shortest path
* @param sender node sending the payment
* @param path candidate path.
* @param amount amount to send to the last node.
* @param currentBlockHeight the height of the chain tip (latest block).
* @param wr ratios used to 'weight' edges when searching for the shortest path
* @param includeLocalChannelCost if the path is for relaying and we need to include the cost of the local channel
*/
def pathWeight(sender: PublicKey, path: Seq[GraphEdge], amount: MilliSatoshi, currentBlockHeight: Long, wr: Option[WeightRatios]): RichWeight = {
def pathWeight(sender: PublicKey, path: Seq[GraphEdge], amount: MilliSatoshi, currentBlockHeight: Long, wr: Option[WeightRatios], includeLocalChannelCost: Boolean): RichWeight = {
path.foldRight(RichWeight(amount, 0, CltvExpiryDelta(0), 0)) { (edge, prev) =>
addEdgeWeight(sender, edge, prev, currentBlockHeight, wr)
addEdgeWeight(sender, edge, prev, currentBlockHeight, wr, includeLocalChannelCost)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ object RouteCalculation {
capacityFactor = routerConf.searchRatioChannelCapacity
))
},
mpp = MultiPartParams(routerConf.mppMinPartAmount, routerConf.mppMaxParts)
mpp = MultiPartParams(routerConf.mppMinPartAmount, routerConf.mppMaxParts),
includeLocalChannelCost = false,
)

/**
Expand Down Expand Up @@ -257,7 +258,7 @@ object RouteCalculation {

val boundaries: RichWeight => Boolean = { weight => feeOk(weight.cost - amount) && lengthOk(weight.length) && cltvOk(weight.cltv) }

val foundRoutes: Seq[Graph.WeightedPath] = Graph.yenKshortestPaths(g, localNodeId, targetNodeId, amount, ignoredEdges, ignoredVertices, extraEdges, numRoutes, routeParams.ratios, currentBlockHeight, boundaries)
val foundRoutes: Seq[Graph.WeightedPath] = Graph.yenKshortestPaths(g, localNodeId, targetNodeId, amount, ignoredEdges, ignoredVertices, extraEdges, numRoutes, routeParams.ratios, currentBlockHeight, boundaries, routeParams.includeLocalChannelCost)
if (foundRoutes.nonEmpty) {
val (directRoutes, indirectRoutes) = foundRoutes.partition(_.path.length == 1)
val routes = if (routeParams.randomize) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ object Router {

case class MultiPartParams(minPartAmount: MilliSatoshi, maxParts: Int)

case class RouteParams(randomize: Boolean, maxFeeBase: MilliSatoshi, maxFeePct: Double, routeMaxLength: Int, routeMaxCltv: CltvExpiryDelta, ratios: Option[WeightRatios], mpp: MultiPartParams) {
case class RouteParams(randomize: Boolean, maxFeeBase: MilliSatoshi, maxFeePct: Double, routeMaxLength: Int, routeMaxCltv: CltvExpiryDelta, ratios: Option[WeightRatios], mpp: MultiPartParams, includeLocalChannelCost: Boolean) {
def getMaxFee(amount: MilliSatoshi): MilliSatoshi = {
// The payment fee must satisfy either the flat fee or the percentage fee, not necessarily both.
maxFeeBase.max(amount * maxFeePct)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ abstract class IntegrationSpec extends TestKitBaseClass with BitcoindService wit
ageFactor = 0,
capacityFactor = 0
)),
mpp = MultiPartParams(15000000 msat, 6)
mpp = MultiPartParams(15000000 msat, 6),
includeLocalChannelCost = false,
))

// we need to provide a value higher than every node's fulfill-safety-before-timeout
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ object MultiPartPaymentLifecycleSpec {
val expiry = CltvExpiry(1105)
val finalAmount = 1000000 msat
val finalRecipient = randomKey().publicKey
val routeParams = RouteParams(randomize = false, 15000 msat, 0.01, 6, CltvExpiryDelta(1008), None, MultiPartParams(1000 msat, 5))
val routeParams = RouteParams(randomize = false, 15000 msat, 0.01, 6, CltvExpiryDelta(1008), None, MultiPartParams(1000 msat, 5), false)
val maxFee = 15000 msat // max fee for the defaultAmount

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec {
import payFixture._
import cfg._

val request = SendPayment(sender.ref, d, Onion.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret.get), 5, routeParams = Some(RouteParams(randomize = false, 100 msat, 0.0, 20, CltvExpiryDelta(2016), None, MultiPartParams(10000 msat, 5))))
val request = SendPayment(sender.ref, d, Onion.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret.get), 5, routeParams = Some(RouteParams(randomize = false, 100 msat, 0.0, 20, CltvExpiryDelta(2016), None, MultiPartParams(10000 msat, 5), false)))
sender.send(paymentFSM, request)
val routeRequest = routerForwarder.expectMsgType[RouteRequest]
val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -487,10 +487,10 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl

val routeRequest = router.expectMessageType[RouteRequest]
val routeParams = routeRequest.routeParams.get
val fee = nodeFee(nodeParams.feeBase, nodeParams.feeProportionalMillionth, outgoingAmount)
assert(routeParams.maxFeePct === 0) // should be disabled
assert(routeParams.maxFeeBase === incomingAmount - outgoingAmount - fee) // we collect our fee and then use what remains for the rest of the route
assert(routeParams.routeMaxCltv === incomingSinglePart.add.cltvExpiry - outgoingExpiry - nodeParams.expiryDelta) // we apply our cltv delta
assert(routeParams.maxFeeBase === incomingAmount - outgoingAmount)
assert(routeParams.routeMaxCltv === incomingSinglePart.add.cltvExpiry - outgoingExpiry)
assert(routeParams.includeLocalChannelCost)
}

test("relay incoming multi-part payment") { f =>
Expand Down
Loading

0 comments on commit f857368

Please sign in to comment.