From 5d347f6d6f6813895412846a240fc046d49c4817 Mon Sep 17 00:00:00 2001 From: nicole pardal <109545900+npardal@users.noreply.github.com> Date: Mon, 27 Oct 2025 11:59:12 -0700 Subject: [PATCH] server: Consolidate embedding truncation in runner (#12730) Currently, checking the length of prompts for embeddings to ensure they fit in the context window (and possible truncation) occurs in two places - the Ollama server and runner. This can lead to inconsistencies in both the checks and reported number of tokens processed. Since we have to do this processing in the runner, this consolidates all of the logic there. --- integration/embed_test.go | 215 ++++++++++++++++++++++++++++++++++ llm/server.go | 35 +++--- runner/llamarunner/runner.go | 13 +- runner/ollamarunner/runner.go | 13 +- server/routes.go | 68 ++--------- server/sched_test.go | 4 +- 6 files changed, 264 insertions(+), 84 deletions(-) diff --git a/integration/embed_test.go b/integration/embed_test.go index 3a8bcd24..432df9ab 100644 --- a/integration/embed_test.go +++ b/integration/embed_test.go @@ -4,7 +4,9 @@ package integration import ( "context" + "errors" "math" + "strings" "testing" "time" @@ -299,3 +301,216 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req return client.Embed(ctx, &req) } + +func TestEmbedTruncation(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + t.Run("single input token count", func(t *testing.T) { + req := api.EmbedRequest{ + Model: "all-minilm", + Input: "why is the sky blue?", + } + + res, err := embedTestHelper(ctx, client, t, req) + if err != nil { + t.Fatal(err) + } + + if res.PromptEvalCount <= 0 { + t.Fatalf("expected positive token count, got %d", res.PromptEvalCount) + } + }) + + t.Run("batch parallel token counting", func(t *testing.T) { + req := api.EmbedRequest{ + Model: "all-minilm", + Input: []string{"cat", "dog and mouse", "bird"}, + } + + res, err := embedTestHelper(ctx, client, t, req) + if err != nil { + t.Fatal(err) + } + + if len(res.Embeddings) != 3 { + t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings)) + } + + if res.PromptEvalCount <= 0 { + t.Fatalf("expected positive token count, got %d", res.PromptEvalCount) + } + }) + + t.Run("truncation single input", func(t *testing.T) { + truncTrue := true + longInput := strings.Repeat("word ", 100) + + req := api.EmbedRequest{ + Model: "all-minilm", + Input: longInput, + Truncate: &truncTrue, + Options: map[string]any{"num_ctx": 50}, + } + + res, err := embedTestHelper(ctx, client, t, req) + if err != nil { + t.Fatal(err) + } + + if res.PromptEvalCount > 50 { + t.Fatalf("expected tokens <= 50 after truncation, got %d", res.PromptEvalCount) + } + + if res.PromptEvalCount == 0 { + t.Fatal("expected non-zero token count after truncation") + } + }) + + t.Run("truncation batch", func(t *testing.T) { + truncTrue := true + req := api.EmbedRequest{ + Model: "all-minilm", + Input: []string{"short", strings.Repeat("long ", 100), "medium text"}, + Truncate: &truncTrue, + Options: map[string]any{"num_ctx": 30}, + } + + res, err := embedTestHelper(ctx, client, t, req) + if err != nil { + t.Fatal(err) + } + + if len(res.Embeddings) != 3 { + t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings)) + } + + if res.PromptEvalCount > 90 { + t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount) + } + }) + + t.Run("truncate false error", func(t *testing.T) { + truncFalse := false + req := api.EmbedRequest{ + Model: "all-minilm", + Input: strings.Repeat("word ", 100), + Truncate: &truncFalse, + Options: map[string]any{"num_ctx": 10}, + } + + _, err := embedTestHelper(ctx, client, t, req) + if err == nil { + t.Fatal("expected error when truncate=false with long input") + } + + if !strings.Contains(err.Error(), "exceeds maximum context length") { + t.Fatalf("expected context length error, got: %v", err) + } + }) + + t.Run("runner token count accuracy", func(t *testing.T) { + baseline := api.EmbedRequest{Model: "all-minilm", Input: "test"} + baseRes, err := embedTestHelper(ctx, client, t, baseline) + if err != nil { + t.Fatal(err) + } + + batch := api.EmbedRequest{ + Model: "all-minilm", + Input: []string{"test", "test", "test"}, + } + batchRes, err := embedTestHelper(ctx, client, t, batch) + if err != nil { + t.Fatal(err) + } + + expectedCount := baseRes.PromptEvalCount * 3 + if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 { + t.Fatalf("expected ~%d tokens (3 × %d), got %d", + expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount) + } + }) +} + +// TestEmbedStatusCode tests that errors from the embedding endpoint +// properly preserve their HTTP status codes when returned to the client. +// This test specifically checks the error handling path in EmbedHandler +// where api.StatusError errors should maintain their original status code. +func TestEmbedStatusCode(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + // Pull the model if needed + if err := PullIfMissing(ctx, client, "all-minilm"); err != nil { + t.Fatal(err) + } + + t.Run("truncation error status code", func(t *testing.T) { + truncFalse := false + longInput := strings.Repeat("word ", 100) + + req := api.EmbedRequest{ + Model: "all-minilm", + Input: longInput, + Truncate: &truncFalse, + Options: map[string]any{"num_ctx": 10}, + } + + _, err := embedTestHelper(ctx, client, t, req) + if err == nil { + t.Fatal("expected error when truncate=false with long input") + } + + // Check that it's a StatusError with the correct status code + var statusErr api.StatusError + if !errors.As(err, &statusErr) { + t.Fatalf("expected api.StatusError, got %T: %v", err, err) + } + + // The error should be a 4xx client error (likely 400 Bad Request) + // not a 500 Internal Server Error + if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 { + t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode) + } + + // Verify the error message is meaningful + if !strings.Contains(err.Error(), "context length") { + t.Errorf("expected error message to mention context length, got: %v", err) + } + }) + + t.Run("batch truncation error status code", func(t *testing.T) { + truncFalse := false + req := api.EmbedRequest{ + Model: "all-minilm", + Input: []string{ + "short input", + strings.Repeat("very long input ", 100), + "another short input", + }, + Truncate: &truncFalse, + Options: map[string]any{"num_ctx": 10}, + } + + _, err := embedTestHelper(ctx, client, t, req) + if err == nil { + t.Fatal("expected error when one input exceeds context with truncate=false") + } + + // Check that it's a StatusError with the correct status code + var statusErr api.StatusError + if !errors.As(err, &statusErr) { + t.Fatalf("expected api.StatusError, got %T: %v", err, err) + } + + // The error should be a 4xx client error, not a 500 Internal Server Error + if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 { + t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode) + } + }) +} diff --git a/llm/server.go b/llm/server.go index f8b232df..302b7a01 100644 --- a/llm/server.go +++ b/llm/server.go @@ -69,7 +69,7 @@ type LlamaServer interface { Ping(ctx context.Context) error WaitUntilRunning(ctx context.Context) error Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error - Embedding(ctx context.Context, input string) ([]float32, error) + Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) Tokenize(ctx context.Context, content string) ([]int, error) Detokenize(ctx context.Context, tokens []int) (string, error) Close() error @@ -1545,14 +1545,16 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu } type EmbeddingRequest struct { - Content string `json:"content"` + Content string `json:"content"` + Truncate bool `json:"truncate"` } type EmbeddingResponse struct { - Embedding []float32 `json:"embedding"` + Embedding []float32 `json:"embedding"` + PromptEvalCount int `json:"prompt_eval_count"` } -func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) { +func (s *llmServer) Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) { logutil.Trace("embedding request", "input", input) if err := s.sem.Acquire(ctx, 1); err != nil { @@ -1561,51 +1563,54 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err } else { slog.Error("Failed to acquire semaphore", "error", err) } - return nil, err + return nil, 0, err } defer s.sem.Release(1) // Make sure the server is ready status, err := s.getServerStatusRetry(ctx) if err != nil { - return nil, err + return nil, 0, err } else if status != ServerStatusReady { - return nil, fmt.Errorf("unexpected server status: %s", status) + return nil, 0, fmt.Errorf("unexpected server status: %s", status) } - data, err := json.Marshal(EmbeddingRequest{Content: input}) + data, err := json.Marshal(EmbeddingRequest{Content: input, Truncate: truncate}) if err != nil { - return nil, fmt.Errorf("error marshaling embed data: %w", err) + return nil, 0, fmt.Errorf("error marshaling embed data: %w", err) } r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data)) if err != nil { - return nil, fmt.Errorf("error creating embed request: %w", err) + return nil, 0, fmt.Errorf("error creating embed request: %w", err) } r.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(r) if err != nil { - return nil, fmt.Errorf("do embedding request: %w", err) + return nil, 0, fmt.Errorf("do embedding request: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("error reading embed response: %w", err) + return nil, 0, fmt.Errorf("error reading embed response: %w", err) } if resp.StatusCode >= 400 { log.Printf("llm embedding error: %s", body) - return nil, fmt.Errorf("%s", body) + return nil, 0, api.StatusError{ + StatusCode: resp.StatusCode, + ErrorMessage: string(body), + } } var e EmbeddingResponse if err := json.Unmarshal(body, &e); err != nil { - return nil, fmt.Errorf("unmarshal tokenize response: %w", err) + return nil, 0, fmt.Errorf("unmarshal tokenize response: %w", err) } - return e.Embedding, nil + return e.Embedding, e.PromptEvalCount, nil } type TokenizeRequest struct { diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index 87b43256..be899b3b 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -709,13 +709,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{ embedding: true, - - // TODO (jmorganca): this should be provided by the server via the - // request options and truncated here in the runner, instead of relying on - // the server's truncate logic - truncate: true, + truncate: req.Truncate, }) if err != nil { + if errors.Is(err, errorInputTooLong) { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError) return } @@ -758,7 +758,8 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { embedding := <-seq.embedding if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{ - Embedding: embedding, + Embedding: embedding, + PromptEvalCount: seq.numPromptInputs, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 7b72bf92..b0cf6373 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -948,13 +948,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{ embedding: true, - - // TODO (jmorganca): this should be provided by the server via the - // request options and truncated here in the runner, instead of relying on - // the server's truncate logic - truncate: true, + truncate: req.Truncate, }) if err != nil { + if errors.Is(err, errorInputTooLong) { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError) return } @@ -995,7 +995,8 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { } if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{ - Embedding: <-seq.embedding, + Embedding: <-seq.embedding, + PromptEvalCount: seq.numPromptInputs, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) } diff --git a/server/routes.go b/server/routes.go index 3d32a9aa..99a64d19 100644 --- a/server/routes.go +++ b/server/routes.go @@ -21,6 +21,7 @@ import ( "os/signal" "slices" "strings" + "sync/atomic" "syscall" "time" @@ -659,7 +660,7 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } - r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive) + r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive) if err != nil { handleScheduleError(c, req.Model, err) return @@ -672,61 +673,12 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } - kvData, _, err := getModelData(m.ModelPath, false) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - var count int - for i, s := range input { - tokens, err := r.Tokenize(c.Request.Context(), s) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - ctxLen := min(opts.NumCtx, int(kvData.ContextLength())) - if len(tokens) > ctxLen { - if !truncate { - c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"}) - return - } - - if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) { - ctxLen-- - } - - if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) { - ctxLen-- - } - - slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens)) - if ctxLen <= 0 { - // return error if the truncated input would be empty or just special tokens - c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"}) - return - } - - tokens = tokens[:ctxLen] - - s, err = r.Detokenize(c.Request.Context(), tokens) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - } - - count += len(tokens) - - input[i] = s - } - var g errgroup.Group embeddings := make([][]float32, len(input)) + var totalTokens uint64 for i, text := range input { g.Go(func() error { - embedding, err := r.Embedding(c.Request.Context(), text) + embedding, tokenCount, err := r.Embedding(c.Request.Context(), text, truncate) if err != nil { return err } @@ -736,12 +688,18 @@ func (s *Server) EmbedHandler(c *gin.Context) { embedding = normalize(embedding[:req.Dimensions]) } embeddings[i] = embedding + atomic.AddUint64(&totalTokens, uint64(tokenCount)) return nil }) } if err := g.Wait(); err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())}) + var serr api.StatusError + if errors.As(err, &serr) { + c.AbortWithStatusJSON(serr.StatusCode, gin.H{"error": strings.TrimSpace(serr.ErrorMessage)}) + } else { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())}) + } return } @@ -750,7 +708,7 @@ func (s *Server) EmbedHandler(c *gin.Context) { Embeddings: embeddings, TotalDuration: time.Since(checkpointStart), LoadDuration: checkpointLoaded.Sub(checkpointStart), - PromptEvalCount: count, + PromptEvalCount: int(totalTokens), } c.JSON(http.StatusOK, resp) } @@ -796,7 +754,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { return } - embedding, err := r.Embedding(c.Request.Context(), req.Prompt) + embedding, _, err := r.Embedding(c.Request.Context(), req.Prompt, true) if err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())}) return diff --git a/server/sched_test.go b/server/sched_test.go index 316a817f..c531e7eb 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -780,8 +780,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn return s.completionResp } -func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) { - return s.embeddingResp, s.embeddingRespErr +func (s *mockLlm) Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) { + return s.embeddingResp, 0, s.embeddingRespErr } func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {