mirror of https://gitee.com/namelin2022/ollama
committed by
GitHub
2 changed files with 0 additions and 237 deletions
@ -1,178 +0,0 @@ |
|||
package benchmark |
|||
|
|||
import ( |
|||
"context" |
|||
"flag" |
|||
"fmt" |
|||
"testing" |
|||
"time" |
|||
|
|||
"github.com/ollama/ollama/api" |
|||
) |
|||
|
|||
// Command line flags
|
|||
var modelFlag string |
|||
|
|||
func init() { |
|||
flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark") |
|||
flag.Lookup("m").DefValue = "model" |
|||
} |
|||
|
|||
// modelName returns the model name from flags, failing the test if not set
|
|||
func modelName(b *testing.B) string { |
|||
if modelFlag == "" { |
|||
b.Fatal("Error: -m flag is required for benchmark tests") |
|||
} |
|||
return modelFlag |
|||
} |
|||
|
|||
// TestCase describes one benchmark scenario: a named prompt and the
// maximum number of tokens the model may generate for it.
type TestCase struct {
	name      string // subtest name, e.g. "short_prompt"
	prompt    string // prompt text sent to the model
	maxTokens int    // generation cap, passed as the num_predict option
}
|||
|
|||
// runGenerateBenchmark contains the common generate and metrics logic
|
|||
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) { |
|||
start := time.Now() |
|||
var ttft time.Duration |
|||
var metrics api.Metrics |
|||
|
|||
err := client.Generate(ctx, req, func(resp api.GenerateResponse) error { |
|||
if ttft == 0 && resp.Response != "" { |
|||
ttft = time.Since(start) |
|||
} |
|||
if resp.Done { |
|||
metrics = resp.Metrics |
|||
} |
|||
return nil |
|||
}) |
|||
|
|||
// Report custom metrics as part of the benchmark results
|
|||
b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms") |
|||
b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms") |
|||
|
|||
// Token throughput metrics
|
|||
promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds() |
|||
genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds() |
|||
b.ReportMetric(promptThroughput, "prompt_tok/s") |
|||
b.ReportMetric(genThroughput, "gen_tok/s") |
|||
|
|||
// Token counts
|
|||
b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens") |
|||
b.ReportMetric(float64(metrics.EvalCount), "gen_tokens") |
|||
if err != nil { |
|||
b.Fatal(err) |
|||
} |
|||
} |
|||
|
|||
// BenchmarkColdStart runs benchmarks with model loading from cold state
|
|||
func BenchmarkColdStart(b *testing.B) { |
|||
client := setup(b) |
|||
tests := []TestCase{ |
|||
{"short_prompt", "Write a long story", 100}, |
|||
{"medium_prompt", "Write a detailed economic analysis", 500}, |
|||
{"long_prompt", "Write a comprehensive AI research paper", 1000}, |
|||
} |
|||
m := modelName(b) |
|||
|
|||
for _, tt := range tests { |
|||
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) { |
|||
ctx := b.Context() |
|||
|
|||
// Set number of tokens as our throughput metric
|
|||
b.SetBytes(int64(tt.maxTokens)) |
|||
|
|||
for b.Loop() { |
|||
b.StopTimer() |
|||
// Ensure model is unloaded before each iteration
|
|||
unload(client, m, b) |
|||
b.StartTimer() |
|||
|
|||
req := &api.GenerateRequest{ |
|||
Model: m, |
|||
Prompt: tt.prompt, |
|||
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1}, |
|||
} |
|||
|
|||
runGenerateBenchmark(b, ctx, client, req) |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
|
|||
// BenchmarkWarmStart runs benchmarks with pre-loaded model
|
|||
func BenchmarkWarmStart(b *testing.B) { |
|||
client := setup(b) |
|||
tests := []TestCase{ |
|||
{"short_prompt", "Write a long story", 100}, |
|||
{"medium_prompt", "Write a detailed economic analysis", 500}, |
|||
{"long_prompt", "Write a comprehensive AI research paper", 1000}, |
|||
} |
|||
m := modelName(b) |
|||
|
|||
for _, tt := range tests { |
|||
b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) { |
|||
ctx := b.Context() |
|||
|
|||
// Pre-warm the model
|
|||
warmup(client, m, tt.prompt, b) |
|||
|
|||
// Set number of tokens as our throughput metric
|
|||
b.SetBytes(int64(tt.maxTokens)) |
|||
|
|||
for b.Loop() { |
|||
req := &api.GenerateRequest{ |
|||
Model: m, |
|||
Prompt: tt.prompt, |
|||
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1}, |
|||
} |
|||
|
|||
runGenerateBenchmark(b, ctx, client, req) |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
|
|||
// setup verifies server and model availability
|
|||
func setup(b *testing.B) *api.Client { |
|||
client, err := api.ClientFromEnvironment() |
|||
if err != nil { |
|||
b.Fatal(err) |
|||
} |
|||
if _, err := client.Show(b.Context(), &api.ShowRequest{Model: modelName(b)}); err != nil { |
|||
b.Fatalf("Model unavailable: %v", err) |
|||
} |
|||
|
|||
return client |
|||
} |
|||
|
|||
// warmup ensures the model is loaded and warmed up
|
|||
func warmup(client *api.Client, model string, prompt string, b *testing.B) { |
|||
for range 3 { |
|||
err := client.Generate( |
|||
context.Background(), |
|||
&api.GenerateRequest{ |
|||
Model: model, |
|||
Prompt: prompt, |
|||
Options: map[string]any{"num_predict": 50, "temperature": 0.1}, |
|||
}, |
|||
func(api.GenerateResponse) error { return nil }, |
|||
) |
|||
if err != nil { |
|||
b.Logf("Error during model warm-up: %v", err) |
|||
} |
|||
} |
|||
} |
|||
|
|||
// unload forces model unloading using KeepAlive: 0 parameter
|
|||
func unload(client *api.Client, model string, b *testing.B) { |
|||
req := &api.GenerateRequest{ |
|||
Model: model, |
|||
KeepAlive: &api.Duration{Duration: 0}, |
|||
} |
|||
if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil { |
|||
b.Logf("Unload error: %v", err) |
|||
} |
|||
time.Sleep(1 * time.Second) |
|||
} |
|||
@ -1,59 +0,0 @@ |
|||
# Benchmark |
|||
|
|||
Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes. |
|||
|
|||
## When to use |
|||
|
|||
Run these benchmarks when: |
|||
- Making changes to the model inference engine |
|||
- Modifying model loading/unloading logic |
|||
- Changing prompt processing or token generation code |
|||
- Implementing a new model architecture |
|||
- Testing performance across different hardware setups |
|||
|
|||
## Prerequisites |
|||
- Ollama server running locally with `ollama serve` on `127.0.0.1:11434` |
|||
## Usage and Examples |
|||
|
|||
> [!NOTE]
> All commands must be run from the root directory of the Ollama project.
|||
|
|||
Basic syntax: |
|||
```bash |
|||
go test -bench=. ./benchmark/... -m $MODEL_NAME |
|||
``` |
|||
|
|||
Required flags: |
|||
- `-bench=.`: Run all benchmarks |
|||
- `-m`: Model name to benchmark |
|||
|
|||
Optional flags: |
|||
- `-count N`: Number of times to run the benchmark (useful for statistical analysis) |
|||
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes) |
|||
|
|||
Common usage patterns: |
|||
|
|||
Single benchmark run with a model specified: |
|||
```bash |
|||
go test -bench=. ./benchmark/... -m llama3.3 |
|||
``` |
|||
|
|||
## Output metrics |
|||
|
|||
The benchmark reports several key metrics: |
|||
|
|||
- `gen_tok/s`: Generated tokens per second |
|||
- `prompt_tok/s`: Prompt processing tokens per second |
|||
- `ttft_ms`: Time to first token in milliseconds |
|||
- `load_ms`: Model load time in milliseconds |
|||
- `gen_tokens`: Total tokens generated |
|||
- `prompt_tokens`: Total prompt tokens processed |
|||
|
|||
Each benchmark runs two scenarios: |
|||
- Cold start: Model is loaded from disk for each test |
|||
- Warm start: Model is pre-loaded in memory |
|||
|
|||
Three scenarios are tested for each start mode, differing in how many tokens the model is allowed to generate:
- Short prompt (up to 100 generated tokens)
- Medium prompt (up to 500 generated tokens)
- Long prompt (up to 1000 generated tokens)
|||
Loading…
Reference in new issue