test: harden scheduler tests (#12662)

* test: harden scheduler tests

This removes reschedDelay which was stale code, and adds
a new configurable timeout for the waitForVRAMRecovery so
tests can now set the timeout to be very short to avoid the
scheduler getting stuck and hitting a test timeout.

* test: tune tests for partial loads

Give stress tests more time when the model is split between CPU/GPU
This commit is contained in:
Daniel Hiltgen
2025-10-17 08:56:44 -07:00
committed by GitHub
parent 270679932f
commit 68e04c7ff8
10 changed files with 195 additions and 143 deletions

View File

@@ -109,6 +109,8 @@ func TestMultiModelStress(t *testing.T) {
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
initialTimeout := 120 * time.Second
streamTimeout := 20 * time.Second
// Make sure all the models are pulled before we get started
for _, model := range chosenModels {
@@ -147,6 +149,8 @@ chooseModels:
for _, m := range models.Models {
if m.SizeVRAM == 0 {
slog.Info("model running on CPU", "name", m.Name, "target", targetLoadCount, "chosen", chosenModels[:targetLoadCount])
initialTimeout = 240 * time.Second
streamTimeout = 30 * time.Second
break chooseModels
}
}
@@ -172,10 +176,7 @@ chooseModels:
k := r.Int() % len(reqs)
reqs[k].Model = chosenModels[i]
slog.Info("Starting", "model", reqs[k].Model, "iteration", j, "request", reqs[k].Messages[0].Content)
DoChat(ctx, t, client, reqs[k], resps[k],
120*time.Second, // Be extra patient for the model to load initially
10*time.Second, // Once results start streaming, fail if they stall
)
DoChat(ctx, t, client, reqs[k], resps[k], initialTimeout, streamTimeout)
}
}(i)
}

View File

