truncation: fixed runner truncation logic + removed server truncation (#12839)

This PR consolidates all embedding prompt-length checking, truncation, and prompt token counting into the runner to ensure a single source of truth.
This commit is contained in:
nicole pardal
2025-12-08 11:20:28 -08:00
committed by GitHub
parent 5dae738067
commit e082d60a24
6 changed files with 278 additions and 88 deletions

View File

@@ -4,7 +4,9 @@ package integration
import (
"context"
"errors"
"math"
"strings"
"testing"
"time"
@@ -204,8 +206,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
t.Fatalf("expected %v, got %v (similarity: %f)", expected[0:5], res.Embeddings[0][0:5], sim)
}
if res.PromptEvalCount != 6 {
t.Fatalf("expected 6 prompt tokens, got %d", res.PromptEvalCount)
if res.PromptEvalCount != 8 {
t.Fatalf("expected 8 prompt tokens, got %d", res.PromptEvalCount)
}
}
@@ -251,8 +253,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
t.Fatalf("expected %v, got %v (similarity: %f)", expected[1][0:5], res.Embeddings[1][0:5], sim)
}
if res.PromptEvalCount != 12 {
t.Fatalf("expected 12 prompt tokens, got %d", res.PromptEvalCount)
if res.PromptEvalCount != 16 {
t.Fatalf("expected 16 prompt tokens, got %d", res.PromptEvalCount)
}
}
@@ -275,7 +277,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
cases := []struct {
name string
request api.EmbedRequest
check func(*api.EmbedResponse, error)
check func(*testing.T, *api.EmbedResponse, error)
}{
{
name: "target truncation",
@@ -283,7 +285,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Model: "all-minilm",
Input: "why",
},
check: func(got *api.EmbedResponse, err error) {
check: func(t *testing.T, got *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
@@ -300,10 +302,11 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Input: "why is the sky blue?",
Options: map[string]any{"num_ctx": 3},
},
check: func(got *api.EmbedResponse, err error) {
check: func(t *testing.T, got *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
t.Logf("PromptEvalCount: want=%d got=%d", want.PromptEvalCount, got.PromptEvalCount)
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
}
@@ -317,10 +320,11 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 3},
},
check: func(got *api.EmbedResponse, err error) {
check: func(t *testing.T, got *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
t.Logf("PromptEvalCount: want=%d got=%d", want.PromptEvalCount, got.PromptEvalCount)
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
}
@@ -334,21 +338,21 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 3},
},
check: func(res *api.EmbedResponse, err error) {
if err.Error() != "input exceeds maximum context length" {
check: func(t *testing.T, res *api.EmbedResponse, err error) {
if err.Error() != "the input length exceeds the context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
},
},
{
name: "input after truncate error",
name: "input after truncate error with context length of 1",
request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 1},
},
check: func(res *api.EmbedResponse, err error) {
check: func(t *testing.T, res *api.EmbedResponse, err error) {
if err.Error() != "input after truncation exceeds maximum context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
@@ -362,7 +366,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 0},
},
check: func(res *api.EmbedResponse, err error) {
check: func(t *testing.T, res *api.EmbedResponse, err error) {
if err.Error() != "input after truncation exceeds maximum context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
@@ -375,7 +379,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Input: "why is the sky blue? Why is the sky blue? hi there my",
Options: map[string]any{"num_ctx": 16},
},
check: func(res *api.EmbedResponse, err error) {
check: func(t *testing.T, res *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
@@ -385,7 +389,8 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
for _, req := range cases {
t.Run(req.name, func(t *testing.T) {
req.check(embedTestHelper(ctx, client, t, req.request))
resp, err := embedTestHelper(ctx, client, t, req.request)
req.check(t, resp, err)
})
}
}
@@ -409,3 +414,173 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req
return client.Embed(ctx, &req)
}
func TestEmbedTruncation(t *testing.T) {
// Use test deadline if set, otherwise default to 2 minutes
timeout := 2 * time.Minute
if deadline, ok := t.Deadline(); ok {
timeout = time.Until(deadline) - 10*time.Second // Reserve 10s buffer
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
for _, model := range libraryEmbedModels {
model := model
t.Run(model, func(t *testing.T) {
// Check if we're running out of time (reserve 20s for current model)
if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second {
t.Skip("skipping remaining tests to avoid timeout")
}
// Give each model its own budget to account for first-time pulls/loads
mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute)
defer mcancel()
t.Run("truncation batch", func(t *testing.T) {
truncTrue := true
req := api.EmbedRequest{
Model: model,
Input: []string{"short", strings.Repeat("long ", 100), "medium text"},
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 30},
}
res, err := embedTestHelper(mctx, client, t, req)
if err != nil {
t.Fatal(err)
}
if len(res.Embeddings) != 3 {
t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
}
if res.PromptEvalCount > 90 {
t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount)
}
})
t.Run("runner token count accuracy", func(t *testing.T) {
baseline := api.EmbedRequest{Model: model, Input: "test"}
baseRes, err := embedTestHelper(mctx, client, t, baseline)
if err != nil {
t.Fatal(err)
}
batch := api.EmbedRequest{
Model: model,
Input: []string{"test", "test", "test"},
}
batchRes, err := embedTestHelper(mctx, client, t, batch)
if err != nil {
t.Fatal(err)
}
expectedCount := baseRes.PromptEvalCount * 3
if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 {
t.Fatalf("expected ~%d tokens (3 × %d), got %d",
expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount)
}
})
})
}
}
// TestEmbedStatusCode tests that errors from the embedding endpoint
// properly preserve their HTTP status codes when returned to the client.
// This test specifically checks the error handling path in EmbedHandler
// where api.StatusError errors should maintain their original status code.
func TestEmbedStatusCode(t *testing.T) {
// Use test deadline if set, otherwise default to 2 minutes
timeout := 2 * time.Minute
if deadline, ok := t.Deadline(); ok {
timeout = time.Until(deadline) - 10*time.Second // Reserve 10s buffer
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
for _, model := range libraryEmbedModels {
model := model
t.Run(model, func(t *testing.T) {
// Check if we're running out of time (reserve 20s for current model)
if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second {
t.Skip("skipping remaining tests to avoid timeout")
}
mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute)
defer mcancel()
// Pull the model if needed
if err := PullIfMissing(mctx, client, model); err != nil {
t.Fatal(err)
}
t.Run("truncation error status code", func(t *testing.T) {
truncFalse := false
longInput := strings.Repeat("word ", 100)
req := api.EmbedRequest{
Model: model,
Input: longInput,
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 10},
}
_, err := embedTestHelper(mctx, client, t, req)
if err == nil {
t.Fatal("expected error when truncate=false with long input")
}
// Check that it's a StatusError with the correct status code
var statusErr api.StatusError
if !errors.As(err, &statusErr) {
t.Fatalf("expected api.StatusError, got %T: %v", err, err)
}
// The error should be a 4xx client error (likely 400 Bad Request)
// not a 500 Internal Server Error
if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
}
// Verify the error message is meaningful
if !strings.Contains(err.Error(), "context length") {
t.Errorf("expected error message to mention context length, got: %v", err)
}
})
t.Run("batch truncation error status code", func(t *testing.T) {
truncFalse := false
req := api.EmbedRequest{
Model: model,
Input: []string{
"short input",
strings.Repeat("very long input ", 100),
"another short input",
},
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 10},
}
_, err := embedTestHelper(mctx, client, t, req)
if err == nil {
t.Fatal("expected error when one input exceeds context with truncate=false")
}
// Check that it's a StatusError with the correct status code
var statusErr api.StatusError
if !errors.As(err, &statusErr) {
t.Fatalf("expected api.StatusError, got %T: %v", err, err)
}
// The error should be a 4xx client error, not a 500 Internal Server Error
if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
}
})
})
}
}