From 6745182885e798be81c70f277eaea81cc4f71524 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Tue, 9 Sep 2025 09:32:15 -0700 Subject: [PATCH] tests: reduce stress on CPU to 2 models (#12161) * tests: reduce stress on CPU to 2 models This should avoid flakes due to systems getting overloaded with 3 (or more) models running concurrently * tests: allow slow systems to pass on timeout If a slow system is still streaming a response, and the response will pass validation, don't fail just because the system is slow. * test: unload embedding models more quickly --- integration/concurrency_test.go | 8 ++++ integration/embed_test.go | 5 ++- integration/utils_test.go | 72 ++++++++++++++++++++------------- 3 files changed, 55 insertions(+), 30 deletions(-) diff --git a/integration/concurrency_test.go b/integration/concurrency_test.go index 331bb6e7..3104eacc 100644 --- a/integration/concurrency_test.go +++ b/integration/concurrency_test.go @@ -121,6 +121,7 @@ func TestMultiModelStress(t *testing.T) { // The intent is to go 1 over what can fit so we force the scheduler to thrash targetLoadCount := 0 slog.Info("Loading models to find how many can fit in VRAM before overflowing") +chooseModels: for i, model := range chosenModels { req := &api.GenerateRequest{Model: model} slog.Info("loading", "model", model) @@ -142,6 +143,13 @@ func TestMultiModelStress(t *testing.T) { slog.Info("found model load capacity", "target", targetLoadCount, "current", loaded, "chosen", chosenModels[:targetLoadCount]) break } + // Effectively limit model count to 2 on CPU only systems to avoid thrashing and timeouts + for _, m := range models.Models { + if m.SizeVRAM == 0 { + slog.Info("model running on CPU", "name", m.Name, "target", targetLoadCount, "chosen", chosenModels[:targetLoadCount]) + break chooseModels + } + } } } if targetLoadCount == len(chosenModels) { diff --git a/integration/embed_test.go b/integration/embed_test.go index 09369dbb..eb00f4ba 100644 --- a/integration/embed_test.go +++ b/integration/embed_test.go @@ -38,8 +38,9 @@ func TestAllMiniLMEmbeddings(t *testing.T) { defer cleanup() req := api.EmbeddingRequest{ - Model: "all-minilm", - Prompt: "why is the sky blue?", + Model: "all-minilm", + Prompt: "why is the sky blue?", + KeepAlive: &api.Duration{Duration: 10 * time.Second}, } res, err := embeddingTestHelper(ctx, client, t, req) diff --git a/integration/utils_test.go b/integration/utils_test.go index 2bb6a157..ec74b2e3 100644 --- a/integration/utils_test.go +++ b/integration/utils_test.go @@ -502,6 +502,22 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap done <- 0 }() + var response string + verify := func() { + // Verify the response contains the expected data + response = buf.String() + atLeastOne := false + for _, resp := range anyResp { + if strings.Contains(strings.ToLower(response), resp) { + atLeastOne = true + break + } + } + if !atLeastOne { + t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response) + } + } + select { case <-stallTimer.C: if buf.Len() == 0 { @@ -517,21 +533,14 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap if genErr != nil { t.Fatalf("%s failed with %s request prompt %s", genErr, genReq.Model, genReq.Prompt) } - // Verify the response contains the expected data - response := buf.String() - atLeastOne := false - for _, resp := range anyResp { - if strings.Contains(strings.ToLower(response), resp) { - atLeastOne = true - break - } - } - if !atLeastOne { - t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response) - } + verify() slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response) case <-ctx.Done(): - t.Error("outer test context done while waiting for generate") + // On slow systems, we might timeout before some models finish rambling, so check what we have so far to see + // if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass + // if they are still generating valid responses + slog.Warn("outer test context done while waiting for generate") + verify() } return context } @@ -599,6 +608,22 @@ func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatR done <- 0 }() + var response string + verify := func() { + // Verify the response contains the expected data + response = buf.String() + atLeastOne := false + for _, resp := range anyResp { + if strings.Contains(strings.ToLower(response), resp) { + atLeastOne = true + break + } + } + if !atLeastOne { + t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages) + } + } + select { case <-stallTimer.C: if buf.Len() == 0 { @@ -614,23 +639,14 @@ func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatR if genErr != nil { t.Fatalf("%s failed with %s request prompt %v", genErr, req.Model, req.Messages) } - - // Verify the response contains the expected data - response := buf.String() - atLeastOne := false - for _, resp := range anyResp { - if strings.Contains(strings.ToLower(response), resp) { - atLeastOne = true - break - } - } - if !atLeastOne { - t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages) - } - + verify() slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response) case <-ctx.Done(): - t.Error("outer test context done while waiting for generate") + // On slow systems, we might timeout before some models finish rambling, so check what we have so far to see + // if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass + // if they are still generating valid responses + slog.Warn("outer test context done while waiting for chat") + verify() } return &api.Message{Role: role, Content: buf.String()} }