Skip to content

Commit

Permalink
Extracted the combined recorder from rpc.Conn
Browse files Browse the repository at this point in the history
rpc.Conn now accepts a RecorderFactory rather than an
ObserverFactory. The Recorder interface has the same methods as
Observer, but they can return errors to stop the handling of the
request, which is needed for auditing.

I've kept the Observer interface, since the lack of errors makes
multiplexing simpler. The observer is now embedded in a combined
recorder that forwards messages to it but also passes them on to the
auditlog recorder, which has the opportunity to interrupt the request.
  • Loading branch information
babbageclunk committed Dec 11, 2017
1 parent c388ea9 commit bae76b8
Show file tree
Hide file tree
Showing 12 changed files with 157 additions and 129 deletions.
3 changes: 1 addition & 2 deletions api/apiclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
"gopkg.in/retry.v1"

"github.com/juju/juju/api/base"
"github.com/juju/juju/apiserver/observer"
"github.com/juju/juju/apiserver/params"
"github.com/juju/juju/network"
"github.com/juju/juju/rpc"
Expand Down Expand Up @@ -196,7 +195,7 @@ func Open(info *Info, opts DialOpts) (Connection, error) {
return nil, errors.Trace(err)
}

client := rpc.NewConn(jsoncodec.New(dialResult.conn), observer.None())
client := rpc.NewConn(jsoncodec.New(dialResult.conn), nil)
client.Start()

