mirror of https://gitee.com/namelin2022/ollama
committed by
GitHub
27 changed files with 1868 additions and 340 deletions
@ -1,179 +0,0 @@ |
|||
package server |
|||
|
|||
import ( |
|||
"bytes" |
|||
"encoding/json" |
|||
"fmt" |
|||
"os" |
|||
"path/filepath" |
|||
"testing" |
|||
|
|||
"github.com/google/go-cmp/cmp" |
|||
|
|||
"github.com/ollama/ollama/api" |
|||
"github.com/ollama/ollama/template" |
|||
) |
|||
|
|||
func readFile(t *testing.T, base, name string) *bytes.Buffer { |
|||
t.Helper() |
|||
|
|||
bts, err := os.ReadFile(filepath.Join(base, name)) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
return bytes.NewBuffer(bts) |
|||
} |
|||
|
|||
func TestExecuteWithTools(t *testing.T) { |
|||
p := filepath.Join("testdata", "tools") |
|||
cases := []struct { |
|||
model string |
|||
output string |
|||
ok bool |
|||
}{ |
|||
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, |
|||
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] |
|||
|
|||
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true}, |
|||
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false}, |
|||
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: |
|||
|
|||
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, |
|||
{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, |
|||
{"command-r-plus", "Action: ```json" + ` |
|||
[ |
|||
{ |
|||
"tool_name": "get_current_weather", |
|||
"parameters": { |
|||
"format": "fahrenheit", |
|||
"location": "San Francisco, CA" |
|||
} |
|||
}, |
|||
{ |
|||
"tool_name": "get_current_weather", |
|||
"parameters": { |
|||
"format": "celsius", |
|||
"location": "Toronto, Canada" |
|||
} |
|||
} |
|||
] |
|||
` + "```", true}, |
|||
{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, |
|||
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, |
|||
{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, |
|||
{"llama3-groq-tool-use", `<tool_call> |
|||
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} |
|||
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}} |
|||
</tool_call>`, true}, |
|||
{"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true}, |
|||
{"nemotron", `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} </toolcall>`, true}, |
|||
} |
|||
|
|||
var tools []api.Tool |
|||
if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
var messages []api.Message |
|||
if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
calls := []api.ToolCall{ |
|||
{ |
|||
Function: api.ToolCallFunction{ |
|||
Name: "get_current_weather", |
|||
Arguments: api.ToolCallFunctionArguments{ |
|||
"format": "fahrenheit", |
|||
"location": "San Francisco, CA", |
|||
}, |
|||
}, |
|||
}, |
|||
{ |
|||
Function: api.ToolCallFunction{ |
|||
Name: "get_current_weather", |
|||
Arguments: api.ToolCallFunctionArguments{ |
|||
"format": "celsius", |
|||
"location": "Toronto, Canada", |
|||
}, |
|||
}, |
|||
}, |
|||
} |
|||
|
|||
for _, tt := range cases { |
|||
t.Run(tt.model, func(t *testing.T) { |
|||
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String()) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
t.Run("template", func(t *testing.T) { |
|||
var actual bytes.Buffer |
|||
if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" { |
|||
t.Errorf("mismatch (-got +want):\n%s", diff) |
|||
} |
|||
}) |
|||
|
|||
t.Run("parse", func(t *testing.T) { |
|||
m := &Model{Template: tmpl} |
|||
actual, ok := m.parseToolCalls(tt.output) |
|||
if ok != tt.ok { |
|||
t.Fatalf("expected %t, got %t", tt.ok, ok) |
|||
} |
|||
|
|||
if tt.ok { |
|||
if diff := cmp.Diff(actual, calls); diff != "" { |
|||
t.Errorf("mismatch (-got +want):\n%s", diff) |
|||
} |
|||
} |
|||
}) |
|||
}) |
|||
} |
|||
} |
|||
|
|||
func TestParseObjects(t *testing.T) { |
|||
tests := []struct { |
|||
input string |
|||
want []map[string]any |
|||
}{ |
|||
{ |
|||
input: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, |
|||
want: []map[string]any{ |
|||
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, |
|||
{"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, Canada"}}, |
|||
}, |
|||
}, |
|||
{ |
|||
input: `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall>`, |
|||
want: []map[string]any{ |
|||
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, |
|||
}, |
|||
}, |
|||
{ |
|||
input: `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall> <toolcall>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, ON"}} </toolcall>`, |
|||
want: []map[string]any{ |
|||
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, |
|||
{"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, ON"}}, |
|||
}, |
|||
}, |
|||
{ |
|||
input: `{"name": "get_current_weather", "arguments": `, |
|||
want: nil, |
|||
}, |
|||
} |
|||
|
|||
for _, tc := range tests { |
|||
t.Run(tc.input, func(t *testing.T) { |
|||
got := parseObjects(tc.input) |
|||
|
|||
if diff := cmp.Diff(got, tc.want); diff != "" { |
|||
t.Errorf("mismatch (-got +want):\n%s", diff) |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
@ -0,0 +1,44 @@ |
|||
<|start_header_id|>system<|end_header_id|> |
|||
|
|||
Cutting Knowledge Date: December 2023 |
|||
|
|||
{{ if .System }}{{ .System }} |
|||
{{- end }} |
|||
{{- if .Tools }}When you receive a tool call response, use the output to format an answer to the orginal user question. |
|||
|
|||
You are a helpful assistant with tool calling capabilities. |
|||
{{- end }}<|eot_id|> |
|||
{{- range $i, $_ := .Messages }} |
|||
{{- $last := eq (len (slice $.Messages $i)) 1 }} |
|||
{{- if eq .Role "user" }}<|start_header_id|>user<|end_header_id|> |
|||
{{- if and $.Tools $last }} |
|||
|
|||
Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. |
|||
|
|||
Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables. |
|||
|
|||
{{ range $.Tools }} |
|||
{{- . }} |
|||
{{ end }} |
|||
{{ .Content }}<|eot_id|> |
|||
{{- else }} |
|||
|
|||
{{ .Content }}<|eot_id|> |
|||
{{- end }}{{ if $last }}<|start_header_id|>assistant<|end_header_id|> |
|||
|
|||
{{ end }} |
|||
{{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|> |
|||
{{- if .ToolCalls }} |
|||
{{ range .ToolCalls }} |
|||
{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }} |
|||
{{- else }} |
|||
|
|||
{{ .Content }} |
|||
{{- end }}{{ if not $last }}<|eot_id|>{{ end }} |
|||
{{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|> |
|||
|
|||
{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|> |
|||
|
|||
{{ end }} |
|||
{{- end }} |
|||
{{- end }} |
|||
@ -0,0 +1,24 @@ |
|||
<|start_header_id|>system<|end_header_id|> |
|||
|
|||
Cutting Knowledge Date: December 2023 |
|||
|
|||
You are a knowledgeable assistant. You can answer questions and perform tasks.When you receive a tool call response, use the output to format an answer to the orginal user question. |
|||
|
|||
You are a helpful assistant with tool calling capabilities.<|eot_id|><|start_header_id|>user<|end_header_id|> |
|||
|
|||
What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> |
|||
|
|||
{"name": "get_current_weather", "parameters": {"format":"celsius","location":"Paris, France"}}<|eot_id|><|start_header_id|>ipython<|end_header_id|> |
|||
|
|||
22<|eot_id|><|start_header_id|>assistant<|end_header_id|> |
|||
|
|||
The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|> |
|||
|
|||
Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. |
|||
|
|||
Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables. |
|||
|
|||
{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} |
|||
|
|||
What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|> |
|||
|
|||
@ -0,0 +1,51 @@ |
|||
{{- if .Suffix }}<|fim_prefix|>{{ .Prompt }}<|fim_suffix|>{{ .Suffix }}<|fim_middle|> |
|||
{{- else if .Messages }} |
|||
{{- if or .System .Tools }}<|im_start|>system |
|||
{{- if .System }} |
|||
{{ .System }} |
|||
{{- end }} |
|||
{{- if .Tools }} |
|||
|
|||
# Tools |
|||
|
|||
You may call one or more functions to assist with the user query. |
|||
|
|||
You are provided with function signatures within <tools></tools> XML tags: |
|||
<tools> |
|||
{{- range .Tools }} |
|||
{"type": "function", "function": {{ .Function }}} |
|||
{{- end }} |
|||
</tools> |
|||
|
|||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags: |
|||
<tool_call> |
|||
{"name": <function-name>, "arguments": <args-json-object>} |
|||
</tool_call> |
|||
{{- end }}<|im_end|> |
|||
{{ end }} |
|||
{{- range $i, $_ := .Messages }} |
|||
{{- $last := eq (len (slice $.Messages $i)) 1 -}} |
|||
{{- if eq .Role "user" }}<|im_start|>user |
|||
{{ .Content }}<|im_end|> |
|||
{{ else if eq .Role "assistant" }}<|im_start|>assistant |
|||
{{ if .Content }}{{ .Content }} |
|||
{{- else if .ToolCalls }}<tool_call> |
|||
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} |
|||
{{ end }}</tool_call> |
|||
{{- end }}{{ if not $last }}<|im_end|> |
|||
{{ end }} |
|||
{{- else if eq .Role "tool" }}<|im_start|>user |
|||
<tool_response> |
|||
{{ .Content }} |
|||
</tool_response><|im_end|> |
|||
{{ end }} |
|||
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant |
|||
{{ end }} |
|||
{{- end }} |
|||
{{- else }} |
|||
{{- if .System }}<|im_start|>system |
|||
{{ .System }}<|im_end|> |
|||
{{ end }}{{ if .Prompt }}<|im_start|>user |
|||
{{ .Prompt }}<|im_end|> |
|||
{{ end }}<|im_start|>assistant |
|||
{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }} |
|||
@ -0,0 +1,31 @@ |
|||
<|im_start|>system |
|||
You are a knowledgeable assistant. You can answer questions and perform tasks. |
|||
|
|||
# Tools |
|||
|
|||
You may call one or more functions to assist with the user query. |
|||
|
|||
You are provided with function signatures within <tools></tools> XML tags: |
|||
<tools> |
|||
{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} |
|||
</tools> |
|||
|
|||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags: |
|||
<tool_call> |
|||
{"name": <function-name>, "arguments": <args-json-object>} |
|||
</tool_call><|im_end|> |
|||
<|im_start|>user |
|||
What's the weather like today in Paris?<|im_end|> |
|||
<|im_start|>assistant |
|||
<tool_call> |
|||
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}} |
|||
</tool_call><|im_end|> |
|||
<|im_start|>user |
|||
<tool_response> |
|||
22 |
|||
</tool_response><|im_end|> |
|||
<|im_start|>assistant |
|||
The current temperature in Paris, France is 22 degrees Celsius.<|im_end|> |
|||
<|im_start|>user |
|||
What's the weather like today in San Francisco and Toronto?<|im_end|> |
|||
<|im_start|>assistant |
|||
@ -0,0 +1,50 @@ |
|||
{{- if .Messages }} |
|||
{{- if or .System .Tools }}<|im_start|>system |
|||
{{- if .System }} |
|||
{{ .System }} |
|||
{{- end }} |
|||
{{- if .Tools }} |
|||
|
|||
# Tools |
|||
|
|||
You may call one or more functions to assist with the user query. |
|||
|
|||
You are provided with function signatures within <tools></tools> XML tags: |
|||
<tools> |
|||
{{- range .Tools }} |
|||
{"type": "function", "function": {{ .Function }}} |
|||
{{- end }} |
|||
</tools> |
|||
|
|||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags: |
|||
<tool_call> |
|||
{"name": <function-name>, "arguments": <args-json-object>} |
|||
</tool_call> |
|||
{{- end }}<|im_end|> |
|||
{{ end }} |
|||
{{- range $i, $_ := .Messages }} |
|||
{{- $last := eq (len (slice $.Messages $i)) 1 -}} |
|||
{{- if eq .Role "user" }}<|im_start|>user |
|||
{{ .Content }}<|im_end|> |
|||
{{ else if eq .Role "assistant" }}<|im_start|>assistant |
|||
{{ if .Content }}{{ .Content }} |
|||
{{- else if .ToolCalls }}<tool_call> |
|||
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} |
|||
{{ end }}</tool_call> |
|||
{{- end }}{{ if not $last }}<|im_end|> |
|||
{{ end }} |
|||
{{- else if eq .Role "tool" }}<|im_start|>user |
|||
<tool_response> |
|||
{{ .Content }} |
|||
</tool_response><|im_end|> |
|||
{{ end }} |
|||
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant |
|||
{{ end }} |
|||
{{- end }} |
|||
{{- else }} |
|||
{{- if .System }}<|im_start|>system |
|||
{{ .System }}<|im_end|> |
|||
{{ end }}{{ if .Prompt }}<|im_start|>user |
|||
{{ .Prompt }}<|im_end|> |
|||
{{ end }}<|im_start|>assistant |
|||
{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }} |
|||
@ -0,0 +1,31 @@ |
|||
<|im_start|>system |
|||
You are a knowledgeable assistant. You can answer questions and perform tasks. |
|||
|
|||
# Tools |
|||
|
|||
You may call one or more functions to assist with the user query. |
|||
|
|||
You are provided with function signatures within <tools></tools> XML tags: |
|||
<tools> |
|||
{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} |
|||
</tools> |
|||
|
|||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags: |
|||
<tool_call> |
|||
{"name": <function-name>, "arguments": <args-json-object>} |
|||
</tool_call><|im_end|> |
|||
<|im_start|>user |
|||
What's the weather like today in Paris?<|im_end|> |
|||
<|im_start|>assistant |
|||
<tool_call> |
|||
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}} |
|||
</tool_call><|im_end|> |
|||
<|im_start|>user |
|||
<tool_response> |
|||
22 |
|||
</tool_response><|im_end|> |
|||
<|im_start|>assistant |
|||
The current temperature in Paris, France is 22 degrees Celsius.<|im_end|> |
|||
<|im_start|>user |
|||
What's the weather like today in San Francisco and Toronto?<|im_end|> |
|||
<|im_start|>assistant |
|||
@ -0,0 +1,271 @@ |
|||
package tools |
|||
|
|||
import ( |
|||
"encoding/json" |
|||
"errors" |
|||
"log/slog" |
|||
"strings" |
|||
gotmpl "text/template" |
|||
|
|||
"github.com/ollama/ollama/api" |
|||
"github.com/ollama/ollama/template" |
|||
) |
|||
|
|||
var ( |
|||
errInvalidToolCall = errors.New("invalid tool call format") |
|||
errAccumulateMore = errors.New("need to accumulate more content") |
|||
) |
|||
|
|||
type Parser struct { |
|||
parseLeadingJSON bool |
|||
prefix string |
|||
prefixFound bool |
|||
tmpl gotmpl.Template |
|||
sb strings.Builder |
|||
index int |
|||
name string |
|||
arguments string |
|||
done bool |
|||
} |
|||
|
|||
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
|||
//
|
|||
// Parameters:
|
|||
// - s: The string to parse
|
|||
// - name: The field name from template that identifies the tool call name
|
|||
// - arguments: The field name from template that identifies the tool call arguments
|
|||
//
|
|||
// Returns:
|
|||
// - []api.ToolCall: The parsed tool calls if successful
|
|||
// - error: ErrAccumulateMore if braces unbalanced, ErrInvalidToolCall if invalid, or nil if successful
|
|||
func parseJSONToolCalls(s string, name, arguments string, prefix string) ([]api.ToolCall, error) { |
|||
// Check for balanced braces before attempting to parse
|
|||
braceCount := 0 |
|||
squareCount := 0 |
|||
startIndex := -1 |
|||
var rawToolCalls []string |
|||
s = strings.TrimSpace(s) |
|||
|
|||
// Only track these if we don't have a prefix as it will be cut off from the prefix. Also track in the parseLeadingJSON case.
|
|||
trackSquareBrackets := prefix == "" || !strings.HasSuffix(prefix, "[") || strings.HasPrefix(s, "[") |
|||
for i, c := range s { |
|||
switch c { |
|||
case '{': |
|||
braceCount++ |
|||
if startIndex == -1 { |
|||
startIndex = i |
|||
} |
|||
case '}': |
|||
braceCount-- |
|||
if braceCount == 0 { |
|||
rawToolCalls = append(rawToolCalls, s[startIndex:i+1]) |
|||
startIndex = -1 |
|||
} |
|||
case '[': |
|||
if trackSquareBrackets { |
|||
squareCount++ |
|||
} |
|||
case ']': |
|||
if trackSquareBrackets { |
|||
squareCount-- |
|||
} |
|||
} |
|||
|
|||
// Negative means we have an extra closing brace/bracket
|
|||
if braceCount < 0 || squareCount < 0 { |
|||
return nil, errInvalidToolCall |
|||
} |
|||
} |
|||
|
|||
// If braces/brackets aren't balanced, need more input
|
|||
if braceCount > 0 || squareCount > 0 { |
|||
return nil, errAccumulateMore |
|||
} |
|||
|
|||
t := strings.TrimSpace(s) |
|||
if len(t) == 0 { |
|||
return nil, errAccumulateMore |
|||
} |
|||
// If the input is a single square bracket, it's not a valid tool call
|
|||
if t[0] == '[' && len(t) == 1 { |
|||
return nil, errAccumulateMore |
|||
} |
|||
|
|||
// Attempt full unmarshal of the JSON
|
|||
var toolCalls []api.ToolCall |
|||
for _, rawToolCall := range rawToolCalls { |
|||
var resp map[string]any |
|||
if err := json.Unmarshal([]byte(rawToolCall), &resp); err != nil { |
|||
continue |
|||
} |
|||
|
|||
// Collect nested objects that could contain tool calls
|
|||
objs := collect(resp) |
|||
if len(objs) == 0 { |
|||
continue |
|||
} |
|||
|
|||
// Extract tool calls from objects
|
|||
for _, kv := range objs { |
|||
n, nok := kv[name].(string) |
|||
a, aok := kv[arguments].(map[string]any) |
|||
if nok && aok { |
|||
toolCalls = append(toolCalls, api.ToolCall{ |
|||
Function: api.ToolCallFunction{ |
|||
Name: n, |
|||
Arguments: a, |
|||
}, |
|||
}) |
|||
} else { |
|||
slog.Debug("No valid tool call found in object.", "object", kv) |
|||
} |
|||
} |
|||
} |
|||
|
|||
// Valid JSON, no tool calls found
|
|||
if len(toolCalls) == 0 { |
|||
slog.Debug("No valid tool calls found in any raw tool calls.", "rawToolCalls", rawToolCalls) |
|||
return nil, errInvalidToolCall |
|||
} |
|||
|
|||
return toolCalls, nil |
|||
} |
|||
|
|||
// checkPrefix processes a string to find and handle a prefix pattern.
|
|||
//
|
|||
// Returns:
|
|||
// - The processed string with prefix removed if found
|
|||
// - error: ErrAccumulateMore if prefix is incomplete, or nil if successful
|
|||
func (p *Parser) checkPrefix(s string) (string, error) { |
|||
original := s |
|||
if strings.ContainsRune(s, '\n') { |
|||
s = strings.ReplaceAll(s, "\n", " ") |
|||
} |
|||
|
|||
if s == "" || p.prefix == "" { |
|||
return s, nil |
|||
} |
|||
|
|||
// Check for prefix at start of string
|
|||
if cut, hasPrefix := strings.CutPrefix(s, p.prefix); hasPrefix { |
|||
// Found prefix at start - accumulate for potential tool
|
|||
p.prefixFound = true |
|||
return cut, nil |
|||
} |
|||
|
|||
// Check if prefix overlaps end of string
|
|||
if idx := suffixOverlap(s, p.prefix); idx != -1 { |
|||
// Return everything except overlapping portion
|
|||
p.sb.Reset() |
|||
p.sb.WriteString(s[idx:]) |
|||
return original[:idx], errAccumulateMore |
|||
} |
|||
|
|||
// Check if prefix appears in middle of string
|
|||
if idx := strings.Index(s, p.prefix); idx != -1 { |
|||
// Save remainder starting at prefix for next pass
|
|||
p.sb.Reset() |
|||
p.sb.WriteString(strings.TrimSpace(s[idx:])) |
|||
// Return everything before prefix
|
|||
return original[:idx], errAccumulateMore |
|||
} |
|||
|
|||
// No partial prefix found
|
|||
return s, nil |
|||
} |
|||
|
|||
// Add processes a string input to parse tool calls and content.
|
|||
// It handles prefix detection and JSON parsing to extract tool calls.
|
|||
//
|
|||
// Returns:
|
|||
// - tools: Any parsed tool calls
|
|||
// - content: Non-tool call content
|
|||
func (p *Parser) Add(s string) (tools []api.ToolCall, content string) { |
|||
if strings.TrimSpace(s) == "" { |
|||
return nil, s |
|||
} |
|||
if p.done { |
|||
if p.index == 0 { |
|||
// Return original string if no tool calls found at start
|
|||
return nil, s |
|||
} |
|||
// Return empty if no tool calls found after start
|
|||
return nil, "" |
|||
} |
|||
p.sb.WriteString(s) |
|||
s = p.sb.String() |
|||
|
|||
// Check for prefix pattern in input
|
|||
s, err := p.checkPrefix(s) |
|||
if err != nil { |
|||
// Need more input to complete prefix
|
|||
return nil, s |
|||
} |
|||
|
|||
// Exit if prefix exists in template, greedy parsing is off, and prefix not found
|
|||
if !p.parseLeadingJSON && !p.prefixFound { |
|||
p.sb.Reset() |
|||
return nil, s |
|||
} |
|||
|
|||
toolCalls, err := parseJSONToolCalls(s, p.name, p.arguments, p.prefix) |
|||
if err != nil { |
|||
if errors.Is(err, errAccumulateMore) { |
|||
return nil, "" |
|||
} |
|||
p.sb.Reset() |
|||
// Do not try parsing leading JSON if JSON not found
|
|||
p.parseLeadingJSON = false |
|||
if p.prefix == "" { |
|||
p.done = true |
|||
} |
|||
if p.index != 0 && p.prefix == "" { |
|||
return nil, "" |
|||
} |
|||
if p.prefixFound { |
|||
// Drop tokens since prefix was found
|
|||
return nil, "" |
|||
} |
|||
return nil, s |
|||
} |
|||
|
|||
for _, tc := range toolCalls { |
|||
tc.Function.Index = p.index |
|||
p.index++ |
|||
} |
|||
|
|||
p.sb.Reset() |
|||
return toolCalls, "" |
|||
} |
|||
|
|||
// NewParser creates a new tool call parser from a template. It extracts the tool call format,
|
|||
// prefix, and field names from the template to use for parsing tool calls from model output.
|
|||
//
|
|||
// Returns an error if the template does not contain valid tool call formatting.
|
|||
func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) { |
|||
parsed, err := template.Parse(templateToProcess.Root.String()) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
tt, err := toolTemplate(parsed) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
tp := toolPrefix(templateToProcess) |
|||
|
|||
name, arguments, err := extractToolArgs(tt) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
return &Parser{ |
|||
tmpl: *tt, |
|||
sb: strings.Builder{}, |
|||
prefix: tp, |
|||
parseLeadingJSON: true, |
|||
name: name, |
|||
arguments: arguments, |
|||
}, nil |
|||
} |
|||
@ -0,0 +1,644 @@ |
|||
package tools |
|||
|
|||
import ( |
|||
"bytes" |
|||
"encoding/json" |
|||
"fmt" |
|||
"os" |
|||
"path/filepath" |
|||
"strings" |
|||
"testing" |
|||
|
|||
"github.com/google/go-cmp/cmp" |
|||
|
|||
"github.com/ollama/ollama/api" |
|||
"github.com/ollama/ollama/template" |
|||
) |
|||
|
|||
func readFile(t *testing.T, base, name string) *bytes.Buffer { |
|||
t.Helper() |
|||
|
|||
bts, err := os.ReadFile(filepath.Join(base, name)) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
return bytes.NewBuffer(bts) |
|||
} |
|||
|
|||
func TestParseJSONToolCalls(t *testing.T) { |
|||
tests := []struct { |
|||
name string |
|||
input string |
|||
nameField string |
|||
argsField string |
|||
wantToolCalls []api.ToolCall |
|||
wantErr error |
|||
prefix string |
|||
}{ |
|||
{ |
|||
name: "valid single tool call", |
|||
input: `{"name": "test_tool", "arguments": {"arg1": "value1"}}`, |
|||
nameField: "name", |
|||
argsField: "arguments", |
|||
wantToolCalls: []api.ToolCall{ |
|||
{ |
|||
Function: api.ToolCallFunction{ |
|||
Name: "test_tool", |
|||
Arguments: map[string]any{ |
|||
"arg1": "value1", |
|||
}, |
|||
}, |
|||
}, |
|||
}, |
|||
wantErr: nil, |
|||
prefix: "", |
|||
}, |
|||
{ |
|||
name: "incomplete JSON", |
|||
input: `{"name": "test_tool", "arguments": {"arg1": `, |
|||
nameField: "name", |
|||
argsField: "arguments", |
|||
wantToolCalls: nil, |
|||
wantErr: errAccumulateMore, |
|||
prefix: "", |
|||
}, |
|||
{ |
|||
name: "invalid JSON", |
|||
input: `not json at all`, |
|||
nameField: "name", |
|||
argsField: "arguments", |
|||
wantToolCalls: nil, |
|||
wantErr: errInvalidToolCall, |
|||
prefix: "", |
|||
}, |
|||
{ |
|||
name: "missing required fields", |
|||
input: `{"other": "field"}`, |
|||
nameField: "name", |
|||
argsField: "arguments", |
|||
wantToolCalls: nil, |
|||
wantErr: errInvalidToolCall, |
|||
prefix: "", |
|||
}, |
|||
{ |
|||
name: "multiple tool calls in array", |
|||
input: `[ |
|||
{"name": "tool1", "arguments": {"arg1": 1}}, |
|||
{"name": "tool2", "arguments": {"arg2": "value"}} |
|||
]`, |
|||
nameField: "name", |
|||
argsField: "arguments", |
|||
wantToolCalls: []api.ToolCall{ |
|||
{ |
|||
Function: api.ToolCallFunction{ |
|||
Name: "tool1", |
|||
Arguments: map[string]any{ |
|||
"arg1": float64(1), |
|||
}, |
|||
}, |
|||
}, |
|||
{ |
|||
Function: api.ToolCallFunction{ |
|||
Name: "tool2", |
|||
Arguments: map[string]any{ |
|||
"arg2": "value", |
|||
}, |
|||
}, |
|||
}, |
|||
}, |
|||
wantErr: nil, |
|||
prefix: "", |
|||
}, |
|||
{ |
|||
name: "multiple tool calls without array", |
|||
input: ` |
|||
{"name": "tool1", "arguments": {"arg1": 1}}, |
|||
{"name": "tool2", "arguments": {"arg2": "value"}} |
|||
`, |
|||
nameField: "name", |
|||
argsField: "arguments", |
|||
wantToolCalls: []api.ToolCall{ |
|||
{ |
|||
Function: api.ToolCallFunction{ |
|||
Name: "tool1", |
|||
Arguments: map[string]any{ |
|||
"arg1": float64(1), |
|||
}, |
|||
}, |
|||
}, |
|||
{ |
|||
Function: api.ToolCallFunction{ |
|||
Name: "tool2", |
|||
Arguments: map[string]any{ |
|||
"arg2": "value", |
|||
}, |
|||
}, |
|||
}, |
|||
}, |
|||
wantErr: nil, |
|||
prefix: "", |
|||
}, |
|||
{ |
|||
name: "multiple tool calls with text after", |
|||
input: ` |
|||
{"name": "tool1", "arguments": {"arg1": 1}} text |
|||
{"name": "tool2", "arguments": {"arg2": "value"}} text |
|||
`, |
|||
nameField: "name", |
|||
argsField: "arguments", |
|||
wantToolCalls: []api.ToolCall{ |
|||
{ |
|||
Function: api.ToolCallFunction{ |
|||
Name: "tool1", |
|||
Arguments: map[string]any{ |
|||
"arg1": float64(1), |
|||
}, |
|||
}, |
|||
}, |
|||
{ |
|||
Function: api.ToolCallFunction{ |
|||
Name: "tool2", |
|||
Arguments: map[string]any{ |
|||
"arg2": "value", |
|||
}, |
|||
}, |
|||
}, |
|||
}, |
|||
wantErr: nil, |
|||
prefix: "", |
|||
}, |
|||
{ |
|||
name: "second tool call in array", |
|||
input: ` |
|||
, {"name": "tool2", "arguments": {"arg2": "value"}} |
|||
`, |
|||
nameField: "name", |
|||
argsField: "arguments", |
|||
wantToolCalls: []api.ToolCall{ |
|||
{ |
|||
Function: api.ToolCallFunction{ |
|||
Name: "tool2", |
|||
Arguments: map[string]any{ |
|||
"arg2": "value", |
|||
}, |
|||
}, |
|||
}, |
|||
}, |
|||
wantErr: nil, |
|||
prefix: "", |
|||
}, |
|||
// a bad JSON would not return any tool calls or content as it would always accumulate more
|
|||
{ |
|||
name: "unbalanced square brackets", |
|||
input: `[{"name": "tool1", "arguments": {"arg1": [1, 2}]`, |
|||
nameField: "name", |
|||
argsField: "arguments", |
|||
wantToolCalls: nil, |
|||
wantErr: errAccumulateMore, |
|||
prefix: "", |
|||
}, |
|||
{ |
|||
name: "incomplete square brackets", |
|||
input: `[{"name": "tool1", "arguments": {"arg1": [1, 2, 3`, |
|||
nameField: "name", |
|||
argsField: "arguments", |
|||
wantToolCalls: nil, |
|||
wantErr: errAccumulateMore, |
|||
prefix: "", |
|||
}, |
|||
{ |
|||
name: "nested arrays in arguments", |
|||
input: `{"name": "tool1", "arguments": {"arg1": [1, 2, ["nested", "array"]]}}`, |
|||
nameField: "name", |
|||
argsField: "arguments", |
|||
wantToolCalls: []api.ToolCall{ |
|||
{ |
|||
Function: api.ToolCallFunction{ |
|||
Name: "tool1", |
|||
Arguments: map[string]any{ |
|||
"arg1": []any{float64(1), float64(2), []any{"nested", "array"}}, |
|||
}, |
|||
}, |
|||
}, |
|||
}, |
|||
wantErr: nil, |
|||
prefix: "", |
|||
}, |
|||
} |
|||
|
|||
for _, tt := range tests { |
|||
t.Run(tt.name, func(t *testing.T) { |
|||
gotCalls, err := parseJSONToolCalls(tt.input, tt.nameField, tt.argsField, tt.prefix) |
|||
|
|||
if err != tt.wantErr { |
|||
t.Errorf("parseJSONToolCalls() error = %v, want %v", err, tt.wantErr) |
|||
} |
|||
|
|||
if len(gotCalls) != 0 && tt.wantErr != nil { |
|||
t.Errorf("parseJSONToolCalls() valid = %v, want %v", len(gotCalls) == 0, tt.wantErr == nil) |
|||
} |
|||
|
|||
if diff := cmp.Diff(gotCalls, tt.wantToolCalls); diff != "" { |
|||
t.Errorf("parseJSONToolCalls() tool calls mismatch (-got +want):\n%s", diff) |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
|
|||
func TestParseToolCalls(t *testing.T) { |
|||
p := filepath.Join("testdata") |
|||
t1 := api.ToolCall{ |
|||
Function: api.ToolCallFunction{ |
|||
Name: "get_current_weather", |
|||
Arguments: api.ToolCallFunctionArguments{ |
|||
"format": "fahrenheit", |
|||
"location": "San Francisco, CA", |
|||
}, |
|||
}, |
|||
} |
|||
t2 := api.ToolCall{ |
|||
Function: api.ToolCallFunction{ |
|||
Name: "get_current_weather", |
|||
Arguments: api.ToolCallFunctionArguments{ |
|||
"format": "celsius", |
|||
"location": "Toronto, Canada", |
|||
}, |
|||
}, |
|||
} |
|||
|
|||
cases := []struct { |
|||
name string |
|||
model string |
|||
output string |
|||
expectedToolCall []api.ToolCall |
|||
expectedTokens string |
|||
}{ |
|||
{ |
|||
name: "mistral malformed json with tool calls prefix", |
|||
model: "mistral", |
|||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`, |
|||
expectedToolCall: []api.ToolCall{t1}, |
|||
expectedTokens: "", |
|||
}, |
|||
{ |
|||
name: "mistral multiple tool calls without prefix", |
|||
model: "mistral", |
|||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}} ]`, |
|||
expectedToolCall: []api.ToolCall{t1, t2}, |
|||
expectedTokens: "", |
|||
}, |
|||
{ |
|||
name: "mistral tool calls with text between no prefix", |
|||
model: "mistral", |
|||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] |
|||
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, |
|||
expectedToolCall: []api.ToolCall{t1, t2}, |
|||
expectedTokens: `model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, |
|||
}, |
|||
{ |
|||
name: "mistral valid json with tool calls prefix", |
|||
model: "mistral", |
|||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, |
|||
expectedToolCall: []api.ToolCall{t1, t2}, |
|||
expectedTokens: "", |
|||
}, |
|||
{ |
|||
name: "mistral multiple tool calls with text between and prefix", |
|||
model: "mistral", |
|||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] |
|||
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, |
|||
expectedToolCall: []api.ToolCall{t1, t2, t1, t2}, |
|||
expectedTokens: "", |
|||
}, |
|||
{ |
|||
name: "mistral incomplete json with tool calls prefix", |
|||
model: "mistral", |
|||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `, |
|||
expectedToolCall: []api.ToolCall{}, |
|||
expectedTokens: "", |
|||
}, |
|||
{ |
|||
name: "mistral invalid tool call with explanatory text no prefix", |
|||
model: "mistral", |
|||
output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: |
|||
|
|||
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, |
|||
expectedToolCall: []api.ToolCall{}, |
|||
expectedTokens: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, |
|||
}, |
|||
{ |
|||
name: "mistral tool calls without prefix", |
|||
model: "mistral", |
|||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, |
|||
expectedToolCall: []api.ToolCall{t1, t2}, |
|||
expectedTokens: "", |
|||
}, |
|||
{ |
|||
name: "command r plus tool calls with json block format", |
|||
model: "command-r-plus", |
|||
output: "Action: ```json" + ` |
|||
[ |
|||
{ |
|||
"tool_name": "get_current_weather", |
|||
"parameters": { |
|||
"format": "fahrenheit", |
|||
"location": "San Francisco, CA" |
|||
} |
|||
}, |
|||
{ |
|||
"tool_name": "get_current_weather", |
|||
"parameters": { |
|||
"format": "celsius", |
|||
"location": "Toronto, Canada" |
|||
} |
|||
} |
|||
] |
|||
` + "```", |
|||
expectedToolCall: []api.ToolCall{t1, t2}, |
|||
expectedTokens: "", |
|||
}, |
|||
{ |
|||
name: "firefunction tool calls with functools prefix", |
|||
model: "firefunction", |
|||
output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, |
|||
expectedToolCall: []api.ToolCall{t1, t2}, |
|||
expectedTokens: "", |
|||
}, |
|||
{ |
|||
name: "llama3 groq single tool call with xml tags", |
|||
model: "llama3-groq-tool-use", |
|||
output: `<tool_call> |
|||
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} |
|||
</tool_call>`, |
|||
expectedToolCall: []api.ToolCall{t1}, |
|||
expectedTokens: "", |
|||
}, |
|||
{ |
|||
name: "xlam tool calls with wrapper object", |
|||
model: "xlam", |
|||
output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, |
|||
expectedToolCall: []api.ToolCall{t1, t2}, |
|||
expectedTokens: "", |
|||
}, |
|||
{ |
|||
name: "qwen2.5 single tool call with prefix", |
|||
model: "qwen2.5", |
|||
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`, |
|||
expectedToolCall: []api.ToolCall{t1}, |
|||
expectedTokens: "", |
|||
}, |
|||
{ |
|||
name: "qwen2.5 multiple tool calls with and without prefix", |
|||
model: "qwen2.5", |
|||
output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} <tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call> <tool_call>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}</tool_call>`, |
|||
expectedToolCall: []api.ToolCall{t1, t1, t2}, |
|||
expectedTokens: "", |
|||
}, |
|||
{ |
|||
name: "qwen2.5 plain text response no tool calls", |
|||
model: "qwen2.5", |
|||
output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", |
|||
expectedToolCall: []api.ToolCall{}, |
|||
expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", |
|||
}, |
|||
{ |
|||
name: "qwen2.5 tool calls with trailing text", |
|||
model: "qwen2.5", |
|||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`, |
|||
expectedToolCall: []api.ToolCall{t1, t2}, |
|||
expectedTokens: "some tokens after call", |
|||
}, |
|||
{ |
|||
name: "qwen2.5 tool calls with initial text", |
|||
model: "qwen2.5", |
|||
output: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, |
|||
expectedToolCall: []api.ToolCall{}, |
|||
expectedTokens: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, |
|||
}, |
|||
{ |
|||
name: "qwen2.5 tool calls with prefix and trailing text", |
|||
model: "qwen2.5", |
|||
output: `<tool_call> [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] </tool_call> some tokens after call`, |
|||
expectedToolCall: []api.ToolCall{t1, t2}, |
|||
expectedTokens: "", |
|||
}, |
|||
{ |
|||
name: "qwen2.5 tool calls with prefix and initial text", |
|||
model: "qwen2.5", |
|||
output: `some tokens before call <tool_call> [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] </tool_call>`, |
|||
expectedToolCall: []api.ToolCall{t1, t2}, |
|||
expectedTokens: "some tokens before call", |
|||
}, |
|||
{ |
|||
name: "qwen2.5 tool calls without and with prefix", |
|||
model: "qwen2.5", |
|||
output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} <tool_call>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}</tool_call>`, |
|||
expectedToolCall: []api.ToolCall{t1, t2}, |
|||
expectedTokens: "", |
|||
}, |
|||
{ |
|||
name: "qwen2.5 tool calls without and with prefix and text between", |
|||
model: "qwen2.5", |
|||
output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} some tokens between <tool_call>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}</tool_call> some tokens after call`, |
|||
expectedToolCall: []api.ToolCall{t1, t2}, |
|||
expectedTokens: "some tokens between", |
|||
}, |
|||
{ |
|||
name: "qwen2.5 tool calls without prefix and invalid tool call with other tokens", |
|||
model: "qwen2.5", |
|||
output: `hi [{"options": "foo"}]`, |
|||
expectedToolCall: []api.ToolCall{}, |
|||
expectedTokens: `hi [{"options": "foo"}]`, |
|||
}, |
|||
{ |
|||
name: "qwen2.5 tool calls with prefix and invalid tool call", |
|||
model: "qwen2.5", |
|||
output: `<tool_call> [{"options": "foo"}] </tool_call> `, |
|||
expectedToolCall: []api.ToolCall{}, |
|||
expectedTokens: ``, |
|||
}, |
|||
{ |
|||
name: "qwen3 tool call with think prefix and tool prefix (sent as a single token)", |
|||
model: "qwen3", |
|||
output: `<think>Okay, let me think what tool we should use...</think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`, |
|||
expectedToolCall: []api.ToolCall{t1}, |
|||
expectedTokens: "<think>Okay, let me think what tool we should use...</think>", |
|||
}, |
|||
{ |
|||
name: "qwen3 tool call with think prefix, tool prefix, and whitespace (sent as separate tokens)", |
|||
model: "qwen3", |
|||
output: `<think>Okay, let me think what tool we should use...</think> <tool_call>{ "name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, |
|||
expectedToolCall: []api.ToolCall{t1}, |
|||
expectedTokens: "<think>Okay, let me think what tool we should use...</think>", |
|||
}, |
|||
{ |
|||
name: "qwen3 empty think prefix without tool prefix and invalid tool call", |
|||
model: "qwen3", |
|||
output: `<think></think> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, |
|||
expectedToolCall: []api.ToolCall{}, |
|||
expectedTokens: `<think></think> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, |
|||
}, |
|||
{ |
|||
name: "qwen3 empty think prefix with tool prefix and valid tool call", |
|||
model: "qwen3", |
|||
output: `<think></think><tool_call>{ "name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, |
|||
expectedToolCall: []api.ToolCall{t1}, |
|||
expectedTokens: `<think></think>`, |
|||
}, |
|||
{ |
|||
name: "qwen3 invalid tool call with fake tool prefix (single rune suffix match)", |
|||
model: "qwen3", |
|||
output: `<think></think>< fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, |
|||
expectedToolCall: []api.ToolCall{}, |
|||
expectedTokens: `<think></think>< fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, |
|||
}, |
|||
{ |
|||
name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)", |
|||
model: "qwen3", |
|||
output: `<think></think><tool_c fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, |
|||
expectedToolCall: []api.ToolCall{}, |
|||
expectedTokens: `<think></think><tool_c fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, |
|||
}, |
|||
{ |
|||
name: "qwen3 invalid tool call with malformed tool prefix", |
|||
model: "qwen3", |
|||
output: `<think></think><tool_cfakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, |
|||
expectedToolCall: []api.ToolCall{}, |
|||
expectedTokens: `<think></think><tool_cfakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, |
|||
}, |
|||
{ |
|||
name: "model with prefix in template, no prefix in output", |
|||
model: "qwen2.5", |
|||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, |
|||
expectedToolCall: []api.ToolCall{t1, t2}, |
|||
expectedTokens: "", |
|||
}, |
|||
{ |
|||
name: "model with prefix in template, prefix in output", |
|||
model: "qwen2.5", |
|||
output: `<tool_call>[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call>`, |
|||
expectedToolCall: []api.ToolCall{t1, t2}, |
|||
expectedTokens: "", |
|||
}, |
|||
{ |
|||
name: "model without prefix in template, no prefix in output", |
|||
model: "llama3.2", |
|||
output: `[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, |
|||
expectedToolCall: []api.ToolCall{t1, t2}, |
|||
expectedTokens: "", |
|||
}, |
|||
{ |
|||
name: "model without prefix in template, no prefix in output, single tool call", |
|||
model: "llama3.2", |
|||
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, |
|||
expectedToolCall: []api.ToolCall{t1}, |
|||
expectedTokens: "", |
|||
}, |
|||
{ |
|||
name: "model without prefix in template, prefix in output", |
|||
model: "llama3.2", |
|||
output: `<tool_call> [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call>`, |
|||
expectedToolCall: []api.ToolCall{}, |
|||
expectedTokens: `<tool_call> [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call>`, |
|||
}, |
|||
{ |
|||
name: "model with prefix in template, no prefix in output, tokens before", |
|||
model: "qwen2.5", |
|||
output: `some tokens before [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, |
|||
expectedToolCall: []api.ToolCall{}, |
|||
expectedTokens: `some tokens before [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, |
|||
}, |
|||
{ |
|||
name: "model with prefix in template, prefix in output, tokens after", |
|||
model: "qwen2.5", |
|||
output: `<tool_call>[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`, |
|||
expectedToolCall: []api.ToolCall{t1, t2}, |
|||
expectedTokens: "", |
|||
}, |
|||
{ |
|||
name: "model without prefix in template, no prefix in output, tokens after", |
|||
model: "llama3.2", |
|||
output: `[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`, |
|||
expectedToolCall: []api.ToolCall{t1, t2}, |
|||
expectedTokens: "", |
|||
}, |
|||
{ |
|||
name: "model without prefix in template, no prefix in output, tokens before", |
|||
model: "llama3.2", |
|||
output: `some tokens before [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, |
|||
expectedToolCall: []api.ToolCall{}, |
|||
expectedTokens: `some tokens before [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, |
|||
}, |
|||
{ |
|||
name: "model without prefix in template, prefix in output, tokens after", |
|||
model: "llama3.2", |
|||
output: `<tool_call> [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`, |
|||
expectedToolCall: []api.ToolCall{}, |
|||
expectedTokens: `<tool_call> [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`, |
|||
}, |
|||
} |
|||
|
|||
var tools []api.Tool |
|||
if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
var messages []api.Message |
|||
if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
for _, tt := range cases { |
|||
t.Run(tt.name, func(t *testing.T) { |
|||
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String()) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
t.Run("template", func(t *testing.T) { |
|||
actual := &bytes.Buffer{} // Create new buffer for each test
|
|||
if err := tmpl.Execute(actual, template.Values{Tools: tools, Messages: messages}); err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" { |
|||
t.Errorf("mismatch (-got +want):\n%s", diff) |
|||
} |
|||
}) |
|||
|
|||
t.Run("parse", func(t *testing.T) { |
|||
tp, err := NewParser(tmpl.Template) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
got := []api.ToolCall{} |
|||
var gotTokens strings.Builder |
|||
|
|||
tokens := strings.Fields(tt.output) |
|||
for _, tok := range tokens { |
|||
s := " " + tok |
|||
|
|||
toolCalls, content := tp.Add(s) |
|||
if len(content) > 0 { |
|||
gotTokens.WriteString(content) |
|||
} else if len(toolCalls) > 0 { |
|||
got = append(got, toolCalls...) |
|||
} |
|||
} |
|||
|
|||
// Compare tool calls if we expect any
|
|||
if diff := cmp.Diff(got, tt.expectedToolCall); diff != "" { |
|||
t.Errorf("tool calls mismatch (-got +want):\n%s", diff) |
|||
} |
|||
|
|||
// Compare tokens if we expect any
|
|||
stripped := strings.TrimSpace(gotTokens.String()) |
|||
if diff := cmp.Diff(stripped, tt.expectedTokens); diff != "" { |
|||
t.Log("actualTokens", stripped, "expectedTokens", tt.expectedTokens) |
|||
t.Errorf("tokens mismatch (-got +want):\n%s", diff) |
|||
} |
|||
}) |
|||
}) |
|||
} |
|||
} |
|||
@ -0,0 +1,227 @@ |
|||
package tools |
|||
|
|||
import ( |
|||
"bytes" |
|||
"encoding/json" |
|||
"errors" |
|||
"log/slog" |
|||
"slices" |
|||
"strings" |
|||
gotmpl "text/template" |
|||
"text/template/parse" |
|||
|
|||
"github.com/ollama/ollama/api" |
|||
"github.com/ollama/ollama/template" |
|||
) |
|||
|
|||
// extractToolCallsFormat traverses a template AST to find text that follows a ".ToolCalls" condition.
|
|||
// It walks the template nodes looking for if-statements containing ".ToolCalls" and extracts any
|
|||
// immediate text nodes that follow. This is used to identify tool call prefixes and formatting.
|
|||
//
|
|||
// Returns:
|
|||
// - string: The extracted text following the first ".ToolCalls" condition found
|
|||
// - bool: Whether a ".ToolCalls" condition was found in the template
|
|||
func extractToolCallsFormat(tmpl *gotmpl.Template) (string, bool) { |
|||
if tmpl == nil || tmpl.Tree == nil { |
|||
slog.Debug("template or tree is nil") |
|||
return "", false |
|||
} |
|||
|
|||
var result string |
|||
var found bool |
|||
|
|||
var walk func(nodes []parse.Node) |
|||
walk = func(nodes []parse.Node) { |
|||
for _, node := range nodes { |
|||
if found { |
|||
return |
|||
} |
|||
|
|||
switch n := node.(type) { |
|||
case *parse.IfNode: |
|||
if isToolCallsNode(n) { |
|||
// Collect immediate TextNode(s) at start of IfNode's list
|
|||
var sb strings.Builder |
|||
for _, innerNode := range n.List.Nodes { |
|||
if tn, ok := innerNode.(*parse.TextNode); ok { |
|||
sb.Write(tn.Text) |
|||
} else { |
|||
// Stop at first non-text node
|
|||
break |
|||
} |
|||
} |
|||
result = sb.String() |
|||
found = true |
|||
return |
|||
} |
|||
// Recurse into child nodes
|
|||
walk(n.List.Nodes) |
|||
if n.ElseList != nil { |
|||
walk(n.ElseList.Nodes) |
|||
} |
|||
case *parse.ListNode: |
|||
walk(n.Nodes) |
|||
case *parse.RangeNode: |
|||
walk(n.List.Nodes) |
|||
if n.ElseList != nil { |
|||
walk(n.ElseList.Nodes) |
|||
} |
|||
case *parse.WithNode: |
|||
walk(n.List.Nodes) |
|||
if n.ElseList != nil { |
|||
walk(n.ElseList.Nodes) |
|||
} |
|||
default: |
|||
// Continue to next node
|
|||
continue |
|||
} |
|||
} |
|||
} |
|||
|
|||
walk(tmpl.Tree.Root.Nodes) |
|||
return result, found |
|||
} |
|||
|
|||
// isToolCallsNode detects if a node's condition includes ".ToolCalls"
|
|||
func isToolCallsNode(n *parse.IfNode) bool { |
|||
for _, cmd := range n.Pipe.Cmds { |
|||
for _, arg := range cmd.Args { |
|||
if field, ok := arg.(*parse.FieldNode); ok { |
|||
if slices.Contains(field.Ident, "ToolCalls") { |
|||
return true |
|||
} |
|||
} |
|||
} |
|||
} |
|||
return false |
|||
} |
|||
|
|||
func toolPrefix(tmpl *gotmpl.Template) string { |
|||
tokenText, ok := extractToolCallsFormat(tmpl) |
|||
if !ok { |
|||
return "" |
|||
} |
|||
tokenText = strings.TrimSpace(tokenText) |
|||
tokenText = strings.ReplaceAll(tokenText, "\r", "") |
|||
tokenText = strings.ReplaceAll(tokenText, "\n", " ") |
|||
|
|||
return tokenText |
|||
} |
|||
|
|||
// toolTemplate creates a subtree from the node that ranges over .ToolCalls
|
|||
//
|
|||
// Returns:
|
|||
// - *gotmpl.Template: The subtree containing the .ToolCalls range
|
|||
// - error: Error if parsing failed
|
|||
func toolTemplate(t *template.Template) (*gotmpl.Template, error) { |
|||
tmpl := t.Subtree(func(n parse.Node) bool { |
|||
if t, ok := n.(*parse.RangeNode); ok { |
|||
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") |
|||
} |
|||
|
|||
return false |
|||
}) |
|||
|
|||
if tmpl == nil { |
|||
return nil, errors.New("failed to find tool template") |
|||
} |
|||
|
|||
return tmpl, nil |
|||
} |
|||
|
|||
// suffixOverlap returns the index in s where the longest suffix overlap with prefix begins
|
|||
//
|
|||
// Returns:
|
|||
// - int: The starting index in s where the suffix overlap begins
|
|||
func suffixOverlap(s, prefix string) int { |
|||
max := min(len(prefix), len(s)) |
|||
for i := max; i > 0; i-- { |
|||
if strings.HasSuffix(s, prefix[:i]) { |
|||
return len(s) - i |
|||
} |
|||
} |
|||
return -1 |
|||
} |
|||
|
|||
// extractToolArgs executes a template with a known tool call format to extract the name and arguments
|
|||
//
|
|||
// Returns:
|
|||
// - string: The name of the tool call
|
|||
// - string: The arguments of the tool call
|
|||
// - error: Error if parsing failed
|
|||
func extractToolArgs(tmpl *gotmpl.Template) (name, arguments string, err error) { |
|||
var b bytes.Buffer |
|||
if err := tmpl.Execute(&b, map[string][]api.ToolCall{ |
|||
"ToolCalls": { |
|||
{ |
|||
Function: api.ToolCallFunction{ |
|||
Name: "@@name@@", |
|||
Arguments: api.ToolCallFunctionArguments{ |
|||
"@@argument@@": 1, |
|||
}, |
|||
}, |
|||
}, |
|||
}, |
|||
}); err != nil { |
|||
return "", "", err |
|||
} |
|||
|
|||
var obj any |
|||
err = json.Unmarshal(b.Bytes(), &obj) |
|||
if err != nil { |
|||
return "", "", err |
|||
} |
|||
|
|||
var objs []map[string]any |
|||
switch v := obj.(type) { |
|||
case map[string]any: |
|||
objs = []map[string]any{v} |
|||
case []map[string]any: |
|||
objs = v |
|||
case []any: |
|||
objs = collect(v) |
|||
} |
|||
if len(objs) == 0 { |
|||
return "", "", errors.New("no template objects found") |
|||
} |
|||
|
|||
// find the keys that correspond to the name and arguments fields
|
|||
for k, v := range objs[0] { |
|||
switch v.(type) { |
|||
case string: |
|||
name = k |
|||
case map[string]any: |
|||
arguments = k |
|||
} |
|||
} |
|||
|
|||
if name == "" || arguments == "" { |
|||
slog.Debug("missing required fields in tool call template", "name", name, "arguments", arguments) |
|||
return "", "", errors.New("missing required fields in tool call template") |
|||
} |
|||
|
|||
return name, arguments, nil |
|||
} |
|||
|
|||
// collect recursively traverses an object to collect all nested maps
|
|||
//
|
|||
// Returns:
|
|||
// - []map[string]any: A slice of all nested maps found in the object
|
|||
func collect(obj any) []map[string]any { |
|||
var all []map[string]any |
|||
switch o := obj.(type) { |
|||
case map[string]any: |
|||
all = append(all, o) |
|||
for _, v := range o { |
|||
all = append(all, collect(v)...) |
|||
} |
|||
case []any: |
|||
for _, v := range o { |
|||
all = append(all, collect(v)...) |
|||
} |
|||
default: |
|||
return nil |
|||
} |
|||
|
|||
return all |
|||
} |
|||
@ -0,0 +1,464 @@ |
|||
package tools |
|||
|
|||
import ( |
|||
"testing" |
|||
gotmpl "text/template" |
|||
|
|||
"github.com/ollama/ollama/template" |
|||
) |
|||
|
|||
func TestExtractToolCallsFormat(t *testing.T) { |
|||
cases := []struct { |
|||
name string |
|||
template string |
|||
want string |
|||
found bool |
|||
}{ |
|||
{ |
|||
name: "nil template", |
|||
template: "", |
|||
want: "", |
|||
found: false, |
|||
}, |
|||
{ |
|||
name: "basic tool call with text", |
|||
template: "{{if .ToolCalls}}Hello world{{end}}", |
|||
want: "Hello world", |
|||
found: true, |
|||
}, |
|||
{ |
|||
name: "tool call with json format", |
|||
template: "{{if .ToolCalls}}```json\n{{end}}", |
|||
want: "```json\n", |
|||
found: true, |
|||
}, |
|||
{ |
|||
name: "tool call in range", |
|||
template: "{{range .ToolCalls}}tool: {{.}}{{end}}", |
|||
want: "", |
|||
found: false, |
|||
}, |
|||
{ |
|||
name: "tool call with multiple text nodes", |
|||
template: "{{if .ToolCalls}}First text{{if .Something}}inner{{end}}Second text{{end}}", |
|||
want: "First text", |
|||
found: true, |
|||
}, |
|||
{ |
|||
name: "nested if without tool calls", |
|||
template: "{{if .Something}}{{if .OtherThing}}text{{end}}{{end}}", |
|||
want: "", |
|||
found: false, |
|||
}, |
|||
} |
|||
|
|||
for _, tc := range cases { |
|||
t.Run(tc.name, func(t *testing.T) { |
|||
tmpl, err := gotmpl.New("test").Parse(tc.template) |
|||
if err != nil && tc.template != "" { |
|||
t.Fatalf("failed to parse template: %v", err) |
|||
} |
|||
|
|||
got, found := extractToolCallsFormat(tmpl) |
|||
if got != tc.want { |
|||
t.Errorf("got text %q, want %q", got, tc.want) |
|||
} |
|||
if found != tc.found { |
|||
t.Errorf("got found %v, want %v", found, tc.found) |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
|
|||
func TestToolPrefix(t *testing.T) { |
|||
cases := []struct { |
|||
name string |
|||
template string |
|||
want string |
|||
}{ |
|||
{ |
|||
name: "basic tool call with action prefix", |
|||
template: "{{if .ToolCalls}}Action: ```json{{end}}", |
|||
want: "Action: ```json", |
|||
}, |
|||
{ |
|||
name: "incomplete functools bracket", |
|||
template: "{{if .ToolCalls}}functools[{{end}}", |
|||
want: "functools[", |
|||
}, |
|||
{ |
|||
name: "tool call with angle brackets", |
|||
template: "{{if .ToolCalls}}Hello, world! <tool_call>{{end}}", |
|||
want: "Hello, world! <tool_call>", |
|||
}, |
|||
{ |
|||
name: "multiple tool call formats", |
|||
template: "{{if .ToolCalls}}[tool_call] <tool_call>{{end}}", |
|||
want: "[tool_call] <tool_call>", |
|||
}, |
|||
{ |
|||
name: "single angle bracket tool call", |
|||
template: "{{if .ToolCalls}}<tool_call>{{end}}", |
|||
want: "<tool_call>", |
|||
}, |
|||
{ |
|||
name: "incomplete angle bracket after tool call", |
|||
template: "{{if .ToolCalls}}[tool_call] <{{end}}", |
|||
want: "[tool_call] <", |
|||
}, |
|||
{ |
|||
name: "angle bracket prefix with tool call", |
|||
template: "{{if .ToolCalls}}> <tool_call>{{end}}", |
|||
want: "> <tool_call>", |
|||
}, |
|||
{ |
|||
name: "uppercase tool call with incomplete bracket", |
|||
template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}", |
|||
want: "[TOOL_CALL] [", |
|||
}, |
|||
{ |
|||
name: "uppercase tool call with adjacent bracket", |
|||
template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}", |
|||
want: "[TOOL_CALL][", |
|||
}, |
|||
{ |
|||
name: "tool call with pipe delimiters", |
|||
template: "{{if .ToolCalls}}<|tool_call|>{{end}}", |
|||
want: "<|tool_call|>", |
|||
}, |
|||
{ |
|||
name: "tool with no prefix", |
|||
template: "{{if .ToolCalls}}{{end}}", |
|||
want: "", |
|||
}, |
|||
} |
|||
|
|||
for _, tt := range cases { |
|||
t.Run(tt.name, func(t *testing.T) { |
|||
tmpl, err := gotmpl.New("test").Parse(tt.template) |
|||
if err != nil { |
|||
t.Fatalf("failed to parse template: %v", err) |
|||
} |
|||
got := toolPrefix(tmpl) |
|||
if got != tt.want { |
|||
t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want) |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
|
|||
func TestToolTemplate(t *testing.T) { |
|||
cases := []struct { |
|||
name string |
|||
template string |
|||
want bool |
|||
}{ |
|||
{ |
|||
name: "basic tool call range", |
|||
template: "{{range .ToolCalls}}test{{end}}", |
|||
want: true, |
|||
}, |
|||
{ |
|||
name: "no tool calls", |
|||
template: "{{range .Other}}test{{end}}", |
|||
want: false, |
|||
}, |
|||
{ |
|||
name: "nested tool calls", |
|||
template: "{{range .Outer}}{{range .ToolCalls}}test{{end}}{{end}}", |
|||
want: true, |
|||
}, |
|||
{ |
|||
name: "empty template", |
|||
template: "", |
|||
want: false, |
|||
}, |
|||
{ |
|||
name: "tool calls in if statement", |
|||
template: "{{if .ToolCalls}}test{{end}}", |
|||
want: false, |
|||
}, |
|||
} |
|||
|
|||
for _, tt := range cases { |
|||
t.Run(tt.name, func(t *testing.T) { |
|||
tmpl, err := gotmpl.New("test").Parse(tt.template) |
|||
if err != nil { |
|||
t.Fatalf("failed to parse template: %v", err) |
|||
} |
|||
|
|||
parsed, err := template.Parse(tmpl.Root.String()) |
|||
if err != nil { |
|||
t.Fatalf("failed to parse template: %v", err) |
|||
} |
|||
|
|||
_, err = toolTemplate(parsed) |
|||
if err != nil && tt.want { |
|||
t.Errorf("toolTemplate() = %v; want %v", err, tt.want) |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
|
|||
func TestSuffixOverlap(t *testing.T) { |
|||
cases := []struct { |
|||
name string |
|||
s string |
|||
d string |
|||
want int |
|||
}{ |
|||
{ |
|||
name: "no overlap", |
|||
s: "hello world", |
|||
d: "<tool_call>", |
|||
want: -1, |
|||
}, |
|||
{ |
|||
name: "full overlap", |
|||
s: "<tool_call>", |
|||
d: "<tool_call>", |
|||
want: 0, |
|||
}, |
|||
{ |
|||
name: "partial overlap", |
|||
s: "text <tool_call>", |
|||
d: "<tool_call>", |
|||
want: 5, |
|||
}, |
|||
{ |
|||
name: "delimiter longer than string", |
|||
s: "<tool>", |
|||
d: "<tool_call>", |
|||
want: -1, |
|||
}, |
|||
{ |
|||
name: "empty string", |
|||
s: "", |
|||
d: "<tool_call>", |
|||
want: -1, |
|||
}, |
|||
{ |
|||
name: "empty delimiter", |
|||
s: "<tool_call>", |
|||
d: "", |
|||
want: -1, |
|||
}, |
|||
{ |
|||
name: "single char overlap", |
|||
s: "test<", |
|||
d: "<tool_call>", |
|||
want: 4, |
|||
}, |
|||
{ |
|||
name: "partial tool call", |
|||
s: "hello <tool_", |
|||
d: "<tool_call>", |
|||
want: 6, |
|||
}, |
|||
} |
|||
|
|||
for _, tt := range cases { |
|||
t.Run(tt.name, func(t *testing.T) { |
|||
got := suffixOverlap(tt.s, tt.d) |
|||
if got != tt.want { |
|||
t.Errorf("suffixOverlap(%q, %q) = %d; want %d", tt.s, tt.d, got, tt.want) |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
|
|||
func TestExtractToolArgs(t *testing.T) { |
|||
cases := []struct { |
|||
name string |
|||
template string |
|||
want string |
|||
ok bool |
|||
}{ |
|||
{ |
|||
name: "basic tool call with text after", |
|||
template: `{{if .ToolCalls}}tool response{{end}}`, |
|||
want: "tool response", |
|||
ok: true, |
|||
}, |
|||
{ |
|||
name: "tool call with mixed content after", |
|||
template: `{{if .ToolCalls}}<tool_call>{{.Something}}{{end}}`, |
|||
want: "<tool_call>", |
|||
ok: true, |
|||
}, |
|||
{ |
|||
name: "tool call with no text after", |
|||
template: `{{if .ToolCalls}}{{.Something}}{{end}}`, |
|||
want: "", |
|||
ok: true, |
|||
}, |
|||
{ |
|||
name: "nested tool call", |
|||
template: `{{if .Something}}{{if .ToolCalls}}[TOOL_CALL]{{end}}{{end}}`, |
|||
want: "[TOOL_CALL]", |
|||
ok: true, |
|||
}, |
|||
{ |
|||
name: "no tool calls", |
|||
template: `{{if .Something}}no tools here{{end}}`, |
|||
want: "", |
|||
ok: false, |
|||
}, |
|||
{ |
|||
name: "empty template", |
|||
template: ``, |
|||
want: "", |
|||
ok: false, |
|||
}, |
|||
{ |
|||
name: "multiple tool calls sections", |
|||
template: `{{if .ToolCalls}}first{{end}}{{if .ToolCalls}}second{{end}}`, |
|||
want: "first", |
|||
ok: true, |
|||
}, |
|||
{ |
|||
name: "range over tool calls", |
|||
template: `{{if .ToolCalls}}{{range .ToolCalls}}tool{{end}}{{end}}`, |
|||
want: "", |
|||
ok: true, |
|||
}, |
|||
{ |
|||
name: "tool calls with pipe delimiters", |
|||
template: `{{if .ToolCalls}}<|tool|>{{end}}`, |
|||
want: "<|tool|>", |
|||
ok: true, |
|||
}, |
|||
{ |
|||
name: "tool calls with nested template", |
|||
template: `{{if .ToolCalls}}{{template "tool" .}}{{end}}`, |
|||
want: "", |
|||
ok: true, |
|||
}, |
|||
{ |
|||
name: "tool calls with whitespace variations", |
|||
template: `{{if .ToolCalls}} tool {{end}}`, |
|||
want: " tool ", |
|||
ok: true, |
|||
}, |
|||
} |
|||
|
|||
for _, tt := range cases { |
|||
t.Run(tt.name, func(t *testing.T) { |
|||
tmpl, err := gotmpl.New("test").Parse(tt.template) |
|||
if err != nil { |
|||
t.Fatalf("failed to parse template: %v", err) |
|||
} |
|||
|
|||
got, ok := extractToolCallsFormat(tmpl) |
|||
if got != tt.want { |
|||
t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want) |
|||
} |
|||
if ok != tt.ok { |
|||
t.Errorf("TextAfterToolCalls() ok = %v, want %v", ok, tt.ok) |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
|
|||
func TestCollect(t *testing.T) { |
|||
cases := []struct { |
|||
name string |
|||
obj any |
|||
want []map[string]any |
|||
}{ |
|||
{ |
|||
name: "simple map", |
|||
obj: map[string]any{ |
|||
"key": "value", |
|||
}, |
|||
want: []map[string]any{ |
|||
{"key": "value"}, |
|||
}, |
|||
}, |
|||
{ |
|||
name: "nested map", |
|||
obj: map[string]any{ |
|||
"outer": map[string]any{ |
|||
"inner": "value", |
|||
}, |
|||
}, |
|||
want: []map[string]any{ |
|||
{"outer": map[string]any{"inner": "value"}}, |
|||
{"inner": "value"}, |
|||
}, |
|||
}, |
|||
{ |
|||
name: "array of maps", |
|||
obj: []any{ |
|||
map[string]any{"key1": "val1"}, |
|||
map[string]any{"key2": "val2"}, |
|||
}, |
|||
want: []map[string]any{ |
|||
{"key1": "val1"}, |
|||
{"key2": "val2"}, |
|||
}, |
|||
}, |
|||
{ |
|||
name: "deeply nested", |
|||
obj: map[string]any{ |
|||
"l1": map[string]any{ |
|||
"l2": map[string]any{ |
|||
"l3": "value", |
|||
}, |
|||
}, |
|||
}, |
|||
want: []map[string]any{ |
|||
{"l1": map[string]any{"l2": map[string]any{"l3": "value"}}}, |
|||
{"l2": map[string]any{"l3": "value"}}, |
|||
{"l3": "value"}, |
|||
}, |
|||
}, |
|||
{ |
|||
name: "non-map value", |
|||
obj: "string", |
|||
want: nil, |
|||
}, |
|||
} |
|||
|
|||
for _, tt := range cases { |
|||
t.Run(tt.name, func(t *testing.T) { |
|||
got := collect(tt.obj) |
|||
if len(got) != len(tt.want) { |
|||
t.Errorf("collect() got %d maps, want %d", len(got), len(tt.want)) |
|||
return |
|||
} |
|||
|
|||
// Compare each map in the result
|
|||
for i := range tt.want { |
|||
if !mapsEqual(got[i], tt.want[i]) { |
|||
t.Errorf("collect() map[%d] = %v, want %v", i, got[i], tt.want[i]) |
|||
} |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
|
|||
// mapsEqual compares two maps for deep equality
|
|||
func mapsEqual(m1, m2 map[string]any) bool { |
|||
if len(m1) != len(m2) { |
|||
return false |
|||
} |
|||
for k, v1 := range m1 { |
|||
v2, ok := m2[k] |
|||
if !ok { |
|||
return false |
|||
} |
|||
switch val1 := v1.(type) { |
|||
case map[string]any: |
|||
val2, ok := v2.(map[string]any) |
|||
if !ok || !mapsEqual(val1, val2) { |
|||
return false |
|||
} |
|||
default: |
|||
if v1 != v2 { |
|||
return false |
|||
} |
|||
} |
|||
} |
|||
return true |
|||
} |
|||
Loading…
Reference in new issue