diff --git a/cmd/bench/README.md b/cmd/bench/README.md
new file mode 100644
index 00000000..210cc4a2
--- /dev/null
+++ b/cmd/bench/README.md
@@ -0,0 +1,114 @@
+Ollama Benchmark Tool
+---------------------
+
+A Go-based command-line tool for benchmarking Ollama models with configurable parameters and multiple output formats.
+
+## Features
+
+ * Benchmark multiple models in a single run
+ * Support for both text and image prompts
+ * Configurable generation parameters (temperature, max tokens, seed, etc.)
+ * Markdown, benchstat, and CSV output formats
+ * Detailed performance metrics (prefill, generate, load, total durations)
+
+## Building from Source
+
+```
+go build -o bench bench.go
+./bench -model gpt-oss:20b -epochs 6 -format csv
+```
+
+### Using Go Run (without building)
+
+```
+go run bench.go -model gpt-oss:20b -epochs 3
+```
+
+## Usage
+
+### Basic Example
+
+```
+./bench -model gemma3 -epochs 6
+```
+
+### Benchmark Multiple Models
+
+```
+./bench -model gemma3,gemma3n -epochs 6 -max-tokens 100 -p "Write me a short story" -format benchstat | tee gemma.bench
+benchstat -col /name gemma.bench
+```
+
+### With Image Prompt
+
+```
+./bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
+```
+
+### Advanced Example
+
+```
+./bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
+```
+
+## Command Line Options
+
+| Option | Description | Default |
+|--------|-------------|---------|
+| -model | Comma-separated list of models to benchmark | (required) |
+| -epochs | Number of iterations per model | 6 |
+| -max-tokens | Maximum tokens for model response | 200 |
+| -temperature | Temperature parameter | 0.0 |
+| -seed | Random seed | 0 (unset) |
+| -timeout | Timeout in seconds | 300 |
+| -p | Prompt text | (built-in long story prompt) |
+| -image | Image file to include in prompt | |
+| -k | Keep-alive duration in seconds | 0 |
+| -format | Output format (markdown, benchstat, csv) | markdown |
+| -output | Output file for results | "" (stdout) |
+| -v | Verbose mode (print system information) | false |
+| -debug | Show debug information | false |
+
+## Output Formats
+
+### Markdown Format
+
+The default markdown format is suitable for copying and pasting into a GitHub issue and will look like:
+```
+| Model | Step | Count | Duration | nsPerToken | tokensPerSec |
+|-------|------|-------|----------|------------|--------------|
+| gpt-oss:20b | prefill | 124 | 30.006458ms | 241987.56 | 4132.44 |
+| gpt-oss:20b | generate | 200 | 2.646843954s | 13234219.77 | 75.56 |
+| gpt-oss:20b | load | 1 | 121.674208ms | - | - |
+| gpt-oss:20b | total | 1 | 2.861047625s | - | - |
+```
+
+### Benchstat Format
+
+Compatible with Go's benchstat tool (golang.org/x/perf/cmd/benchstat) for statistical analysis:
+
+```
+BenchmarkModel/name=gpt-oss:20b/step=prefill 128 78125.00 ns/token 12800.00 token/sec
+BenchmarkModel/name=gpt-oss:20b/step=generate 512 19531.25 ns/token 51200.00 token/sec
+BenchmarkModel/name=gpt-oss:20b/step=load 1 1500000000 ns/request
+```
+
+### CSV Format
+
+Machine-readable comma-separated values:
+
+```
+NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC
+gpt-oss:20b,prefill,128,78125.00,12800.00
+gpt-oss:20b,generate,512,19531.25,51200.00
+gpt-oss:20b,load,1,1500000000,0
+```
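+
+For programmatic post-processing, the CSV output can be read with Go's standard `encoding/csv` package. The sketch below is only an illustration (the `results.csv` name is an example, e.g. produced with `-format csv -output results.csv`); it prints the generation throughput per model:
+
+```go
+package main
+
+import (
+	"encoding/csv"
+	"fmt"
+	"os"
+	"strconv"
+)
+
+func main() {
+	f, err := os.Open("results.csv")
+	if err != nil {
+		panic(err)
+	}
+	defer f.Close()
+
+	rows, err := csv.NewReader(f).ReadAll()
+	if err != nil {
+		panic(err)
+	}
+	for _, row := range rows {
+		// Skip the header row (NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC)
+		// and anything that isn't a "generate" entry.
+		if len(row) < 5 || row[1] != "generate" {
+			continue
+		}
+		tokensPerSec, _ := strconv.ParseFloat(row[4], 64)
+		fmt.Printf("%s: %.2f tokens/sec\n", row[0], tokensPerSec)
+	}
+}
+```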
+
+## Metrics Explained
+
+The tool reports four types of metrics for each model:
+
+ * prefill: Time spent processing the prompt
+ * generate: Time spent generating the response
+ * load: Model loading time (one-time cost)
+ * total: Total request duration
+
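+The per-token figures above are derived from the `api.Metrics` fields returned by the Ollama chat API. The following is a minimal sketch mirroring what `OutputMetrics` in `bench.go` computes for the "generate" step (the numbers in `main` are illustrative):
+
+```go
+package main
+
+import (
+	"fmt"
+	"time"
+
+	"github.com/ollama/ollama/api"
+)
+
+// generateStats turns the "generate" portion of an api.Metrics value into the
+// ns/token and token/sec figures reported by the benchmark tool.
+func generateStats(m api.Metrics) (nsPerToken, tokensPerSec float64) {
+	if m.EvalCount == 0 {
+		return 0, 0
+	}
+	nsPerToken = float64(m.EvalDuration.Nanoseconds()) / float64(m.EvalCount)
+	tokensPerSec = float64(m.EvalCount) / m.EvalDuration.Seconds()
+	return nsPerToken, tokensPerSec
+}
+
+func main() {
+	// Example: 200 tokens generated in 2.5 seconds.
+	nsTok, tokSec := generateStats(api.Metrics{EvalCount: 200, EvalDuration: 2500 * time.Millisecond})
+	fmt.Printf("%.2f ns/token, %.2f tokens/sec\n", nsTok, tokSec)
+}
+```
+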
diff --git a/cmd/bench/bench.go b/cmd/bench/bench.go
new file mode 100644
index 00000000..25df1817
--- /dev/null
+++ b/cmd/bench/bench.go
@@ -0,0 +1,309 @@
+package main
+
+import (
+	"cmp"
+	"context"
+	"flag"
+	"fmt"
+	"io"
+	"os"
+	"runtime"
+	"slices"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/ollama/ollama/api"
+)
+
+type flagOptions struct {
+	models      *string
+	epochs      *int
+	maxTokens   *int
+	temperature *float64
+	seed        *int
+	timeout     *int
+	prompt      *string
+	imageFile   *string
+	keepAlive   *float64
+	format      *string
+	outputFile  *string
+	debug       *bool
+	verbose     *bool
+}
+
+type Metrics struct {
+	Model    string
+	Step     string
+	Count    int
+	Duration time.Duration
+}
+
+var once sync.Once
+
+const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker.
+Write the story with a sense of whimsy.`
+
+// OutputMetrics writes the collected metrics to w in the requested format
+// (benchstat, csv, or markdown). The per-format header is printed only once per run.
+func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
+	switch format {
+	case "benchstat":
+		if verbose {
+			printHeader := func() {
+				fmt.Fprintf(w, "sysname: %s\n", runtime.GOOS)
+				fmt.Fprintf(w, "machine: %s\n", runtime.GOARCH)
+			}
+			once.Do(printHeader)
+		}
+		for _, m := range metrics {
+			if m.Step == "generate" || m.Step == "prefill" {
+				if m.Count > 0 {
+					nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count)
+					tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
+
+					fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d %.2f ns/token %.2f token/sec\n",
+						m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
+				} else {
+					fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d 0 ns/token 0 token/sec\n",
+						m.Model, m.Step, m.Count)
+				}
+			} else {
+				var suffix string
+				if m.Step == "load" {
+					suffix = "/step=load"
+				}
+				fmt.Fprintf(w, "BenchmarkModel/name=%s%s 1 %d ns/request\n",
+					m.Model, suffix, m.Duration.Nanoseconds())
+			}
+		}
+	case "csv":
+		printHeader := func() {
+			headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
+			fmt.Fprintln(w, strings.Join(headings, ","))
+		}
+		once.Do(printHeader)
+
+		for _, m := range metrics {
+			if m.Step == "generate" || m.Step == "prefill" {
+				var nsPerToken float64
+				var tokensPerSec float64
+				if m.Count > 0 {
+					nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
+					tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
+				}
+				fmt.Fprintf(w, "%s,%s,%d,%.2f,%.2f\n", m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
+			} else {
+				fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
+			}
+		}
+	case "markdown":
+		printHeader := func() {
+			fmt.Fprintln(w, "| Model | Step | Count | Duration | nsPerToken | tokensPerSec |")
+			fmt.Fprintln(w, "|-------|------|-------|----------|------------|--------------|")
+		}
+		once.Do(printHeader)
+
+		for _, m := range metrics {
+			var nsPerToken, tokensPerSec float64
+			var nsPerTokenStr, tokensPerSecStr string
+
+			if (m.Step == "generate" || m.Step == "prefill") && m.Count > 0 {
+				nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
+				tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
+				nsPerTokenStr = fmt.Sprintf("%.2f", nsPerToken)
+				tokensPerSecStr = fmt.Sprintf("%.2f", tokensPerSec)
+			} else {
+				nsPerTokenStr = "-"
+				tokensPerSecStr = "-"
+			}
+
+			fmt.Fprintf(w, "| %s | %s | %d | %v | %s | %s |\n",
+				m.Model, m.Step, m.Count, m.Duration, nsPerTokenStr, tokensPerSecStr)
+		}
+	default:
+		fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
+	}
+}
+
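+// BenchmarkChat runs *fOpt.epochs chat requests against each model in the
+// comma-separated fOpt.models list and reports one prefill/generate/load/total
+// metric set per request via OutputMetrics. Failures of individual requests are
+// logged to stderr and the remaining epochs continue.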
+func BenchmarkChat(fOpt flagOptions) error {
+	models := strings.Split(*fOpt.models, ",")
+
+	// todo - add multi-image support
+	var imgData api.ImageData
+	var err error
+	if *fOpt.imageFile != "" {
+		imgData, err = readImage(*fOpt.imageFile)
+		if err != nil {
+			fmt.Fprintf(os.Stderr, "ERROR: Couldn't read image '%s': %v\n", *fOpt.imageFile, err)
+			return err
+		}
+	}
+
+	if *fOpt.debug && imgData != nil {
+		fmt.Fprintf(os.Stderr, "Read file '%s'\n", *fOpt.imageFile)
+	}
+
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "ERROR: Couldn't create ollama client: %v\n", err)
+		return err
+	}
+
+	// Write results to the requested output file, or stdout by default.
+	var out io.Writer = os.Stdout
+	if *fOpt.outputFile != "" {
+		f, err := os.Create(*fOpt.outputFile)
+		if err != nil {
+			fmt.Fprintf(os.Stderr, "ERROR: Couldn't create output file '%s': %v\n", *fOpt.outputFile, err)
+			return err
+		}
+		defer f.Close()
+		out = f
+	}
+
+	for _, model := range models {
+		for range *fOpt.epochs {
+			options := make(map[string]interface{})
+			if *fOpt.maxTokens > 0 {
+				options["num_predict"] = *fOpt.maxTokens
+			}
+			options["temperature"] = *fOpt.temperature
+			if fOpt.seed != nil && *fOpt.seed > 0 {
+				options["seed"] = *fOpt.seed
+			}
+
+			var keepAliveDuration *api.Duration
+			if *fOpt.keepAlive > 0 {
+				duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
+				keepAliveDuration = &duration
+			}
+
+			req := &api.ChatRequest{
+				Model: model,
+				Messages: []api.Message{
+					{
+						Role:    "user",
+						Content: *fOpt.prompt,
+					},
+				},
+				Options:   options,
+				KeepAlive: keepAliveDuration,
+			}
+
+			if imgData != nil {
+				req.Messages[0].Images = []api.ImageData{imgData}
+			}
+
+			var responseMetrics *api.Metrics
+
+			ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
+
+			err = client.Chat(ctx, req, func(resp api.ChatResponse) error {
+				if *fOpt.debug {
+					fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Message.Thinking, resp.Message.Content))
+				}
+
+				if resp.Done {
+					responseMetrics = &resp.Metrics
+				}
+				return nil
+			})
+			// Release the timeout context as soon as the request finishes instead
+			// of deferring inside the loop.
+			cancel()
+
+			if *fOpt.debug {
+				fmt.Fprintln(os.Stderr)
+			}
+
+			if err != nil {
+				if ctx.Err() == context.DeadlineExceeded {
+					fmt.Fprintf(os.Stderr, "ERROR: Chat request timed out with model '%s' after %ds\n", model, *fOpt.timeout)
+					continue
+				}
+				fmt.Fprintf(os.Stderr, "ERROR: Couldn't chat with model '%s': %v\n", model, err)
+				continue
+			}
+
+			if responseMetrics == nil {
+				fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
+				continue
+			}
+
+			metrics := []Metrics{
+				{
+					Model:    model,
+					Step:     "prefill",
+					Count:    responseMetrics.PromptEvalCount,
+					Duration: responseMetrics.PromptEvalDuration,
+				},
+				{
+					Model:    model,
+					Step:     "generate",
+					Count:    responseMetrics.EvalCount,
+					Duration: responseMetrics.EvalDuration,
+				},
+				{
+					Model:    model,
+					Step:     "load",
+					Count:    1,
+					Duration: responseMetrics.LoadDuration,
+				},
+				{
+					Model:    model,
+					Step:     "total",
+					Count:    1,
+					Duration: responseMetrics.TotalDuration,
+				},
+			}
+
+			OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)
+
+			if *fOpt.keepAlive > 0 {
+				// Give the server time to unload the model before the next request.
+				time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
+			}
+		}
+	}
+	return nil
+}
+
+func readImage(filePath string) (api.ImageData, error) {
+	file, err := os.Open(filePath)
+	if err != nil {
+		return nil, err
+	}
+	defer file.Close()
+
+	data, err := io.ReadAll(file)
+	if err != nil {
+		return nil, err
+	}
+
+	return api.ImageData(data), nil
+}
+
+func main() {
+	fOpt := flagOptions{
+		models:      flag.String("model", "", "Model to benchmark"),
+		epochs:      flag.Int("epochs", 6, "Number of epochs (iterations) per model"),
+		maxTokens:   flag.Int("max-tokens", 200, "Maximum tokens for model response"),
+		temperature: flag.Float64("temperature", 0, "Temperature parameter"),
+		seed:        flag.Int("seed", 0, "Random seed"),
+		timeout:     flag.Int("timeout", 60*5, "Timeout in seconds"),
+		prompt:      flag.String("p", DefaultPrompt, "Prompt to use"),
+		imageFile:   flag.String("image", "", "Filename for an image to include"),
+		keepAlive:   flag.Float64("k", 0, "Keep alive duration in seconds"),
+		format:      flag.String("format", "markdown", "Output format [markdown|benchstat|csv]"),
+		outputFile:  flag.String("output", "", "Output file for results (stdout if empty)"),
+		verbose:     flag.Bool("v", false, "Show system information"),
+		debug:       flag.Bool("debug", false, "Show debug information"),
+	}
+
+	flag.Usage = func() {
+		fmt.Fprintf(os.Stderr, "Usage: %s [OPTIONS]\n\n", os.Args[0])
+		fmt.Fprintf(os.Stderr, "Description:\n")
+		fmt.Fprintf(os.Stderr, "  Model benchmarking tool with configurable parameters\n\n")
+		fmt.Fprintf(os.Stderr,
"Options:\n") + flag.PrintDefaults() + fmt.Fprintf(os.Stderr, "\nExamples:\n") + fmt.Fprintf(os.Stderr, " bench -model gpt-oss:20b -epochs 3 -temperature 0.7\n") + } + flag.Parse() + + if !slices.Contains([]string{"markdown", "benchstat", "csv"}, *fOpt.format) { + fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format) + os.Exit(1) + } + + if len(*fOpt.models) == 0 { + fmt.Fprintf(os.Stderr, "ERROR: No model(s) specified to benchmark.\n") + flag.Usage() + return + } + + BenchmarkChat(fOpt) +} diff --git a/cmd/bench/bench_test.go b/cmd/bench/bench_test.go new file mode 100644 index 00000000..bcd282d7 --- /dev/null +++ b/cmd/bench/bench_test.go @@ -0,0 +1,463 @@ +package main + +import ( + "bytes" + "crypto/rand" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "github.com/ollama/ollama/api" +) + +func createTestFlagOptions() flagOptions { + models := "test-model" + format := "benchstat" + epochs := 1 + maxTokens := 100 + temperature := 0.7 + seed := 42 + timeout := 30 + prompt := "test prompt" + imageFile := "" + keepAlive := 5.0 + verbose := false + debug := false + + return flagOptions{ + models: &models, + format: &format, + epochs: &epochs, + maxTokens: &maxTokens, + temperature: &temperature, + seed: &seed, + timeout: &timeout, + prompt: &prompt, + imageFile: &imageFile, + keepAlive: &keepAlive, + verbose: &verbose, + debug: &debug, + } +} + +func captureOutput(f func()) string { + oldStdout := os.Stdout + oldStderr := os.Stderr + defer func() { + os.Stdout = oldStdout + os.Stderr = oldStderr + }() + + r, w, _ := os.Pipe() + os.Stdout = w + os.Stderr = w + + f() + + w.Close() + var buf bytes.Buffer + io.Copy(&buf, r) + return buf.String() +} + +func createMockOllamaServer(t *testing.T, responses []api.ChatResponse) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/chat" { + t.Errorf("Expected path /api/chat, got %s", r.URL.Path) + http.Error(w, "Not found", http.StatusNotFound) + return + } + + if r.Method != "POST" { + t.Errorf("Expected POST method, got %s", r.Method) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + for _, resp := range responses { + jsonData, err := json.Marshal(resp) + if err != nil { + t.Errorf("Failed to marshal response: %v", err) + return + } + w.Write(jsonData) + w.Write([]byte("\n")) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + time.Sleep(10 * time.Millisecond) // Simulate some delay + } + })) +} + +func TestBenchmarkChat_Success(t *testing.T) { + fOpt := createTestFlagOptions() + + mockResponses := []api.ChatResponse{ + { + Model: "test-model", + Message: api.Message{ + Role: "assistant", + Content: "test response part 1", + }, + Done: false, + }, + { + Model: "test-model", + Message: api.Message{ + Role: "assistant", + Content: "test response part 2", + }, + Done: true, + Metrics: api.Metrics{ + PromptEvalCount: 10, + PromptEvalDuration: 100 * time.Millisecond, + EvalCount: 50, + EvalDuration: 500 * time.Millisecond, + TotalDuration: 600 * time.Millisecond, + LoadDuration: 50 * time.Millisecond, + }, + }, + } + + server := createMockOllamaServer(t, mockResponses) + defer server.Close() + + t.Setenv("OLLAMA_HOST", server.URL) + + output := captureOutput(func() { + err := BenchmarkChat(fOpt) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + }) + + if 
!strings.Contains(output, "BenchmarkModel/name=test-model/step=prefill") { + t.Errorf("Expected output to contain prefill metrics, got: %s", output) + } + if !strings.Contains(output, "BenchmarkModel/name=test-model/step=generate") { + t.Errorf("Expected output to contain generate metrics, got: %s", output) + } + if !strings.Contains(output, "ns/token") { + t.Errorf("Expected output to contain ns/token metric, got: %s", output) + } +} + +func TestBenchmarkChat_ServerError(t *testing.T) { + fOpt := createTestFlagOptions() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "Internal server error", http.StatusInternalServerError) + })) + defer server.Close() + + t.Setenv("OLLAMA_HOST", server.URL) + + output := captureOutput(func() { + err := BenchmarkChat(fOpt) + if err != nil { + t.Errorf("Expected error to be handled internally, got returned error: %v", err) + } + }) + + if !strings.Contains(output, "ERROR: Couldn't chat with model") { + t.Errorf("Expected error message about chat failure, got: %s", output) + } +} + +func TestBenchmarkChat_Timeout(t *testing.T) { + fOpt := createTestFlagOptions() + shortTimeout := 1 // Very short timeout + fOpt.timeout = &shortTimeout + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Simulate a long delay that will cause timeout + time.Sleep(2 * time.Second) + + w.Header().Set("Content-Type", "application/json") + response := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + Content: "test response", + }, + Done: true, + Metrics: api.Metrics{ + PromptEvalCount: 10, + PromptEvalDuration: 100 * time.Millisecond, + EvalCount: 50, + EvalDuration: 500 * time.Millisecond, + TotalDuration: 600 * time.Millisecond, + LoadDuration: 50 * time.Millisecond, + }, + } + jsonData, _ := json.Marshal(response) + w.Write(jsonData) + })) + defer server.Close() + + t.Setenv("OLLAMA_HOST", server.URL) + + output := captureOutput(func() { + err := BenchmarkChat(fOpt) + if err != nil { + t.Errorf("Expected timeout to be handled internally, got returned error: %v", err) + } + }) + + if !strings.Contains(output, "ERROR: Chat request timed out") { + t.Errorf("Expected timeout error message, got: %s", output) + } +} + +func TestBenchmarkChat_NoMetrics(t *testing.T) { + fOpt := createTestFlagOptions() + + mockResponses := []api.ChatResponse{ + { + Model: "test-model", + Message: api.Message{ + Role: "assistant", + Content: "test response", + }, + Done: false, // Never sends Done=true + }, + } + + server := createMockOllamaServer(t, mockResponses) + defer server.Close() + + t.Setenv("OLLAMA_HOST", server.URL) + + output := captureOutput(func() { + err := BenchmarkChat(fOpt) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + }) + + if !strings.Contains(output, "ERROR: No metrics received") { + t.Errorf("Expected no metrics error message, got: %s", output) + } +} + +func TestBenchmarkChat_MultipleModels(t *testing.T) { + fOpt := createTestFlagOptions() + models := "model1,model2" + epochs := 2 + fOpt.models = &models + fOpt.epochs = &epochs + + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + + w.Header().Set("Content-Type", "application/json") + + var req api.ChatRequest + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &req) + + response := api.ChatResponse{ + Model: req.Model, + Message: api.Message{ + Role: "assistant", + Content: "test 
response for " + req.Model, + }, + Done: true, + Metrics: api.Metrics{ + PromptEvalCount: 10, + PromptEvalDuration: 100 * time.Millisecond, + EvalCount: 50, + EvalDuration: 500 * time.Millisecond, + TotalDuration: 600 * time.Millisecond, + LoadDuration: 50 * time.Millisecond, + }, + } + jsonData, _ := json.Marshal(response) + w.Write(jsonData) + })) + defer server.Close() + + t.Setenv("OLLAMA_HOST", server.URL) + + output := captureOutput(func() { + err := BenchmarkChat(fOpt) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + }) + + // Should be called 4 times (2 models × 2 epochs) + if callCount != 4 { + t.Errorf("Expected 4 API calls, got %d", callCount) + } + + if !strings.Contains(output, "BenchmarkModel/name=model1") || !strings.Contains(output, "BenchmarkModel/name=model2") { + t.Errorf("Expected output for both models, got: %s", output) + } +} + +func TestBenchmarkChat_WithImage(t *testing.T) { + fOpt := createTestFlagOptions() + + tmpfile, err := os.CreateTemp(t.TempDir(), "testimage") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(tmpfile.Name()) + + content := []byte("fake image data") + if _, err := tmpfile.Write(content); err != nil { + t.Fatalf("Failed to write to temp file: %v", err) + } + tmpfile.Close() + + tmpfileName := tmpfile.Name() + fOpt.imageFile = &tmpfileName + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify the request contains image data + var req api.ChatRequest + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &req) + + if len(req.Messages) == 0 || len(req.Messages[0].Images) == 0 { + t.Error("Expected request to contain images") + } + + w.Header().Set("Content-Type", "application/json") + response := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + Content: "test response with image", + }, + Done: true, + Metrics: api.Metrics{ + PromptEvalCount: 10, + PromptEvalDuration: 100 * time.Millisecond, + EvalCount: 50, + EvalDuration: 500 * time.Millisecond, + TotalDuration: 600 * time.Millisecond, + LoadDuration: 50 * time.Millisecond, + }, + } + jsonData, _ := json.Marshal(response) + w.Write(jsonData) + })) + defer server.Close() + + t.Setenv("OLLAMA_HOST", server.URL) + + output := captureOutput(func() { + err := BenchmarkChat(fOpt) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + }) + + if !strings.Contains(output, "BenchmarkModel/name=test-model") { + t.Errorf("Expected benchmark output, got: %s", output) + } +} + +func TestBenchmarkChat_ImageError(t *testing.T) { + randFileName := func() string { + const charset = "abcdefghijklmnopqrstuvwxyz0123456789" + const length = 8 + + result := make([]byte, length) + rand.Read(result) // Fill with random bytes + + for i := range result { + result[i] = charset[result[i]%byte(len(charset))] + } + + return string(result) + ".txt" + } + + fOpt := createTestFlagOptions() + imageFile := randFileName() + fOpt.imageFile = &imageFile + + output := captureOutput(func() { + err := BenchmarkChat(fOpt) + if err == nil { + t.Error("Expected error from image reading, got nil") + } + }) + + if !strings.Contains(output, "ERROR: Couldn't read image") { + t.Errorf("Expected image read error message, got: %s", output) + } +} + +func TestReadImage_Success(t *testing.T) { + tmpfile, err := os.CreateTemp(t.TempDir(), "testimage") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(tmpfile.Name()) + + content := []byte("fake 
image data") + if _, err := tmpfile.Write(content); err != nil { + t.Fatalf("Failed to write to temp file: %v", err) + } + tmpfile.Close() + + imgData, err := readImage(tmpfile.Name()) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + if imgData == nil { + t.Error("Expected image data, got nil") + } + + expected := api.ImageData(content) + if string(imgData) != string(expected) { + t.Errorf("Expected image data %v, got %v", expected, imgData) + } +} + +func TestReadImage_FileNotFound(t *testing.T) { + imgData, err := readImage("nonexistentfile.jpg") + if err == nil { + t.Error("Expected error for non-existent file, got nil") + } + if imgData != nil { + t.Error("Expected nil image data for non-existent file") + } +} + +func TestOptionsMapCreation(t *testing.T) { + fOpt := createTestFlagOptions() + + options := make(map[string]interface{}) + if *fOpt.maxTokens > 0 { + options["num_predict"] = *fOpt.maxTokens + } + options["temperature"] = *fOpt.temperature + if fOpt.seed != nil && *fOpt.seed > 0 { + options["seed"] = *fOpt.seed + } + + if options["num_predict"] != *fOpt.maxTokens { + t.Errorf("Expected num_predict %d, got %v", *fOpt.maxTokens, options["num_predict"]) + } + if options["temperature"] != *fOpt.temperature { + t.Errorf("Expected temperature %f, got %v", *fOpt.temperature, options["temperature"]) + } + if options["seed"] != *fOpt.seed { + t.Errorf("Expected seed %d, got %v", *fOpt.seed, options["seed"]) + } +}