-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtypechat.go
150 lines (121 loc) · 3.18 KB
/
typechat.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
package typechat
import (
"context"
"encoding/json"
"fmt"
)
// Role identifies who authored a chat Message. The name field is unexported,
// so code outside this package cannot set it and should use the predefined
// Role* variables.
type Role struct {
	name string
}

// String returns the role's name, for example "user".
func (role Role) String() string {
	name := role.name
	return name
}

// Predefined roles used when building chat messages.
var (
	// RoleUser is the "user" role.
	RoleUser = Role{name: "user"}
	// RoleSystem is the "system" role.
	RoleSystem = Role{name: "system"}
	// RoleAssistant is the "assistant" role.
	RoleAssistant = Role{name: "assistant"}
)
// Message is one entry of the conversation handed to client.Do: a piece of
// text together with the Role of its author.
type Message struct {
	Content string
	Role    Role
}

// newSystemMessage wraps content in a Message authored by RoleSystem.
func newSystemMessage(content string) Message {
	return Message{Role: RoleSystem, Content: content}
}

// newUserMessage wraps content in a Message authored by RoleUser.
func newUserMessage(content string) Message {
	return Message{Role: RoleUser, Content: content}
}

// newAssistantMessage wraps content in a Message authored by RoleAssistant.
func newAssistantMessage(content string) Message {
	return Message{Role: RoleAssistant, Content: content}
}
// client is the model backend a Prompt talks to. Do sends the conversation in
// prompt and returns the model's textual response. Prompt.exec decodes that
// response as JSON and re-queries on parse failure.
type client interface {
	Do(ctx context.Context, prompt []Message) (response string, err error)
}
// Prompt is a generic typechat prompt. T is the type handed to the prompt
// builder: Execute decodes the model's JSON reply into a T, while
// CreateProgram presumably uses T as the API whose methods program steps
// refer to.
type Prompt[T any] struct {
	model   client // backend that answers the prompt
	prompt  string // user request text handed to the prompt builder
	retries int    // total attempts exec may make; NewPrompt ensures >= 1
}
// opt is a functional option that configures a Prompt[T] inside NewPrompt.
type opt[T any] func(*Prompt[T])

// PromptRetries limits how many times the model is queried when its response
// cannot be parsed. Despite the name, the value is the total number of
// attempts, including the first one. NewPrompt treats values <= 0 as 1.
func PromptRetries[T any](retries int) opt[T] {
	var apply opt[T] = func(p *Prompt[T]) {
		p.retries = retries
	}
	return apply
}
// NewPrompt returns a Prompt[T] that sends prompt to model. The opts are
// applied in order after model and prompt are set. If no option leaves a
// positive retry count, the prompt makes a single attempt.
func NewPrompt[T any](model client, prompt string, opts ...opt[T]) *Prompt[T] {
	p := new(Prompt[T])
	p.model = model
	p.prompt = prompt
	for i := range opts {
		opts[i](p)
	}
	// Guarantee at least one attempt so exec always queries the model.
	if p.retries < 1 {
		p.retries = 1
	}
	return p
}
// Execute sends the user prompt to the model and decodes its JSON reply into
// a value of type T. If a reply cannot be parsed, the prompt is repaired and
// re-sent, for at most Prompt.retries attempts in total.
//
// On error the zero value of T is returned, so data partially decoded by a
// failed attempt is never exposed to the caller.
func (p *Prompt[T]) Execute(ctx context.Context) (T, error) {
	var zero T
	b, err := newBuilder[T](promptUserRequest, p.prompt)
	if err != nil {
		return zero, fmt.Errorf("failed to create prompt builder: %w", err)
	}
	var result T
	if err := p.exec(ctx, b, &result); err != nil {
		// json.Unmarshal may have filled some fields before failing; discard them.
		return zero, fmt.Errorf("failed to execute prompt: %w", err)
	}
	return result, nil
}
// CreateProgram asks the model to turn the user request into a typechat
// Program and decodes the JSON reply into it. Program steps refer to methods
// of the API interface (presumably T). See the Program struct for its layout.
// Unparseable replies are repaired and re-sent, for at most Prompt.retries
// attempts in total.
//
// On error the zero Program is returned rather than a partially decoded one.
func (p *Prompt[T]) CreateProgram(ctx context.Context) (Program, error) {
	var empty Program
	b, err := newBuilder[T](promptProgram, p.prompt)
	if err != nil {
		return empty, fmt.Errorf("failed to create prompt builder: %w", err)
	}
	var program Program
	if err := p.exec(ctx, b, &program); err != nil {
		// Don't leak fields a failed json.Unmarshal may have populated.
		return empty, fmt.Errorf("failed to execute prompt: %w", err)
	}
	return program, nil
}
// exec sends the prompt built by b to the model and JSON-decodes the response
// into output. output must be a non-nil pointer, as json.Unmarshal requires.
// When decoding fails and attempts remain, b.repair builds a follow-up prompt
// from the bad response and the decode error, and the model is queried again.
// At most p.retries attempts are made, and always at least one.
//
// Errors from building the prompt, from the model client, and from b.repair
// are returned immediately. If every attempt fails to parse, the returned
// error wraps the last decode error.
//
// NOTE(review): output is reused across attempts. json.Unmarshal can leave
// some fields set when it fails with a type error, and a later successful
// decode does not clear fields absent from the new JSON. Consider decoding
// into a fresh value per attempt.
func (p *Prompt[T]) exec(ctx context.Context, b *builder[T], output any) error {
	prompt, err := b.prompt()
	if err != nil {
		return fmt.Errorf("failed to build prompt: %w", err)
	}

	// A zero-value Prompt (not built via NewPrompt) would otherwise skip the
	// loop and report success without ever querying the model.
	attempts := p.retries
	if attempts < 1 {
		attempts = 1
	}

	var parseErr error
	for attempt := 1; attempt <= attempts; attempt++ {
		resp, err := p.model.Do(ctx, prompt)
		if err != nil {
			return fmt.Errorf("failed to query model: %w", err)
		}

		parseErr = json.Unmarshal([]byte(resp), output)
		if parseErr == nil {
			// Parsed successfully: stop, so the model is not queried again and
			// the decoded value is not overwritten.
			return nil
		}

		if attempt == attempts {
			// No attempts left; a repair prompt would never be sent.
			break
		}
		prompt, err = b.repair(resp, parseErr)
		if err != nil {
			return fmt.Errorf("failed to repair prompt: %w", err)
		}
	}
	return fmt.Errorf("failed to parse prompt response with %d retries: %w", attempts, parseErr)
}