Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactor away the double parsing
  • Loading branch information
sgoedecke committed Jul 21, 2025
commit dae657699f222ee3debaed0d1d9dcd0ca4d25a96
46 changes: 14 additions & 32 deletions pkg/prompt/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,11 @@ type Choice struct {
Score float64 `yaml:"score"`
}

// JsonSchema represents a JSON schema for structured responses as a JSON string
type JsonSchema string
// JsonSchema represents a JSON schema for structured responses
type JsonSchema struct {
Raw string
Parsed map[string]interface{}
}

// UnmarshalYAML implements custom YAML unmarshaling for JsonSchema
// Only supports JSON string format
Expand All @@ -84,13 +87,14 @@ func (js *JsonSchema) UnmarshalYAML(node *yaml.Node) error {
return err
}

// Validate that it's valid JSON
var temp interface{}
if err := json.Unmarshal([]byte(jsonStr), &temp); err != nil {
// Parse and validate the JSON schema
var parsed map[string]interface{}
if err := json.Unmarshal([]byte(jsonStr), &parsed); err != nil {
return fmt.Errorf("invalid JSON in jsonSchema: %w", err)
}

*js = JsonSchema(jsonStr)
js.Raw = jsonStr
js.Parsed = parsed
return nil
}

Expand Down Expand Up @@ -131,17 +135,11 @@ func (f *File) validateResponseFormat() error {
return fmt.Errorf("jsonSchema is required when responseFormat is 'json_schema'")
}

// Parse and validate the JSON schema
var schema map[string]interface{}
if err := json.Unmarshal([]byte(*f.JsonSchema), &schema); err != nil {
return fmt.Errorf("invalid JSON in jsonSchema: %w", err)
}

// Check for required fields
if _, ok := schema["name"]; !ok {
// Check for required fields in the already parsed schema
if _, ok := f.JsonSchema.Parsed["name"]; !ok {
return fmt.Errorf("jsonSchema must contain 'name' field")
}
if _, ok := schema["schema"]; !ok {
if _, ok := f.JsonSchema.Parsed["schema"]; !ok {
return fmt.Errorf("jsonSchema must contain 'schema' field")
}
}
Expand Down Expand Up @@ -204,7 +202,6 @@ func (f *File) BuildChatCompletionOptions(messages []azuremodels.ChatMessage) az
Stream: false,
}

// Apply model parameters
if f.ModelParameters.MaxTokens != nil {
req.MaxTokens = f.ModelParameters.MaxTokens
}
Expand All @@ -215,27 +212,12 @@ func (f *File) BuildChatCompletionOptions(messages []azuremodels.ChatMessage) az
req.TopP = f.ModelParameters.TopP
}

// Apply response format
if f.ResponseFormat != nil {
responseFormat := &azuremodels.ResponseFormat{
Type: *f.ResponseFormat,
}
if f.JsonSchema != nil {
// Parse the JSON schema string into a map
var schemaMap map[string]interface{}
if err := json.Unmarshal([]byte(*f.JsonSchema), &schemaMap); err != nil {
// This should not happen as we validate during unmarshaling
// but we'll handle it gracefully
schemaMap = map[string]interface{}{
"name": "default_schema",
"strict": true,
"schema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
},
}
}
responseFormat.JsonSchema = &schemaMap
responseFormat.JsonSchema = &f.JsonSchema.Parsed
}
req.ResponseFormat = responseFormat
}
Expand Down
13 changes: 7 additions & 6 deletions pkg/prompt/prompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,8 @@
require.Equal(t, "json_schema", *promptFile.ResponseFormat)
require.NotNil(t, promptFile.JsonSchema)

// Parse the JSON schema string to verify its contents
var schema map[string]interface{}
err = json.Unmarshal([]byte(*promptFile.JsonSchema), &schema)
require.NoError(t, err)

// Verify the schema contents using the already parsed data
schema := promptFile.JsonSchema.Parsed
require.Equal(t, "describe_animal", schema["name"])
require.Equal(t, true, schema["strict"])
require.Contains(t, schema, "schema")
Expand Down Expand Up @@ -270,7 +267,11 @@
promptFile := &File{
Model: "openai/gpt-4o",
ResponseFormat: func() *string { s := "json_schema"; return &s }(),
JsonSchema: func() *JsonSchema { js := JsonSchema(jsonSchemaStr); return &js }(),
JsonSchema: func() *JsonSchema {
js := &JsonSchema{Raw: jsonSchemaStr}
json.Unmarshal([]byte(jsonSchemaStr), &js.Parsed)

Check failure on line 272 in pkg/prompt/prompt_test.go

View workflow job for this annotation

GitHub Actions / lint

Error return value of `json.Unmarshal` is not checked (errcheck)
return js
}(),
}

messages := []azuremodels.ChatMessage{
Expand Down
Loading