mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 14:26:30 +00:00
perf: build graph for next batch async to keep GPU busy (#11863)
* perf: build graph for next batch in parallel to keep GPU busy This refactors the main run loop of the ollama runner to perform the main GPU intensive tasks (Compute+Floats) in a go routine so we can prepare the next batch in parallel to reduce the amount of time the GPU stalls waiting for the next batch of work. * tests: tune integration tests for ollama engine This tunes the integration tests to focus more on models supported by the new engine.
This commit is contained in:
@@ -2,10 +2,13 @@
|
||||
|
||||
This directory contains integration tests to exercise Ollama end-to-end to verify behavior
|
||||
|
||||
By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...`
|
||||
By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...` Some tests require additional tags to enable to allow scoped testing to keep the duration reasonable. For example, testing a broad set of models requires `-tags=integration,models` and a longer timeout (~60m or more depending on the speed of your GPU.). To view the current set of tag combinations use `find integration -type f | xargs grep "go:build"`
|
||||
|
||||
|
||||
The integration tests have 2 modes of operating.
|
||||
|
||||
1. By default, they will start the server on a random port, run the tests, and then shutdown the server.
|
||||
2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote
|
||||
2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote based on your `OLLAMA_HOST` environment variable
|
||||
|
||||
> [!IMPORTANT]
|
||||
> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.
|
||||
|
||||
@@ -390,7 +390,7 @@ func TestAPIEmbeddings(t *testing.T) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
req := api.EmbeddingRequest{
|
||||
Model: "orca-mini",
|
||||
Model: libraryEmbedModels[0],
|
||||
Prompt: "why is the sky blue?",
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBlueSky(t *testing.T) {
|
||||
@@ -37,8 +36,8 @@ func TestUnicode(t *testing.T) {
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
// DeepSeek has a Unicode tokenizer regex, making it a unicode torture test
|
||||
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K",
|
||||
Prompt: "天空为什么是蓝色的?",
|
||||
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K", // TODO is there an ollama-engine model we can switch to and keep the coverage?
|
||||
Prompt: "天空为什么是蓝色的?", // Why is the sky blue?
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
@@ -50,8 +49,20 @@ func TestUnicode(t *testing.T) {
|
||||
}
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
DoGenerate(ctx, t, client, req, []string{"散射", "频率"}, 120*time.Second, 120*time.Second)
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
slog.Info("loading", "model", req.Model)
|
||||
err := client.Generate(ctx, &api.GenerateRequest{Model: req.Model}, func(response api.GenerateResponse) error { return nil })
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", req.Model, err)
|
||||
}
|
||||
skipIfNotGPULoaded(ctx, t, client, req.Model, 100)
|
||||
|
||||
DoGenerate(ctx, t, client, req, []string{
|
||||
"散射", // scattering
|
||||
"频率", // frequency
|
||||
}, 120*time.Second, 120*time.Second)
|
||||
}
|
||||
|
||||
func TestExtendedUnicodeOutput(t *testing.T) {
|
||||
@@ -69,7 +80,9 @@ func TestExtendedUnicodeOutput(t *testing.T) {
|
||||
}
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second)
|
||||
}
|
||||
|
||||
@@ -84,7 +97,9 @@ func TestUnicodeModelDir(t *testing.T) {
|
||||
}
|
||||
|
||||
modelDir, err := os.MkdirTemp("", "ollama_埃")
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(modelDir)
|
||||
slog.Info("unicode", "OLLAMA_MODELS", modelDir)
|
||||
|
||||
|
||||
@@ -14,8 +14,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
@@ -79,21 +77,21 @@ func TestMultiModelStress(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// All models compatible with ollama-engine
|
||||
smallModels := []string{
|
||||
"llama3.2:1b",
|
||||
"qwen3:0.6b",
|
||||
"gemma:2b",
|
||||
"deepseek-r1:1.5b",
|
||||
"starcoder2:3b",
|
||||
"gemma2:2b",
|
||||
"deepseek-r1:1.5b", // qwen2 arch
|
||||
"gemma3:270m",
|
||||
}
|
||||
mediumModels := []string{
|
||||
"qwen3:8b",
|
||||
"llama2",
|
||||
"deepseek-r1:7b",
|
||||
"mistral",
|
||||
"dolphin-mistral",
|
||||
"gemma:7b",
|
||||
"codellama:7b",
|
||||
"llama3.2:3b", // ~3.4G
|
||||
"qwen3:8b", // ~6.6G
|
||||
"gpt-oss:20b", // ~15G
|
||||
"deepseek-r1:7b", // ~5.6G
|
||||
"gemma3:4b", // ~5.8G
|
||||
"gemma2:9b", // ~8.1G
|
||||
}
|
||||
|
||||
var chosenModels []string
|
||||
@@ -114,7 +112,9 @@ func TestMultiModelStress(t *testing.T) {
|
||||
|
||||
// Make sure all the models are pulled before we get started
|
||||
for _, model := range chosenModels {
|
||||
require.NoError(t, PullIfMissing(ctx, client, model))
|
||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Determine how many models we can load in parallel before we exceed VRAM
|
||||
|
||||
@@ -22,7 +22,7 @@ func TestLongInputContext(t *testing.T) {
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: "llama2",
|
||||
Model: smol,
|
||||
Prompt: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
@@ -36,7 +36,7 @@ func TestLongInputContext(t *testing.T) {
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("PullIfMissing failed: %v", err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia"}, 120*time.Second, 10*time.Second)
|
||||
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
func TestContextExhaustion(t *testing.T) {
|
||||
@@ -49,7 +49,7 @@ func TestContextExhaustion(t *testing.T) {
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: "llama2",
|
||||
Model: smol,
|
||||
Prompt: "Write me a story with a ton of emojis?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
@@ -63,10 +63,10 @@ func TestContextExhaustion(t *testing.T) {
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("PullIfMissing failed: %v", err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
|
||||
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water"}, 120*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
// Send multiple requests with prior context and ensure the response is coherant and expected
|
||||
// Send multiple generate requests with prior context and ensure the response is coherant and expected
|
||||
func TestGenerateWithHistory(t *testing.T) {
|
||||
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
||||
req, resp := GenerateRequests()
|
||||
@@ -111,5 +111,56 @@ func TestGenerateWithHistory(t *testing.T) {
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
}
|
||||
|
||||
// Send multiple chat requests with prior context and ensure the response is coherant and expected
|
||||
func TestChatWithHistory(t *testing.T) {
|
||||
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
||||
req, resp := ChatRequests()
|
||||
numParallel := 2
|
||||
iterLimit := 2
|
||||
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// Get the server running (if applicable) warm the model up with a single initial empty request
|
||||
slog.Info("loading", "model", modelOverride)
|
||||
err := client.Generate(ctx,
|
||||
&api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
|
||||
func(response api.GenerateResponse) error { return nil },
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", modelOverride, err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numParallel)
|
||||
for i := range numParallel {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
k := i % len(req)
|
||||
req[k].Model = modelOverride
|
||||
for j := 0; j < iterLimit; j++ {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
slog.Info("exceeded soft timeout, winding down test")
|
||||
return
|
||||
}
|
||||
slog.Info("Starting", "thread", i, "iter", j)
|
||||
// On slower GPUs it can take a while to process the concurrent requests
|
||||
// so we allow a much longer initial timeout
|
||||
assistant := DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
|
||||
if assistant == nil {
|
||||
t.Fatalf("didn't get an assistant response for context")
|
||||
}
|
||||
req[k].Messages = append(req[k].Messages,
|
||||
*assistant,
|
||||
api.Message{Role: "user", Content: "tell me more!"},
|
||||
)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVisionModels(t *testing.T) {
|
||||
@@ -32,7 +31,9 @@ func TestVisionModels(t *testing.T) {
|
||||
for _, v := range testCases {
|
||||
t.Run(v.model, func(t *testing.T) {
|
||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req := api.GenerateRequest{
|
||||
Model: v.model,
|
||||
Prompt: "what does the text in this image say?",
|
||||
@@ -52,7 +53,9 @@ func TestVisionModels(t *testing.T) {
|
||||
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
|
||||
resp := "the ollam"
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// llava models on CPU can be quite slow to start
|
||||
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
||||
})
|
||||
@@ -62,7 +65,9 @@ func TestVisionModels(t *testing.T) {
|
||||
func TestIntegrationSplitBatch(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req := api.GenerateRequest{
|
||||
Model: "gemma3:4b",
|
||||
// Fill up a chunk of the batch so the image will partially spill over into the next one
|
||||
@@ -84,7 +89,9 @@ func TestIntegrationSplitBatch(t *testing.T) {
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// llava models on CPU can be quite slow to start,
|
||||
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
|
||||
}
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
|
||||
// package to avoid circular dependencies
|
||||
|
||||
var (
|
||||
stream = false
|
||||
req = [2]api.GenerateRequest{
|
||||
{
|
||||
Model: smol,
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: smol,
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
},
|
||||
}
|
||||
resp = [2][]string{
|
||||
{"sunlight", "scattering", "interact"},
|
||||
{"england", "english", "massachusetts", "pilgrims"},
|
||||
}
|
||||
)
|
||||
|
||||
func TestIntegrationSimple(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
||||
defer cancel()
|
||||
GenerateTestHelper(ctx, t, req[0], resp[0])
|
||||
}
|
||||
@@ -13,12 +13,12 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestMaxQueue(t *testing.T) {
|
||||
t.Skip("this test needs to be re-evaluated to use a proper embedding model")
|
||||
|
||||
if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
|
||||
t.Skip("Max Queue test requires spawning a local server so we can adjust the queue size")
|
||||
return
|
||||
@@ -45,7 +45,9 @@ func TestMaxQueue(t *testing.T) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Context for the worker threads so we can shut them down
|
||||
// embedCtx, embedCancel := context.WithCancel(ctx)
|
||||
@@ -89,7 +91,9 @@ func TestMaxQueue(t *testing.T) {
|
||||
switch {
|
||||
case genErr == nil:
|
||||
successCount++
|
||||
require.Greater(t, len(resp.Embedding), 5) // somewhat arbitrary, but sufficient to be reasonable
|
||||
if len(resp.Embedding) < 5 { // somewhat arbitrary, but sufficient to be reasonable
|
||||
t.Fatalf("embeddings shorter than expected: %d", len(resp.Embedding))
|
||||
}
|
||||
case errors.Is(genErr, context.Canceled):
|
||||
canceledCount++
|
||||
case strings.Contains(genErr.Error(), "busy"):
|
||||
@@ -97,7 +101,9 @@ func TestMaxQueue(t *testing.T) {
|
||||
case strings.Contains(genErr.Error(), "connection reset by peer"):
|
||||
resetByPeerCount++
|
||||
default:
|
||||
require.NoError(t, genErr, "%d request failed", i)
|
||||
if genErr != nil {
|
||||
t.Fatalf("%d request failed", i)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Info("embed finished", "id", i)
|
||||
@@ -108,8 +114,13 @@ func TestMaxQueue(t *testing.T) {
|
||||
embedwg.Wait()
|
||||
|
||||
slog.Info("embeds completed", "success", successCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
|
||||
require.Equal(t, resetByPeerCount, 0, "Connections reset by peer, have you updated your fd and socket limits?")
|
||||
require.True(t, busyCount > 0, "no requests hit busy error but some should have")
|
||||
require.True(t, canceledCount == 0, "no requests should have been canceled due to timeout")
|
||||
|
||||
if resetByPeerCount != 0 {
|
||||
t.Fatalf("Connections reset by peer, have you updated your fd and socket limits? %d", resetByPeerCount)
|
||||
}
|
||||
if busyCount == 0 {
|
||||
t.Fatalf("no requests hit busy error but some should have")
|
||||
}
|
||||
if canceledCount > 0 {
|
||||
t.Fatalf("no requests should have been canceled due to timeout %d", canceledCount)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -25,11 +26,11 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/app/lifecycle"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
smol = "llama3.2:1b"
|
||||
smol = "llama3.2:1b"
|
||||
stream = false
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -435,7 +436,9 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
|
||||
}
|
||||
lifecycle.ServerLogFile = fp.Name()
|
||||
fp.Close()
|
||||
require.NoError(t, startServer(t, ctx, testEndpoint))
|
||||
if err := startServer(t, ctx, testEndpoint); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
return client, testEndpoint, func() {
|
||||
@@ -468,7 +471,9 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
|
||||
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
|
||||
if err := PullIfMissing(ctx, client, genReq.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
@@ -509,7 +514,9 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
||||
slog.Warn("model is too large for the target test system", "model", genReq.Model, "error", genErr)
|
||||
return context
|
||||
}
|
||||
require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
|
||||
if genErr != nil {
|
||||
t.Fatalf("%s failed with %s request prompt %s", genErr, genReq.Model, genReq.Prompt)
|
||||
}
|
||||
// Verify the response contains the expected data
|
||||
response := buf.String()
|
||||
atLeastOne := false
|
||||
@@ -519,7 +526,9 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, atLeastOne, "%s: none of %v found in %s", genReq.Model, anyResp, response)
|
||||
if !atLeastOne {
|
||||
t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
|
||||
}
|
||||
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for generate")
|
||||
@@ -561,17 +570,97 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
[][]string{
|
||||
{"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
|
||||
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
|
||||
{"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states"},
|
||||
{"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states", "cultural", "hardship", "autumn", "festival"},
|
||||
{"fourth", "july", "declaration", "independence"},
|
||||
{"nitrogen", "oxygen", "carbon", "dioxide"},
|
||||
}
|
||||
}
|
||||
|
||||
func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) *api.Message {
|
||||
stallTimer := time.NewTimer(initialTimeout)
|
||||
var buf bytes.Buffer
|
||||
role := "assistant"
|
||||
fn := func(response api.ChatResponse) error {
|
||||
// fmt.Print(".")
|
||||
role = response.Message.Role
|
||||
buf.Write([]byte(response.Message.Content))
|
||||
if !stallTimer.Reset(streamTimeout) {
|
||||
return errors.New("stall was detected while streaming response, aborting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
stream := true
|
||||
req.Stream = &stream
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
genErr = client.Chat(ctx, &req, fn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
if buf.Len() == 0 {
|
||||
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
|
||||
} else {
|
||||
t.Errorf("generate stalled. Response so far:%s", buf.String())
|
||||
}
|
||||
case <-done:
|
||||
if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") {
|
||||
slog.Warn("model is too large for the target test system", "model", req.Model, "error", genErr)
|
||||
return nil
|
||||
}
|
||||
if genErr != nil {
|
||||
t.Fatalf("%s failed with %s request prompt %v", genErr, req.Model, req.Messages)
|
||||
}
|
||||
|
||||
// Verify the response contains the expected data
|
||||
response := buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
|
||||
}
|
||||
|
||||
slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for generate")
|
||||
}
|
||||
return &api.Message{Role: role, Content: buf.String()}
|
||||
}
|
||||
|
||||
func ChatRequests() ([]api.ChatRequest, [][]string) {
|
||||
genReqs, results := GenerateRequests()
|
||||
reqs := make([]api.ChatRequest, len(genReqs))
|
||||
// think := api.ThinkValue{Value: "low"}
|
||||
for i := range reqs {
|
||||
reqs[i].Model = genReqs[i].Model
|
||||
reqs[i].Stream = genReqs[i].Stream
|
||||
reqs[i].KeepAlive = genReqs[i].KeepAlive
|
||||
// reqs[i].Think = &think
|
||||
reqs[i].Messages = []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: genReqs[i].Prompt,
|
||||
},
|
||||
}
|
||||
}
|
||||
return reqs, results
|
||||
}
|
||||
|
||||
func skipUnderMinVRAM(t *testing.T, gb uint64) {
|
||||
// TODO use info API in the future
|
||||
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
|
||||
maxVram, err := strconv.ParseUint(s, 10, 64)
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Don't hammer on small VRAM cards...
|
||||
if maxVram < gb*format.GibiByte {
|
||||
t.Skip("skipping with small VRAM to avoid timeouts")
|
||||
@@ -579,6 +668,39 @@ func skipUnderMinVRAM(t *testing.T, gb uint64) {
|
||||
}
|
||||
}
|
||||
|
||||
// Skip if the target model isn't X% GPU loaded to avoid excessive runtime
|
||||
func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, model string, minPercent int) {
|
||||
models, err := client.ListRunning(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to list running models: %s", err)
|
||||
}
|
||||
loaded := []string{}
|
||||
for _, m := range models.Models {
|
||||
loaded = append(loaded, m.Name)
|
||||
if m.Name != model {
|
||||
continue
|
||||
}
|
||||
gpuPercent := 0
|
||||
switch {
|
||||
case m.SizeVRAM == 0:
|
||||
gpuPercent = 0
|
||||
case m.SizeVRAM == m.Size:
|
||||
gpuPercent = 100
|
||||
case m.SizeVRAM > m.Size || m.Size == 0:
|
||||
t.Logf("unexpected size detected: %d", m.SizeVRAM)
|
||||
default:
|
||||
sizeCPU := m.Size - m.SizeVRAM
|
||||
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 110)
|
||||
gpuPercent = int(100 - cpuPercent)
|
||||
}
|
||||
if gpuPercent < minPercent {
|
||||
t.Skip(fmt.Sprintf("test requires minimum %d%% GPU load, but model %s only has %d%%", minPercent, model, gpuPercent))
|
||||
}
|
||||
return
|
||||
}
|
||||
t.Skip(fmt.Sprintf("model %s not loaded - actually loaded: %v", model, loaded))
|
||||
}
|
||||
|
||||
func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {
|
||||
deadline, hasDeadline := t.Deadline()
|
||||
if !hasDeadline {
|
||||
|
||||
Reference in New Issue
Block a user