Skip to content

Commit

Permalink
Use proper data type for timestamps in Postgres (#1778)
Browse files Browse the repository at this point in the history
Did some refactoring in tests and introduced a new `migrationCheck`
helper method.

Note that the change of data type in sqlite for the `commitment_number`
field (from `BLOB` to `INTEGER`) is not a migration. If the table has
been created before, it will stay like it was. It doesn't matter due to
how sqlite stores data, and we make sure in tests that there is no
regression.
  • Loading branch information
pm47 authored Apr 22, 2021
1 parent 4a1dfd2 commit e14c40d
Show file tree
Hide file tree
Showing 7 changed files with 534 additions and 395 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.sqlite.SQLiteConnection
import scodec.Codec
import scodec.bits.{BitVector, ByteVector}

import java.sql.{Connection, ResultSet, Statement}
import java.sql.{Connection, ResultSet, Statement, Timestamp}
import java.util.UUID
import javax.sql.DataSource
import scala.collection.immutable.Queue
Expand Down Expand Up @@ -123,18 +123,16 @@ trait JdbcUtils {

def getByteVector32FromHexNullable(columnLabel: String): Option[ByteVector32] = {
val s = rs.getString(columnLabel)
if (rs.wasNull()) None else {
Some(ByteVector32(ByteVector.fromValidHex(s)))
}
if (rs.wasNull()) None else Some(ByteVector32(ByteVector.fromValidHex(s)))
}

def getBitVectorOpt(columnLabel: String): Option[BitVector] = Option(rs.getBytes(columnLabel)).map(BitVector(_))

def getByteVector(columnLabel: String): ByteVector = ByteVector(rs.getBytes(columnLabel))

def getByteVectorNullable(columnLabel: String): ByteVector = {
def getByteVectorNullable(columnLabel: String): Option[ByteVector] = {
val result = rs.getBytes(columnLabel)
if (rs.wasNull()) ByteVector.empty else ByteVector(result)
if (rs.wasNull()) None else Some(ByteVector(result))
}

def getByteVector32(columnLabel: String): ByteVector32 = ByteVector32(ByteVector(rs.getBytes(columnLabel)))
Expand Down Expand Up @@ -164,6 +162,11 @@ trait JdbcUtils {
if (rs.wasNull()) None else Some(MilliSatoshi(result))
}

def getTimestampNullable(label: String): Option[Timestamp] = {
val result = rs.getTimestamp(label)
if (rs.wasNull()) None else Some(result)
}

}

object ExtendedResultSet {
Expand Down
85 changes: 50 additions & 35 deletions eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ import fr.acinq.eclair.transactions.Transactions.PlaceHolderPubKey
import fr.acinq.eclair.{MilliSatoshi, MilliSatoshiLong}
import grizzled.slf4j.Logging

import java.sql.Statement
import java.sql.{Statement, Timestamp}
import java.time.Instant
import java.util.UUID
import javax.sql.DataSource
import scala.collection.immutable.Queue
Expand All @@ -40,7 +41,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
import ExtendedResultSet._

val DB_NAME = "audit"
val CURRENT_VERSION = 5
val CURRENT_VERSION = 6

case class RelayedPart(channelId: ByteVector32, amount: MilliSatoshi, direction: String, relayType: String, timestamp: Long)

Expand All @@ -52,15 +53,25 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
statement.executeUpdate("CREATE INDEX relayed_trampoline_payment_hash_idx ON relayed_trampoline(payment_hash)")
}

def migration56(statement: Statement): Unit = {
statement.executeUpdate("ALTER TABLE sent ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
statement.executeUpdate("ALTER TABLE received ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
statement.executeUpdate("ALTER TABLE relayed ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
statement.executeUpdate("ALTER TABLE relayed_trampoline ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
statement.executeUpdate("ALTER TABLE network_fees ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
statement.executeUpdate("ALTER TABLE channel_events ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
statement.executeUpdate("ALTER TABLE channel_errors ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
}

getVersion(statement, DB_NAME) match {
case None =>
statement.executeUpdate("CREATE TABLE sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE relayed_trampoline (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, next_node_id TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp BIGINT NOT NULL)")
statement.executeUpdate("CREATE TABLE sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
statement.executeUpdate("CREATE TABLE received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
statement.executeUpdate("CREATE TABLE relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
statement.executeUpdate("CREATE TABLE relayed_trampoline (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, next_node_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
statement.executeUpdate("CREATE TABLE network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
statement.executeUpdate("CREATE TABLE channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")
statement.executeUpdate("CREATE TABLE channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)")

statement.executeUpdate("CREATE INDEX sent_timestamp_idx ON sent(timestamp)")
statement.executeUpdate("CREATE INDEX received_timestamp_idx ON received(timestamp)")
Expand All @@ -74,6 +85,10 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
case Some(v@4) =>
logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION")
migration45(statement)
migration56(statement)
case Some(v@5) =>
logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION")
migration56(statement)
case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do
case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion")
}
Expand All @@ -90,7 +105,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
statement.setBoolean(4, e.isFunder)
statement.setBoolean(5, e.isPrivate)
statement.setString(6, e.event.label)
statement.setLong(7, System.currentTimeMillis)
statement.setTimestamp(7, Timestamp.from(Instant.now()))
statement.executeUpdate()
}
}
Expand All @@ -109,7 +124,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
statement.setString(7, e.paymentPreimage.toHex)
statement.setString(8, e.recipientNodeId.value.toHex)
statement.setString(9, p.toChannelId.toHex)
statement.setLong(10, p.timestamp)
statement.setTimestamp(10, Timestamp.from(Instant.ofEpochMilli(p.timestamp)))
statement.addBatch()
})
statement.executeBatch()
Expand All @@ -124,7 +139,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
statement.setLong(1, p.amount.toLong)
statement.setString(2, e.paymentHash.toHex)
statement.setString(3, p.fromChannelId.toHex)
statement.setLong(4, p.timestamp)
statement.setTimestamp(4, Timestamp.from(Instant.ofEpochMilli(p.timestamp)))
statement.addBatch()
})
statement.executeBatch()
Expand All @@ -143,7 +158,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
statement.setString(1, e.paymentHash.toHex)
statement.setLong(2, nextTrampolineAmount.toLong)
statement.setString(3, nextTrampolineNodeId.value.toHex)
statement.setLong(4, e.timestamp)
statement.setTimestamp(4, Timestamp.from(Instant.ofEpochMilli(e.timestamp)))
statement.executeUpdate()
}
// trampoline relayed payments do MPP aggregation and may have M inputs and N outputs
Expand All @@ -156,7 +171,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
statement.setString(3, p.channelId.toHex)
statement.setString(4, p.direction)
statement.setString(5, p.relayType)
statement.setLong(6, e.timestamp)
statement.setTimestamp(6, Timestamp.from(Instant.ofEpochMilli(e.timestamp)))
statement.executeUpdate()
}
}
Expand All @@ -171,7 +186,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
statement.setString(3, e.tx.txid.toHex)
statement.setLong(4, e.fee.toLong)
statement.setString(5, e.txType)
statement.setLong(6, System.currentTimeMillis)
statement.setTimestamp(6, Timestamp.from(Instant.now()))
statement.executeUpdate()
}
}
Expand All @@ -189,17 +204,17 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
statement.setString(3, errorName)
statement.setString(4, errorMessage)
statement.setBoolean(5, e.isFatal)
statement.setLong(6, System.currentTimeMillis)
statement.setTimestamp(6, Timestamp.from(Instant.now()))
statement.executeUpdate()
}
}
}

