Skip to content

Commit

Permalink
Fix API regression (#1729)
Browse files Browse the repository at this point in the history
We incorrectly applied error handlers at each sub-route instead of applying
it after grouping all sub-routes together. The result was that only `getinfo`
could actually be called.
  • Loading branch information
t-bast authored Mar 12, 2021
1 parent ded5ce0 commit 8dc64db
Show file tree
Hide file tree
Showing 14 changed files with 55 additions and 45 deletions.
11 changes: 6 additions & 5 deletions eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ trait Service extends EclairDirectives with WebSocket with Node with Channel wit

/**
* Allows router access to the API password as configured in eclair.conf
*
* @return
*/
def password: String

Expand All @@ -49,9 +47,12 @@ trait Service extends EclairDirectives with WebSocket with Node with Channel wit
implicit val mat: Materializer

/**
* Collect routes from all sub-routers here. This is the main entrypoint for the global
* http request router of the API service.
* Collect routes from all sub-routers here.
* This is the main entrypoint for the global http request router of the API service.
* This is where we handle errors to ensure all routes are correctly tried before rejecting.
*/
val route: Route = nodeRoutes ~ channelRoutes ~ feeRoutes ~ pathFindingRoutes ~ invoiceRoutes ~ paymentRoutes ~ messageRoutes ~ onChainRoutes ~ webSocket
val route: Route = securedHandler {
nodeRoutes ~ channelRoutes ~ feeRoutes ~ pathFindingRoutes ~ invoiceRoutes ~ paymentRoutes ~ messageRoutes ~ onChainRoutes ~ webSocket
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ trait AuthDirective {
*/
def authenticated: Directive0 = authenticateBasicAsync(realm = "Access restricted", userPassAuthenticator).tflatMap { _ => pass }


private def userPassAuthenticator(credentials: Credentials): Future[Option[String]] = credentials match {
case p@Credentials.Provided(id) if p.verify(password) => Future.successful(Some(id))
case _ => akka.pattern.after(1 second, using = actorSystem.scheduler)(Future.successful(None))(actorSystem.dispatcher) // force a 1 sec pause to deter brute force
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@ trait DefaultHeaders {
/**
* Adds customHeaders to all http responses.
*/
def eclairHeaders:Directive0 = respondWithDefaultHeaders(customHeaders)

def eclairHeaders: Directive0 = respondWithDefaultHeaders(customHeaders)

private val customHeaders = `Access-Control-Allow-Headers`("Content-Type, Authorization") ::
`Access-Control-Allow-Methods`(POST) ::
`Cache-Control`(public, `no-store`, `max-age`(0)) :: Nil
`Access-Control-Allow-Methods`(POST) ::
`Cache-Control`(public, `no-store`, `max-age`(0)) :: Nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,28 @@ import fr.acinq.eclair.api.Service

import scala.concurrent.duration.DurationInt

class EclairDirectives extends Directives with TimeoutDirective with ErrorDirective with AuthDirective with DefaultHeaders with ExtraDirectives { this: Service =>
class EclairDirectives extends Directives with TimeoutDirective with ErrorDirective with AuthDirective with DefaultHeaders with ExtraDirectives {
this: Service =>

/**
* Prepares inner routes to be exposed as public API with default headers, error handlers and basic authentication.
* Prepares inner routes to be exposed as public API with default headers, basic authentication and error handling.
* Must be applied *after* aggregating all the inner routes.
*/
private def securedHandler:Directive0 = eclairHeaders & handled & authenticated
def securedHandler: Directive0 = eclairHeaders & handled & authenticated

/**
* Provides a Timeout to the inner route either from request param or the default.
*/
private def standardHandler:Directive1[Timeout] = toStrictEntity(5 seconds) & withTimeout
private def standardHandler: Directive1[Timeout] = toStrictEntity(5 seconds) & withTimeout

/**
* Handles POST requests with given simple path. The inner route is wrapped in a standard handler and provides a Timeout as parameter.
*/
def postRequest(p:String):Directive1[Timeout] = securedHandler & post & path(p) & standardHandler
def postRequest(p: String): Directive1[Timeout] = standardHandler & post & path(p)

/**
* Handles GET requests with given simple path. The inner route is wrapped in a standard handler and provides a Timeout as parameter.
*/
def getRequest(p:String):Directive1[Timeout] = securedHandler & get & path(p) & standardHandler
def getRequest(p: String): Directive1[Timeout] = standardHandler & get & path(p)

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,9 @@ trait ErrorDirective {
this: Service with EclairDirectives =>

/**
* Handles API exceptions and rejections. Produces json formatted
* error responses.
* Handles API exceptions and rejections. Produces json formatted error responses.
*/
def handled: Directive0 = handleExceptions(apiExceptionHandler) &
handleRejections(apiRejectionHandler)

def handled: Directive0 = handleExceptions(apiExceptionHandler) & handleRejections(apiRejectionHandler)

import fr.acinq.eclair.api.serde.JsonSupport.{formats, marshaller, serialization}

Expand All @@ -45,9 +42,7 @@ trait ErrorDirective {
// map all the rejections to a JSON error object ErrorResponse
private val apiRejectionHandler = RejectionHandler.default.mapRejectionResponse {
case res@HttpResponse(_, _, ent: HttpEntity.Strict, _) =>
res.withEntity(
HttpEntity(ContentTypes.`application/json`, serialization.writePretty(ErrorResponse(ent.data.utf8String)))
)
res.withEntity(HttpEntity(ContentTypes.`application/json`, serialization.writePretty(ErrorResponse(ent.data.utf8String))))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package fr.acinq.eclair.api.directives

import fr.acinq.eclair.api.serde.JsonSupport.serialization
import akka.http.scaladsl.common.{NameReceptacle, NameUnmarshallerReceptacle}
import akka.http.scaladsl.marshalling.ToResponseMarshaller
import akka.http.scaladsl.model.StatusCodes.NotFound
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,20 @@ import akka.http.scaladsl.model.{ContentTypes, HttpRequest, HttpResponse, Status
import akka.http.scaladsl.server.{Directive0, Directive1, Directives}
import akka.util.Timeout
import fr.acinq.eclair.api.serde.FormParamExtractors._
import fr.acinq.eclair.api.serde.JsonSupport._
import fr.acinq.eclair.api.serde.JsonSupport
import fr.acinq.eclair.api.serde.JsonSupport._

import scala.concurrent.duration.DurationInt

trait TimeoutDirective extends Directives {

import JsonSupport.{formats, serialization}


/**
* Extracts a given request timeout from an optional form field. Provides either the
* extracted Timeout or a default Timeout to the inner route.
*/
def withTimeout:Directive1[Timeout] = extractTimeout.tflatMap { timeout =>
def withTimeout: Directive1[Timeout] = extractTimeout.tflatMap { timeout =>
withTimeoutRequest(timeout._1) & provide(timeout._1)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ import akka.http.scaladsl.server.{MalformedFormFieldRejection, Route}
import akka.util.Timeout
import fr.acinq.bitcoin.Satoshi
import fr.acinq.eclair.MilliSatoshi
import fr.acinq.eclair.api.serde.FormParamExtractors._
import fr.acinq.eclair.api.Service
import fr.acinq.eclair.api.directives.EclairDirectives
import fr.acinq.eclair.api.serde.FormParamExtractors._
import fr.acinq.eclair.blockchain.fee.FeeratePerByte
import scodec.bits.ByteVector

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@

package fr.acinq.eclair.api.handlers

import akka.http.scaladsl.server.{MalformedFormFieldRejection, Route}
import fr.acinq.bitcoin.{ByteVector32, Satoshi}
import fr.acinq.eclair.api.serde.FormParamExtractors._
import akka.http.scaladsl.server.Route
import fr.acinq.bitcoin.ByteVector32
import fr.acinq.eclair.api.Service
import fr.acinq.eclair.api.directives.EclairDirectives
import fr.acinq.eclair.payment.PaymentRequest
import fr.acinq.eclair.api.serde.FormParamExtractors._

trait Invoice {
this: Service with EclairDirectives =>
Expand Down Expand Up @@ -60,6 +59,6 @@ trait Invoice {
}
}

val invoiceRoutes: Route = createInvoice ~ getInvoice ~ listInvoices ~ listPendingInvoices ~ parseInvoice
val invoiceRoutes: Route = createInvoice ~ getInvoice ~ listInvoices ~ listPendingInvoices ~ parseInvoice

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
package fr.acinq.eclair.api.handlers

import akka.http.scaladsl.server.Route
import fr.acinq.eclair.api.serde.FormParamExtractors._
import fr.acinq.eclair.api.Service
import fr.acinq.eclair.api.directives.EclairDirectives
import fr.acinq.eclair.api.serde.FormParamExtractors._
import scodec.bits.ByteVector

trait Message {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ import akka.http.scaladsl.server.Route
import com.google.common.net.HostAndPort
import fr.acinq.eclair.api.Service
import fr.acinq.eclair.api.directives.EclairDirectives
import fr.acinq.eclair.io.NodeURI
import fr.acinq.eclair.api.serde.FormParamExtractors._
import fr.acinq.eclair.io.NodeURI

trait Node {
this: Service with EclairDirectives =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,12 @@ package fr.acinq.eclair.api.handlers
import akka.http.scaladsl.server.{MalformedFormFieldRejection, Route}
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.bitcoin.{ByteVector32, Satoshi}
import fr.acinq.eclair.api.serde.FormParamExtractors.pubkeyListUnmarshaller
import fr.acinq.eclair.api.Service
import fr.acinq.eclair.api.directives.EclairDirectives
import fr.acinq.eclair.api.serde.FormParamExtractors.{pubkeyListUnmarshaller, _}
import fr.acinq.eclair.payment.PaymentRequest
import fr.acinq.eclair.router.Router.{PredefinedChannelRoute, PredefinedNodeRoute}
import fr.acinq.eclair.{CltvExpiryDelta, MilliSatoshi}
import fr.acinq.eclair.api.serde.FormParamExtractors._

import java.util.UUID

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ trait WebSocket {
handleWebSocketMessages(makeSocketHandler)
}


// Init the websocket message flow
private lazy val makeSocketHandler: Flow[Message, TextMessage.Strict, NotUsed] = {

Expand Down Expand Up @@ -74,5 +73,4 @@ trait WebSocket {
.map(TextMessage.apply)
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM
override implicit val mat: Materializer = materializer
}

def mockApi(eclair:Eclair = mock[Eclair]): MockService = {
def mockApi(eclair: Eclair = mock[Eclair]): MockService = {
new MockService(eclair)
}

Expand Down Expand Up @@ -122,7 +122,7 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM
test("API returns invalid channelId on invalid channelId form data") {
Post("/channel", FormData(Map("channelId" -> "hey")).toEntity) ~>
addCredentials(BasicHttpCredentials("", mockApi().password)) ~>
Route.seal(mockApi().channel) ~>
Route.seal(mockApi().route) ~>
check {
assert(handled)
assert(status == BadRequest)
Expand Down Expand Up @@ -247,6 +247,17 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM
assert(entityAs[String] == "\"created channel 56d7d6eda04d80138270c49709f1eadb5ab4939e5061309ccdacdb98ce637d0e\"")
eclair.open(nodeId, 50000 sat, None, None, Some(100 msat, 10), None, None)(any[Timeout]).wasCalled(once)
}

Post("/open", FormData("nodeId" -> nodeId.toString(), "fundingSatoshis" -> "25000", "feeBaseMsat" -> "250", "feeProportionalMillionths" -> "10").toEntity) ~>
addCredentials(BasicHttpCredentials("", mockApi().password)) ~>
addHeader("Content-Type", "application/json") ~>
Route.seal(mockService.route) ~>
check {
assert(handled)
assert(status == OK)
assert(entityAs[String] == "\"created channel 56d7d6eda04d80138270c49709f1eadb5ab4939e5061309ccdacdb98ce637d0e\"")
eclair.open(nodeId, 25000 sat, None, None, Some(250 msat, 10), None, None)(any[Timeout]).wasCalled(once)
}
}

test("'close' method should accept channelIds and shortChannelIds") {
Expand Down Expand Up @@ -338,7 +349,7 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM

Post("/payinvoice", FormData("invoice" -> invoice).toEntity) ~>
addCredentials(BasicHttpCredentials("", mockApi().password)) ~>
Route.seal(mockService.payInvoice) ~>
Route.seal(mockService.route) ~>
check {
assert(handled)
assert(status == BadRequest)
Expand Down Expand Up @@ -372,6 +383,15 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM
assert(status == OK)
eclair.send(Some("42"), any, 123 msat, any, any, any, Some(112233 sat), Some(2.34))(any[Timeout]).wasCalled(once)
}

Post("/payinvoice", FormData("invoice" -> invoice, "amountMsat" -> "456", "feeThresholdSat" -> "10", "maxFeePct" -> "0.5").toEntity) ~>
addCredentials(BasicHttpCredentials("", mockApi().password)) ~>
Route.seal(mockService.route) ~>
check {
assert(handled)
assert(status == OK)
eclair.send(None, any, 456 msat, any, any, any, Some(10 sat), Some(0.5))(any[Timeout]).wasCalled(once)
}
}

test("'getreceivedinfo'") {
Expand Down Expand Up @@ -412,7 +432,7 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM

Post("/getreceivedinfo", FormData("paymentHash" -> expired.toHex).toEntity) ~>
addCredentials(BasicHttpCredentials("", mockApi().password)) ~>
Route.seal(mockService.getReceivedInfo) ~>
Route.seal(mockService.route) ~>
check {
assert(handled)
assert(status == OK)
Expand Down Expand Up @@ -457,7 +477,7 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM

Post("/getsentinfo", FormData("id" -> failed.toString).toEntity) ~>
addCredentials(BasicHttpCredentials("", mockApi().password)) ~>
Route.seal(mockService.getSentInfo) ~>
Route.seal(mockService.route) ~>
check {
assert(handled)
assert(status == OK)
Expand Down

0 comments on commit 8dc64db

Please sign in to comment.