-
Notifications
You must be signed in to change notification settings - Fork 1
/
typechat_test.go
92 lines (84 loc) · 2.12 KB
/
typechat_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
package typechat
import (
"context"
"encoding/json"
"testing"
)
// mockModelClient is a canned test double that the tests pass to NewPrompt in
// place of a real model client. Every Do call hands back the same response and
// err it was built with.
type mockModelClient struct {
	response string // text returned by every Do call
	err      error  // error returned by every Do call; nil means success
}

// Do returns the preconfigured response and error. It ignores both ctx and
// prompt, so results do not depend on what the caller sends.
func (c mockModelClient) Do(ctx context.Context, prompt []Message) (response string, err error) {
	response, err = c.response, c.err
	return response, err
}
// TestTypeChat exercises Prompt against mockModelClient. It covers:
//   - decoding a JSON model response into a typed result (Execute);
//   - configuring retries via PromptRetries;
//   - the default retry count;
//   - building a Program from a model response (CreateProgram).
//
// Failures that make later assertions meaningless, or that would make them
// panic (e.g. indexing result.Steps), use t.Fatalf to stop the subtest.
func TestTypeChat(t *testing.T) {
	// Result is the typed shape the mocked JSON response is decoded into.
	type Result struct {
		Sentiment string `json:"sentiment"`
	}

	t.Run("it should generate the prompt and return the result", func(t *testing.T) {
		ctx := context.Background()
		m := mockModelClient{
			response: `{"sentiment": "positive"}`,
		}
		p := NewPrompt[Result](m, "That game was awesome!")
		result, err := p.Execute(ctx)
		if err != nil {
			// result is not meaningful on error; stop before reading it.
			t.Fatalf("Expected no error, got %v", err)
		}
		if result.Sentiment != "positive" {
			t.Errorf("Expected positive, got %v", result.Sentiment)
		}
	})

	t.Run("it should configure retries", func(t *testing.T) {
		p := NewPrompt[Result](nil, "", PromptRetries[Result](5))
		if p.retries != 5 {
			t.Errorf("Expected 5 retries, got %v", p.retries)
		}
	})

	t.Run("it should set a default number of retries", func(t *testing.T) {
		// Single source of truth for the expected default, so the check and
		// the failure message cannot drift apart.
		const wantRetries = 1
		p := NewPrompt[Result](nil, "")
		if p.retries != wantRetries {
			t.Errorf("Expected %d retries, got %v", wantRetries, p.retries)
		}
	})

	t.Run("it should accept an API interface and return a program", func(t *testing.T) {
		ctx := context.Background()
		// API is used only as Prompt's type parameter. The step names in the
		// canned program below match its method names.
		type API interface {
			Step1(name string) (string, error)
			Step2(value int) error
		}
		program := Program{
			Steps: []FunctionCall{
				{Name: "Step1", Args: []any{"name"}},
				{Name: "Step2", Args: []any{2}},
			},
		}
		b, err := json.Marshal(program)
		if err != nil {
			// Without a valid canned response the rest of the test is moot.
			t.Fatalf("Expected no error, got %v", err)
		}
		m := mockModelClient{response: string(b)}
		p := NewPrompt[API](m, "")
		result, err := p.CreateProgram(ctx)
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		if len(result.Steps) != 2 {
			// Fatal: the index expressions below would panic otherwise.
			t.Fatalf("Expected 2 steps, got %v", len(result.Steps))
		}
		if result.Steps[0].Name != "Step1" {
			t.Errorf("Expected Step1, got %v", result.Steps[0].Name)
		}
		if result.Steps[1].Name != "Step2" {
			t.Errorf("Expected Step2, got %v", result.Steps[1].Name)
		}
	})
}