Skip to content

Commit

Permalink
Introduce actor factories (#1744)
Browse files Browse the repository at this point in the history
This removes unnecessary fields and allows more flexibility in tests.
  • Loading branch information
t-bast authored Mar 31, 2021
1 parent e5429eb commit c6a76af
Show file tree
Hide file tree
Showing 12 changed files with 195 additions and 136 deletions.
14 changes: 9 additions & 5 deletions eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import fr.acinq.eclair.channel.Register
import fr.acinq.eclair.crypto.keymanager.{LocalChannelKeyManager, LocalNodeKeyManager}
import fr.acinq.eclair.db.Databases.FileBackup
import fr.acinq.eclair.db.{Databases, DbEventHandler, FileBackupHandler}
import fr.acinq.eclair.io.{ClientSpawner, Server, Switchboard}
import fr.acinq.eclair.io.{ClientSpawner, Peer, Server, Switchboard}
import fr.acinq.eclair.payment.receive.PaymentHandler
import fr.acinq.eclair.payment.relay.Relayer
import fr.acinq.eclair.payment.send.{Autoprobe, PaymentInitiator}
Expand Down Expand Up @@ -290,8 +290,8 @@ class Setup(datadir: File,
new ElectrumEclairWallet(electrumWallet, nodeParams.chainHash)
}
_ = wallet.getReceiveAddress.map(address => logger.info(s"initial wallet address=$address"))
// do not change the name of this actor. it is used in the configuration to specify a custom bounded mailbox

// do not change the name of this actor. it is used in the configuration to specify a custom bounded mailbox
backupHandler = if (config.getBoolean("enable-db-backup")) {
nodeParams.db match {
case fileBackup: FileBackup => system.actorOf(SimpleSupervisor.props(
Expand All @@ -314,10 +314,14 @@ class Setup(datadir: File,
// Before initializing the switchboard (which re-connects us to the network) and the user-facing parts of the system,
// we want to make sure the handler for post-restart broken HTLCs has finished initializing.
_ <- postRestartCleanUpInitialized.future
switchboard = system.actorOf(SimpleSupervisor.props(Switchboard.props(nodeParams, watcher, relayer, wallet), "switchboard", SupervisorStrategy.Resume))

channelFactory = Peer.SimpleChannelFactory(nodeParams, watcher, relayer, wallet)
peerFactory = Switchboard.SimplePeerFactory(nodeParams, wallet, channelFactory)

switchboard = system.actorOf(SimpleSupervisor.props(Switchboard.props(nodeParams, peerFactory), "switchboard", SupervisorStrategy.Resume))
clientSpawner = system.actorOf(SimpleSupervisor.props(ClientSpawner.props(nodeParams.keyPair, nodeParams.socksProxy_opt, nodeParams.peerConnectionConf, switchboard, router), "client-spawner", SupervisorStrategy.Restart))
server = system.actorOf(SimpleSupervisor.props(Server.props(nodeParams.keyPair, nodeParams.peerConnectionConf, switchboard, router, serverBindingAddress, Some(tcpBound)), "server", SupervisorStrategy.Restart))
paymentInitiator = system.actorOf(SimpleSupervisor.props(PaymentInitiator.props(nodeParams, router, register), "payment-initiator", SupervisorStrategy.Restart))
paymentInitiator = system.actorOf(SimpleSupervisor.props(PaymentInitiator.props(nodeParams, PaymentInitiator.SimplePaymentFactory(nodeParams, router, register)), "payment-initiator", SupervisorStrategy.Restart))
_ = for (i <- 0 until config.getInt("autoprobe-count")) yield system.actorOf(SimpleSupervisor.props(Autoprobe.props(nodeParams, router, paymentInitiator), s"payment-autoprobe-$i", SupervisorStrategy.Restart))

kit = Kit(
Expand Down Expand Up @@ -381,11 +385,11 @@ class Setup(datadir: File,

}

// @formatter:off
object Setup {
final case class Seeds(nodeSeed: ByteVector, channelSeed: ByteVector)
}

// @formatter:off
sealed trait Bitcoin
case class Bitcoind(bitcoinClient: BasicBitcoinJsonRPCClient) extends Bitcoin
case class Electrum(electrumClient: ActorRef) extends Bitcoin
Expand Down
23 changes: 16 additions & 7 deletions eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package fr.acinq.eclair.io

import akka.actor.{Actor, ActorRef, ExtendedActorSystem, FSM, OneForOneStrategy, PossiblyHarmful, Props, Status, SupervisorStrategy, Terminated}
import akka.actor.{Actor, ActorContext, ActorRef, ExtendedActorSystem, FSM, OneForOneStrategy, PossiblyHarmful, Props, Status, SupervisorStrategy, Terminated}
import akka.event.Logging.MDC
import akka.event.{BusLogging, DiagnosticLoggingAdapter}
import akka.util.Timeout
Expand Down Expand Up @@ -48,7 +48,7 @@ import java.net.InetSocketAddress
*
* Created by PM on 26/08/2016.
*/
class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, watcher: ActorRef, relayer: ActorRef, wallet: EclairWallet) extends FSMDiagnosticActorLogging[Peer.State, Peer.Data] {
class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: EclairWallet, channelFactory: Peer.ChannelFactory) extends FSMDiagnosticActorLogging[Peer.State, Peer.Data] {

import Peer._

Expand All @@ -57,7 +57,7 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, watcher: ActorRe
when(INSTANTIATING) {
case Event(Init(storedChannels), _) =>
val channels = storedChannels.map { state =>
val channel = spawnChannel(nodeParams, origin_opt = None)
val channel = spawnChannel(origin_opt = None)
channel ! INPUT_RESTORED(state)
FinalChannelId(state.channelId) -> channel
}.toMap
Expand Down Expand Up @@ -294,12 +294,12 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, watcher: ActorRe
(Helpers.getFinalScriptPubKey(wallet, nodeParams.chainHash), None)
}
val localParams = makeChannelParams(nodeParams, features, finalScript, walletStaticPaymentBasepoint, funder, fundingAmount)
val channel = spawnChannel(nodeParams, origin_opt)
val channel = spawnChannel(origin_opt)
(channel, localParams)
}

def spawnChannel(nodeParams: NodeParams, origin_opt: Option[ActorRef]): ActorRef = {
val channel = context.actorOf(Channel.props(nodeParams, wallet, remoteNodeId, watcher, relayer, origin_opt))
def spawnChannel(origin_opt: Option[ActorRef]): ActorRef = {
val channel = channelFactory.spawn(context, remoteNodeId, origin_opt)
context watch channel
channel
}
Expand Down Expand Up @@ -353,7 +353,16 @@ object Peer {
val UNKNOWN_CHANNEL_MESSAGE: ByteVector = ByteVector.view("unknown channel".getBytes())
// @formatter:on

def props(nodeParams: NodeParams, remoteNodeId: PublicKey, watcher: ActorRef, relayer: ActorRef, wallet: EclairWallet): Props = Props(new Peer(nodeParams, remoteNodeId, watcher, relayer: ActorRef, wallet))
trait ChannelFactory {
def spawn(context: ActorContext, remoteNodeId: PublicKey, origin_opt: Option[ActorRef]): ActorRef
}

case class SimpleChannelFactory(nodeParams: NodeParams, watcher: ActorRef, relayer: ActorRef, wallet: EclairWallet) extends ChannelFactory {
override def spawn(context: ActorContext, remoteNodeId: PublicKey, origin_opt: Option[ActorRef]): ActorRef =
context.actorOf(Channel.props(nodeParams, wallet, remoteNodeId, watcher, relayer, origin_opt))
}

def props(nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: EclairWallet, channelFactory: ChannelFactory): Props = Props(new Peer(nodeParams, remoteNodeId, wallet, channelFactory))

// @formatter:off

Expand Down
17 changes: 13 additions & 4 deletions eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package fr.acinq.eclair.io

import akka.actor.{Actor, ActorLogging, ActorRef, OneForOneStrategy, Props, Status, SupervisorStrategy}
import akka.actor.{Actor, ActorContext, ActorLogging, ActorRef, OneForOneStrategy, Props, Status, SupervisorStrategy}
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.NodeParams
import fr.acinq.eclair.blockchain.EclairWallet
Expand All @@ -29,7 +29,7 @@ import fr.acinq.eclair.router.Router.RouterConf
* Ties network connections to peers.
* Created by PM on 14/02/2017.
*/
class Switchboard(nodeParams: NodeParams, watcher: ActorRef, relayer: ActorRef, wallet: EclairWallet) extends Actor with ActorLogging {
class Switchboard(nodeParams: NodeParams, peerFactory: Switchboard.PeerFactory) extends Actor with ActorLogging {

import Switchboard._

Expand Down Expand Up @@ -103,7 +103,7 @@ class Switchboard(nodeParams: NodeParams, watcher: ActorRef, relayer: ActorRef,
*/
def getPeer(remoteNodeId: PublicKey): Option[ActorRef] = context.child(peerActorName(remoteNodeId))

def createPeer(remoteNodeId: PublicKey): ActorRef = context.actorOf(Peer.props(nodeParams, remoteNodeId, watcher, relayer, wallet), name = peerActorName(remoteNodeId))
def createPeer(remoteNodeId: PublicKey): ActorRef = peerFactory.spawn(context, remoteNodeId)

def createOrGetPeer(remoteNodeId: PublicKey, offlineChannels: Set[HasCommitments]): ActorRef = {
getPeer(remoteNodeId) match {
Expand All @@ -124,7 +124,16 @@ class Switchboard(nodeParams: NodeParams, watcher: ActorRef, relayer: ActorRef,

object Switchboard {

def props(nodeParams: NodeParams, watcher: ActorRef, relayer: ActorRef, wallet: EclairWallet) = Props(new Switchboard(nodeParams, watcher, relayer, wallet))
trait PeerFactory {
def spawn(context: ActorContext, remoteNodeId: PublicKey): ActorRef
}

case class SimplePeerFactory(nodeParams: NodeParams, wallet: EclairWallet, channelFactory: Peer.ChannelFactory) extends PeerFactory {
override def spawn(context: ActorContext, remoteNodeId: PublicKey): ActorRef =
context.actorOf(Peer.props(nodeParams, remoteNodeId, wallet, channelFactory), name = peerActorName(remoteNodeId))
}

def props(nodeParams: NodeParams, peerFactory: PeerFactory) = Props(new Switchboard(nodeParams, peerFactory))

def peerActorName(remoteNodeId: PublicKey): String = s"peer-$remoteNodeId"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,10 @@ import fr.acinq.eclair.payment.OutgoingPacket.Upstream
import fr.acinq.eclair.payment._
import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM
import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM.HtlcPart
import fr.acinq.eclair.payment.relay.NodeRelay.FsmFactory
import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.{PreimageReceived, SendMultiPartPayment}
import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig
import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPayment
import fr.acinq.eclair.payment.send.{MultiPartPaymentLifecycle, PaymentLifecycle}
import fr.acinq.eclair.payment.send.{MultiPartPaymentLifecycle, PaymentInitiator, PaymentLifecycle}
import fr.acinq.eclair.router.Router.RouteParams
import fr.acinq.eclair.router.{BalanceTooLow, RouteCalculation, RouteNotFound}
import fr.acinq.eclair.wire.protocol._
Expand All @@ -60,29 +59,32 @@ object NodeRelay {
private case class WrappedPaymentFailed(paymentFailed: PaymentFailed) extends Command
// @formatter:on

def apply(nodeParams: NodeParams, parent: akka.actor.typed.ActorRef[NodeRelayer.Command], router: ActorRef, register: ActorRef, relayId: UUID, paymentHash: ByteVector32, fsmFactory: FsmFactory = new FsmFactory): Behavior[Command] =
Behaviors.setup { context =>
Behaviors.withMdc(Logs.mdc(
category_opt = Some(Logs.LogCategory.PAYMENT),
parentPaymentId_opt = Some(relayId), // for a node relay, we use the same identifier for the whole relay itself, and the outgoing payment
paymentHash_opt = Some(paymentHash))) {
new NodeRelay(nodeParams, parent, router, register, relayId, paymentHash, context, fsmFactory)()
}
}
trait OutgoingPaymentFactory {
def spawnOutgoingPayFSM(context: ActorContext[NodeRelay.Command], cfg: SendPaymentConfig, multiPart: Boolean): ActorRef
}

/**
* This is supposed to be overridden in tests
*/
class FsmFactory {
def spawnOutgoingPayFSM(context: ActorContext[NodeRelay.Command], nodeParams: NodeParams, router: ActorRef, register: ActorRef, cfg: SendPaymentConfig, multiPart: Boolean): ActorRef = {
case class SimpleOutgoingPaymentFactory(nodeParams: NodeParams, router: ActorRef, register: ActorRef) extends OutgoingPaymentFactory {
val paymentFactory = PaymentInitiator.SimplePaymentFactory(nodeParams, router, register)

override def spawnOutgoingPayFSM(context: ActorContext[Command], cfg: SendPaymentConfig, multiPart: Boolean): ActorRef = {
if (multiPart) {
context.toClassic.actorOf(MultiPartPaymentLifecycle.props(nodeParams, cfg, router, register))
context.toClassic.actorOf(MultiPartPaymentLifecycle.props(nodeParams, cfg, router, paymentFactory))
} else {
context.toClassic.actorOf(PaymentLifecycle.props(nodeParams, cfg, router, register))
}
}
}

def apply(nodeParams: NodeParams, parent: akka.actor.typed.ActorRef[NodeRelayer.Command], register: ActorRef, relayId: UUID, paymentHash: ByteVector32, outgoingPaymentFactory: OutgoingPaymentFactory): Behavior[Command] =
Behaviors.setup { context =>
Behaviors.withMdc(Logs.mdc(
category_opt = Some(Logs.LogCategory.PAYMENT),
parentPaymentId_opt = Some(relayId), // for a node relay, we use the same identifier for the whole relay itself, and the outgoing payment
paymentHash_opt = Some(paymentHash))) {
new NodeRelay(nodeParams, parent, register, relayId, paymentHash, context, outgoingPaymentFactory)()
}
}

def validateRelay(nodeParams: NodeParams, upstream: Upstream.Trampoline, payloadOut: Onion.NodeRelayPayload): Option[FailureMessage] = {
val fee = nodeFee(nodeParams.feeBase, nodeParams.feeProportionalMillionth, payloadOut.amountToForward)
if (upstream.amountIn - payloadOut.amountToForward < fee) {
Expand Down Expand Up @@ -139,12 +141,11 @@ object NodeRelay {
*/
class NodeRelay private(nodeParams: NodeParams,
parent: akka.actor.typed.ActorRef[NodeRelayer.Command],
router: ActorRef,
register: ActorRef,
relayId: UUID,
paymentHash: ByteVector32,
context: ActorContext[NodeRelay.Command],
fsmFactory: FsmFactory) {
outgoingPaymentFactory: NodeRelay.OutgoingPaymentFactory) {

import NodeRelay._

Expand Down Expand Up @@ -285,20 +286,20 @@ class NodeRelay private(nodeParams: NodeParams,
case Some(paymentSecret) if Features(features).hasFeature(Features.BasicMultiPartPayment) =>
context.log.debug("sending the payment to non-trampoline recipient using MPP")
val payment = SendMultiPartPayment(payFsmAdapters, paymentSecret, payloadOut.outgoingNodeId, payloadOut.amountToForward, payloadOut.outgoingCltv, nodeParams.maxPaymentAttempts, routingHints, Some(routeParams))
val payFSM = fsmFactory.spawnOutgoingPayFSM(context, nodeParams, router, register, paymentCfg, multiPart = true)
val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, multiPart = true)
payFSM ! payment
payFSM
case _ =>
context.log.debug("sending the payment to non-trampoline recipient without MPP")
val finalPayload = Onion.createSinglePartPayload(payloadOut.amountToForward, payloadOut.outgoingCltv, payloadOut.paymentSecret)
val payment = SendPayment(payFsmAdapters, payloadOut.outgoingNodeId, finalPayload, nodeParams.maxPaymentAttempts, routingHints, Some(routeParams))
val payFSM = fsmFactory.spawnOutgoingPayFSM(context, nodeParams, router, register, paymentCfg, multiPart = false)
val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, multiPart = false)
payFSM ! payment
payFSM
}
case None =>
context.log.debug("sending the payment to the next trampoline node")
val payFSM = fsmFactory.spawnOutgoingPayFSM(context, nodeParams, router, register, paymentCfg, multiPart = true)
val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, multiPart = true)
val paymentSecret = randomBytes32 // we generate a new secret to protect against probing attacks
val payment = SendMultiPartPayment(payFsmAdapters, paymentSecret, payloadOut.outgoingNodeId, payloadOut.amountToForward, payloadOut.outgoingCltv, nodeParams.maxPaymentAttempts, routeParams = Some(routeParams), additionalTlvs = Seq(OnionTlv.TrampolineOnion(packetOut)))
payFSM ! payment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ object NodeRelayer {
case None =>
val relayId = UUID.randomUUID()
context.log.debug(s"spawning a new handler with relayId=$relayId")
val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, router, register, relayId, paymentHash), relayId.toString)
val outgoingPaymentFactory = NodeRelay.SimpleOutgoingPaymentFactory(nodeParams, router, register)
val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, register, relayId, paymentHash, outgoingPaymentFactory), relayId.toString)
context.log.debug("forwarding incoming htlc to new handler")
handler ! NodeRelay.Relay(nodeRelayPacket)
apply(nodeParams, router, register, children + (paymentHash -> handler))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ import java.util.concurrent.TimeUnit
* Sender for a multi-part payment (see https://github.com/lightningnetwork/lightning-rfc/blob/master/04-onion-routing.md#basic-multi-part-payments).
* The payment will be split into multiple sub-payments that will be sent in parallel.
*/
class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: ActorRef, register: ActorRef) extends FSMDiagnosticActorLogging[MultiPartPaymentLifecycle.State, MultiPartPaymentLifecycle.Data] {
class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: ActorRef, paymentFactory: PaymentInitiator.PaymentFactory) extends FSMDiagnosticActorLogging[MultiPartPaymentLifecycle.State, MultiPartPaymentLifecycle.Data] {

import MultiPartPaymentLifecycle._

Expand Down Expand Up @@ -202,13 +202,13 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig,
case Event(_: Status.Failure, _) => stay
}

def spawnChildPaymentFsm(childId: UUID): ActorRef = {
private def spawnChildPaymentFsm(childId: UUID): ActorRef = {
val upstream = cfg.upstream match {
case Upstream.Local(_) => Upstream.Local(childId)
case _ => cfg.upstream
}
val childCfg = cfg.copy(id = childId, publishEvent = false, upstream = upstream)
context.actorOf(PaymentLifecycle.props(nodeParams, childCfg, router, register))
paymentFactory.spawnOutgoingPayment(context, childCfg)
}

private def gotoAbortedOrStop(d: PaymentAborted): State = {
Expand Down Expand Up @@ -265,7 +265,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig,

object MultiPartPaymentLifecycle {

def props(nodeParams: NodeParams, cfg: SendPaymentConfig, router: ActorRef, register: ActorRef) = Props(new MultiPartPaymentLifecycle(nodeParams, cfg, router, register))
def props(nodeParams: NodeParams, cfg: SendPaymentConfig, router: ActorRef, paymentFactory: PaymentInitiator.PaymentFactory) = Props(new MultiPartPaymentLifecycle(nodeParams, cfg, router, paymentFactory))

/**
* Send a payment to a given node. The payment may be split into multiple child payments, for which a path-finding
Expand Down
Loading

0 comments on commit c6a76af

Please sign in to comment.