@@ -78,7 +78,7 @@ func TestContextExhaustion(t *testing.T) {
// Send multiple generate requests with prior context and ensure the response is coherant and expected
func TestParallelGenerateWithHistory(t *testing.T) {
modelOverride := "gpt-oss:20b"
modelName := "gpt-oss:20b"
req, resp := GenerateRequests()
numParallel := 2
iterLimit := 2
@@ -88,15 +88,23 @@ func TestParallelGenerateWithHistory(t *testing.T) {
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
initialTimeout := 120 * time.Second
streamTimeout := 20 * time.Second
// Get the server running (if applicable) warm the model up with a single initial request
slog.Info("loading", "model", modelOverride)
slog.Info("loading", "model", modelName)
err := client.Generate(ctx,
&api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
&api.GenerateRequest{Model: modelName, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
func(response api.GenerateResponse) error { return nil },
)
if err != nil {
t.Fatalf("failed to load model %s: %s", modelOverride, err)
t.Fatalf("failed to load model %s: %s", modelName, err)
}
gpuPercent := getGPUPercent(ctx, t, client, modelName)
if gpuPercent < 80 {
slog.Warn("Low GPU percentage - increasing timeouts", "percent", gpuPercent)
initialTimeout = 240 * time.Second
streamTimeout = 30 * time.Second
}
var wg sync.WaitGroup
@@ -105,7 +113,7 @@ func TestParallelGenerateWithHistory(t *testing.T) {
go func(i int) {
defer wg.Done()
k := i % len(req)
req[k].Model = modelOverride
req[k].Model = modelName
for j := 0; j < iterLimit; j++ {
if time.Now().Sub(started) > softTimeout {
slog.Info("exceeded soft timeout, winding down test")
@@ -114,7 +122,7 @@ func TestParallelGenerateWithHistory(t *testing.T) {
slog.Info("Starting", "thread", i, "iter", j)
// On slower GPUs it can take a while to process the concurrent requests
// so we allow a much longer initial timeout
c := DoGenerate(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
c := DoGenerate(ctx, t, client, req[k], resp[k], initialTimeout, streamTimeout)
req[k].Context = c
req[k].Prompt = "tell me more!"
}
@@ -165,7 +173,7 @@ func TestGenerateWithHistory(t *testing.T) {
// Send multiple chat requests with prior context and ensure the response is coherant and expected
func TestParallelChatWithHistory(t *testing.T) {
modelOverride := "gpt-oss:20b"
modelName := "gpt-oss:20b"
req, resp := ChatRequests()
numParallel := 2
iterLimit := 2
@@ -175,15 +183,23 @@ func TestParallelChatWithHistory(t *testing.T) {
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
initialTimeout := 120 * time.Second
streamTimeout := 20 * time.Second
// Get the server running (if applicable) warm the model up with a single initial empty request
slog.Info("loading", "model", modelOverride)
slog.Info("loading", "model", modelName)
err := client.Generate(ctx,
&api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
&api.GenerateRequest{Model: modelName, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
func(response api.GenerateResponse) error { return nil },
)
if err != nil {
t.Fatalf("failed to load model %s: %s", modelOverride, err)
t.Fatalf("failed to load model %s: %s", modelName, err)
}
gpuPercent := getGPUPercent(ctx, t, client, modelName)
if gpuPercent < 80 {
slog.Warn("Low GPU percentage - increasing timeouts", "percent", gpuPercent)
initialTimeout = 240 * time.Second
streamTimeout = 30 * time.Second
}
var wg sync.WaitGroup
@@ -192,7 +208,7 @@ func TestParallelChatWithHistory(t *testing.T) {
go func(i int) {
defer wg.Done()
k := i % len(req)
req[k].Model = modelOverride
req[k].Model = modelName
for j := 0; j < iterLimit; j++ {
if time.Now().Sub(started) > softTimeout {
slog.Info("exceeded soft timeout, winding down test")
@@ -201,7 +217,7 @@ func TestParallelChatWithHistory(t *testing.T) {
slog.Info("Starting", "thread", i, "iter", j)
// On slower GPUs it can take a while to process the concurrent requests
// so we allow a much longer initial timeout
assistant := DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
assistant := DoChat(ctx, t, client, req[k], resp[k], initialTimeout, streamTimeout)
if assistant == nil {
t.Fatalf("didn't get an assistant response for context")
}

View File

@@ -65,6 +65,23 @@ func TestModelsChat(t *testing.T) {
}
}
}
initialTimeout := 120 * time.Second
streamTimeout := 30 * time.Second
slog.Info("loading", "model", model)
err := client.Generate(ctx,
&api.GenerateRequest{Model: model, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
func(response api.GenerateResponse) error { return nil },
)
if err != nil {
t.Fatalf("failed to load model %s: %s", model, err)
}
gpuPercent := getGPUPercent(ctx, t, client, model)
if gpuPercent < 80 {
slog.Warn("Low GPU percentage - increasing timeouts", "percent", gpuPercent)
initialTimeout = 240 * time.Second
streamTimeout = 40 * time.Second
}
// TODO - fiddle with context size
req := api.ChatRequest{
Model: model,
@@ -80,7 +97,7 @@ func TestModelsChat(t *testing.T) {
"seed": 123,
},
}
DoChat(ctx, t, client, req, blueSkyExpected, 120*time.Second, 30*time.Second)
DoChat(ctx, t, client, req, blueSkyExpected, initialTimeout, streamTimeout)
// best effort unload once we're done with the model
client.Generate(ctx, &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
})

View File

@@ -743,6 +743,13 @@ func skipUnderMinVRAM(t *testing.T, gb uint64) {
// Skip if the target model isn't X% GPU loaded to avoid excessive runtime
func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, model string, minPercent int) {
gpuPercent := getGPUPercent(ctx, t, client, model)
if gpuPercent < minPercent {
t.Skip(fmt.Sprintf("test requires minimum %d%% GPU load, but model %s only has %d%%", minPercent, model, gpuPercent))
}
}
func getGPUPercent(ctx context.Context, t *testing.T, client *api.Client, model string) int {
models, err := client.ListRunning(ctx)
if err != nil {
t.Fatalf("failed to list running models: %s", err)
@@ -772,12 +779,10 @@ func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, m
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 110)
gpuPercent = int(100 - cpuPercent)
}
if gpuPercent < minPercent {
t.Skip(fmt.Sprintf("test requires minimum %d%% GPU load, but model %s only has %d%%", minPercent, model, gpuPercent))
}
return
return gpuPercent
}
t.Skip(fmt.Sprintf("model %s not loaded - actually loaded: %v", model, loaded))
t.Fatalf("model %s not loaded - actually loaded: %v", model, loaded)
return 0
}
func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {