Make result set an iterable (#1823)
This allows us to use the full power of Scala collections to iterate
over results, convert to options, etc., while staying purely functional
and immutable.

There is a catch, though: the iterator is lazy, so it must be materialized
before the result set is closed, by converting the end result into a
collection or an option. In other words, database methods must never
return an `Iterable` or `Iterator`.
pm47 authored May 25, 2021
1 parent f829a2e commit 4dc2910
Showing 23 changed files with 363 additions and 490 deletions.
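The commit message above describes the new pattern; before the diffs, here is a minimal, self-contained sketch of it (assuming Scala 2.13 and the sqlite-jdbc driver on the classpath; the demo object, table and query are illustrative, only the `ExtendedResultSet` wrapper mirrors code added in this commit):

```scala
import java.sql.{Connection, DriverManager, ResultSet}

// Mirrors the ExtendedResultSet wrapper introduced in this commit: wrap a ResultSet
// in an Iterable so the usual collection combinators (map, headOption, foldLeft, ...) apply.
case class ExtendedResultSet(rs: ResultSet) extends Iterable[ResultSet] {
  // Lazy: each hasNext call advances the underlying JDBC cursor.
  override def iterator: Iterator[ResultSet] = new Iterator[ResultSet] {
    def hasNext: Boolean = rs.next()
    def next(): ResultSet = rs
  }
}

object IterableResultSetDemo extends App {
  // Hypothetical in-memory database, assuming the sqlite-jdbc driver is available.
  val connection: Connection = DriverManager.getConnection("jdbc:sqlite::memory:")
  val statement = connection.createStatement()
  statement.executeUpdate("CREATE TABLE versions (db_name TEXT NOT NULL, version INTEGER NOT NULL)")
  statement.executeUpdate("INSERT INTO versions VALUES ('audit', 7)")

  // Right: materialize (here with map + headOption) while the result set is still open.
  val version: Option[Int] =
    ExtendedResultSet(statement.executeQuery("SELECT version FROM versions WHERE db_name='audit'"))
      .map(rs => rs.getInt("version"))
      .headOption

  statement.close()
  println(version) // Some(7)

  // Wrong: returning the ExtendedResultSet itself would defer rs.next() until after the
  // statement is closed, which is why db methods must not return an Iterable or Iterator.
}
```

Materializing with `headOption`, `toSeq` or a fold keeps all row handling inside the scope that owns the statement, which is exactly why the new database methods never return the `Iterable` itself.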
@@ -20,7 +20,7 @@ import akka.actor.{Actor, ActorLogging, Props}
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.bitcoin.{ByteVector32, Satoshi}
import fr.acinq.eclair.NodeParams
import fr.acinq.eclair.channel.Helpers.Closing.{ClosingType, CurrentRemoteClose, LocalClose, MutualClose, NextRemoteClose, RecoveryClose, RevokedClose}
import fr.acinq.eclair.channel.Helpers.Closing._
import fr.acinq.eclair.channel.Monitoring.{Metrics => ChannelMetrics, Tags => ChannelTags}
import fr.acinq.eclair.channel._
import fr.acinq.eclair.db.DbEventHandler.ChannelEvent
@@ -16,10 +16,10 @@

package fr.acinq.eclair.db

import java.io.Closeable

import fr.acinq.eclair.blockchain.fee.FeeratesPerKB

import java.io.Closeable

/**
* This database stores the fee rates retrieved by a [[fr.acinq.eclair.blockchain.fee.FeeProvider]].
*/
@@ -16,16 +16,15 @@

package fr.acinq.eclair.db

import java.io.File
import java.nio.file.{Files, StandardCopyOption}

import akka.actor.{Actor, ActorLogging, Props}
import akka.dispatch.{BoundedMessageQueueSemantics, RequiresMessageQueue}
import fr.acinq.eclair.KamonExt
import fr.acinq.eclair.channel.ChannelPersisted
import fr.acinq.eclair.db.Databases.FileBackup
import fr.acinq.eclair.db.Monitoring.Metrics

import java.io.File
import java.nio.file.{Files, StandardCopyOption}
import scala.sys.process.Process
import scala.util.{Failure, Success, Try}

@@ -16,14 +16,13 @@

package fr.acinq.eclair.db

import java.io.Closeable

import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.bitcoin.{ByteVector32, Satoshi}
import fr.acinq.eclair.ShortChannelId
import fr.acinq.eclair.router.Router.PublicChannel
import fr.acinq.eclair.wire.protocol.{ChannelAnnouncement, ChannelUpdate, NodeAnnouncement}

import java.io.Closeable
import scala.collection.immutable.SortedMap

trait NetworkDb extends Closeable {
@@ -16,15 +16,15 @@

package fr.acinq.eclair.db

import java.io.Closeable
import java.util.UUID

import fr.acinq.bitcoin.ByteVector32
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.payment._
import fr.acinq.eclair.router.Router.{ChannelHop, Hop, NodeHop}
import fr.acinq.eclair.{MilliSatoshi, ShortChannelId}

import java.io.Closeable
import java.util.UUID

trait PaymentsDb extends IncomingPaymentsDb with OutgoingPaymentsDb with PaymentsOverviewDb with Closeable

trait IncomingPaymentsDb {
4 changes: 2 additions & 2 deletions eclair-core/src/main/scala/fr/acinq/eclair/db/PeersDb.scala
@@ -16,11 +16,11 @@

package fr.acinq.eclair.db

import java.io.Closeable

import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.wire.protocol.NodeAddress

import java.io.Closeable

trait PeersDb extends Closeable {

def addOrUpdatePeer(nodeId: PublicKey, address: NodeAddress): Unit
42 changes: 24 additions & 18 deletions eclair-core/src/main/scala/fr/acinq/eclair/db/jdbc/JdbcUtils.scala
@@ -19,16 +19,17 @@ package fr.acinq.eclair.db.jdbc
import fr.acinq.bitcoin.ByteVector32
import fr.acinq.eclair.MilliSatoshi
import org.sqlite.SQLiteConnection
import scodec.Codec
import scodec.Decoder
import scodec.bits.{BitVector, ByteVector}

import java.sql.{Connection, ResultSet, Statement, Timestamp}
import java.util.UUID
import javax.sql.DataSource
import scala.collection.immutable.Queue

trait JdbcUtils {

import ExtendedResultSet._

def withConnection[T](f: Connection => T)(implicit dataSource: DataSource): T = {
val connection = dataSource.getConnection()
try {
@@ -72,15 +73,16 @@ trait JdbcUtils {
def getVersion(statement: Statement, db_name: String): Option[Int] = {
createVersionTable(statement)
// if there was a previous version installed, this will return a different value from current version
val rs = statement.executeQuery(s"SELECT version FROM versions WHERE db_name='$db_name'")
if (rs.next()) Some(rs.getInt("version")) else None
statement.executeQuery(s"SELECT version FROM versions WHERE db_name='$db_name'")
.map(rs => rs.getInt("version"))
.headOption
}

/**
* Updates the version for a particular logical database, it will overwrite the previous version.
*
* NB: we could define this method in [[fr.acinq.eclair.db.sqlite.SqliteUtils]] and [[fr.acinq.eclair.db.pg.PgUtils]]
* but it would make testing more complicated because we need to use one or the other depending on the backend.
* but it would make testing more complicated because we need to use one or the other depending on the backend.
*/
def setVersion(statement: Statement, db_name: String, newVersion: Int): Unit = {
createVersionTable(statement)
@@ -96,20 +98,25 @@
}
}

/**
* This helper assumes that there is a "data" column available, decodable with the provided codec
*
* TODO: we should use an scala.Iterator instead
*/
def codecSequence[T](rs: ResultSet, codec: Codec[T]): Seq[T] = {
var q: Queue[T] = Queue()
while (rs.next()) {
q = q :+ codec.decode(BitVector(rs.getBytes("data"))).require.value
case class ExtendedResultSet(rs: ResultSet) extends Iterable[ResultSet] {

/**
* Iterates over all rows of a result set.
*
* Careful: the iterator is lazy, it must be materialized before the [[ResultSet]] is closed, by converting the end
* result in a collection or an option.
*/
override def iterator: Iterator[ResultSet] = {
// @formatter:off
new Iterator[ResultSet] {
def hasNext: Boolean = rs.next()
def next(): ResultSet = rs
}
// @formatter:on
}
q
}

case class ExtendedResultSet(rs: ResultSet) {
/** This helper assumes that there is a "data" column available, that can be decoded with the provided codec */
def mapCodec[T](codec: Decoder[T]): Iterable[T] = rs.map(rs => codec.decode(BitVector(rs.getBytes("data"))).require.value)

def getByteVectorFromHex(columnLabel: String): ByteVector = {
val s = rs.getString(columnLabel).stripPrefix("\\x")
@@ -166,7 +173,6 @@ trait JdbcUtils {
val result = rs.getTimestamp(label)
if (rs.wasNull()) None else Some(result)
}

}

object ExtendedResultSet {
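The new `mapCodec` helper above assumes a `data` column that can be decoded with the provided scodec decoder, and that the caller materializes the result before the result set is closed. A standalone sketch of the same idea, without the wrapper (assumes scodec on the classpath; the decoder choice and function names are illustrative, not taken from the codebase):

```scala
import java.sql.ResultSet
import scodec.Decoder
import scodec.bits.BitVector
import scodec.codecs.utf8_32

// Walk the remaining rows, decode the "data" BLOB of each one with the given
// scodec decoder, and materialize into a List before the ResultSet is closed.
def decodeDataColumn[T](rs: ResultSet, decoder: Decoder[T]): List[T] =
  Iterator.continually(rs)
    .takeWhile(_.next())
    .map(row => decoder.decode(BitVector(row.getBytes("data"))).require.value)
    .toList

// Example: if the "data" column held length-prefixed UTF-8 strings (illustrative).
def readLabels(rs: ResultSet): List[String] = decodeDataColumn(rs, utf8_32)
```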
162 changes: 76 additions & 86 deletions eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala
@@ -33,7 +33,6 @@ import java.sql.{Statement, Timestamp}
import java.time.Instant
import java.util.UUID
import javax.sql.DataSource
import scala.collection.immutable.Queue

class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {

@@ -215,30 +214,28 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
using(pg.prepareStatement("SELECT * FROM sent WHERE timestamp BETWEEN ? AND ?")) { statement =>
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
val rs = statement.executeQuery()
var sentByParentId = Map.empty[UUID, PaymentSent]
while (rs.next()) {
val parentId = UUID.fromString(rs.getString("parent_payment_id"))
val part = PaymentSent.PartialPayment(
UUID.fromString(rs.getString("payment_id")),
MilliSatoshi(rs.getLong("amount_msat")),
MilliSatoshi(rs.getLong("fees_msat")),
rs.getByteVector32FromHex("to_channel_id"),
None, // we don't store the route in the audit DB
rs.getTimestamp("timestamp").getTime)
val sent = sentByParentId.get(parentId) match {
case Some(s) => s.copy(parts = s.parts :+ part)
case None => PaymentSent(
parentId,
rs.getByteVector32FromHex("payment_hash"),
rs.getByteVector32FromHex("payment_preimage"),
MilliSatoshi(rs.getLong("recipient_amount_msat")),
PublicKey(rs.getByteVectorFromHex("recipient_node_id")),
Seq(part))
}
sentByParentId = sentByParentId + (parentId -> sent)
}
sentByParentId.values.toSeq.sortBy(_.timestamp)
statement.executeQuery()
.foldLeft(Map.empty[UUID, PaymentSent]) { (sentByParentId, rs) =>
val parentId = UUID.fromString(rs.getString("parent_payment_id"))
val part = PaymentSent.PartialPayment(
UUID.fromString(rs.getString("payment_id")),
MilliSatoshi(rs.getLong("amount_msat")),
MilliSatoshi(rs.getLong("fees_msat")),
rs.getByteVector32FromHex("to_channel_id"),
None, // we don't store the route in the audit DB
rs.getTimestamp("timestamp").getTime)
val sent = sentByParentId.get(parentId) match {
case Some(s) => s.copy(parts = s.parts :+ part)
case None => PaymentSent(
parentId,
rs.getByteVector32FromHex("payment_hash"),
rs.getByteVector32FromHex("payment_preimage"),
MilliSatoshi(rs.getLong("recipient_amount_msat")),
PublicKey(rs.getByteVectorFromHex("recipient_node_id")),
Seq(part))
}
sentByParentId + (parentId -> sent)
}.values.toSeq.sortBy(_.timestamp)
}
}

@@ -247,98 +244,91 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
using(pg.prepareStatement("SELECT * FROM received WHERE timestamp BETWEEN ? AND ?")) { statement =>
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
val rs = statement.executeQuery()
var receivedByHash = Map.empty[ByteVector32, PaymentReceived]
while (rs.next()) {
val paymentHash = rs.getByteVector32FromHex("payment_hash")
val part = PaymentReceived.PartialPayment(
MilliSatoshi(rs.getLong("amount_msat")),
rs.getByteVector32FromHex("from_channel_id"),
rs.getTimestamp("timestamp").getTime)
val received = receivedByHash.get(paymentHash) match {
case Some(r) => r.copy(parts = r.parts :+ part)
case None => PaymentReceived(paymentHash, Seq(part))
}
receivedByHash = receivedByHash + (paymentHash -> received)
}
receivedByHash.values.toSeq.sortBy(_.timestamp)
statement.executeQuery()
.foldLeft(Map.empty[ByteVector32, PaymentReceived]) { (receivedByHash, rs) =>
val paymentHash = rs.getByteVector32FromHex("payment_hash")
val part = PaymentReceived.PartialPayment(
MilliSatoshi(rs.getLong("amount_msat")),
rs.getByteVector32FromHex("from_channel_id"),
rs.getTimestamp("timestamp").getTime)
val received = receivedByHash.get(paymentHash) match {
case Some(r) => r.copy(parts = r.parts :+ part)
case None => PaymentReceived(paymentHash, Seq(part))
}
receivedByHash + (paymentHash -> received)
}.values.toSeq.sortBy(_.timestamp)
}
}

override def listRelayed(from: Long, to: Long): Seq[PaymentRelayed] =
inTransaction { pg =>
var trampolineByHash = Map.empty[ByteVector32, (MilliSatoshi, PublicKey)]
using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp BETWEEN ? and ?")) { statement =>
val trampolineByHash = using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp BETWEEN ? and ?")) { statement =>
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
val rs = statement.executeQuery()
while (rs.next()) {
val paymentHash = rs.getByteVector32FromHex("payment_hash")
val amount = MilliSatoshi(rs.getLong("amount_msat"))
val nodeId = PublicKey(rs.getByteVectorFromHex("next_node_id"))
trampolineByHash += (paymentHash -> (amount, nodeId))
}
statement.executeQuery()
.foldLeft(Map.empty[ByteVector32, (MilliSatoshi, PublicKey)]) { (trampolineByHash, rs) =>
val paymentHash = rs.getByteVector32FromHex("payment_hash")
val amount = MilliSatoshi(rs.getLong("amount_msat"))
val nodeId = PublicKey(rs.getByteVectorFromHex("next_node_id"))
trampolineByHash + (paymentHash -> (amount, nodeId))
}
}
using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp BETWEEN ? and ?")) { statement =>
val relayedByHash = using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp BETWEEN ? and ?")) { statement =>
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
val rs = statement.executeQuery()
var relayedByHash = Map.empty[ByteVector32, Seq[RelayedPart]]
while (rs.next()) {
val paymentHash = rs.getByteVector32FromHex("payment_hash")
val part = RelayedPart(
rs.getByteVector32FromHex("channel_id"),
MilliSatoshi(rs.getLong("amount_msat")),
rs.getString("direction"),
rs.getString("relay_type"),
rs.getTimestamp("timestamp").getTime)
relayedByHash = relayedByHash + (paymentHash -> (relayedByHash.getOrElse(paymentHash, Nil) :+ part))
}
relayedByHash.flatMap {
case (paymentHash, parts) =>
// We may have been routing multiple payments for the same payment_hash (MPP) in both cases (trampoline and channel).
// NB: we may link the wrong in-out parts, but the overall sum will be correct: we sort by amounts to minimize the risk of mismatch.
val incoming = parts.filter(_.direction == "IN").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount)
val outgoing = parts.filter(_.direction == "OUT").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount)
parts.headOption match {
case Some(RelayedPart(_, _, _, "channel", timestamp)) => incoming.zip(outgoing).map {
case (in, out) => ChannelPaymentRelayed(in.amount, out.amount, paymentHash, in.channelId, out.channelId, timestamp)
}
case Some(RelayedPart(_, _, _, "trampoline", timestamp)) =>
val (nextTrampolineAmount, nextTrampolineNodeId) = trampolineByHash.getOrElse(paymentHash, (0 msat, PlaceHolderPubKey))
TrampolinePaymentRelayed(paymentHash, incoming, outgoing, nextTrampolineNodeId, nextTrampolineAmount, timestamp) :: Nil
case _ => Nil
}
}.toSeq.sortBy(_.timestamp)
statement.executeQuery()
.foldLeft(Map.empty[ByteVector32, Seq[RelayedPart]]) { (relayedByHash, rs) =>
val paymentHash = rs.getByteVector32FromHex("payment_hash")
val part = RelayedPart(
rs.getByteVector32FromHex("channel_id"),
MilliSatoshi(rs.getLong("amount_msat")),
rs.getString("direction"),
rs.getString("relay_type"),
rs.getTimestamp("timestamp").getTime)
relayedByHash + (paymentHash -> (relayedByHash.getOrElse(paymentHash, Nil) :+ part))
}
}
relayedByHash.flatMap {
case (paymentHash, parts) =>
// We may have been routing multiple payments for the same payment_hash (MPP) in both cases (trampoline and channel).
// NB: we may link the wrong in-out parts, but the overall sum will be correct: we sort by amounts to minimize the risk of mismatch.
val incoming = parts.filter(_.direction == "IN").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount)
val outgoing = parts.filter(_.direction == "OUT").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount)
parts.headOption match {
case Some(RelayedPart(_, _, _, "channel", timestamp)) => incoming.zip(outgoing).map {
case (in, out) => ChannelPaymentRelayed(in.amount, out.amount, paymentHash, in.channelId, out.channelId, timestamp)
}
case Some(RelayedPart(_, _, _, "trampoline", timestamp)) =>
val (nextTrampolineAmount, nextTrampolineNodeId) = trampolineByHash.getOrElse(paymentHash, (0 msat, PlaceHolderPubKey))
TrampolinePaymentRelayed(paymentHash, incoming, outgoing, nextTrampolineNodeId, nextTrampolineAmount, timestamp) :: Nil
case _ => Nil
}
}.toSeq.sortBy(_.timestamp)
}

override def listNetworkFees(from: Long, to: Long): Seq[NetworkFee] =
inTransaction { pg =>
using(pg.prepareStatement("SELECT * FROM network_fees WHERE timestamp BETWEEN ? and ? ORDER BY timestamp")) { statement =>
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
val rs = statement.executeQuery()
var q: Queue[NetworkFee] = Queue()
while (rs.next()) {
q = q :+ NetworkFee(
statement.executeQuery().map { rs =>
NetworkFee(
remoteNodeId = PublicKey(rs.getByteVectorFromHex("node_id")),
channelId = rs.getByteVector32FromHex("channel_id"),
txId = rs.getByteVector32FromHex("tx_id"),
fee = Satoshi(rs.getLong("fee_sat")),
txType = rs.getString("tx_type"),
timestamp = rs.getTimestamp("timestamp").getTime)
}
q
}.toSeq
}
}

override def stats(from: Long, to: Long): Seq[Stats] = {
val networkFees = listNetworkFees(from, to).foldLeft(Map.empty[ByteVector32, Satoshi]) { case (feeByChannelId, f) =>
val networkFees = listNetworkFees(from, to).foldLeft(Map.empty[ByteVector32, Satoshi]) { (feeByChannelId, f) =>
feeByChannelId + (f.channelId -> (feeByChannelId.getOrElse(f.channelId, 0 sat) + f.fee))
}
case class Relayed(amount: MilliSatoshi, fee: MilliSatoshi, direction: String)
val relayed = listRelayed(from, to).foldLeft(Map.empty[ByteVector32, Seq[Relayed]]) { case (previous, e) =>
val relayed = listRelayed(from, to).foldLeft(Map.empty[ByteVector32, Seq[Relayed]]) { (previous, e) =>
// NB: we must avoid counting the fee twice: we associate it to the outgoing channels rather than the incoming ones.
val current = e match {
case c: ChannelPaymentRelayed => Map(
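All of the `PgAuditDb` rewrites above share one shape: a mutable `var` map filled inside `while (rs.next())` becomes a single `foldLeft` over the iterable result set. A tiny stand-alone illustration of that accumulation pattern on plain data (the `Part` type and values are made up):

```scala
// Illustrative only: group parts by parent id the same way listSent/listRelayed now do,
// folding each element into an immutable Map instead of mutating a var.
case class Part(parentId: String, amountMsat: Long)

def groupByParent(rows: Iterable[Part]): Map[String, Seq[Part]] =
  rows.foldLeft(Map.empty[String, Seq[Part]]) { (acc, part) =>
    acc + (part.parentId -> (acc.getOrElse(part.parentId, Nil) :+ part))
  }

// groupByParent(Seq(Part("a", 1), Part("a", 2), Part("b", 3)))
// => Map(a -> List(Part(a,1), Part(a,2)), b -> List(Part(b,3)))
```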

