Skip to content

Commit

Permalink
Restores ability to define host functions w/o context via reflection (#…
Browse files Browse the repository at this point in the history
…832)

This restores the ability to leave out the initial context parameter
when defining functions with reflection. This is important because some
projects are porting from a different library to wazero, and all the
alternatives are not contextualized.

For example, this project is porting envoy host functions, and the
original definitions (in mosn) don't have a context parameter. By being
lenient, they can migrate easier.

See https://github.com/cisco-open/nasp/blob/6b813482b6177385311fdfc42d92a9d3e096b8f2/pkg/proxywasm/wazero/imports_v1.go

Signed-off-by: Adrian Cole <[email protected]>
  • Loading branch information
codefromthecrypt authored Oct 28, 2022
1 parent 1ac6c06 commit d108ce4
Show file tree
Hide file tree
Showing 25 changed files with 115 additions and 69 deletions.
2 changes: 1 addition & 1 deletion config.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func (c *compiledModule) Name() (moduleName string) {
}

// Close implements CompiledModule.Close
func (c *compiledModule) Close(_ context.Context) error {
func (c *compiledModule) Close(context.Context) error {
c.compiledEngine.DeleteCompiledModule(c.module)
// It is possible the underlying may need to return an error later, but in any case this matches api.Module.Close.
return nil
Expand Down
4 changes: 2 additions & 2 deletions examples/import-go/age-calculator.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ func main() {
// host-defined functions, but any name would do.
_, err := r.NewHostModuleBuilder("env").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, v uint32) {
WithFunc(func(v uint32) {
fmt.Println("log_i32 >>", v)
}).
Export("log_i32").
NewFunctionBuilder().
WithFunc(func(context.Context) uint32 {
WithFunc(func() uint32 {
if envYear, err := strconv.ParseUint(os.Getenv("CURRENT_YEAR"), 10, 64); err == nil {
return uint32(envYear) // Allow env-override to prevent annual test maintenance!
}
Expand Down
2 changes: 1 addition & 1 deletion examples/namespace/counter.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ type counter struct {
counter uint32
}

func (e *counter) getAndIncrement(context.Context) (ret uint32) {
func (e *counter) getAndIncrement() (ret uint32) {
ret = e.counter
e.counter++
return
Expand Down
2 changes: 1 addition & 1 deletion experimental/logging/log_listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ func Test_loggingListener(t *testing.T) {

var out bytes.Buffer
lf := logging.NewLoggingListenerFactory(&out)
fn := func(context.Context) {}
fn := func() {}
for _, tt := range tests {
tc := tt
t.Run(tc.name, func(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion imports/assemblyscript/assemblyscript_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func Example_functionExporter() {
// First construct your own module builder for "env"
envBuilder := r.NewHostModuleBuilder("env").
NewFunctionBuilder().
WithFunc(func(context.Context) uint32 { return 1 }).
WithFunc(func() uint32 { return 1 }).
Export("get_int")

// Now, add AssemblyScript special function imports into it.
Expand Down
2 changes: 1 addition & 1 deletion imports/emscripten/emscripten_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func Example_functionExporter() {
// you need.
envBuilder := r.NewHostModuleBuilder("env").
NewFunctionBuilder().
WithFunc(func(context.Context) uint32 { return 1 }).
WithFunc(func() uint32 { return 1 }).
Export("get_int")

// Now, add Emscripten special function imports into it.
Expand Down
2 changes: 1 addition & 1 deletion internal/engine/compiler/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ func TestCompiler_SliceAllocatedOnHeap(t *testing.T) {

const hostModuleName = "env"
const hostFnName = "grow_and_shrink_goroutine_stack"
hm, err := wasm.NewHostModule(hostModuleName, map[string]interface{}{hostFnName: func(context.Context) {
hm, err := wasm.NewHostModule(hostModuleName, map[string]interface{}{hostFnName: func() {
// This function aggressively grow the goroutine stack by recursively
// calling the function many times.
var callNum = 1000
Expand Down
2 changes: 1 addition & 1 deletion internal/integration_test/engine/adhoc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func testGlobalExtend(t *testing.T, r wazero.Runtime) {
}

func testUnreachable(t *testing.T, r wazero.Runtime) {
callUnreachable := func(context.Context) {
callUnreachable := func() {
panic("panic in host function")
}

Expand Down
4 changes: 2 additions & 2 deletions internal/integration_test/vs/wasmedge/wasmedge.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func (r *wasmedgeRuntime) Instantiate(_ context.Context, cfg *vs.RuntimeConfig)
return
}

func (r *wasmedgeRuntime) Close(_ context.Context) error {
func (r *wasmedgeRuntime) Close(context.Context) error {
if conf := r.conf; conf != nil {
conf.Release()
}
Expand Down Expand Up @@ -183,7 +183,7 @@ func (m *wasmedgeModule) WriteMemory(_ context.Context, offset uint32, bytes []b
return nil
}

func (m *wasmedgeModule) Close(_ context.Context) error {
func (m *wasmedgeModule) Close(context.Context) error {
if env := m.env; env != nil {
env.Release()
}
Expand Down
4 changes: 2 additions & 2 deletions internal/integration_test/vs/wasmer/wasmer.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func (r *wasmerRuntime) Instantiate(_ context.Context, cfg *vs.RuntimeConfig) (m
return
}

func (r *wasmerRuntime) Close(_ context.Context) error {
func (r *wasmerRuntime) Close(context.Context) error {
r.engine = nil
return nil
}
Expand Down Expand Up @@ -195,7 +195,7 @@ func (m *wasmerModule) Memory() []byte {
return m.mem.Data()
}

func (m *wasmerModule) Close(_ context.Context) error {
func (m *wasmerModule) Close(context.Context) error {
if instance := m.instance; instance != nil {
instance.Close()
}
Expand Down
4 changes: 2 additions & 2 deletions internal/integration_test/vs/wasmtime/wasmtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func (r *wasmtimeRuntime) Instantiate(_ context.Context, cfg *vs.RuntimeConfig)
return
}

func (r *wasmtimeRuntime) Close(_ context.Context) error {
func (r *wasmtimeRuntime) Close(context.Context) error {
r.engine = nil
return nil // wasmtime only closes via finalizer
}
Expand Down Expand Up @@ -193,7 +193,7 @@ func (m *wasmtimeModule) WriteMemory(_ context.Context, offset uint32, bytes []b
return nil
}

func (m *wasmtimeModule) Close(_ context.Context) error {
func (m *wasmtimeModule) Close(context.Context) error {
m.store = nil
m.instance = nil
m.funcs = nil
Expand Down
2 changes: 1 addition & 1 deletion internal/sys/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func (c *FSContext) CloseFile(_ context.Context, fd uint32) bool {
}

// Close implements io.Closer
func (c *FSContext) Close(_ context.Context) (err error) {
func (c *FSContext) Close(context.Context) (err error) {
// Close any files opened in this context
for fd, entry := range c.openedFiles {
delete(c.openedFiles, fd)
Expand Down
2 changes: 1 addition & 1 deletion internal/testing/enginetest/enginetest.go
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ const (
callImportCallDivByGoName = "call_import->" + callDivByGoName
)

func divByGo(_ context.Context, d uint32) uint32 {
func divByGo(d uint32) uint32 {
if d == math.MaxUint32 {
panic(errors.New("host-function panic"))
}
Expand Down
3 changes: 1 addition & 2 deletions internal/wasm/binary/encoder_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package binary

import (
"context"
"testing"

"github.com/tetratelabs/wazero/internal/leb128"
Expand Down Expand Up @@ -211,7 +210,7 @@ func TestModule_Encode(t *testing.T) {

func TestModule_Encode_HostFunctionSection_Unsupported(t *testing.T) {
// We don't currently have an approach to serialize reflect.Value pointers
fn := func(context.Context) {}
fn := func() {}

captured := require.CapturePanic(func() {
EncodeModule(&wasm.Module{
Expand Down
3 changes: 1 addition & 2 deletions internal/wasm/function_definition_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package wasm

import (
"context"
"testing"

"github.com/tetratelabs/wazero/api"
Expand All @@ -10,7 +9,7 @@ import (

func TestModule_BuildFunctionDefinitions(t *testing.T) {
nopCode := &Code{Body: []byte{OpcodeEnd}}
fn := func(context.Context) {}
fn := func() {}
tests := []struct {
name string
m *Module
Expand Down
10 changes: 5 additions & 5 deletions internal/wasm/global.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func (g *mutableGlobal) Type() api.ValueType {
}

// Get implements the same method as documented on api.Global.
func (g *mutableGlobal) Get(_ context.Context) uint64 {
func (g *mutableGlobal) Get(context.Context) uint64 {
return g.g.Val
}

Expand Down Expand Up @@ -54,7 +54,7 @@ func (g globalI32) Type() api.ValueType {
}

// Get implements the same method as documented on api.Global.
func (g globalI32) Get(_ context.Context) uint64 {
func (g globalI32) Get(context.Context) uint64 {
return uint64(g)
}

Expand All @@ -74,7 +74,7 @@ func (g globalI64) Type() api.ValueType {
}

// Get implements the same method as documented on api.Global.
func (g globalI64) Get(_ context.Context) uint64 {
func (g globalI64) Get(context.Context) uint64 {
return uint64(g)
}

Expand All @@ -94,7 +94,7 @@ func (g globalF32) Type() api.ValueType {
}

// Get implements the same method as documented on api.Global.
func (g globalF32) Get(_ context.Context) uint64 {
func (g globalF32) Get(context.Context) uint64 {
return uint64(g)
}

Expand All @@ -114,7 +114,7 @@ func (g globalF64) Type() api.ValueType {
}

// Get implements the same method as documented on api.Global.
func (g globalF64) Get(_ context.Context) uint64 {
func (g globalF64) Get(context.Context) uint64 {
return uint64(g)
}

Expand Down
58 changes: 39 additions & 19 deletions internal/wasm/gofunc.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ import (
"github.com/tetratelabs/wazero/api"
)

type paramsKind byte

const (
paramsKindNoContext paramsKind = iota
paramsKindContext
paramsKindContextModule
)

// Below are reflection code to get the interface type used to parse functions and set values.

var moduleType = reflect.TypeOf((*api.Module)(nil)).Elem()
Expand Down Expand Up @@ -46,6 +54,7 @@ var _ api.GoFunction = (*reflectGoFunction)(nil)

type reflectGoFunction struct {
fn *reflect.Value
pk paramsKind
params, results []ValueType
}

Expand All @@ -55,12 +64,16 @@ func (f *reflectGoFunction) EqualTo(that interface{}) bool {
return false
} else {
// TODO compare reflect pointers
return bytes.Equal(f.params, f2.params) && bytes.Equal(f.results, f2.results)
return f.pk == f2.pk &&
bytes.Equal(f.params, f2.params) && bytes.Equal(f.results, f2.results)
}
}

// Call implements the same method as documented on api.GoFunction.
func (f *reflectGoFunction) Call(ctx context.Context, params []uint64) []uint64 {
if f.pk == paramsKindNoContext {
ctx = nil
}
return callGoFunc(ctx, nil, f.fn, params)
}

Expand Down Expand Up @@ -93,8 +106,11 @@ func callGoFunc(ctx context.Context, mod api.Module, fn *reflect.Value, params [
if tp.NumIn() != 0 {
in = make([]reflect.Value, tp.NumIn())

i := 1
in[0] = newContextVal(ctx)
i := 0
if ctx != nil {
in[0] = newContextVal(ctx)
i++
}
if mod != nil {
in[1] = newModuleVal(mod)
i++
Expand Down Expand Up @@ -175,15 +191,19 @@ func parseGoReflectFunc(fn interface{}) (params, results []ValueType, code *Code
return
}

needsMod, needsErr := needsModule(p)
if needsErr != nil {
err = needsErr
pk, kindErr := kind(p)
if kindErr != nil {
err = kindErr
return
}

pOffset := 1 // ctx
if needsMod {
pOffset = 2 // ctx, mod
pOffset := 0
switch pk {
case paramsKindNoContext:
case paramsKindContext:
pOffset = 1
case paramsKindContextModule:
pOffset = 2
}

pCount := p.NumIn() - pOffset
Expand Down Expand Up @@ -234,30 +254,30 @@ func parseGoReflectFunc(fn interface{}) (params, results []ValueType, code *Code
}

code = &Code{IsHostFunction: true}
if needsMod {
if pk == paramsKindContextModule {
code.GoFunc = &reflectGoModuleFunction{fn: &fnV, params: params, results: results}
} else {
code.GoFunc = &reflectGoFunction{fn: &fnV, params: params, results: results}
code.GoFunc = &reflectGoFunction{pk: pk, fn: &fnV, params: params, results: results}
}
return
}

func needsModule(p reflect.Type) (bool, error) {
func kind(p reflect.Type) (paramsKind, error) {
pCount := p.NumIn()
if pCount == 0 {
return false, errors.New("invalid signature: context.Context must be param[0]")
}
if p.In(0).Kind() == reflect.Interface {
if pCount > 0 && p.In(0).Kind() == reflect.Interface {
p0 := p.In(0)
if p0.Implements(moduleType) {
return false, errors.New("invalid signature: api.Module parameter must be preceded by context.Context")
return 0, errors.New("invalid signature: api.Module parameter must be preceded by context.Context")
} else if p0.Implements(goContextType) {
if pCount >= 2 && p.In(1).Implements(moduleType) {
return true, nil
return paramsKindContextModule, nil
}
return paramsKindContext, nil
}
}
return false, nil
// Without context param allows portability with reflective runtimes.
// This allows people to more easily port to wazero.
return paramsKindNoContext, nil
}

func getTypeOf(kind reflect.Kind) (ValueType, bool) {
Expand Down
Loading

0 comments on commit d108ce4

Please sign in to comment.