May 10, 2020

Writing a SQL database from scratch in Go: 4. a database/sql driver

Previously in database basics: <! forgive me, for I have sinned >
1. SELECT, INSERT, CREATE and a REPL
2. binary expressions and WHERE filters
3. indexes

In this post, we'll extend gosql to implement the database/sql driver interface. This will allow us to interact with gosql the same way we would interact with any other database.

Here is a familiar-looking example program (stored in cmd/sqlexample/main.go) that we'll be able to run:

package main

import (
    "database/sql"
    "fmt"

    _ "github.com/eatonphil/gosql"
)

func main() {
    db, err := sql.Open("postgres", "")
    if err != nil {
        panic(err)
    }
    defer db.Close()

    _, err = db.Query("CREATE TABLE users (name TEXT, age INT);")
    if err != nil {
        panic(err)
    }

    _, err = db.Query("INSERT INTO users VALUES ('Terry', 45);")
    if err != nil {
        panic(err)
    }

    _, err = db.Query("INSERT INTO users VALUES ('Anette', 57);")
    if err != nil {
        panic(err)
    }

    rows, err := db.Query("SELECT name, age FROM users;")
    if err != nil {
        panic(err)
    }

    var name string
    var age uint64
    defer rows.Close()
    for rows.Next() {
        err := rows.Scan(&name, &age)
        if err != nil {
            panic(err)
        }

        fmt.Printf("Name: %s, Age: %d\n", name, age)
    }

    if err = rows.Err(); err != nil {
        panic(err)
    }
}

Our gosql driver will use a single instance of the Backend for all connections.

Aside from that, it is a simple matter of wrapping our existing APIs in structs that implement the database/sql/driver.Driver interface.

This post is largely a discussion of this commit.

Implementing the driver

A driver is registered by calling sql.Register with a name and a driver instance.

We'll add the registration code to an init function in a new file, driver.go:

package gosql

import "database/sql"

type Driver struct {
    bkd Backend
}

func init() {
    sql.Register("postgres", &Driver{NewMemoryBackend()})
}
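sql.Register keeps a process-wide table of named drivers, and it panics if the same name is registered twice or if the driver is nil. That matters here because "postgres" is also the name used by other drivers, for instance github.com/lib/pq, so importing both into one program would panic at startup. The sketch below uses a hypothetical stubDriver plus helpers tryRegister and isRegistered (none of these are part of gosql) to turn that panic into an error and to look a name up in sql.Drivers(), which returns the registered names in sorted order.

```go
package main

import (
    "database/sql"
    "database/sql/driver"
    "errors"
    "fmt"
    "sort"
)

// stubDriver is a hypothetical driver used only to exercise sql.Register.
type stubDriver struct{}

func (stubDriver) Open(name string) (driver.Conn, error) {
    return nil, errors.New("stubDriver: Open not implemented")
}

// tryRegister calls sql.Register and turns its panic into an error.
func tryRegister(name string, d driver.Driver) (err error) {
    defer func() {
        if r := recover(); r != nil {
            err = fmt.Errorf("%v", r)
        }
    }()
    sql.Register(name, d)
    return nil
}

// isRegistered reports whether name is in sql.Drivers(), which is sorted.
func isRegistered(name string) bool {
    names := sql.Drivers()
    i := sort.SearchStrings(names, name)
    return i < len(names) && names[i] == name
}

func main() {
    fmt.Println(tryRegister("gosql-sketch", stubDriver{})) // <nil>
    fmt.Println(isRegistered("gosql-sketch"))              // true
    // Registering the same name again panics inside sql.Register;
    // tryRegister reports that panic as an error instead.
    fmt.Println(tryRegister("gosql-sketch", stubDriver{}))
}
```

Choosing a name that no other driver uses avoids the collision entirely; the example program would then pass that name to sql.Open instead of "postgres".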

According to the Driver interface, we need only implement Open, which returns a connection instance implementing the database/sql/driver.Conn interface.

type Driver struct {
    bkd Backend
}

func (d *Driver) Open(name string) (driver.Conn, error) {
    // The data source name is ignored: every connection shares d.bkd.
    return &Conn{d.bkd}, nil
}

func init() {
    sql.Register("postgres", &Driver{NewMemoryBackend()})
}

Implementing the connection

According to the Conn interface, we must implement:

  • Prepare(query string) (driver.Stmt, error) to handle prepared statements
  • Close to handle cleanup
  • and Begin to start a transaction

The connection can also optionally implement Query and Exec (the driver.Queryer and driver.Execer interfaces).

To simplify things we'll panic on Prepare and on Begin (we don't have transactions yet). There's no cleanup required so we'll do nothing in Close.

type Conn struct {
    bkd Backend
}

func (dc *Conn) Prepare(query string) (driver.Stmt, error) {
    panic("Prepare not implemented")
}

func (dc *Conn) Begin() (driver.Tx, error) {
    panic("Begin not implemented")
}

func (dc *Conn) Close() error {
    return nil
}

The only method we actually need, Query, is not part of the Conn interface; it comes from the optional driver.Queryer interface. It takes a query string and a slice of query parameters, and returns an instance implementing the database/sql/driver.Rows interface.

To implement Query, we mostly copy the logic from the cmd/main.go REPL. The only change is that, when handling SELECT, we return a struct that implements the database/sql/driver.Rows interface.

database/sql/driver.Rows is not the same type as database/sql.Rows, which may sound more familiar. database/sql/driver.Rows is a simpler, lower-level interface; database/sql.Rows is a struct that wraps it and provides Next, Scan, and Err on top.

If we receive parameterized query arguments, we'll panic for now, since parameterization isn't supported yet. And if the query contains multiple statements, we'll process only the first one.

func (dc *Conn) Query(query string, args []driver.Value) (driver.Rows, error) {
    if len(args) > 0 {
        // TODO: support parameterization
        panic("Parameterization not supported")
    }

    parser := Parser{}
    ast, err := parser.Parse(query)
    if err != nil {
        return nil, fmt.Errorf("Error while parsing: %s", err)
    }

    // NOTE: ignoring all but the first statement
    if len(ast.Statements) == 0 {
        return nil, fmt.Errorf("No statements to execute")
    }
    stmt := ast.Statements[0]
    switch stmt.Kind {
    case CreateIndexKind:
        err = dc.bkd.CreateIndex(stmt.CreateIndexStatement)
        if err != nil {
            return nil, fmt.Errorf("Error adding index on table: %s", err)
        }
    case CreateTableKind:
        err = dc.bkd.CreateTable(stmt.CreateTableStatement)
        if err != nil {
            return nil, fmt.Errorf("Error creating table: %s", err)
        }
    case DropTableKind:
        err = dc.bkd.DropTable(stmt.DropTableStatement)
        if err != nil {
            return nil, fmt.Errorf("Error dropping table: %s", err)
        }
    case InsertKind:
        err = dc.bkd.Insert(stmt.InsertStatement)
        if err != nil {
            return nil, fmt.Errorf("Error inserting values: %s", err)
        }
    case SelectKind:
        results, err := dc.bkd.Select(stmt.SelectStatement)
        if err != nil {
            return nil, err
        }

        return &Rows{
            rows:    results.Rows,
            columns: results.Columns,
            index:   0,
        }, nil
    }

    return nil, nil
}

Implementing results

According to the Rows interface we must implement:

  • Columns() []string to return a slice of column names
  • Next(dest []Value) error to populate dest with the next row's worth of cells
  • and Close() error

Our Rows struct will contain the rows and columns as returned from Backend, along with an index field that Next uses to track which row to return next.

type Rows struct {
    columns []ResultColumn
    index   uint64
    rows    [][]Cell
}

func (r *Rows) Columns() []string {}

func (r *Rows) Close() error {}

func (r *Rows) Next(dest []driver.Value) error {}

For Columns we simply need to extract and return the column names from ResultColumn.

func (r *Rows) Columns() []string {
    columns := []string{}
    for _, c := range r.columns {
        columns = append(columns, c.Name)
    }

    return columns
}

For Next we need to iterate over each cell in the current row, retrieve its Go value, and store it in dest. The dest argument is a slice of driver.Value (an empty interface) with one slot per column, so the main work is unwrapping each cell's pointer into a plain value, or nil for SQL NULL.

Once there are no rows left, Next must return io.EOF.

func (r *Rows) Next(dest []driver.Value) error {
    if r.index >= uint64(len(r.rows)) {
        return io.EOF
    }

    row := r.rows[r.index]

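    for idx, cell := range row {
        typ := r.columns[idx].Type
        switch typ {
        case IntType:
            // Store an untyped nil for NULL: a nil pointer stored in an
            // interface is not == nil, so it would not read as NULL.
            if i := cell.AsInt(); i == nil {
                dest[idx] = nil
            } else {
                // int64 is the integer type driver.Value is documented to hold
                dest[idx] = int64(*i)
            }
        case TextType:
            if s := cell.AsText(); s == nil {
                dest[idx] = nil
            } else {
                dest[idx] = *s
            }
        case BoolType:
            if b := cell.AsBool(); b == nil {
                dest[idx] = nil
            } else {
                dest[idx] = *b
            }
        }
    }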

    r.index++
    return nil
}

Finally, in Close we'll set index to the number of rows so that any later call to Next only ever returns io.EOF.

func (r *Rows) Close() error {
    r.index = uint64(len(r.rows))
    return nil
}

And that's all the changes needed to implement a database/sql driver! See here for driver.go in full.

Running the example

With the driver in place we can try out the example:

$ go build ./cmd/sqlexample/main.go
$ ./main
Name: Terry, Age: 45
Name: Anette, Age: 57