bakeryClient := opts.BakeryClient
Expand Down
2 changes: 1 addition & 1 deletion api/testing/fakeserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func FakeAPIServer(root interface{}) net.Conn {
c0, c1 := net.Pipe()
serverCodec := jsoncodec.NewNet(c1)
serverRPC := rpc.NewConn(serverCodec, nil)
serverRPC.Serve(root, nil)
serverRPC.Serve(root, nil, nil)
serverRPC.Start()
go func() {
<-serverRPC.Dead()
Expand Down
9 changes: 6 additions & 3 deletions apiserver/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,10 @@ func (a *admin) login(req params.LoginRequest, loginVersion int) (params.LoginRe
modelTag = a.root.model.Tag().String()
}

var recorder *auditlog.Recorder
var auditRecorder *auditlog.Recorder
if authResult.userLogin {
// We only audit connections from humans.
recorder, err = auditlog.NewRecorder(
auditRecorder, err = auditlog.NewRecorder(
a.srv.auditLogger,
auditlog.ConversationArgs{
Who: req.AuthTag,
Expand All @@ -155,7 +155,10 @@ func (a *admin) login(req params.LoginRequest, loginVersion int) (params.LoginRe
}
}

a.root.rpcConn.ServeRoot(apiRoot, recorder, serverError)
recorderFactory := observer.NewRecorderFactory(
a.apiObserver, auditRecorder)

a.root.rpcConn.ServeRoot(apiRoot, recorderFactory, serverError)
return params.LoginResult{
Servers: params.FromNetworkHostsPorts(hostPorts),
ControllerTag: a.root.model.ControllerTag().String(),
Expand Down
2 changes: 1 addition & 1 deletion apiserver/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -967,7 +967,7 @@ func (srv *Server) serveConn(
host string,
) error {
codec := jsoncodec.NewWebsocket(wsConn.Conn)
conn := rpc.NewConn(codec, apiObserver)
conn := rpc.NewConn(codec, observer.NewRecorderFactory(apiObserver, nil))

// Note that we don't overwrite modelUUID here because
// newAPIHandler treats an empty modelUUID as signifying
Expand Down
2 changes: 1 addition & 1 deletion apiserver/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func TestingAPIHandler(c *gc.C, pool *state.StatePool, st *state.State) (*apiHan
statePool: pool,
tag: names.NewMachineTag("0"),
}
h, err := newAPIHandler(srv, st, nil, st.ModelUUID(), "testing.invalid:1234")
h, err := newAPIHandler(srv, st, nil, st.ModelUUID(), 6543, "testing.invalid:1234")
c.Assert(err, jc.ErrorIsNil)
return h, h.getResources()
}
Expand Down
83 changes: 83 additions & 0 deletions apiserver/observer/recorder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright 2017 Canonical Ltd.
// Licensed under the AGPLv3, see LICENCE file for details.

package observer

import (
"encoding/json"

"github.com/juju/errors"

"github.com/juju/juju/core/auditlog"
"github.com/juju/juju/rpc"
)

// NewRecorderFactory makes a new rpc.RecorderFactory to make
// recorders that that will update the observer and the auditlog
// recorder when it records a request or reply. The auditlog recorder
// can be nil.
func NewRecorderFactory(observerFactory rpc.ObserverFactory, recorder *auditlog.Recorder) rpc.RecorderFactory {
return func() rpc.Recorder {
return &combinedRecorder{
observer: observerFactory.RPCObserver(),
recorder: recorder,
}
}
}

// combinedRecorder wraps an observer (which might be a multiplexer)
// up with an auditlog recorder into an rpc.Recorder.
type combinedRecorder struct {
observer rpc.Observer
recorder *auditlog.Recorder
}

// ServerRequest implements rpc.Recorder.
func (cr *combinedRecorder) ServerRequest(hdr *rpc.Header, body interface{}) error {
cr.observer.ServerRequest(hdr, body)
if cr.recorder == nil {
return nil
}
// TODO(babbageclunk): make this configurable.
jsonArgs, err := json.Marshal(body)
if err != nil {
return errors.Trace(err)
}
return errors.Trace(cr.recorder.AddRequest(auditlog.RequestArgs{
RequestID: hdr.RequestId,
Facade: hdr.Request.Type,
Method: hdr.Request.Action,
Version: hdr.Request.Version,
Args: string(jsonArgs),
}))
}

// ServerReply implements rpc.Recorder.
func (cr *combinedRecorder) ServerReply(req rpc.Request, replyHdr *rpc.Header, body interface{}) error {
cr.observer.ServerReply(req, replyHdr, body)
if cr.recorder == nil {
return nil
}
var responseErrors []*auditlog.Error
if replyHdr.Error == "" {
var err error
responseErrors, err = extractErrors(body)
if err != nil {
return errors.Trace(err)
}
} else {
responseErrors = []*auditlog.Error{{
Message: replyHdr.Error,
Code: replyHdr.ErrorCode,
}}
}
return errors.Trace(cr.recorder.AddResponse(auditlog.ResponseErrorsArgs{
RequestID: replyHdr.RequestId,
Errors: responseErrors,
}))
}

func extractErrors(body interface{}) ([]*auditlog.Error, error) {
// TODO(babbageclunk): use reflection to find errors in the response body.
return nil, nil
}
3 changes: 2 additions & 1 deletion apiserver/testing/fakeapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/gorilla/websocket"
"github.com/juju/utils"

"github.com/juju/juju/apiserver/observer"
"github.com/juju/juju/apiserver/observer/fakeobserver"
"github.com/juju/juju/rpc"
"github.com/juju/juju/rpc/jsoncodec"
Expand Down Expand Up @@ -77,7 +78,7 @@ func (srv *Server) serveAPI(w http.ResponseWriter, req *http.Request) {

func (srv *Server) serveConn(wsConn *websocket.Conn, modelUUID string) {
codec := jsoncodec.NewWebsocket(wsConn)
conn := rpc.NewConn(codec, &fakeobserver.Instance{})
conn := rpc.NewConn(codec, observer.NewRecorderFactory(&fakeobserver.Instance{}, nil))

root := allVersions{
rpcreflect.ValueOf(reflect.ValueOf(srv.newRoot(modelUUID))),
Expand Down
6 changes: 0 additions & 6 deletions core/auditlog/auditlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,6 @@ func NewRecorder(log AuditLog, c ConversationArgs) (*Recorder, error) {

// AddRequest records a method call to the API.
func (r *Recorder) AddRequest(m RequestArgs) error {
if r == nil {
return nil
}
return errors.Trace(r.log.AddRequest(Request{
ConversationID: r.callID,
ConnectionID: r.connectionID,
Expand All @@ -154,9 +151,6 @@ func (r *Recorder) AddRequest(m RequestArgs) error {

// AddResponse records the result of a method call to the API.
func (r *Recorder) AddResponse(m ResponseErrorsArgs) error {
if r == nil {
return nil
}
return errors.Trace(r.log.AddResponse(ResponseErrors{
ConversationID: r.callID,
ConnectionID: r.connectionID,
Expand Down
4 changes: 2 additions & 2 deletions rpc/dispatch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ func (s *dispatchSuite) SetUpSuite(c *gc.C) {
s.BaseSuite.SetUpSuite(c)
rpcServer := func(ws *websocket.Conn) {
codec := jsoncodec.NewWebsocket(ws)
conn := rpc.NewConn(codec, &notifier{})
conn := rpc.NewConn(codec, nil)

conn.Serve(&DispatchRoot{}, nil)
conn.Serve(&DispatchRoot{}, nil, nil)
conn.Start()

<-conn.Dead()
Expand Down
8 changes: 8 additions & 0 deletions rpc/observers.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ type Observer interface {
ServerReply(req Request, hdr *Header, body interface{})
}

// ObserverFactory is a type which can construct a new Observer.
type ObserverFactory interface {
// RPCObserver will return a new Observer usually constructed
// from the state previously built up in the Observer. The
// returned instance will be utilized per RPC request.
RPCObserver() Observer
}

// NewObserverMultiplexer returns a new ObserverMultiplexer
// with the provided RequestNotifiers.
func NewObserverMultiplexer(rpcObservers ...Observer) *ObserverMultiplexer {
Expand Down
32 changes: 13 additions & 19 deletions rpc/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,11 @@ func (a *CallbackMethods) Factorial(x int64val) (int64val, error) {
}

func (a *ChangeAPIMethods) ChangeAPI() {
a.r.conn.Serve(&changedAPIRoot{}, nil)
a.r.conn.Serve(&changedAPIRoot{}, nil, nil)
}

func (a *ChangeAPIMethods) RemoveAPI() {
a.r.conn.Serve(nil, nil)
a.r.conn.Serve(nil, nil, nil)
}

type changedAPIRoot struct{}
Expand Down Expand Up @@ -987,7 +987,7 @@ func (*rpcSuite) TestBidirectional(c *gc.C) {
client, srvDone, _ := newRPCClientServer(c, srvRoot, nil, true)
defer closeClient(c, client, srvDone)
clientRoot := &Root{conn: client}
client.Serve(clientRoot, nil)
client.Serve(clientRoot, nil, nil)
var r int64val
err := client.Call(rpc.Request{"CallbackMethods", 0, "", "Factorial"}, int64val{12}, &r)
c.Assert(err, jc.ErrorIsNil)
Expand Down Expand Up @@ -1105,12 +1105,13 @@ func newRPCClientServer(
if bidir {
role = roleBoth
}
rpcConn := rpc.NewConn(NewJSONCodec(conn, role), serverNotifier)
recorderFactory := func() rpc.Recorder { return serverNotifier }
rpcConn := rpc.NewConn(NewJSONCodec(conn, role), recorderFactory)
if custroot, ok := root.(*CustomRoot); ok {
rpcConn.ServeRoot(custroot, nil, tfErr)
rpcConn.ServeRoot(custroot, recorderFactory, tfErr)
custroot.root.conn = rpcConn
} else {
rpcConn.Serve(root, tfErr)
rpcConn.Serve(root, recorderFactory, tfErr)
}
if root, ok := root.(*Root); ok {
root.conn = rpcConn
Expand All @@ -1125,9 +1126,9 @@ func newRPCClientServer(
if bidir {
role = roleBoth
}
client = rpc.NewConn(NewJSONCodec(conn, role), &notifier{})
client = rpc.NewConn(NewJSONCodec(conn, role), nil)
client.Start()
return client, srvDone, serverNotifier
return client, server, srvDone, serverNotifier
}

func closeClient(c *gc.C, client *rpc.Conn, srvDone <-chan error) {
Expand Down Expand Up @@ -1225,37 +1226,30 @@ type notifier struct {
serverReplies []replyEvent
}

func (n *notifier) RPCObserver() rpc.Observer {
// For testing, we usually won't want an actual copy of the
// stub. To avoid confusing test failures (e.g. wondering why your
// calls aren't showing up on your stub because the underlying
// code has called DeepCopy) and immense complexity, just return
// the same value.
return n
}

func (n *notifier) reset() {
n.mu.Lock()
defer n.mu.Unlock()
n.serverRequests = nil
n.serverReplies = nil
}

func (n *notifier) ServerRequest(hdr *rpc.Header, body interface{}) {
func (n *notifier) ServerRequest(hdr *rpc.Header, body interface{}) error {
n.mu.Lock()
defer n.mu.Unlock()
n.serverRequests = append(n.serverRequests, requestEvent{
hdr: *hdr,
body: body,
})
return nil
}

func (n *notifier) ServerReply(req rpc.Request, hdr *rpc.Header, body interface{}) {
func (n *notifier) ServerReply(req rpc.Request, hdr *rpc.Header, body interface{}) error {
n.mu.Lock()
defer n.mu.Unlock()
n.serverReplies = append(n.serverReplies, replyEvent{
req: req,
hdr: *hdr,
body: body,
})
return nil
}
Loading

0 comments on commit bae76b8

Please sign in to comment.