override def listSent(from: Long, to: Long): Seq[PaymentSent] =
inTransaction { pg =>
using(pg.prepareStatement("SELECT * FROM sent WHERE timestamp >= ? AND timestamp < ?")) { statement =>
statement.setLong(1, from)
statement.setLong(2, to)
using(pg.prepareStatement("SELECT * FROM sent WHERE timestamp BETWEEN ? AND ?")) { statement =>
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
val rs = statement.executeQuery()
var sentByParentId = Map.empty[UUID, PaymentSent]
while (rs.next()) {
Expand All @@ -210,7 +225,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
MilliSatoshi(rs.getLong("fees_msat")),
rs.getByteVector32FromHex("to_channel_id"),
None, // we don't store the route in the audit DB
rs.getLong("timestamp"))
rs.getTimestamp("timestamp").getTime)
val sent = sentByParentId.get(parentId) match {
case Some(s) => s.copy(parts = s.parts :+ part)
case None => PaymentSent(
Expand All @@ -229,17 +244,17 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {

override def listReceived(from: Long, to: Long): Seq[PaymentReceived] =
inTransaction { pg =>
using(pg.prepareStatement("SELECT * FROM received WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement =>
statement.setLong(1, from)
statement.setLong(2, to)
using(pg.prepareStatement("SELECT * FROM received WHERE timestamp BETWEEN ? AND ?")) { statement =>
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
val rs = statement.executeQuery()
var receivedByHash = Map.empty[ByteVector32, PaymentReceived]
while (rs.next()) {
val paymentHash = rs.getByteVector32FromHex("payment_hash")
val part = PaymentReceived.PartialPayment(
MilliSatoshi(rs.getLong("amount_msat")),
rs.getByteVector32FromHex("from_channel_id"),
rs.getLong("timestamp"))
rs.getTimestamp("timestamp").getTime)
val received = receivedByHash.get(paymentHash) match {
case Some(r) => r.copy(parts = r.parts :+ part)
case None => PaymentReceived(paymentHash, Seq(part))
Expand All @@ -253,9 +268,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
override def listRelayed(from: Long, to: Long): Seq[PaymentRelayed] =
inTransaction { pg =>
var trampolineByHash = Map.empty[ByteVector32, (MilliSatoshi, PublicKey)]
using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp >= ? AND timestamp < ?")) { statement =>
statement.setLong(1, from)
statement.setLong(2, to)
using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp BETWEEN ? and ?")) { statement =>
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
val rs = statement.executeQuery()
while (rs.next()) {
val paymentHash = rs.getByteVector32FromHex("payment_hash")
Expand All @@ -264,9 +279,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
trampolineByHash += (paymentHash -> (amount, nodeId))
}
}
using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement =>
statement.setLong(1, from)
statement.setLong(2, to)
using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp BETWEEN ? and ?")) { statement =>
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
val rs = statement.executeQuery()
var relayedByHash = Map.empty[ByteVector32, Seq[RelayedPart]]
while (rs.next()) {
Expand All @@ -276,7 +291,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
MilliSatoshi(rs.getLong("amount_msat")),
rs.getString("direction"),
rs.getString("relay_type"),
rs.getLong("timestamp"))
rs.getTimestamp("timestamp").getTime)
relayedByHash = relayedByHash + (paymentHash -> (relayedByHash.getOrElse(paymentHash, Nil) :+ part))
}
relayedByHash.flatMap {
Expand All @@ -300,9 +315,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {

override def listNetworkFees(from: Long, to: Long): Seq[NetworkFee] =
inTransaction { pg =>
using(pg.prepareStatement("SELECT * FROM network_fees WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement =>
statement.setLong(1, from)
statement.setLong(2, to)
using(pg.prepareStatement("SELECT * FROM network_fees WHERE timestamp BETWEEN ? and ? ORDER BY timestamp")) { statement =>
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
val rs = statement.executeQuery()
var q: Queue[NetworkFee] = Queue()
while (rs.next()) {
Expand All @@ -312,7 +327,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
txId = rs.getByteVector32FromHex("tx_id"),
fee = Satoshi(rs.getLong("fee_sat")),
txType = rs.getString("tx_type"),
timestamp = rs.getLong("timestamp"))
timestamp = rs.getTimestamp("timestamp").getTime)
}
q
}
Expand Down
Loading

0 comments on commit e14c40d

Please sign in to comment.