tests: basic benchmarking test framework (#12964)
This change adds a basic benchmarking test framework for Ollama which can be used to determine the prefill, eval, load, and total durations for running a given model or models.
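These durations come from the metrics that the Ollama API attaches to the final streamed chat response. The following is a minimal sketch (not part of this change) of where they come from, using the same `github.com/ollama/ollama/api` client the tool below relies on; the model name and prompt are placeholders:

```
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	// Assumes OLLAMA_HOST points at a running server and the model has been pulled.
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.ChatRequest{
		Model:    "gpt-oss:20b", // placeholder model
		Messages: []api.Message{{Role: "user", Content: "Write a short story."}},
	}

	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		if resp.Done {
			// The final response carries the timings the benchmark reports:
			// prefill = PromptEvalDuration, generate = EvalDuration,
			// load = LoadDuration, total = TotalDuration.
			fmt.Println("prefill:", resp.Metrics.PromptEvalDuration)
			fmt.Println("generate:", resp.Metrics.EvalDuration)
			fmt.Println("load:", resp.Metrics.LoadDuration)
			fmt.Println("total:", resp.Metrics.TotalDuration)
		}
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```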
cmd/bench/README.md (new file, 114 lines)
@@ -0,0 +1,114 @@
Ollama Benchmark Tool
---------------------

A Go-based command-line tool for benchmarking Ollama models with configurable parameters and multiple output formats.

## Features

* Benchmark multiple models in a single run
* Support for both text and image prompts
* Configurable generation parameters (temperature, max tokens, seed, etc.)
* Supports markdown, benchstat, and CSV output formats
* Detailed performance metrics (prefill, generate, load, total durations)

## Building from Source

```
go build -o bench bench.go
./bench -model gpt-oss:20b -epochs 6 -format csv
```

### Using Go Run (without building)

```
go run bench.go -model gpt-oss:20b -epochs 3
```

## Usage

### Basic Example

```
./bench -model gemma3 -epochs 6
```

### Benchmark Multiple Models

```
./bench -model gemma3,gemma3n -epochs 6 -max-tokens 100 -p "Write me a short story" | tee gemma.bench
benchstat -col /name gemma.bench
```

### With Image Prompt

```
./bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
```

### Advanced Example

```
./bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
```

## Command Line Options

| Option | Description | Default |
|--------|-------------|---------|
| -model | Comma-separated list of models to benchmark | (required) |
| -epochs | Number of iterations per model | 6 |
| -max-tokens | Maximum tokens for model response | 200 |
| -temperature | Temperature parameter | 0.0 |
| -seed | Random seed | 0 (random) |
| -timeout | Timeout in seconds | 300 |
| -p | Prompt text | built-in long story prompt |
| -image | Image file to include in prompt | (none) |
| -k | Keep-alive duration in seconds | 0 |
| -format | Output format (markdown, benchstat, csv) | markdown |
| -output | Output file for results | "" (stdout) |
| -v | Verbose mode (show system information) | false |
| -debug | Show debug information | false |

## Output Formats

### Markdown Format

The default markdown format is suitable for copying and pasting into a GitHub issue and will look like:

```
| Model | Step | Count | Duration | nsPerToken | tokensPerSec |
|-------|------|-------|----------|------------|--------------|
| gpt-oss:20b | prefill | 124 | 30.006458ms | 241987.56 | 4132.44 |
| gpt-oss:20b | generate | 200 | 2.646843954s | 13234219.77 | 75.56 |
| gpt-oss:20b | load | 1 | 121.674208ms | - | - |
| gpt-oss:20b | total | 1 | 2.861047625s | - | - |
```

### Benchstat Format

Compatible with Go's benchstat tool for statistical analysis:

```
BenchmarkModel/name=gpt-oss:20b/step=prefill 128 78125.00 ns/token 12800.00 token/sec
BenchmarkModel/name=gpt-oss:20b/step=generate 512 19531.25 ns/token 51200.00 token/sec
BenchmarkModel/name=gpt-oss:20b/step=load 1 1500000000 ns/request
```

### CSV Format

Machine-readable comma-separated values:

```
NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC
gpt-oss:20b,prefill,128,78125.00,12800.00
gpt-oss:20b,generate,512,19531.25,51200.00
gpt-oss:20b,load,1,1500000000,0
```
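
If you want to post-process results programmatically, the CSV output can be read back with Go's standard library. A minimal sketch, assuming the results were saved with `-format csv -output results.csv`; the file name and the filtering logic are illustrative, not part of the tool:

```
package main

import (
	"encoding/csv"
	"fmt"
	"log"
	"os"
	"strconv"
)

func main() {
	// Assumes results were written with: ./bench ... -format csv -output results.csv
	f, err := os.Open("results.csv")
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	rows, err := csv.NewReader(f).ReadAll()
	if err != nil {
		log.Fatal(err)
	}
	if len(rows) < 2 {
		return // header only, nothing to report
	}

	// Columns: NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC
	for _, row := range rows[1:] {
		if row[1] != "generate" {
			continue
		}
		tps, err := strconv.ParseFloat(row[4], 64)
		if err != nil {
			log.Fatal(err)
		}
		fmt.Printf("%s: %.2f tokens/sec\n", row[0], tps)
	}
}
```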

## Metrics Explained

The tool reports four types of metrics for each model:

* prefill: Time spent processing the prompt
* generate: Time spent generating the response
* load: Model loading time (one-time cost)
* total: Total request duration
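
For the prefill and generate steps, the per-token columns are derived from each step's token count and duration, mirroring the calculation in bench.go's OutputMetrics. A minimal sketch of the arithmetic, using the generate-step numbers from the markdown example above:

```
package main

import (
	"fmt"
	"time"
)

func main() {
	// A step's token count and duration, as recorded by the tool.
	count := 200
	duration := 2646843954 * time.Nanosecond // ~2.65s for the generate step

	// ns/token and token/sec, as printed in the benchstat, CSV, and markdown outputs.
	nsPerToken := float64(duration.Nanoseconds()) / float64(count)
	tokensPerSec := float64(count) / float64(duration.Nanoseconds()) * 1e9

	fmt.Printf("nsPerToken=%.2f tokensPerSec=%.2f\n", nsPerToken, tokensPerSec)
}
```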
cmd/bench/bench.go (new file, 309 lines)
@@ -0,0 +1,309 @@
package main

import (
	"cmp"
	"context"
	"flag"
	"fmt"
	"io"
	"os"
	"runtime"
	"slices"
	"strings"
	"sync"
	"time"

	"github.com/ollama/ollama/api"
)

type flagOptions struct {
	models      *string
	epochs      *int
	maxTokens   *int
	temperature *float64
	seed        *int
	timeout     *int
	prompt      *string
	imageFile   *string
	keepAlive   *float64
	format      *string
	outputFile  *string
	debug       *bool
	verbose     *bool
}

type Metrics struct {
	Model    string
	Step     string
	Count    int
	Duration time.Duration
}

var once sync.Once

const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`

// OutputMetrics writes the collected metrics to w in the requested format
// (benchstat, csv, or markdown).
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
	switch format {
	case "benchstat":
		if verbose {
			printHeader := func() {
				fmt.Printf("sysname: %s\n", runtime.GOOS)
				fmt.Printf("machine: %s\n", runtime.GOARCH)
			}
			once.Do(printHeader)
		}
		for _, m := range metrics {
			if m.Step == "generate" || m.Step == "prefill" {
				if m.Count > 0 {
					nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count)
					tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9

					fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d %.2f ns/token %.2f token/sec\n",
						m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
				} else {
					fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d 0 ns/token 0 token/sec\n",
						m.Model, m.Step, m.Count)
				}
			} else {
				var suffix string
				if m.Step == "load" {
					suffix = "/step=load"
				}
				fmt.Fprintf(w, "BenchmarkModel/name=%s%s 1 %d ns/request\n",
					m.Model, suffix, m.Duration.Nanoseconds())
			}
		}
	case "csv":
		printHeader := func() {
			headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
			fmt.Fprintln(w, strings.Join(headings, ","))
		}
		once.Do(printHeader)

		for _, m := range metrics {
			if m.Step == "generate" || m.Step == "prefill" {
				var nsPerToken float64
				var tokensPerSec float64
				if m.Count > 0 {
					nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
					tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
				}
				fmt.Fprintf(w, "%s,%s,%d,%.2f,%.2f\n", m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
			} else {
				fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
			}
		}
	case "markdown":
		printHeader := func() {
			fmt.Fprintln(w, "| Model | Step | Count | Duration | nsPerToken | tokensPerSec |")
			fmt.Fprintln(w, "|-------|------|-------|----------|------------|--------------|")
		}
		once.Do(printHeader)

		for _, m := range metrics {
			var nsPerToken, tokensPerSec float64
			var nsPerTokenStr, tokensPerSecStr string

			if m.Step == "generate" || m.Step == "prefill" {
				nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
				tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
				nsPerTokenStr = fmt.Sprintf("%.2f", nsPerToken)
				tokensPerSecStr = fmt.Sprintf("%.2f", tokensPerSec)
			} else {
				nsPerTokenStr = "-"
				tokensPerSecStr = "-"
			}

			fmt.Fprintf(w, "| %s | %s | %d | %v | %s | %s |\n",
				m.Model, m.Step, m.Count, m.Duration, nsPerTokenStr, tokensPerSecStr)
		}
	default:
		fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
	}
}

// BenchmarkChat runs the configured chat benchmark for each model and epoch,
// emitting metrics for every completed request.
func BenchmarkChat(fOpt flagOptions) error {
	models := strings.Split(*fOpt.models, ",")

	// todo - add multi-image support
	var imgData api.ImageData
	var err error
	if *fOpt.imageFile != "" {
		imgData, err = readImage(*fOpt.imageFile)
		if err != nil {
			fmt.Fprintf(os.Stderr, "ERROR: Couldn't read image '%s': %v\n", *fOpt.imageFile, err)
			return err
		}
	}

	if *fOpt.debug && imgData != nil {
		fmt.Fprintf(os.Stderr, "Read file '%s'\n", *fOpt.imageFile)
	}

	client, err := api.ClientFromEnvironment()
	if err != nil {
		fmt.Fprintf(os.Stderr, "ERROR: Couldn't create ollama client: %v\n", err)
		return err
	}

	// Write results to the requested output file, or stdout by default.
	out := os.Stdout
	if fOpt.outputFile != nil && *fOpt.outputFile != "" {
		f, err := os.Create(*fOpt.outputFile)
		if err != nil {
			fmt.Fprintf(os.Stderr, "ERROR: Couldn't create output file '%s': %v\n", *fOpt.outputFile, err)
			return err
		}
		defer f.Close()
		out = f
	}

	for _, model := range models {
		for range *fOpt.epochs {
			options := make(map[string]interface{})
			if *fOpt.maxTokens > 0 {
				options["num_predict"] = *fOpt.maxTokens
			}
			options["temperature"] = *fOpt.temperature
			if fOpt.seed != nil && *fOpt.seed > 0 {
				options["seed"] = *fOpt.seed
			}

			var keepAliveDuration *api.Duration
			if *fOpt.keepAlive > 0 {
				duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
				keepAliveDuration = &duration
			}

			req := &api.ChatRequest{
				Model: model,
				Messages: []api.Message{
					{
						Role:    "user",
						Content: *fOpt.prompt,
					},
				},
				Options:   options,
				KeepAlive: keepAliveDuration,
			}

			if imgData != nil {
				req.Messages[0].Images = []api.ImageData{imgData}
			}

			var responseMetrics *api.Metrics

			ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
			defer cancel()

			err = client.Chat(ctx, req, func(resp api.ChatResponse) error {
				if *fOpt.debug {
					fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Message.Thinking, resp.Message.Content))
				}

				if resp.Done {
					responseMetrics = &resp.Metrics
				}
				return nil
			})

			if *fOpt.debug {
				fmt.Fprintln(os.Stderr)
			}

			if err != nil {
				if ctx.Err() == context.DeadlineExceeded {
					fmt.Fprintf(os.Stderr, "ERROR: Chat request timed out with model '%s' after %vs\n", model, *fOpt.timeout)
					continue
				}
				fmt.Fprintf(os.Stderr, "ERROR: Couldn't chat with model '%s': %v\n", model, err)
				continue
			}

			if responseMetrics == nil {
				fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
				continue
			}

			metrics := []Metrics{
				{
					Model:    model,
					Step:     "prefill",
					Count:    responseMetrics.PromptEvalCount,
					Duration: responseMetrics.PromptEvalDuration,
				},
				{
					Model:    model,
					Step:     "generate",
					Count:    responseMetrics.EvalCount,
					Duration: responseMetrics.EvalDuration,
				},
				{
					Model:    model,
					Step:     "load",
					Count:    1,
					Duration: responseMetrics.LoadDuration,
				},
				{
					Model:    model,
					Step:     "total",
					Count:    1,
					Duration: responseMetrics.TotalDuration,
				},
			}

			OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)

			if *fOpt.keepAlive > 0 {
				time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
			}
		}
	}
	return nil
}

// readImage loads the file at filePath and returns its raw bytes as api.ImageData.
func readImage(filePath string) (api.ImageData, error) {
	file, err := os.Open(filePath)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	data, err := io.ReadAll(file)
	if err != nil {
		return nil, err
	}

	return api.ImageData(data), nil
}

func main() {
	fOpt := flagOptions{
		models:      flag.String("model", "", "Comma-separated list of models to benchmark"),
		epochs:      flag.Int("epochs", 6, "Number of epochs (iterations) per model"),
		maxTokens:   flag.Int("max-tokens", 200, "Maximum tokens for model response"),
		temperature: flag.Float64("temperature", 0, "Temperature parameter"),
		seed:        flag.Int("seed", 0, "Random seed"),
		timeout:     flag.Int("timeout", 60*5, "Timeout in seconds"),
		prompt:      flag.String("p", DefaultPrompt, "Prompt to use"),
		imageFile:   flag.String("image", "", "Filename for an image to include"),
		keepAlive:   flag.Float64("k", 0, "Keep alive duration in seconds"),
		format:      flag.String("format", "markdown", "Output format [markdown|benchstat|csv]"),
		outputFile:  flag.String("output", "", "Output file for results (stdout if empty)"),
		verbose:     flag.Bool("v", false, "Show system information"),
		debug:       flag.Bool("debug", false, "Show debug information"),
	}

	flag.Usage = func() {
		fmt.Fprintf(os.Stderr, "Usage: %s [OPTIONS]\n\n", os.Args[0])
		fmt.Fprintf(os.Stderr, "Description:\n")
		fmt.Fprintf(os.Stderr, "  Model benchmarking tool with configurable parameters\n\n")
		fmt.Fprintf(os.Stderr, "Options:\n")
		flag.PrintDefaults()
		fmt.Fprintf(os.Stderr, "\nExamples:\n")
		fmt.Fprintf(os.Stderr, "  bench -model gpt-oss:20b -epochs 3 -temperature 0.7\n")
	}
	flag.Parse()

	if !slices.Contains([]string{"markdown", "benchstat", "csv"}, *fOpt.format) {
		fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format)
		os.Exit(1)
	}

	if len(*fOpt.models) == 0 {
		fmt.Fprintf(os.Stderr, "ERROR: No model(s) specified to benchmark.\n")
		flag.Usage()
		return
	}

	if err := BenchmarkChat(fOpt); err != nil {
		os.Exit(1)
	}
}
cmd/bench/bench_test.go (new file, 463 lines)
@@ -0,0 +1,463 @@
package main

import (
	"bytes"
	"crypto/rand"
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
)

func createTestFlagOptions() flagOptions {
	models := "test-model"
	format := "benchstat"
	epochs := 1
	maxTokens := 100
	temperature := 0.7
	seed := 42
	timeout := 30
	prompt := "test prompt"
	imageFile := ""
	keepAlive := 5.0
	verbose := false
	debug := false

	return flagOptions{
		models:      &models,
		format:      &format,
		epochs:      &epochs,
		maxTokens:   &maxTokens,
		temperature: &temperature,
		seed:        &seed,
		timeout:     &timeout,
		prompt:      &prompt,
		imageFile:   &imageFile,
		keepAlive:   &keepAlive,
		verbose:     &verbose,
		debug:       &debug,
	}
}

func captureOutput(f func()) string {
	oldStdout := os.Stdout
	oldStderr := os.Stderr
	defer func() {
		os.Stdout = oldStdout
		os.Stderr = oldStderr
	}()

	r, w, _ := os.Pipe()
	os.Stdout = w
	os.Stderr = w

	f()

	w.Close()
	var buf bytes.Buffer
	io.Copy(&buf, r)
	return buf.String()
}

func createMockOllamaServer(t *testing.T, responses []api.ChatResponse) *httptest.Server {
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/api/chat" {
			t.Errorf("Expected path /api/chat, got %s", r.URL.Path)
			http.Error(w, "Not found", http.StatusNotFound)
			return
		}

		if r.Method != "POST" {
			t.Errorf("Expected POST method, got %s", r.Method)
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)

		for _, resp := range responses {
			jsonData, err := json.Marshal(resp)
			if err != nil {
				t.Errorf("Failed to marshal response: %v", err)
				return
			}
			w.Write(jsonData)
			w.Write([]byte("\n"))
			if f, ok := w.(http.Flusher); ok {
				f.Flush()
			}
			time.Sleep(10 * time.Millisecond) // Simulate some delay
		}
	}))
}

func TestBenchmarkChat_Success(t *testing.T) {
	fOpt := createTestFlagOptions()

	mockResponses := []api.ChatResponse{
		{
			Model: "test-model",
			Message: api.Message{
				Role:    "assistant",
				Content: "test response part 1",
			},
			Done: false,
		},
		{
			Model: "test-model",
			Message: api.Message{
				Role:    "assistant",
				Content: "test response part 2",
			},
			Done: true,
			Metrics: api.Metrics{
				PromptEvalCount:    10,
				PromptEvalDuration: 100 * time.Millisecond,
				EvalCount:          50,
				EvalDuration:       500 * time.Millisecond,
				TotalDuration:      600 * time.Millisecond,
				LoadDuration:       50 * time.Millisecond,
			},
		},
	}

	server := createMockOllamaServer(t, mockResponses)
	defer server.Close()

	t.Setenv("OLLAMA_HOST", server.URL)

	output := captureOutput(func() {
		err := BenchmarkChat(fOpt)
		if err != nil {
			t.Errorf("Expected no error, got %v", err)
		}
	})

	if !strings.Contains(output, "BenchmarkModel/name=test-model/step=prefill") {
		t.Errorf("Expected output to contain prefill metrics, got: %s", output)
	}
	if !strings.Contains(output, "BenchmarkModel/name=test-model/step=generate") {
		t.Errorf("Expected output to contain generate metrics, got: %s", output)
	}
	if !strings.Contains(output, "ns/token") {
		t.Errorf("Expected output to contain ns/token metric, got: %s", output)
	}
}

func TestBenchmarkChat_ServerError(t *testing.T) {
	fOpt := createTestFlagOptions()

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		http.Error(w, "Internal server error", http.StatusInternalServerError)
	}))
	defer server.Close()

	t.Setenv("OLLAMA_HOST", server.URL)

	output := captureOutput(func() {
		err := BenchmarkChat(fOpt)
		if err != nil {
			t.Errorf("Expected error to be handled internally, got returned error: %v", err)
		}
	})

	if !strings.Contains(output, "ERROR: Couldn't chat with model") {
		t.Errorf("Expected error message about chat failure, got: %s", output)
	}
}

func TestBenchmarkChat_Timeout(t *testing.T) {
	fOpt := createTestFlagOptions()
	shortTimeout := 1 // Very short timeout
	fOpt.timeout = &shortTimeout

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Simulate a long delay that will cause timeout
		time.Sleep(2 * time.Second)

		w.Header().Set("Content-Type", "application/json")
		response := api.ChatResponse{
			Model: "test-model",
			Message: api.Message{
				Role:    "assistant",
				Content: "test response",
			},
			Done: true,
			Metrics: api.Metrics{
				PromptEvalCount:    10,
				PromptEvalDuration: 100 * time.Millisecond,
				EvalCount:          50,
				EvalDuration:       500 * time.Millisecond,
				TotalDuration:      600 * time.Millisecond,
				LoadDuration:       50 * time.Millisecond,
			},
		}
		jsonData, _ := json.Marshal(response)
		w.Write(jsonData)
	}))
	defer server.Close()

	t.Setenv("OLLAMA_HOST", server.URL)

	output := captureOutput(func() {
		err := BenchmarkChat(fOpt)
		if err != nil {
			t.Errorf("Expected timeout to be handled internally, got returned error: %v", err)
		}
	})

	if !strings.Contains(output, "ERROR: Chat request timed out") {
		t.Errorf("Expected timeout error message, got: %s", output)
	}
}

func TestBenchmarkChat_NoMetrics(t *testing.T) {
	fOpt := createTestFlagOptions()

	mockResponses := []api.ChatResponse{
		{
			Model: "test-model",
			Message: api.Message{
				Role:    "assistant",
				Content: "test response",
			},
			Done: false, // Never sends Done=true
		},
	}

	server := createMockOllamaServer(t, mockResponses)
	defer server.Close()

	t.Setenv("OLLAMA_HOST", server.URL)

	output := captureOutput(func() {
		err := BenchmarkChat(fOpt)
		if err != nil {
			t.Errorf("Expected no error, got %v", err)
		}
	})

	if !strings.Contains(output, "ERROR: No metrics received") {
		t.Errorf("Expected no metrics error message, got: %s", output)
	}
}

func TestBenchmarkChat_MultipleModels(t *testing.T) {
	fOpt := createTestFlagOptions()
	models := "model1,model2"
	epochs := 2
	fOpt.models = &models
	fOpt.epochs = &epochs

	callCount := 0
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		callCount++

		w.Header().Set("Content-Type", "application/json")

		var req api.ChatRequest
		body, _ := io.ReadAll(r.Body)
		json.Unmarshal(body, &req)

		response := api.ChatResponse{
			Model: req.Model,
			Message: api.Message{
				Role:    "assistant",
				Content: "test response for " + req.Model,
			},
			Done: true,
			Metrics: api.Metrics{
				PromptEvalCount:    10,
				PromptEvalDuration: 100 * time.Millisecond,
				EvalCount:          50,
				EvalDuration:       500 * time.Millisecond,
				TotalDuration:      600 * time.Millisecond,
				LoadDuration:       50 * time.Millisecond,
			},
		}
		jsonData, _ := json.Marshal(response)
		w.Write(jsonData)
	}))
	defer server.Close()

	t.Setenv("OLLAMA_HOST", server.URL)

	output := captureOutput(func() {
		err := BenchmarkChat(fOpt)
		if err != nil {
			t.Errorf("Expected no error, got %v", err)
		}
	})

	// Should be called 4 times (2 models × 2 epochs)
	if callCount != 4 {
		t.Errorf("Expected 4 API calls, got %d", callCount)
	}

	if !strings.Contains(output, "BenchmarkModel/name=model1") || !strings.Contains(output, "BenchmarkModel/name=model2") {
		t.Errorf("Expected output for both models, got: %s", output)
	}
}

func TestBenchmarkChat_WithImage(t *testing.T) {
	fOpt := createTestFlagOptions()

	tmpfile, err := os.CreateTemp(t.TempDir(), "testimage")
	if err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	defer os.Remove(tmpfile.Name())

	content := []byte("fake image data")
	if _, err := tmpfile.Write(content); err != nil {
		t.Fatalf("Failed to write to temp file: %v", err)
	}
	tmpfile.Close()

	tmpfileName := tmpfile.Name()
	fOpt.imageFile = &tmpfileName

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Verify the request contains image data
		var req api.ChatRequest
		body, _ := io.ReadAll(r.Body)
		json.Unmarshal(body, &req)

		if len(req.Messages) == 0 || len(req.Messages[0].Images) == 0 {
			t.Error("Expected request to contain images")
		}

		w.Header().Set("Content-Type", "application/json")
		response := api.ChatResponse{
			Model: "test-model",
			Message: api.Message{
				Role:    "assistant",
				Content: "test response with image",
			},
			Done: true,
			Metrics: api.Metrics{
				PromptEvalCount:    10,
				PromptEvalDuration: 100 * time.Millisecond,
				EvalCount:          50,
				EvalDuration:       500 * time.Millisecond,
				TotalDuration:      600 * time.Millisecond,
				LoadDuration:       50 * time.Millisecond,
			},
		}
		jsonData, _ := json.Marshal(response)
		w.Write(jsonData)
	}))
	defer server.Close()

	t.Setenv("OLLAMA_HOST", server.URL)

	output := captureOutput(func() {
		err := BenchmarkChat(fOpt)
		if err != nil {
			t.Errorf("Expected no error, got %v", err)
		}
	})

	if !strings.Contains(output, "BenchmarkModel/name=test-model") {
		t.Errorf("Expected benchmark output, got: %s", output)
	}
}

func TestBenchmarkChat_ImageError(t *testing.T) {
	randFileName := func() string {
		const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
		const length = 8

		result := make([]byte, length)
		rand.Read(result) // Fill with random bytes

		for i := range result {
			result[i] = charset[result[i]%byte(len(charset))]
		}

		return string(result) + ".txt"
	}

	fOpt := createTestFlagOptions()
	imageFile := randFileName()
	fOpt.imageFile = &imageFile

	output := captureOutput(func() {
		err := BenchmarkChat(fOpt)
		if err == nil {
			t.Error("Expected error from image reading, got nil")
		}
	})

	if !strings.Contains(output, "ERROR: Couldn't read image") {
		t.Errorf("Expected image read error message, got: %s", output)
	}
}

func TestReadImage_Success(t *testing.T) {
	tmpfile, err := os.CreateTemp(t.TempDir(), "testimage")
	if err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	defer os.Remove(tmpfile.Name())

	content := []byte("fake image data")
	if _, err := tmpfile.Write(content); err != nil {
		t.Fatalf("Failed to write to temp file: %v", err)
	}
	tmpfile.Close()

	imgData, err := readImage(tmpfile.Name())
	if err != nil {
		t.Errorf("Expected no error, got %v", err)
	}

	if imgData == nil {
		t.Error("Expected image data, got nil")
	}

	expected := api.ImageData(content)
	if string(imgData) != string(expected) {
		t.Errorf("Expected image data %v, got %v", expected, imgData)
	}
}

func TestReadImage_FileNotFound(t *testing.T) {
	imgData, err := readImage("nonexistentfile.jpg")
	if err == nil {
		t.Error("Expected error for non-existent file, got nil")
	}
	if imgData != nil {
		t.Error("Expected nil image data for non-existent file")
	}
}

func TestOptionsMapCreation(t *testing.T) {
	fOpt := createTestFlagOptions()

	options := make(map[string]interface{})
	if *fOpt.maxTokens > 0 {
		options["num_predict"] = *fOpt.maxTokens
	}
	options["temperature"] = *fOpt.temperature
	if fOpt.seed != nil && *fOpt.seed > 0 {
		options["seed"] = *fOpt.seed
	}

	if options["num_predict"] != *fOpt.maxTokens {
		t.Errorf("Expected num_predict %d, got %v", *fOpt.maxTokens, options["num_predict"])
	}
	if options["temperature"] != *fOpt.temperature {
		t.Errorf("Expected temperature %f, got %v", *fOpt.temperature, options["temperature"])
	}
	if options["seed"] != *fOpt.seed {
		t.Errorf("Expected seed %d, got %v", *fOpt.seed, options["seed"])
	}
}