server: Consolidate embedding truncation in runner (#12730)

Currently, checking whether embedding prompts fit in the context window
(and truncating them when necessary) happens in two places: the Ollama
server and the runner. This can lead to inconsistencies, both in the
checks themselves and in the reported number of tokens processed. Since
the runner has to do this processing anyway, this change consolidates
all of the logic there.
Author: nicole pardal
Date: 2025-10-27 11:59:12 -07:00
Committed by: GitHub
Parent: b97eb2b858
Commit: 5d347f6d6f
6 changed files with 264 additions and 84 deletions
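
In practice, the consolidation gives the runner's /embedding endpoint both jobs: deciding whether to truncate an over-long prompt and reporting how many tokens it actually processed. The new wire contract between server and runner comes down to two extra fields; a minimal sketch of the types introduced in the llm package diff below:

	type EmbeddingRequest struct {
		Content  string `json:"content"`
		Truncate bool   `json:"truncate"` // the runner truncates, instead of the server pre-truncating
	}

	type EmbeddingResponse struct {
		Embedding       []float32 `json:"embedding"`
		PromptEvalCount int       `json:"prompt_eval_count"` // token count, as measured by the runner
	}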

View File

@@ -4,7 +4,9 @@ package integration
import (
"context"
"errors"
"math"
"strings"
"testing"
"time"
@@ -299,3 +301,216 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req
return client.Embed(ctx, &req)
}
func TestEmbedTruncation(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
t.Run("single input token count", func(t *testing.T) {
req := api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
}
res, err := embedTestHelper(ctx, client, t, req)
if err != nil {
t.Fatal(err)
}
if res.PromptEvalCount <= 0 {
t.Fatalf("expected positive token count, got %d", res.PromptEvalCount)
}
})
t.Run("batch parallel token counting", func(t *testing.T) {
req := api.EmbedRequest{
Model: "all-minilm",
Input: []string{"cat", "dog and mouse", "bird"},
}
res, err := embedTestHelper(ctx, client, t, req)
if err != nil {
t.Fatal(err)
}
if len(res.Embeddings) != 3 {
t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
}
if res.PromptEvalCount <= 0 {
t.Fatalf("expected positive token count, got %d", res.PromptEvalCount)
}
})
t.Run("truncation single input", func(t *testing.T) {
truncTrue := true
longInput := strings.Repeat("word ", 100)
req := api.EmbedRequest{
Model: "all-minilm",
Input: longInput,
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 50},
}
res, err := embedTestHelper(ctx, client, t, req)
if err != nil {
t.Fatal(err)
}
if res.PromptEvalCount > 50 {
t.Fatalf("expected tokens <= 50 after truncation, got %d", res.PromptEvalCount)
}
if res.PromptEvalCount == 0 {
t.Fatal("expected non-zero token count after truncation")
}
})
t.Run("truncation batch", func(t *testing.T) {
truncTrue := true
req := api.EmbedRequest{
Model: "all-minilm",
Input: []string{"short", strings.Repeat("long ", 100), "medium text"},
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 30},
}
res, err := embedTestHelper(ctx, client, t, req)
if err != nil {
t.Fatal(err)
}
if len(res.Embeddings) != 3 {
t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
}
if res.PromptEvalCount > 90 {
t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount)
}
})
t.Run("truncate false error", func(t *testing.T) {
truncFalse := false
req := api.EmbedRequest{
Model: "all-minilm",
Input: strings.Repeat("word ", 100),
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 10},
}
_, err := embedTestHelper(ctx, client, t, req)
if err == nil {
t.Fatal("expected error when truncate=false with long input")
}
if !strings.Contains(err.Error(), "exceeds maximum context length") {
t.Fatalf("expected context length error, got: %v", err)
}
})
t.Run("runner token count accuracy", func(t *testing.T) {
baseline := api.EmbedRequest{Model: "all-minilm", Input: "test"}
baseRes, err := embedTestHelper(ctx, client, t, baseline)
if err != nil {
t.Fatal(err)
}
batch := api.EmbedRequest{
Model: "all-minilm",
Input: []string{"test", "test", "test"},
}
batchRes, err := embedTestHelper(ctx, client, t, batch)
if err != nil {
t.Fatal(err)
}
expectedCount := baseRes.PromptEvalCount * 3
if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 {
t.Fatalf("expected ~%d tokens (3 × %d), got %d",
expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount)
}
})
}
// TestEmbedStatusCode tests that errors from the embedding endpoint
// properly preserve their HTTP status codes when returned to the client.
// This test specifically checks the error handling path in EmbedHandler
// where api.StatusError errors should maintain their original status code.
func TestEmbedStatusCode(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Pull the model if needed
if err := PullIfMissing(ctx, client, "all-minilm"); err != nil {
t.Fatal(err)
}
t.Run("truncation error status code", func(t *testing.T) {
truncFalse := false
longInput := strings.Repeat("word ", 100)
req := api.EmbedRequest{
Model: "all-minilm",
Input: longInput,
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 10},
}
_, err := embedTestHelper(ctx, client, t, req)
if err == nil {
t.Fatal("expected error when truncate=false with long input")
}
// Check that it's a StatusError with the correct status code
var statusErr api.StatusError
if !errors.As(err, &statusErr) {
t.Fatalf("expected api.StatusError, got %T: %v", err, err)
}
// The error should be a 4xx client error (likely 400 Bad Request)
// not a 500 Internal Server Error
if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
}
// Verify the error message is meaningful
if !strings.Contains(err.Error(), "context length") {
t.Errorf("expected error message to mention context length, got: %v", err)
}
})
t.Run("batch truncation error status code", func(t *testing.T) {
truncFalse := false
req := api.EmbedRequest{
Model: "all-minilm",
Input: []string{
"short input",
strings.Repeat("very long input ", 100),
"another short input",
},
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 10},
}
_, err := embedTestHelper(ctx, client, t, req)
if err == nil {
t.Fatal("expected error when one input exceeds context with truncate=false")
}
// Check that it's a StatusError with the correct status code
var statusErr api.StatusError
if !errors.As(err, &statusErr) {
t.Fatalf("expected api.StatusError, got %T: %v", err, err)
}
// The error should be a 4xx client error, not a 500 Internal Server Error
if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
}
})
}
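
For API consumers, TestEmbedStatusCode above captures the observable contract: with Truncate set to false, an over-long input now fails with the runner's original 400 Bad Request instead of a generic 500. A hedged usage sketch (assuming a running local server, a pulled all-minilm model, and the same client/ctx setup these tests use, with net/http, errors, and strings imported):

	truncate := false
	req := api.EmbedRequest{
		Model:    "all-minilm",
		Input:    strings.Repeat("word ", 100),
		Truncate: &truncate,
		Options:  map[string]any{"num_ctx": 10},
	}
	if _, err := client.Embed(ctx, &req); err != nil {
		var serr api.StatusError
		if errors.As(err, &serr) && serr.StatusCode == http.StatusBadRequest {
			// the input exceeds the context window and truncation was disabled
		}
	}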

View File

@@ -69,7 +69,7 @@ type LlamaServer interface {
Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
-Embedding(ctx context.Context, input string) ([]float32, error)
+Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error)
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
@@ -1545,14 +1545,16 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
}
type EmbeddingRequest struct {
-Content string `json:"content"`
+Content  string `json:"content"`
+Truncate bool   `json:"truncate"`
}
type EmbeddingResponse struct {
-Embedding []float32 `json:"embedding"`
+Embedding       []float32 `json:"embedding"`
+PromptEvalCount int       `json:"prompt_eval_count"`
}
-func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
+func (s *llmServer) Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) {
logutil.Trace("embedding request", "input", input)
if err := s.sem.Acquire(ctx, 1); err != nil {
@@ -1561,51 +1563,54 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
} else {
slog.Error("Failed to acquire semaphore", "error", err)
}
-return nil, err
+return nil, 0, err
}
defer s.sem.Release(1)
// Make sure the server is ready
status, err := s.getServerStatusRetry(ctx)
if err != nil {
-return nil, err
+return nil, 0, err
} else if status != ServerStatusReady {
-return nil, fmt.Errorf("unexpected server status: %s", status)
+return nil, 0, fmt.Errorf("unexpected server status: %s", status)
}
-data, err := json.Marshal(EmbeddingRequest{Content: input})
+data, err := json.Marshal(EmbeddingRequest{Content: input, Truncate: truncate})
if err != nil {
-return nil, fmt.Errorf("error marshaling embed data: %w", err)
+return nil, 0, fmt.Errorf("error marshaling embed data: %w", err)
}
r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
if err != nil {
-return nil, fmt.Errorf("error creating embed request: %w", err)
+return nil, 0, fmt.Errorf("error creating embed request: %w", err)
}
r.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(r)
if err != nil {
-return nil, fmt.Errorf("do embedding request: %w", err)
+return nil, 0, fmt.Errorf("do embedding request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
-return nil, fmt.Errorf("error reading embed response: %w", err)
+return nil, 0, fmt.Errorf("error reading embed response: %w", err)
}
if resp.StatusCode >= 400 {
-log.Printf("llm embedding error: %s", body)
-return nil, fmt.Errorf("%s", body)
+return nil, 0, api.StatusError{
+StatusCode:   resp.StatusCode,
+ErrorMessage: string(body),
+}
}
var e EmbeddingResponse
if err := json.Unmarshal(body, &e); err != nil {
-return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
+return nil, 0, fmt.Errorf("unmarshal tokenize response: %w", err)
}
-return e.Embedding, nil
+return e.Embedding, e.PromptEvalCount, nil
}
type TokenizeRequest struct {

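On the wire, the change above amounts to one extra boolean in the request and one extra integer in the response. A sketch of what the runner now exchanges with the server (illustrative values):

	// request body POSTed to the runner's /embedding endpoint
	data, _ := json.Marshal(llm.EmbeddingRequest{Content: "why is the sky blue?", Truncate: true})
	// data == []byte(`{"content":"why is the sky blue?","truncate":true}`)

	// response body, decoded into llm.EmbeddingResponse, e.g.
	// {"embedding":[0.12,-0.03,...],"prompt_eval_count":6}  (the count here is illustrative)
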
View File

@@ -709,13 +709,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
embedding: true,
-// TODO (jmorganca): this should be provided by the server via the
-// request options and truncated here in the runner, instead of relying on
-// the server's truncate logic
-truncate: true,
+truncate: req.Truncate,
})
if err != nil {
+if errors.Is(err, errorInputTooLong) {
+http.Error(w, err.Error(), http.StatusBadRequest)
+return
+}
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
return
}
@@ -758,7 +758,8 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
embedding := <-seq.embedding
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
-Embedding: embedding,
+Embedding:       embedding,
+PromptEvalCount: seq.numPromptInputs,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}

View File

@@ -948,13 +948,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
embedding: true,
-// TODO (jmorganca): this should be provided by the server via the
-// request options and truncated here in the runner, instead of relying on
-// the server's truncate logic
-truncate: true,
+truncate: req.Truncate,
})
if err != nil {
+if errors.Is(err, errorInputTooLong) {
+http.Error(w, err.Error(), http.StatusBadRequest)
+return
+}
http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError)
return
}
@@ -995,7 +995,8 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
}
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
-Embedding: <-seq.embedding,
+Embedding:       <-seq.embedding,
+PromptEvalCount: seq.numPromptInputs,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}

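Both runners now make the same call at sequence-creation time: NewSequence receives the caller's truncate flag and returns errorInputTooLong when truncation is disabled and the input does not fit. The NewSequence body is not part of this diff, but the check the flag feeds is conceptually something like this (a hypothetical sketch, not the actual implementation):

	if len(inputs) > numCtx {
		if !truncate {
			return nil, errorInputTooLong // mapped to 400 Bad Request above
		}
		inputs = inputs[:numCtx] // keep the prefix that fits the context window
	}
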
View File

@@ -21,6 +21,7 @@ import (
"os/signal"
"slices"
"strings"
"sync/atomic"
"syscall"
"time"
@@ -659,7 +660,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
-r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
+r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
@@ -672,61 +673,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
-kvData, _, err := getModelData(m.ModelPath, false)
-if err != nil {
-c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-return
-}
-var count int
-for i, s := range input {
-tokens, err := r.Tokenize(c.Request.Context(), s)
-if err != nil {
-c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-return
-}
-ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
-if len(tokens) > ctxLen {
-if !truncate {
-c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"})
-return
-}
-if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
-ctxLen--
-}
-if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
-ctxLen--
-}
-slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens))
-if ctxLen <= 0 {
-// return error if the truncated input would be empty or just special tokens
-c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"})
-return
-}
-tokens = tokens[:ctxLen]
-s, err = r.Detokenize(c.Request.Context(), tokens)
-if err != nil {
-c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-return
-}
-}
-count += len(tokens)
-input[i] = s
-}
var g errgroup.Group
embeddings := make([][]float32, len(input))
+var totalTokens uint64
for i, text := range input {
g.Go(func() error {
-embedding, err := r.Embedding(c.Request.Context(), text)
+embedding, tokenCount, err := r.Embedding(c.Request.Context(), text, truncate)
if err != nil {
return err
}
@@ -736,12 +688,18 @@ func (s *Server) EmbedHandler(c *gin.Context) {
embedding = normalize(embedding[:req.Dimensions])
}
embeddings[i] = embedding
+atomic.AddUint64(&totalTokens, uint64(tokenCount))
return nil
})
}
if err := g.Wait(); err != nil {
-c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
+var serr api.StatusError
+if errors.As(err, &serr) {
+c.AbortWithStatusJSON(serr.StatusCode, gin.H{"error": strings.TrimSpace(serr.ErrorMessage)})
+} else {
+c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
+}
return
}
@@ -750,7 +708,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
Embeddings: embeddings,
TotalDuration: time.Since(checkpointStart),
LoadDuration: checkpointLoaded.Sub(checkpointStart),
-PromptEvalCount: count,
+PromptEvalCount: int(totalTokens),
}
c.JSON(http.StatusOK, resp)
}
@@ -796,7 +754,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
-embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
+embedding, _, err := r.Embedding(c.Request.Context(), req.Prompt, true)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
return

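Inputs are embedded in parallel, one errgroup goroutine each, so the handler accumulates the per-input token counts with sync/atomic; a plain int would be a data race. A self-contained illustration of the counting pattern (toy word-count tokenizer and hypothetical inputs, not the handler itself):

	package main

	import (
		"fmt"
		"strings"
		"sync/atomic"

		"golang.org/x/sync/errgroup"
	)

	func main() {
		inputs := []string{"cat", "dog and mouse", "bird"}
		embeddings := make([][]float32, len(inputs))
		var total uint64
		var g errgroup.Group
		for i, text := range inputs {
			g.Go(func() error {
				// stand-in for the runner call, which would return the real token count
				n := len(strings.Fields(text))
				embeddings[i] = make([]float32, 4) // placeholder embedding
				atomic.AddUint64(&total, uint64(n)) // safe across goroutines
				return nil
			})
		}
		if err := g.Wait(); err != nil {
			panic(err)
		}
		fmt.Println("prompt_eval_count:", total) // prints 5
	}

As in the handler itself, the closures capture the loop variables directly, which relies on Go 1.22's per-iteration loop variable semantics.
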
View File

@@ -780,8 +780,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
return s.completionResp
}
-func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) {
-return s.embeddingResp, s.embeddingRespErr
+func (s *mockLlm) Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) {
+return s.embeddingResp, 0, s.embeddingRespErr
}
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {