mirror of https://gitee.com/namelin2022/ollama
29 changed files with 300 additions and 163 deletions
@ -0,0 +1,158 @@ |
|||
package template |
|||
|
|||
import ( |
|||
"bytes" |
|||
"embed" |
|||
"encoding/json" |
|||
"errors" |
|||
"io" |
|||
"math" |
|||
"slices" |
|||
"strings" |
|||
"sync" |
|||
"text/template" |
|||
"text/template/parse" |
|||
|
|||
"github.com/agnivade/levenshtein" |
|||
"golang.org/x/exp/maps" |
|||
) |
|||
|
|||
//go:embed index.json
|
|||
var indexBytes []byte |
|||
|
|||
//go:embed *.gotmpl
|
|||
var templatesFS embed.FS |
|||
|
|||
var templatesOnce = sync.OnceValues(func() ([]*named, error) { |
|||
var templates []*named |
|||
if err := json.Unmarshal(indexBytes, &templates); err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
for _, t := range templates { |
|||
bts, err := templatesFS.ReadFile(t.Name + ".gotmpl") |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
// normalize line endings
|
|||
t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n")) |
|||
} |
|||
|
|||
return templates, nil |
|||
}) |
|||
|
|||
type named struct { |
|||
Name string `json:"name"` |
|||
Template string `json:"template"` |
|||
Bytes []byte |
|||
} |
|||
|
|||
func (t named) Reader() io.Reader { |
|||
return bytes.NewReader(t.Bytes) |
|||
} |
|||
|
|||
func Named(s string) (*named, error) { |
|||
templates, err := templatesOnce() |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
var template *named |
|||
score := math.MaxInt |
|||
for _, t := range templates { |
|||
if s := levenshtein.ComputeDistance(s, t.Template); s < score { |
|||
score = s |
|||
template = t |
|||
} |
|||
} |
|||
|
|||
if score < 100 { |
|||
return template, nil |
|||
} |
|||
|
|||
return nil, errors.New("no matching template found") |
|||
} |
|||
|
|||
type Template struct { |
|||
*template.Template |
|||
raw string |
|||
} |
|||
|
|||
func (t *Template) String() string { |
|||
return t.raw |
|||
} |
|||
|
|||
var DefaultTemplate, _ = Parse("{{ .Prompt }}") |
|||
|
|||
func Parse(s string) (*Template, error) { |
|||
t, err := template.New("").Option("missingkey=zero").Parse(s) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
return &Template{Template: t, raw: s}, nil |
|||
} |
|||
|
|||
func (t *Template) Vars() []string { |
|||
var vars []string |
|||
for _, n := range t.Tree.Root.Nodes { |
|||
vars = append(vars, parseNode(n)...) |
|||
} |
|||
|
|||
set := make(map[string]struct{}) |
|||
for _, n := range vars { |
|||
set[strings.ToLower(n)] = struct{}{} |
|||
} |
|||
|
|||
vars = maps.Keys(set) |
|||
slices.Sort(vars) |
|||
return vars |
|||
} |
|||
|
|||
func parseNode(n parse.Node) []string { |
|||
switch n := n.(type) { |
|||
case *parse.ActionNode: |
|||
return parseNode(n.Pipe) |
|||
case *parse.IfNode: |
|||
names := parseNode(n.Pipe) |
|||
names = append(names, parseNode(n.List)...) |
|||
if n.ElseList != nil { |
|||
names = append(names, parseNode(n.ElseList)...) |
|||
} |
|||
return names |
|||
case *parse.RangeNode: |
|||
names := parseNode(n.Pipe) |
|||
names = append(names, parseNode(n.List)...) |
|||
if n.ElseList != nil { |
|||
names = append(names, parseNode(n.ElseList)...) |
|||
} |
|||
return names |
|||
case *parse.WithNode: |
|||
names := parseNode(n.Pipe) |
|||
names = append(names, parseNode(n.List)...) |
|||
if n.ElseList != nil { |
|||
names = append(names, parseNode(n.ElseList)...) |
|||
} |
|||
return names |
|||
case *parse.PipeNode: |
|||
var names []string |
|||
for _, c := range n.Cmds { |
|||
for _, a := range c.Args { |
|||
names = append(names, parseNode(a)...) |
|||
} |
|||
} |
|||
return names |
|||
case *parse.ListNode: |
|||
var names []string |
|||
for _, n := range n.Nodes { |
|||
names = append(names, parseNode(n)...) |
|||
} |
|||
|
|||
return names |
|||
case *parse.FieldNode: |
|||
return n.Ident |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
@ -0,0 +1,89 @@ |
|||
package template |
|||
|
|||
import ( |
|||
"bufio" |
|||
"bytes" |
|||
"encoding/json" |
|||
"io" |
|||
"os" |
|||
"path/filepath" |
|||
"slices" |
|||
"testing" |
|||
"text/template" |
|||
|
|||
"github.com/ollama/ollama/llm" |
|||
) |
|||
|
|||
func TestNamed(t *testing.T) { |
|||
f, err := os.Open(filepath.Join("testdata", "templates.jsonl")) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
defer f.Close() |
|||
|
|||
scanner := bufio.NewScanner(f) |
|||
for scanner.Scan() { |
|||
var ss map[string]string |
|||
if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
for k, v := range ss { |
|||
t.Run(k, func(t *testing.T) { |
|||
kv := llm.KV{"tokenizer.chat_template": v} |
|||
s := kv.ChatTemplate() |
|||
r, err := Named(s) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
if r.Name != k { |
|||
t.Errorf("expected %q, got %q", k, r.Name) |
|||
} |
|||
|
|||
var b bytes.Buffer |
|||
if _, err := io.Copy(&b, r.Reader()); err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
tmpl, err := template.New(s).Parse(b.String()) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
if tmpl.Tree.Root.String() == "" { |
|||
t.Errorf("empty %s template", k) |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestParse(t *testing.T) { |
|||
cases := []struct { |
|||
template string |
|||
capabilities []string |
|||
}{ |
|||
{"{{ .Prompt }}", []string{"prompt"}}, |
|||
{"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}}, |
|||
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}}, |
|||
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "system", "tools"}}, |
|||
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}}, |
|||
{"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}}, |
|||
{"{{ .Prompt }} {{ .Suffix }}", []string{"prompt", "suffix"}}, |
|||
} |
|||
|
|||
for _, tt := range cases { |
|||
t.Run("", func(t *testing.T) { |
|||
tmpl, err := Parse(tt.template) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
vars := tmpl.Vars() |
|||
if !slices.Equal(tt.capabilities, vars) { |
|||
t.Errorf("expected %v, got %v", tt.capabilities, vars) |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
@ -1,70 +0,0 @@ |
|||
package templates |
|||
|
|||
import ( |
|||
"bytes" |
|||
"embed" |
|||
"encoding/json" |
|||
"errors" |
|||
"io" |
|||
"math" |
|||
"sync" |
|||
|
|||
"github.com/agnivade/levenshtein" |
|||
) |
|||
|
|||
//go:embed index.json
|
|||
var indexBytes []byte |
|||
|
|||
//go:embed *.gotmpl
|
|||
var templatesFS embed.FS |
|||
|
|||
var templatesOnce = sync.OnceValues(func() ([]*Template, error) { |
|||
var templates []*Template |
|||
if err := json.Unmarshal(indexBytes, &templates); err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
for _, t := range templates { |
|||
bts, err := templatesFS.ReadFile(t.Name + ".gotmpl") |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
// normalize line endings
|
|||
t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n")) |
|||
} |
|||
|
|||
return templates, nil |
|||
}) |
|||
|
|||
type Template struct { |
|||
Name string `json:"name"` |
|||
Template string `json:"template"` |
|||
Bytes []byte |
|||
} |
|||
|
|||
func (t Template) Reader() io.Reader { |
|||
return bytes.NewReader(t.Bytes) |
|||
} |
|||
|
|||
func NamedTemplate(s string) (*Template, error) { |
|||
templates, err := templatesOnce() |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
var template *Template |
|||
score := math.MaxInt |
|||
for _, t := range templates { |
|||
if s := levenshtein.ComputeDistance(s, t.Template); s < score { |
|||
score = s |
|||
template = t |
|||
} |
|||
} |
|||
|
|||
if score < 100 { |
|||
return template, nil |
|||
} |
|||
|
|||
return nil, errors.New("no matching template found") |
|||
} |
|||
@ -1,59 +0,0 @@ |
|||
package templates |
|||
|
|||
import ( |
|||
"bufio" |
|||
"bytes" |
|||
"encoding/json" |
|||
"io" |
|||
"os" |
|||
"path/filepath" |
|||
"testing" |
|||
"text/template" |
|||
|
|||
"github.com/ollama/ollama/llm" |
|||
) |
|||
|
|||
func TestKVChatTemplate(t *testing.T) { |
|||
f, err := os.Open(filepath.Join("testdata", "templates.jsonl")) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
defer f.Close() |
|||
|
|||
scanner := bufio.NewScanner(f) |
|||
for scanner.Scan() { |
|||
var ss map[string]string |
|||
if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
for k, v := range ss { |
|||
t.Run(k, func(t *testing.T) { |
|||
kv := llm.KV{"tokenizer.chat_template": v} |
|||
s := kv.ChatTemplate() |
|||
r, err := NamedTemplate(s) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
if r.Name != k { |
|||
t.Errorf("expected %q, got %q", k, r.Name) |
|||
} |
|||
|
|||
var b bytes.Buffer |
|||
if _, err := io.Copy(&b, r.Reader()); err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
tmpl, err := template.New(s).Parse(b.String()) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
if tmpl.Tree.Root.String() == "" { |
|||
t.Errorf("empty %s template", k) |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
} |
|||
Loading…
Reference in new issue