mirror of https://gitee.com/namelin2022/ollama
29 changed files with 300 additions and 163 deletions
@ -0,0 +1,158 @@ |
|||||
|
package template |
||||
|
|
||||
|
import ( |
||||
|
"bytes" |
||||
|
"embed" |
||||
|
"encoding/json" |
||||
|
"errors" |
||||
|
"io" |
||||
|
"math" |
||||
|
"slices" |
||||
|
"strings" |
||||
|
"sync" |
||||
|
"text/template" |
||||
|
"text/template/parse" |
||||
|
|
||||
|
"github.com/agnivade/levenshtein" |
||||
|
"golang.org/x/exp/maps" |
||||
|
) |
||||
|
|
||||
|
//go:embed index.json
|
||||
|
var indexBytes []byte |
||||
|
|
||||
|
//go:embed *.gotmpl
|
||||
|
var templatesFS embed.FS |
||||
|
|
||||
|
var templatesOnce = sync.OnceValues(func() ([]*named, error) { |
||||
|
var templates []*named |
||||
|
if err := json.Unmarshal(indexBytes, &templates); err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
for _, t := range templates { |
||||
|
bts, err := templatesFS.ReadFile(t.Name + ".gotmpl") |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
// normalize line endings
|
||||
|
t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n")) |
||||
|
} |
||||
|
|
||||
|
return templates, nil |
||||
|
}) |
||||
|
|
||||
|
type named struct { |
||||
|
Name string `json:"name"` |
||||
|
Template string `json:"template"` |
||||
|
Bytes []byte |
||||
|
} |
||||
|
|
||||
|
func (t named) Reader() io.Reader { |
||||
|
return bytes.NewReader(t.Bytes) |
||||
|
} |
||||
|
|
||||
|
func Named(s string) (*named, error) { |
||||
|
templates, err := templatesOnce() |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
var template *named |
||||
|
score := math.MaxInt |
||||
|
for _, t := range templates { |
||||
|
if s := levenshtein.ComputeDistance(s, t.Template); s < score { |
||||
|
score = s |
||||
|
template = t |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
if score < 100 { |
||||
|
return template, nil |
||||
|
} |
||||
|
|
||||
|
return nil, errors.New("no matching template found") |
||||
|
} |
||||
|
|
||||
|
type Template struct { |
||||
|
*template.Template |
||||
|
raw string |
||||
|
} |
||||
|
|
||||
|
func (t *Template) String() string { |
||||
|
return t.raw |
||||
|
} |
||||
|
|
||||
|
var DefaultTemplate, _ = Parse("{{ .Prompt }}") |
||||
|
|
||||
|
func Parse(s string) (*Template, error) { |
||||
|
t, err := template.New("").Option("missingkey=zero").Parse(s) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
return &Template{Template: t, raw: s}, nil |
||||
|
} |
||||
|
|
||||
|
func (t *Template) Vars() []string { |
||||
|
var vars []string |
||||
|
for _, n := range t.Tree.Root.Nodes { |
||||
|
vars = append(vars, parseNode(n)...) |
||||
|
} |
||||
|
|
||||
|
set := make(map[string]struct{}) |
||||
|
for _, n := range vars { |
||||
|
set[strings.ToLower(n)] = struct{}{} |
||||
|
} |
||||
|
|
||||
|
vars = maps.Keys(set) |
||||
|
slices.Sort(vars) |
||||
|
return vars |
||||
|
} |
||||
|
|
||||
|
func parseNode(n parse.Node) []string { |
||||
|
switch n := n.(type) { |
||||
|
case *parse.ActionNode: |
||||
|
return parseNode(n.Pipe) |
||||
|
case *parse.IfNode: |
||||
|
names := parseNode(n.Pipe) |
||||
|
names = append(names, parseNode(n.List)...) |
||||
|
if n.ElseList != nil { |
||||
|
names = append(names, parseNode(n.ElseList)...) |
||||
|
} |
||||
|
return names |
||||
|
case *parse.RangeNode: |
||||
|
names := parseNode(n.Pipe) |
||||
|
names = append(names, parseNode(n.List)...) |
||||
|
if n.ElseList != nil { |
||||
|
names = append(names, parseNode(n.ElseList)...) |
||||
|
} |
||||
|
return names |
||||
|
case *parse.WithNode: |
||||
|
names := parseNode(n.Pipe) |
||||
|
names = append(names, parseNode(n.List)...) |
||||
|
if n.ElseList != nil { |
||||
|
names = append(names, parseNode(n.ElseList)...) |
||||
|
} |
||||
|
return names |
||||
|
case *parse.PipeNode: |
||||
|
var names []string |
||||
|
for _, c := range n.Cmds { |
||||
|
for _, a := range c.Args { |
||||
|
names = append(names, parseNode(a)...) |
||||
|
} |
||||
|
} |
||||
|
return names |
||||
|
case *parse.ListNode: |
||||
|
var names []string |
||||
|
for _, n := range n.Nodes { |
||||
|
names = append(names, parseNode(n)...) |
||||
|
} |
||||
|
|
||||
|
return names |
||||
|
case *parse.FieldNode: |
||||
|
return n.Ident |
||||
|
} |
||||
|
|
||||
|
return nil |
||||
|
} |
||||
@ -0,0 +1,89 @@ |
|||||
|
package template |
||||
|
|
||||
|
import ( |
||||
|
"bufio" |
||||
|
"bytes" |
||||
|
"encoding/json" |
||||
|
"io" |
||||
|
"os" |
||||
|
"path/filepath" |
||||
|
"slices" |
||||
|
"testing" |
||||
|
"text/template" |
||||
|
|
||||
|
"github.com/ollama/ollama/llm" |
||||
|
) |
||||
|
|
||||
|
func TestNamed(t *testing.T) { |
||||
|
f, err := os.Open(filepath.Join("testdata", "templates.jsonl")) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
defer f.Close() |
||||
|
|
||||
|
scanner := bufio.NewScanner(f) |
||||
|
for scanner.Scan() { |
||||
|
var ss map[string]string |
||||
|
if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
|
||||
|
for k, v := range ss { |
||||
|
t.Run(k, func(t *testing.T) { |
||||
|
kv := llm.KV{"tokenizer.chat_template": v} |
||||
|
s := kv.ChatTemplate() |
||||
|
r, err := Named(s) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
|
||||
|
if r.Name != k { |
||||
|
t.Errorf("expected %q, got %q", k, r.Name) |
||||
|
} |
||||
|
|
||||
|
var b bytes.Buffer |
||||
|
if _, err := io.Copy(&b, r.Reader()); err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
|
||||
|
tmpl, err := template.New(s).Parse(b.String()) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
|
||||
|
if tmpl.Tree.Root.String() == "" { |
||||
|
t.Errorf("empty %s template", k) |
||||
|
} |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestParse(t *testing.T) { |
||||
|
cases := []struct { |
||||
|
template string |
||||
|
capabilities []string |
||||
|
}{ |
||||
|
{"{{ .Prompt }}", []string{"prompt"}}, |
||||
|
{"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}}, |
||||
|
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}}, |
||||
|
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "system", "tools"}}, |
||||
|
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}}, |
||||
|
{"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}}, |
||||
|
{"{{ .Prompt }} {{ .Suffix }}", []string{"prompt", "suffix"}}, |
||||
|
} |
||||
|
|
||||
|
for _, tt := range cases { |
||||
|
t.Run("", func(t *testing.T) { |
||||
|
tmpl, err := Parse(tt.template) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
|
||||
|
vars := tmpl.Vars() |
||||
|
if !slices.Equal(tt.capabilities, vars) { |
||||
|
t.Errorf("expected %v, got %v", tt.capabilities, vars) |
||||
|
} |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
@ -1,70 +0,0 @@ |
|||||
package templates |
|
||||
|
|
||||
import ( |
|
||||
"bytes" |
|
||||
"embed" |
|
||||
"encoding/json" |
|
||||
"errors" |
|
||||
"io" |
|
||||
"math" |
|
||||
"sync" |
|
||||
|
|
||||
"github.com/agnivade/levenshtein" |
|
||||
) |
|
||||
|
|
||||
//go:embed index.json
|
|
||||
var indexBytes []byte |
|
||||
|
|
||||
//go:embed *.gotmpl
|
|
||||
var templatesFS embed.FS |
|
||||
|
|
||||
var templatesOnce = sync.OnceValues(func() ([]*Template, error) { |
|
||||
var templates []*Template |
|
||||
if err := json.Unmarshal(indexBytes, &templates); err != nil { |
|
||||
return nil, err |
|
||||
} |
|
||||
|
|
||||
for _, t := range templates { |
|
||||
bts, err := templatesFS.ReadFile(t.Name + ".gotmpl") |
|
||||
if err != nil { |
|
||||
return nil, err |
|
||||
} |
|
||||
|
|
||||
// normalize line endings
|
|
||||
t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n")) |
|
||||
} |
|
||||
|
|
||||
return templates, nil |
|
||||
}) |
|
||||
|
|
||||
type Template struct { |
|
||||
Name string `json:"name"` |
|
||||
Template string `json:"template"` |
|
||||
Bytes []byte |
|
||||
} |
|
||||
|
|
||||
func (t Template) Reader() io.Reader { |
|
||||
return bytes.NewReader(t.Bytes) |
|
||||
} |
|
||||
|
|
||||
func NamedTemplate(s string) (*Template, error) { |
|
||||
templates, err := templatesOnce() |
|
||||
if err != nil { |
|
||||
return nil, err |
|
||||
} |
|
||||
|
|
||||
var template *Template |
|
||||
score := math.MaxInt |
|
||||
for _, t := range templates { |
|
||||
if s := levenshtein.ComputeDistance(s, t.Template); s < score { |
|
||||
score = s |
|
||||
template = t |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
if score < 100 { |
|
||||
return template, nil |
|
||||
} |
|
||||
|
|
||||
return nil, errors.New("no matching template found") |
|
||||
} |
|
||||
@ -1,59 +0,0 @@ |
|||||
package templates |
|
||||
|
|
||||
import ( |
|
||||
"bufio" |
|
||||
"bytes" |
|
||||
"encoding/json" |
|
||||
"io" |
|
||||
"os" |
|
||||
"path/filepath" |
|
||||
"testing" |
|
||||
"text/template" |
|
||||
|
|
||||
"github.com/ollama/ollama/llm" |
|
||||
) |
|
||||
|
|
||||
func TestKVChatTemplate(t *testing.T) { |
|
||||
f, err := os.Open(filepath.Join("testdata", "templates.jsonl")) |
|
||||
if err != nil { |
|
||||
t.Fatal(err) |
|
||||
} |
|
||||
defer f.Close() |
|
||||
|
|
||||
scanner := bufio.NewScanner(f) |
|
||||
for scanner.Scan() { |
|
||||
var ss map[string]string |
|
||||
if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil { |
|
||||
t.Fatal(err) |
|
||||
} |
|
||||
|
|
||||
for k, v := range ss { |
|
||||
t.Run(k, func(t *testing.T) { |
|
||||
kv := llm.KV{"tokenizer.chat_template": v} |
|
||||
s := kv.ChatTemplate() |
|
||||
r, err := NamedTemplate(s) |
|
||||
if err != nil { |
|
||||
t.Fatal(err) |
|
||||
} |
|
||||
|
|
||||
if r.Name != k { |
|
||||
t.Errorf("expected %q, got %q", k, r.Name) |
|
||||
} |
|
||||
|
|
||||
var b bytes.Buffer |
|
||||
if _, err := io.Copy(&b, r.Reader()); err != nil { |
|
||||
t.Fatal(err) |
|
||||
} |
|
||||
|
|
||||
tmpl, err := template.New(s).Parse(b.String()) |
|
||||
if err != nil { |
|
||||
t.Fatal(err) |
|
||||
} |
|
||||
|
|
||||
if tmpl.Tree.Root.String() == "" { |
|
||||
t.Errorf("empty %s template", k) |
|
||||
} |
|
||||
}) |
|
||||
} |
|
||||
} |
|
||||
} |
|
||||
Loading…
Reference in new issue