Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat(plugins): Use wazero instead of wasmtime
  • Loading branch information
kyleconroy committed Dec 5, 2023
commit 5e3d938a3faba144b6de71eb8c446bef9290d1d9
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ go 1.21

require (
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230321174746-8dcc6526cfb1
github.com/bytecodealliance/wasmtime-go/v14 v14.0.0
github.com/cubicdaiya/gonp v1.0.4
github.com/davecgh/go-spew v1.1.1
github.com/fatih/structtag v1.2.0
Expand All @@ -20,6 +19,7 @@ require (
github.com/riza-io/grpc-go v0.2.0
github.com/spf13/cobra v1.8.0
github.com/spf13/pflag v1.0.5
github.com/tetratelabs/wazero v1.5.0
github.com/wasilibs/go-pgquery v0.0.0-20231205013331-96e794bb074e
github.com/xeipuuv/gojsonschema v1.2.0
golang.org/x/sync v0.5.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230321174746-8dcc6526cfb1/g
github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI=
github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/bytecodealliance/wasmtime-go/v14 v14.0.0 h1:ur7S3P+PAeJmgllhSrKnGQOAmmtUbLQxb/nw2NZiaEM=
github.com/bytecodealliance/wasmtime-go/v14 v14.0.0/go.mod h1:tqOVEUjnXY6aGpSfM9qdVRR6G//Yc513fFYUdzZb/DY=
github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I=
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
Expand Down Expand Up @@ -185,6 +183,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
github.com/wasilibs/go-pgquery v0.0.0-20231205013331-96e794bb074e h1:sGIC6/D0KqpA+qBSDSVDQswU/IJVYkbnUXnipgTLQWk=
github.com/wasilibs/go-pgquery v0.0.0-20231205013331-96e794bb074e/go.mod h1:KW0azBSWqkPZ71r+3O4qt8h6A/NisFLp0rbjZ3py4OE=
github.com/wasilibs/wazerox v0.0.0-20231117065139-b3503f4aeff6 h1:jwbU8u5TuXModzdEG4wI0g4FyuD7ROSttU86go5sPdU=
Expand Down
149 changes: 46 additions & 103 deletions internal/ext/wasm/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package wasm

import (
"bytes"
"context"
"crypto/sha256"
"errors"
Expand All @@ -15,10 +16,11 @@ import (
"os"
"path/filepath"
"runtime"
"runtime/trace"
"strings"

wasmtime "github.com/bytecodealliance/wasmtime-go/v14"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
"github.com/tetratelabs/wazero/sys"
"golang.org/x/sync/singleflight"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -70,13 +72,17 @@ func (r *Runner) getChecksum(ctx context.Context) (string, error) {
return sum, nil
}

func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasmtime.Module, error) {
func (r *Runner) loadBytes(ctx context.Context) ([]byte, error) {
expected, err := r.getChecksum(ctx)
if err != nil {
return nil, err
}
cacheDir, err := cache.PluginsDir()
if err != nil {
return nil, err
}
value, err, _ := flight.Do(expected, func() (interface{}, error) {
return r.loadSerializedModule(ctx, engine, expected)
return r.loadWASM(ctx, cacheDir, expected)
})
if err != nil {
return nil, err
Expand All @@ -85,52 +91,7 @@ func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasm
if !ok {
return nil, fmt.Errorf("returned value was not a byte slice")
}
return wasmtime.NewModuleDeserialize(engine, data)
}

func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engine, expectedSha string) ([]byte, error) {
cacheDir, err := cache.PluginsDir()
if err != nil {
return nil, err
}

pluginDir := filepath.Join(cacheDir, expectedSha)
modName := fmt.Sprintf("plugin_%s_%s_%s.module", runtime.GOOS, runtime.GOARCH, wasmtimeVersion)
modPath := filepath.Join(pluginDir, modName)
_, staterr := os.Stat(modPath)
if staterr == nil {
data, err := os.ReadFile(modPath)
if err != nil {
return nil, err
}
return data, nil
}

wmod, err := r.loadWASM(ctx, cacheDir, expectedSha)
if err != nil {
return nil, err
}

moduRegion := trace.StartRegion(ctx, "wasmtime.NewModule")
module, err := wasmtime.NewModule(engine, wmod)
moduRegion.End()
if err != nil {
return nil, fmt.Errorf("define wasi: %w", err)
}

err = os.Mkdir(pluginDir, 0755)
if err != nil && !os.IsExist(err) {
return nil, fmt.Errorf("mkdirall: %w", err)
}
out, err := module.Serialize()
if err != nil {
return nil, fmt.Errorf("serialize: %w", err)
}
if err := os.WriteFile(modPath, out, 0444); err != nil {
return nil, fmt.Errorf("cache wasm: %w", err)
}

return out, nil
return data, nil
}

func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error) {
Expand Down Expand Up @@ -245,72 +206,56 @@ func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any,
return fmt.Errorf("failed to encode codegen request: %w", err)
}

engine := wasmtime.NewEngine()
module, err := r.loadModule(ctx, engine)
cacheDir, err := cache.PluginsDir()
if err != nil {
return fmt.Errorf("loadModule: %w", err)
return err
}

linker := wasmtime.NewLinker(engine)
if err := linker.DefineWasi(); err != nil {
cache, err := wazero.NewCompilationCacheWithDir(filepath.Join(cacheDir, "wazero"))
if err != nil {
return err
}

dir, err := os.MkdirTemp(os.Getenv("SQLCTMPDIR"), "out")
wasmBytes, err := r.loadBytes(ctx)
if err != nil {
return fmt.Errorf("temp dir: %w", err)
return fmt.Errorf("loadModule: %w", err)
}

defer os.RemoveAll(dir)
stdinPath := filepath.Join(dir, "stdin")
stderrPath := filepath.Join(dir, "stderr")
stdoutPath := filepath.Join(dir, "stdout")
config := wazero.NewRuntimeConfig().WithCompilationCache(cache)
rt := wazero.NewRuntimeWithConfig(ctx, config)
defer rt.Close(ctx)

if err := os.WriteFile(stdinPath, stdinBlob, 0755); err != nil {
return fmt.Errorf("write file: %w", err)
}

// Configure WASI imports to write stdout into a file.
wasiConfig := wasmtime.NewWasiConfig()
wasiConfig.SetArgv([]string{"plugin.wasm", method})
wasiConfig.SetStdinFile(stdinPath)
wasiConfig.SetStdoutFile(stdoutPath)
wasiConfig.SetStderrFile(stderrPath)
// TODO: Handle error
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is Instantiate if you'd like to return the error. Though I think any failure here would be a programming bug, not non-determinstic

wasi_snapshot_preview1.MustInstantiate(ctx, rt)

keys := []string{"SQLC_VERSION"}
vals := []string{info.Version}
for _, key := range r.Env {
keys = append(keys, key)
vals = append(vals, os.Getenv(key))
// Compile the Wasm binary once so that we can skip the entire compilation time during instantiation.
mod, err := rt.CompileModule(ctx, wasmBytes)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's possible, it would be nice to rejigger to scope this to Runner, possibly with some map[/* wasm url */ string]wazero.CompiledModule. The compilation cache is good to reuse across executions of the sqlc process itself, but it's also good to only compile once per wasm within a process if possible, the cache key computation isn't trivial. Though if the latter doesn't happen that much maybe it doesn't matter

if err != nil {
return err
}
wasiConfig.SetEnv(keys, vals)

store := wasmtime.NewStore(engine)
store.SetWasi(wasiConfig)
var stderr, stdout bytes.Buffer

linkRegion := trace.StartRegion(ctx, "linker.DefineModule")
err = linker.DefineModule(store, "", module)
linkRegion.End()
if err != nil {
return fmt.Errorf("define wasi: %w", err)
conf := wazero.NewModuleConfig()
conf = conf.WithArgs("plugin.wasm", method)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit, consider chaining, it's arguably idiomatic for wazero users

conf := wazero.NewModuleConfig().
  WithArgs().
  WithStdin().
  WithStdout().

conf = conf.WithEnv("SQLC_VERSION", info.Version)
for _, key := range r.Env {
conf = conf.WithEnv(key, os.Getenv(key))
}
conf = conf.WithStdin(bytes.NewReader(stdinBlob))
conf = conf.WithStdout(&stdout)
conf = conf.WithStderr(&stderr)

// Run the function
fn, err := linker.GetDefault(store, "")
if err != nil {
return fmt.Errorf("wasi: get default: %w", err)
result, err := rt.InstantiateModule(ctx, mod, conf)
if result != nil {
defer result.Close(ctx)
}

callRegion := trace.StartRegion(ctx, "call _start")
_, err = fn.Call(store)
callRegion.End()

if cerr := checkError(err, stderrPath); cerr != nil {
if cerr := checkError(err, &stderr); cerr != nil {
return cerr
}

// Print WASM stdout
stdoutBlob, err := os.ReadFile(stdoutPath)
stdoutBlob, err := io.ReadAll(&stdout)
if err != nil {
return fmt.Errorf("read file: %w", err)
}
Expand All @@ -331,21 +276,19 @@ func (r *Runner) NewStream(ctx context.Context, desc *grpc.StreamDesc, method st
return nil, status.Error(codes.Unimplemented, "")
}

func checkError(err error, stderrPath string) error {
func checkError(err error, stderr io.Reader) error {
if err == nil {
return err
}

var wtError *wasmtime.Error
if errors.As(err, &wtError) {
if code, ok := wtError.ExitStatus(); ok {
if code == 0 {
return nil
}
if exitErr, ok := err.(*sys.ExitError); ok {
if exitErr.ExitCode() == 0 {
return nil
}
}

// Print WASM stdout
stderrBlob, rferr := os.ReadFile(stderrPath)
stderrBlob, rferr := io.ReadAll(stderr)
if rferr == nil && len(stderrBlob) > 0 {
return errors.New(string(stderrBlob))
}
Expand Down