Skip to content

Commit

Permalink
Simplify EnvPersistence.
Browse files Browse the repository at this point in the history
  • Loading branch information
ericsnowcurrently committed May 19, 2016
1 parent 67425b1 commit 642a288
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 193 deletions.
4 changes: 2 additions & 2 deletions component/all/payload.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ func (payloads) registerState() {
return payloadstate.NewUnitPayloads(persist, unit, machine), nil
}

newEnvPayloads := func(db state.Persistence, persist state.PayloadsEnvPersistence) (state.EnvPayloads, error) {
envPersist := persistence.NewEnvPersistence(db, persist)
newEnvPayloads := func(db state.Persistence, st state.PayloadsEnvPersistence) (state.EnvPayloads, error) {
envPersist := persistence.NewEnvPersistence(db, st)
envPayloads := payloadstate.EnvPayloads{
Persist: envPersist,
}
Expand Down
61 changes: 14 additions & 47 deletions payload/persistence/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,81 +12,48 @@ import (
// EnvPersistenceEntities provides all the information needed to produce
// a new EnvPersistence value.
type EnvPersistenceEntities interface {
// Machines builds the list of the names that identify
// all machines in State.
Machines() ([]string, error)

// MachineUnits builds the list of names that identify all units
// for a given machine.
MachineUnits(machineName string) ([]string, error)
}

// unitPersistence describes the per-unit functionality needed
// for env persistence.
type unitPersistence interface {
// ListAll returns all payloads associated with the unit.
ListAll() ([]payload.Payload, error)
// AssignedMachineID the machine to which the identfies unit is assigned.
AssignedMachineID(unitName string) (string, error)
}

// EnvPersistence provides the persistence functionality for the
// Juju environment as a whole.
type EnvPersistence struct {
db *Persistence
st EnvPersistenceEntities

newUnitPersist func(name string) unitPersistence
}

// NewEnvPersistence wraps the "db" in a new EnvPersistence.
func NewEnvPersistence(db PersistenceBase, st EnvPersistenceEntities) *EnvPersistence {
return &EnvPersistence{
db: NewPersistence(db, ""),
st: st,
newUnitPersist: func(name string) unitPersistence {
return NewPersistence(db, name)
},
}
}

// ListAll returns the list of all payloads in the environment.
func (ep *EnvPersistence) ListAll() ([]payload.FullPayloadInfo, error) {
logger.Tracef("listing all payloads")

machines, err := ep.st.Machines()
docs, err := ep.db.allModelPayloads()
if err != nil {
return nil, errors.Trace(err)
}

var payloads []payload.FullPayloadInfo
for _, machine := range machines {
units, err := ep.st.MachineUnits(machine)
if err != nil {
return nil, errors.Trace(err)
}

for _, unit := range units {
persist := ep.newUnitPersist(unit)

unitPayloads, err := listUnit(persist, unit, machine)
unitMachines := make(map[string]string)
var fullPayloads []payload.FullPayloadInfo
for _, doc := range docs {
machineID, ok := unitMachines[doc.UnitID]
if !ok {
machineID, err = ep.st.AssignedMachineID(doc.UnitID)
if err != nil {
return nil, errors.Trace(err)
}
payloads = append(payloads, unitPayloads...)
unitMachines[doc.UnitID] = machineID
}
}
return payloads, nil
}

// listUnit returns all the payloads for the given unit.
func listUnit(persist unitPersistence, unit, machine string) ([]payload.FullPayloadInfo, error) {
payloads, err := persist.ListAll()
if err != nil {
return nil, errors.Trace(err)
}

var fullPayloads []payload.FullPayloadInfo
for _, pl := range payloads {
fullPayloads = append(fullPayloads, payload.FullPayloadInfo{
Payload: pl,
Machine: machine,
Payload: doc.payload(doc.UnitID),
Machine: machineID,
})
}
return fullPayloads, nil
Expand Down
121 changes: 16 additions & 105 deletions payload/persistence/env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ func (s *envPersistenceSuite) SetUpTest(c *gc.C) {
s.BaseSuite.SetUpTest(c)

s.base = &stubEnvPersistenceBase{
PersistenceBase: s.State,
stub: s.Stub,
fakeStatePersistence: s.State,
stub: s.Stub,
}
}

Expand Down Expand Up @@ -57,49 +57,28 @@ func (s *envPersistenceSuite) TestListAllOkay(c *gc.C) {
s.base.setPayloads(p1, p2)

persist := NewEnvPersistence(s.base, s.base)
persist.newUnitPersist = s.base.newUnitPersistence

payloads, err := persist.ListAll()
c.Assert(err, jc.ErrorIsNil)

checkPayloads(c, payloads, p1, p2)
s.Stub.CheckCallNames(c,
"Machines",

"MachineUnits",

"MachineUnits",
"newUnitPersistence",
"ListAll",
"newUnitPersistence",
"ListAll",

"MachineUnits",
"newUnitPersistence",
"ListAll",
"All",
"AssignedMachineID",
)
}

func (s *envPersistenceSuite) TestListAllEmpty(c *gc.C) {
s.base.setUnits("0")
s.base.setUnits("1", "a-service/0", "a-service/1")
persist := NewEnvPersistence(s.base, s.base)
persist.newUnitPersist = s.base.newUnitPersistence

payloads, err := persist.ListAll()
c.Assert(err, jc.ErrorIsNil)

c.Check(payloads, gc.HasLen, 0)
s.Stub.CheckCallNames(c,
"Machines",

"MachineUnits",

"MachineUnits",
"newUnitPersistence",
"ListAll",
"newUnitPersistence",
"ListAll",
"All",
)
}

Expand All @@ -108,7 +87,6 @@ func (s *envPersistenceSuite) TestListAllFailed(c *gc.C) {
s.Stub.SetErrors(failure)

persist := NewEnvPersistence(s.base, s.base)
persist.newUnitPersist = s.base.newUnitPersistence

_, err := persist.ListAll()

Expand Down Expand Up @@ -151,101 +129,34 @@ func checkPayloads(c *gc.C, payloads []payload.FullPayloadInfo, expectedList ...
}

type stubEnvPersistenceBase struct {
PersistenceBase
*fakeStatePersistence
stub *testing.Stub
machines []string
units map[string]map[string]bool
unitPersists map[string]*stubUnitPersistence
unitMachines map[string]string
}

func (s *stubEnvPersistenceBase) setPayloads(payloads ...payload.FullPayloadInfo) {
if s.unitPersists == nil && len(payloads) > 0 {
s.unitPersists = make(map[string]*stubUnitPersistence)
}

for _, pl := range payloads {
s.setUnits(pl.Machine, pl.Unit)

unitPayloads := s.unitPersists[pl.Unit]
if unitPayloads == nil {
unitPayloads = &stubUnitPersistence{stub: s.stub}
s.unitPersists[pl.Unit] = unitPayloads
}

unitPayloads.setPayloads(pl.Payload)
doc := newPayloadDoc(pl.Unit, "0", pl.Payload)
s.SetDocs(doc)
}
}

func (s *stubEnvPersistenceBase) setUnits(machine string, units ...string) {
if s.units == nil {
s.units = make(map[string]map[string]bool)
if s.unitMachines == nil {
s.unitMachines = make(map[string]string)
}
if _, ok := s.units[machine]; !ok {
s.machines = append(s.machines, machine)
s.units[machine] = make(map[string]bool)
}

for _, unit := range units {
s.units[machine][unit] = true
}
}

func (s *stubEnvPersistenceBase) newUnitPersistence(unit string) unitPersistence {
s.stub.AddCall("newUnitPersistence", unit)
s.stub.NextErr() // pop one off

persist, ok := s.unitPersists[unit]
if !ok {
if s.unitPersists == nil {
s.unitPersists = make(map[string]*stubUnitPersistence)
}
persist = &stubUnitPersistence{stub: s.stub}
s.unitPersists[unit] = persist
s.unitMachines[unit] = machine
}
return persist
}

func (s *stubEnvPersistenceBase) Machines() ([]string, error) {
s.stub.AddCall("Machines")
if err := s.stub.NextErr(); err != nil {
return nil, errors.Trace(err)
}

var names []string
for _, name := range s.machines {
names = append(names, name)
}
return names, nil
}

func (s *stubEnvPersistenceBase) MachineUnits(machine string) ([]string, error) {
s.stub.AddCall("MachineUnits", machine)
if err := s.stub.NextErr(); err != nil {
return nil, errors.Trace(err)
}

var units []string
for unit := range s.units[machine] {
units = append(units, unit)
}
return units, nil
}

type stubUnitPersistence struct {
stub *testing.Stub

payloads []payload.Payload
}

func (s *stubUnitPersistence) setPayloads(payloads ...payload.Payload) {
s.payloads = append(s.payloads, payloads...)
}

func (s *stubUnitPersistence) ListAll() ([]payload.Payload, error) {
s.stub.AddCall("ListAll")
func (s *stubEnvPersistenceBase) AssignedMachineID(unit string) (string, error) {
s.stub.AddCall("AssignedMachineID", unit)
if err := s.stub.NextErr(); err != nil {
return nil, errors.Trace(err)
return "", errors.Trace(err)
}

return s.payloads, nil
return s.unitMachines[unit], nil
}
11 changes: 8 additions & 3 deletions payload/persistence/fakes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ func (sp fakeStatePersistence) CheckNoOps(c *gc.C) {
}

func (sp fakeStatePersistence) All(collName string, query, docs interface{}) error {
actual := docs.(*[]payloadDoc)

sp.AddCall("All", collName, query, docs)
if err := sp.NextErr(); err != nil {
return errors.Trace(err)
Expand All @@ -72,8 +74,12 @@ func (sp fakeStatePersistence) All(collName string, query, docs interface{}) err
var ids []string
elems := query.(bson.D)
if len(elems) < 1 {
err := errors.Errorf("bad query %v", query)
panic(err)
var all []payloadDoc
for _, doc := range sp.docs {
all = append(all, *doc)
}
*actual = all
return nil
}
switch elems[0].Name {
case "_id":
Expand Down Expand Up @@ -103,7 +109,6 @@ func (sp fakeStatePersistence) All(collName string, query, docs interface{}) err
}
found = append(found, *doc)
}
actual := docs.(*[]payloadDoc)
*actual = found
return nil
}
Expand Down
22 changes: 19 additions & 3 deletions payload/persistence/mongo.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@ func (pp Persistence) allID(query bson.D, docs interface{}) error {
}

func (pp Persistence) payloadID(name string) string {
return fmt.Sprintf("payload#%s#%s", pp.unit, name)
return payloadID(pp.unit, name)
}

func payloadID(unit, name string) string {
return fmt.Sprintf("payload#%s#%s", unit, name)
}

func (pp Persistence) newInsertPayloadOps(id string, p payload.Payload) []txn.Op {
Expand Down Expand Up @@ -144,7 +148,11 @@ func (d payloadDoc) match(name, rawID string) bool {
}

func (pp Persistence) newPayloadDoc(stID string, p payload.Payload) *payloadDoc {
id := pp.payloadID(p.Name)
return newPayloadDoc(pp.unit, stID, p)
}

func newPayloadDoc(unit, stID string, p payload.Payload) *payloadDoc {
id := payloadID(unit, p.Name)

definition := p.PayloadClass

Expand All @@ -153,7 +161,7 @@ func (pp Persistence) newPayloadDoc(stID string, p payload.Payload) *payloadDoc

return &payloadDoc{
DocID: id,
UnitID: pp.unit,
UnitID: unit,
Name: definition.Name,

StateID: stID,
Expand All @@ -168,6 +176,14 @@ func (pp Persistence) newPayloadDoc(stID string, p payload.Payload) *payloadDoc
}
}

func (pp Persistence) allModelPayloads() ([]payloadDoc, error) {
var docs []payloadDoc
if err := pp.all(nil, &docs); err != nil {
return nil, errors.Trace(err)
}
return docs, nil
}

func (pp Persistence) allPayloads() (map[string]payloadDoc, error) {
var docs []payloadDoc
query := bson.D{{"unitid", pp.unit}}
Expand Down
Loading

0 comments on commit 642a288

Please sign in to comment.