Browse Source

add capabilities

jmorganca/ggml-static
Michael Yang 2 years ago
parent
commit
a30915bde1
  1. 20
      server/images.go
  2. 8
      server/routes.go
  3. 8
      template/template_test.go

20
server/images.go

@ -34,6 +34,10 @@ import (
"github.com/ollama/ollama/version"
)
type Capability string
const CapabilityCompletion = Capability("completion")
type registryOptions struct {
Insecure bool
Username string
@ -58,8 +62,20 @@ type Model struct {
Template *template.Template
}
func (m *Model) IsEmbedding() bool {
return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert")
func (m *Model) Has(caps ...Capability) bool {
for _, cap := range caps {
switch cap {
case CapabilityCompletion:
if slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert") {
return false
}
default:
slog.Error("unknown capability", "capability", cap)
return false
}
}
return true
}
func (m *Model) String() string {

8
server/routes.go

@ -122,8 +122,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
if model.IsEmbedding() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support generate"})
if !model.Has(CapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support generate", req.Model)})
return
}
@ -1308,8 +1308,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
if model.IsEmbedding() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support chat"})
if !model.Has(CapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support chat", req.Model)})
return
}

8
template/template_test.go

@ -61,8 +61,8 @@ func TestNamed(t *testing.T) {
func TestParse(t *testing.T) {
cases := []struct {
template string
capabilities []string
template string
vars []string
}{
{"{{ .Prompt }}", []string{"prompt"}},
{"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}},
@ -81,8 +81,8 @@ func TestParse(t *testing.T) {
}
vars := tmpl.Vars()
if !slices.Equal(tt.capabilities, vars) {
t.Errorf("expected %v, got %v", tt.capabilities, vars)
if !slices.Equal(tt.vars, vars) {
t.Errorf("expected %v, got %v", tt.vars, vars)
}
})
}

Loading…
Cancel
Save