Skip to content

Commit 8a3c805

Browse files
authored
Do not validate model identifier in certain scenarios (#67)
2 parents 9414f17 + ec447c9 commit 8a3c805

File tree

5 files changed

+349
-11
lines changed

5 files changed

+349
-11
lines changed

cmd/run/run.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"github.com/MakeNowJust/heredoc"
1717
"github.com/briandowns/spinner"
1818
"github.com/github/gh-models/internal/azuremodels"
19+
"github.com/github/gh-models/internal/modelkey"
1920
"github.com/github/gh-models/internal/sse"
2021
"github.com/github/gh-models/pkg/command"
2122
"github.com/github/gh-models/pkg/prompt"
@@ -513,9 +514,21 @@ func validateModelName(modelName string, models []*azuremodels.ModelSummary) (st
513514
return "", errors.New(noMatchErrorMessage)
514515
}
515516

517+
parsedModel, err := modelkey.ParseModelKey(modelName)
518+
if err != nil {
519+
return "", fmt.Errorf("invalid model format: %w", err)
520+
}
521+
522+
if parsedModel.Provider == "custom" {
523+
// Skip validation for custom provider
524+
return parsedModel.String(), nil
525+
}
526+
527+
// For non-custom providers, validate the model exists
528+
expectedModelID := azuremodels.FormatIdentifier(parsedModel.Publisher, parsedModel.ModelName)
516529
foundMatch := false
517530
for _, model := range models {
518-
if model.HasName(modelName) {
531+
if model.HasName(expectedModelID) {
519532
foundMatch = true
520533
break
521534
}
@@ -525,7 +538,7 @@ func validateModelName(modelName string, models []*azuremodels.ModelSummary) (st
525538
return "", errors.New(noMatchErrorMessage)
526539
}
527540

528-
return modelName, nil
541+
return expectedModelID, nil
529542
}
530543

531544
func (h *runCommandHandler) getChatCompletionStreamReader(req azuremodels.ChatCompletionOptions, org string) (sse.Reader[azuremodels.ChatCompletion], error) {

cmd/run/run_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,3 +403,56 @@ func TestParseTemplateVariables(t *testing.T) {
403403
})
404404
}
405405
}
406+
407+
func TestValidateModelName(t *testing.T) {
408+
tests := []struct {
409+
name string
410+
modelName string
411+
expectedModel string
412+
expectError bool
413+
}{
414+
{
415+
name: "custom provider skips validation",
416+
modelName: "custom/mycompany/custom-model",
417+
expectedModel: "custom/mycompany/custom-model",
418+
expectError: false,
419+
},
420+
{
421+
name: "azureml provider requires validation",
422+
modelName: "openai/gpt-4",
423+
expectedModel: "openai/gpt-4",
424+
expectError: false,
425+
},
426+
{
427+
name: "invalid model format",
428+
modelName: "invalid-format",
429+
expectError: true,
430+
},
431+
{
432+
name: "nonexistent azureml model",
433+
modelName: "nonexistent/model",
434+
expectError: true,
435+
},
436+
}
437+
438+
// Create a mock model for testing
439+
mockModel := &azuremodels.ModelSummary{
440+
Name: "gpt-4",
441+
Publisher: "openai",
442+
Task: "chat-completion",
443+
}
444+
models := []*azuremodels.ModelSummary{mockModel}
445+
446+
for _, tt := range tests {
447+
t.Run(tt.name, func(t *testing.T) {
448+
result, err := validateModelName(tt.modelName, models)
449+
450+
if tt.expectError {
451+
require.Error(t, err)
452+
} else {
453+
require.NoError(t, err)
454+
require.Equal(t, tt.expectedModel, result)
455+
}
456+
})
457+
}
458+
}

internal/azuremodels/model_details.go

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ package azuremodels
22

33
import (
44
"fmt"
5-
"strings"
5+
6+
"github.com/github/gh-models/internal/modelkey"
67
)
78

89
// ModelDetails includes detailed information about a model.
@@ -28,12 +29,5 @@ func (m *ModelDetails) ContextLimits() string {
2829

2930
// FormatIdentifier formats the model identifier based on the publisher and model name.
3031
func FormatIdentifier(publisher, name string) string {
31-
formatPart := func(s string) string {
32-
// Replace spaces with dashes and convert to lowercase
33-
result := strings.ToLower(s)
34-
result = strings.ReplaceAll(result, " ", "-")
35-
return result
36-
}
37-
38-
return fmt.Sprintf("%s/%s", formatPart(publisher), formatPart(name))
32+
return modelkey.FormatIdentifier("azureml", publisher, name)
3933
}

internal/modelkey/modelkey.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package modelkey
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
)
7+
8+
type ModelKey struct {
9+
Provider string
10+
Publisher string
11+
ModelName string
12+
}
13+
14+
func ParseModelKey(modelKey string) (*ModelKey, error) {
15+
if modelKey == "" {
16+
return nil, fmt.Errorf("invalid model key format: %s", modelKey)
17+
}
18+
19+
parts := strings.Split(modelKey, "/")
20+
21+
// Check for empty parts
22+
for _, part := range parts {
23+
if part == "" {
24+
return nil, fmt.Errorf("invalid model key format: %s", modelKey)
25+
}
26+
}
27+
28+
switch len(parts) {
29+
case 2:
30+
// Format: publisher/model-name (provider defaults to "azureml")
31+
return &ModelKey{
32+
Provider: "azureml",
33+
Publisher: parts[0],
34+
ModelName: parts[1],
35+
}, nil
36+
case 3:
37+
// Format: provider/publisher/model-name
38+
return &ModelKey{
39+
Provider: parts[0],
40+
Publisher: parts[1],
41+
ModelName: parts[2],
42+
}, nil
43+
default:
44+
return nil, fmt.Errorf("invalid model key format: %s", modelKey)
45+
}
46+
}
47+
48+
// String returns the string representation of the ModelKey.
49+
func (mk *ModelKey) String() string {
50+
provider := formatPart(mk.Provider)
51+
publisher := formatPart(mk.Publisher)
52+
modelName := formatPart(mk.ModelName)
53+
54+
if provider == "azureml" {
55+
return fmt.Sprintf("%s/%s", publisher, modelName)
56+
}
57+
58+
return fmt.Sprintf("%s/%s/%s", provider, publisher, modelName)
59+
}
60+
61+
func formatPart(s string) string {
62+
s = strings.ToLower(s)
63+
s = strings.ReplaceAll(s, " ", "-")
64+
65+
return s
66+
}
67+
68+
func FormatIdentifier(provider, publisher, name string) string {
69+
mk := &ModelKey{
70+
Provider: provider,
71+
Publisher: publisher,
72+
ModelName: name,
73+
}
74+
75+
return mk.String()
76+
}

0 commit comments

Comments
 (0)