-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcontext_test.go
105 lines (81 loc) · 2.12 KB
/
context_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
package typedcontext_test
import (
"context"
"testing"
typedcontext "github.com/mraerino/typed-context"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type (
	// Welcome is a named string type used as a typed-context value in the tests.
	Welcome string

	// RequestID is a second named string type. It shares Welcome's underlying
	// type, which lets the tests check that the two are stored independently.
	RequestID string

	// MyStruct is a small struct stored by pointer to exercise pointer values.
	MyStruct struct {
		ID string
	}
)
func TestGetSet(t *testing.T) {
ctx := context.Background()
value := Welcome("Hello World")
ctx = typedcontext.Set(ctx, value)
actual, ok := typedcontext.Get[Welcome](ctx)
require.True(t, ok)
assert.Equal(t, "Hello World", string(actual))
// override
value2 := Welcome("Goodbye")
ctx = typedcontext.Set(ctx, value2)
actual2, ok := typedcontext.Get[Welcome](ctx)
require.True(t, ok)
assert.Equal(t, "Goodbye", string(actual2))
// different type
value3 := RequestID("0c4f7d51-af18-4475-9fdc-5f022fb8079c")
ctx = typedcontext.Set(ctx, value3)
actual3, ok := typedcontext.Get[Welcome](ctx)
require.True(t, ok)
assert.Equal(t, "Goodbye", string(actual3))
actual4, ok := typedcontext.Get[RequestID](ctx)
require.True(t, ok)
assert.Equal(t, "0c4f7d51-af18-4475-9fdc-5f022fb8079c", string(actual4))
// pointer
value4 := &MyStruct{
ID: "hello",
}
ctx = typedcontext.Set(ctx, value4)
actual5, ok := typedcontext.Get[*MyStruct](ctx)
require.True(t, ok)
assert.Equal(t, "hello", actual5.ID)
}
// ctxKey is an unexported key type for plain context.WithValue lookups. A
// private type rather than a built-in one avoids collisions with keys
// defined by other packages.
type ctxKey uint64

// requestIDKey is the key under which BenchmarkStdlib stores its request ID.
const requestIDKey = ctxKey(iota)
// valStdlib is a package-level sink for the final lookup in BenchmarkStdlib,
// so the benchmark's result is observably used.
var valStdlib any

// BenchmarkStdlib measures a plain context.Value lookup with an unexported
// key type. It is the baseline that BenchmarkTyped is compared against.
func BenchmarkStdlib(b *testing.B) {
	ctx := context.Background()
	ctx = context.WithValue(ctx, requestIDKey, "0c4f7d51-af18-4475-9fdc-5f022fb8079c")
	var val any

	// Report allocations so the numbers are comparable with BenchmarkTyped,
	// and keep the setup above out of the timed region.
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		val = ctx.Value(requestIDKey)
		if val == nil {
			b.Fatal("not found")
		}
		// comma-ok so a wrongly typed value fails the benchmark instead of panicking.
		if s, ok := val.(string); !ok || s != "0c4f7d51-af18-4475-9fdc-5f022fb8079c" {
			b.Fatal("wrong value")
		}
	}
	valStdlib = val
}
// valTyped is a package-level sink for the final lookup in BenchmarkTyped,
// so the benchmark's result is observably used.
var valTyped RequestID

// BenchmarkTyped measures typedcontext.Get for a RequestID value. It is the
// counterpart to the BenchmarkStdlib baseline.
func BenchmarkTyped(b *testing.B) {
	ctx := context.Background()
	ctx = typedcontext.Set(ctx, RequestID("0c4f7d51-af18-4475-9fdc-5f022fb8079c"))
	var val RequestID
	var ok bool

	// Report allocations so the numbers are comparable with BenchmarkStdlib,
	// and keep the setup above out of the timed region.
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		val, ok = typedcontext.Get[RequestID](ctx)
		if !ok {
			b.Fatal("not found")
		}
		if val != "0c4f7d51-af18-4475-9fdc-5f022fb8079c" {
			b.Fatal("wrong value")
		}
	}
	valTyped = val
}