Skip to content

Commit

Permalink
refactor(annotations): refactor queries into the same file as state
Browse files Browse the repository at this point in the history
Refactor the annotations domain to include queries in the same file as
the state methods. Also remove some of the conditional logic where it
was no longer necessery.

This change was motivated by an update to SQLair that checks if an
argument to prepare is used. In this domain, different queries were
returns from functions then used with a single prepare. This refator
prepares all the queries next to their declarations. This means the
types can be easily matched to the query.
  • Loading branch information
Aflynn50 committed Dec 20, 2024
1 parent 71c7981 commit 589e95b
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 185 deletions.
124 changes: 0 additions & 124 deletions domain/annotation/state/query.go

This file was deleted.

163 changes: 118 additions & 45 deletions domain/annotation/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package state

import (
"context"
"fmt"

"github.com/canonical/sqlair"
"github.com/juju/collections/transform"
Expand Down Expand Up @@ -33,25 +34,15 @@ func NewState(factory database.TxnRunnerFactory) *State {
// from the database.
// If no annotations are found, an empty map is returned.
func (st *State) GetAnnotations(ctx context.Context, id annotations.ID) (map[string]string, error) {
getAnnotationsQuery, err := getAnnotationQueryForID(id)
if err != nil {
return nil, errors.Capture(err)
}

annotationUUIDParam := annotationUUID{}
getAnnotationsStmt, err := st.Prepare(getAnnotationsQuery, Annotation{}, annotationUUIDParam)
if err != nil {
return nil, errors.Errorf("preparing get annotations query for ID: %q: %w", id.Name, err)
}

if id.Kind == annotations.KindModel {
return st.getAnnotationsForModel(ctx, id, getAnnotationsStmt)
return st.getAnnotationsForModel(ctx)
}
return st.getAnnotationsForID(ctx, id, getAnnotationsStmt, annotationUUIDParam)
return st.getAnnotationsForID(ctx, id)
}

// GetAnnotations will retrieve all the annotations associated with the given ID
// from the database.
// GetCharmAnnotations will retrieve all the annotations associated with the
// given ID from the database.
// If no annotations are found, an empty map is returned.
func (st *State) GetCharmAnnotations(ctx context.Context, id annotation.GetCharmArgs) (map[string]string, error) {
db, err := st.DB()
Expand Down Expand Up @@ -99,12 +90,19 @@ WHERE c.name = $charmArgs.name AND c.revision = $charmArgs.revision;
// This method is specialized to Models as opposed to the other Kinds because we
// keep annotations per model, so we don't need to try to find the UUID of the
// given ID (the model).
func (st *State) getAnnotationsForModel(ctx context.Context, id annotations.ID, getAnnotationsStmt *sqlair.Statement) (map[string]string, error) {
func (st *State) getAnnotationsForModel(ctx context.Context) (map[string]string, error) {
db, err := st.DB()
if err != nil {
return nil, errors.Capture(err)
}

getAnnotationsStmt, err := st.Prepare(`
SELECT (key, value) AS (&Annotation.*)
FROM annotation_model`, Annotation{})
if err != nil {
return nil, errors.Errorf("preparing get annotations query for model: %w", err)
}

var annotationsResults []Annotation
err = db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error {
err := tx.Query(ctx, getAnnotationsStmt).GetAll(&annotationsResults)
Expand All @@ -114,7 +112,7 @@ func (st *State) getAnnotationsForModel(ctx context.Context, id annotations.ID,
return err
})
if err != nil {
return nil, errors.Errorf("loading annotations for ID: %q: %w", id.Name, err)
return nil, errors.Errorf("loading annotations for model: %w", err)
}

annotations := transform.SliceToMap(annotationsResults, func(a Annotation) (string, string) { return a.Key, a.Value })
Expand All @@ -128,16 +126,31 @@ func (st *State) getAnnotationsForModel(ctx context.Context, id annotations.ID,
// This is separate from the getAnnotationsForModel because for non-model ID
// Kinds we need to find the UUID of the ID before we retrieve annotations from
// the corresponding annotation table.
func (st *State) getAnnotationsForID(ctx context.Context, id annotations.ID, getAnnotationsStmt *sqlair.Statement, annotationUUIDParam annotationUUID) (map[string]string, error) {
func (st *State) getAnnotationsForID(ctx context.Context, id annotations.ID) (map[string]string, error) {
db, err := st.DB()
if err != nil {
return nil, errors.Capture(err)
}

tableName, err := annotationTableNameFromID(id)
if err != nil {
return nil, errors.Capture(err)
}
getAnnotationsQuery := fmt.Sprintf(`
SELECT (key, value) AS (&Annotation.*)
FROM %s
WHERE uuid = $annotationUUID.uuid`, tableName)

getAnnotationsStmt, err := st.Prepare(getAnnotationsQuery, Annotation{}, annotationUUID{})
if err != nil {
return nil, errors.Errorf("preparing get annotations query for ID: %q: %w", id.Name, err)
}

kindQuery, kindQueryParam, err := uuidQueryForID(id)
if err != nil {
return nil, errors.Errorf("preparing get annotations query for ID: %q: %w", id.Name, err)
}
annotationUUIDParam := annotationUUID{UUID: id.String()}
kindQueryStmt, err := st.Prepare(kindQuery, kindQueryParam, annotationUUIDParam)
if err != nil {
return nil, errors.Errorf("preparing get annotations query for ID: %q: %w", id.Name, err)
Expand Down Expand Up @@ -179,28 +192,11 @@ func (st *State) SetAnnotations(
id annotations.ID,
values map[string]string,
) error {
insertQuery, err := setAnnotationQueryForID(id)
if err != nil {
return errors.Capture(err)
}
deleteQuery, err := deleteAnnotationsQueryForID(id)
if err != nil {
return errors.Capture(err)
}

insertStmt, err := st.Prepare(insertQuery, Annotation{}, annotationUUID{})
if err != nil {
return errors.Errorf("preparing set annotations query for ID: %q: %w", id.Name, err)
}
deleteStmt, err := st.Prepare(deleteQuery, annotationUUID{})
if err != nil {
return errors.Errorf("preparing set annotations query for ID: %q: %w", id.Name, err)
}

if id.Kind == annotations.KindModel {
return st.setAnnotationsForModel(ctx, id, values, insertStmt, deleteStmt)
return st.setAnnotationsForModel(ctx, values)
}
return st.setAnnotationsForID(ctx, id, values, insertStmt, deleteStmt)
return st.setAnnotationsForID(ctx, id, values)

}

// setAnnotationsForID associates key/value pairs with the given ID.
Expand All @@ -209,14 +205,35 @@ func (st *State) SetAnnotations(
// corresponding annotation table.
func (st *State) setAnnotationsForID(ctx context.Context, id annotations.ID,
toInsert map[string]string,
setAnnotationsStmt *sqlair.Statement,
deleteAnnotationsStmt *sqlair.Statement,
) error {
db, err := st.DB()
if err != nil {
return errors.Capture(err)
}

tableName, err := annotationTableNameFromID(id)
if err != nil {
return errors.Capture(err)
}
insertQuery := fmt.Sprintf(`
INSERT INTO %s (uuid, key, value)
VALUES ($annotationUUID.uuid, $Annotation.key, $Annotation.value)
ON CONFLICT(uuid, key) DO UPDATE SET value=$Annotation.value`, tableName)

setAnnotationsStmt, err := st.Prepare(insertQuery, Annotation{}, annotationUUID{})
if err != nil {
return errors.Errorf("preparing set annotations query for ID: %q: %w", id.Name, err)
}

deleteQuery := fmt.Sprintf(`
DELETE FROM %s
WHERE uuid = $annotationUUID.uuid`, tableName)

deleteAnnotationsStmt, err := st.Prepare(deleteQuery, annotationUUID{})
if err != nil {
return errors.Errorf("preparing set annotations query for ID: %q: %w", id.Name, err)
}

kindQuery, kindQueryParam, err := uuidQueryForID(id)
if err != nil {
return errors.Errorf("preparing uuid retrieval query for ID: %q: %w", id.Name, err)
Expand Down Expand Up @@ -261,33 +278,44 @@ func (st *State) setAnnotationsForID(ctx context.Context, id annotations.ID,
// This is specialized to models as opposed to the other Kinds because we keep
// annotations per model, so we don't need to try to find the uuid of the given
// id (the model).
func (st *State) setAnnotationsForModel(ctx context.Context, id annotations.ID,
func (st *State) setAnnotationsForModel(ctx context.Context,
toInsert map[string]string,
setAnnotationsStmt *sqlair.Statement,
deleteAnnotationsStmt *sqlair.Statement,
) error {
db, err := st.DB()
if err != nil {
return errors.Capture(err)
}

setAnnotationsStmt, err := st.Prepare(`
INSERT INTO annotation_model (key, value)
VALUES ($Annotation.*)
ON CONFLICT(key) DO UPDATE SET value=$Annotation.value`, Annotation{})
if err != nil {
return errors.Errorf("preparing set annotations query for model: %w", err)
}
deleteAnnotationsStmt, err := st.Prepare(`
DELETE FROM annotation_model`)
if err != nil {
return errors.Errorf("preparing set annotations query for model: %w", err)
}

err = db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error {
if err := tx.Query(ctx, deleteAnnotationsStmt).Run(); err != nil {
return errors.Errorf("unsetting annotations for ID: %s: %w", id.Name, err)
return errors.Errorf("unsetting annotations for model: %w", err)
}

var annotationParam Annotation
for key, value := range toInsert {
annotationParam.Key = key
annotationParam.Value = value
if err := tx.Query(ctx, setAnnotationsStmt, annotationParam).Run(); err != nil {
return errors.Errorf("setting annotations for ID: %s: %w", id.Name, err)
return errors.Errorf("setting annotations for model: %w", err)
}
}
return nil
})
if err != nil {
return errors.Errorf("setting model annotations with uuid: %q: %w", id.Name, err)
return errors.Errorf("setting model annotations: %w", err)
}
return nil
}
Expand Down Expand Up @@ -363,3 +391,48 @@ VALUES ($annotationUUID.*, $Annotation.*)
}
return nil
}

// uuidQueryForID generates a query and parameters for getting the uuid for a
// given annotation ID.
func uuidQueryForID(id annotations.ID) (string, sqlair.M, error) {
switch id.Kind {
case annotations.KindMachine:
return `SELECT &annotationUUID.uuid FROM machine WHERE name = $M.entity_name`,
sqlair.M{"entity_name": id.Name}, nil
case annotations.KindUnit:
return `SELECT &annotationUUID.uuid FROM unit WHERE name = $M.entity_name`,
sqlair.M{"entity_name": id.Name}, nil
case annotations.KindApplication:
return `SELECT &annotationUUID.uuid FROM application WHERE name = $M.entity_name`,
sqlair.M{"entity_name": id.Name}, nil
case annotations.KindStorage:
return `SELECT &annotationUUID.uuid FROM storage_instance WHERE name = $M.entity_name`,
sqlair.M{"entity_name": id.Name}, nil
case annotations.KindModel:
return `SELECT &annotationUUID.uuid FROM model WHERE name = $M.entity_name`,
sqlair.M{"entity_name": id.Name}, nil
default:
return "", nil, errors.Errorf("cannot generate uuid for kind: %q", id.Kind)
}
}

// annotationTableNameFromID keeps the table names for the different annotation
// tables.
func annotationTableNameFromID(id annotations.ID) (string, error) {
var tableName string
switch id.Kind {
case annotations.KindMachine:
tableName = "annotation_machine"
case annotations.KindUnit:
tableName = "annotation_unit"
case annotations.KindApplication:
tableName = "annotation_application"
case annotations.KindStorage:
tableName = "annotation_storage_instance"
case annotations.KindModel:
tableName = "annotation_model"
default:
return "", errors.Errorf("%q: %w", id.Kind, annotationerrors.UnknownKind)
}
return tableName, nil
}
Loading

0 comments on commit 589e95b

Please sign in to comment.