From 29f63f37c87e2c5a908bfb6b2c8b3320052e0bbe Mon Sep 17 00:00:00 2001
From: Patrick Devine
Date: Tue, 28 Oct 2025 14:49:14 -0700
Subject: [PATCH] Revert "server: Consolidate embedding truncation in runner (#12730)" (#12810)

This reverts commit 5d347f6d6f6813895412846a240fc046d49c4817.
---
 integration/embed_test.go     | 215 ----------------------------------
 llm/server.go                 |  35 +++---
 runner/llamarunner/runner.go  |  13 +-
 runner/ollamarunner/runner.go |  13 +-
 server/routes.go              |  68 +++++++++--
 server/sched_test.go          |   4 +-
 6 files changed, 84 insertions(+), 264 deletions(-)

diff --git a/integration/embed_test.go b/integration/embed_test.go
index 432df9ab..3a8bcd24 100644
--- a/integration/embed_test.go
+++ b/integration/embed_test.go
@@ -4,9 +4,7 @@ package integration
 
 import (
 	"context"
-	"errors"
 	"math"
-	"strings"
 	"testing"
 	"time"
 
@@ -301,216 +299,3 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req
 
 	return client.Embed(ctx, &req)
 }
-
-func TestEmbedTruncation(t *testing.T) {
-	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
-	defer cancel()
-	client, _, cleanup := InitServerConnection(ctx, t)
-	defer cleanup()
-
-	t.Run("single input token count", func(t *testing.T) {
-		req := api.EmbedRequest{
-			Model: "all-minilm",
-			Input: "why is the sky blue?",
-		}
-
-		res, err := embedTestHelper(ctx, client, t, req)
-		if err != nil {
-			t.Fatal(err)
-		}
-
-		if res.PromptEvalCount <= 0 {
-			t.Fatalf("expected positive token count, got %d", res.PromptEvalCount)
-		}
-	})
-
-	t.Run("batch parallel token counting", func(t *testing.T) {
-		req := api.EmbedRequest{
-			Model: "all-minilm",
-			Input: []string{"cat", "dog and mouse", "bird"},
-		}
-
-		res, err := embedTestHelper(ctx, client, t, req)
-		if err != nil {
-			t.Fatal(err)
-		}
-
-		if len(res.Embeddings) != 3 {
-			t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
-		}
-
-		if res.PromptEvalCount <= 0 {
-			t.Fatalf("expected positive token count, got %d", res.PromptEvalCount)
-		}
-	})
-
-	t.Run("truncation single input", func(t *testing.T) {
-		truncTrue := true
-		longInput := strings.Repeat("word ", 100)
-
-		req := api.EmbedRequest{
-			Model:    "all-minilm",
-			Input:    longInput,
-			Truncate: &truncTrue,
-			Options:  map[string]any{"num_ctx": 50},
-		}
-
-		res, err := embedTestHelper(ctx, client, t, req)
-		if err != nil {
-			t.Fatal(err)
-		}
-
-		if res.PromptEvalCount > 50 {
-			t.Fatalf("expected tokens <= 50 after truncation, got %d", res.PromptEvalCount)
-		}
-
-		if res.PromptEvalCount == 0 {
-			t.Fatal("expected non-zero token count after truncation")
-		}
-	})
-
-	t.Run("truncation batch", func(t *testing.T) {
-		truncTrue := true
-		req := api.EmbedRequest{
-			Model:    "all-minilm",
-			Input:    []string{"short", strings.Repeat("long ", 100), "medium text"},
-			Truncate: &truncTrue,
-			Options:  map[string]any{"num_ctx": 30},
-		}
-
-		res, err := embedTestHelper(ctx, client, t, req)
-		if err != nil {
-			t.Fatal(err)
-		}
-
-		if len(res.Embeddings) != 3 {
-			t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
-		}
-
-		if res.PromptEvalCount > 90 {
-			t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount)
-		}
-	})
-
-	t.Run("truncate false error", func(t *testing.T) {
-		truncFalse := false
-		req := api.EmbedRequest{
-			Model:    "all-minilm",
-			Input:    strings.Repeat("word ", 100),
-			Truncate: &truncFalse,
-			Options:  map[string]any{"num_ctx": 10},
-		}
-
-		_, err := embedTestHelper(ctx, client, t, req)
-		if err == nil {
-			t.Fatal("expected error when truncate=false with long input")
-		}
-
-		if !strings.Contains(err.Error(), "exceeds maximum context length") {
-			t.Fatalf("expected context length error, got: %v", err)
-		}
-	})
-
-	t.Run("runner token count accuracy", func(t *testing.T) {
-		baseline := api.EmbedRequest{Model: "all-minilm", Input: "test"}
-		baseRes, err := embedTestHelper(ctx, client, t, baseline)
-		if err != nil {
-			t.Fatal(err)
-		}
-
-		batch := api.EmbedRequest{
-			Model: "all-minilm",
-			Input: []string{"test", "test", "test"},
-		}
-		batchRes, err := embedTestHelper(ctx, client, t, batch)
-		if err != nil {
-			t.Fatal(err)
-		}
-
-		expectedCount := baseRes.PromptEvalCount * 3
-		if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 {
-			t.Fatalf("expected ~%d tokens (3 × %d), got %d",
-				expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount)
-		}
-	})
-}
-
-// TestEmbedStatusCode tests that errors from the embedding endpoint
-// properly preserve their HTTP status codes when returned to the client.
-// This test specifically checks the error handling path in EmbedHandler
-// where api.StatusError errors should maintain their original status code.
-func TestEmbedStatusCode(t *testing.T) {
-	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
-	defer cancel()
-	client, _, cleanup := InitServerConnection(ctx, t)
-	defer cleanup()
-
-	// Pull the model if needed
-	if err := PullIfMissing(ctx, client, "all-minilm"); err != nil {
-		t.Fatal(err)
-	}
-
-	t.Run("truncation error status code", func(t *testing.T) {
-		truncFalse := false
-		longInput := strings.Repeat("word ", 100)
-
-		req := api.EmbedRequest{
-			Model:    "all-minilm",
-			Input:    longInput,
-			Truncate: &truncFalse,
-			Options:  map[string]any{"num_ctx": 10},
-		}
-
-		_, err := embedTestHelper(ctx, client, t, req)
-		if err == nil {
-			t.Fatal("expected error when truncate=false with long input")
-		}
-
-		// Check that it's a StatusError with the correct status code
-		var statusErr api.StatusError
-		if !errors.As(err, &statusErr) {
-			t.Fatalf("expected api.StatusError, got %T: %v", err, err)
-		}
-
-		// The error should be a 4xx client error (likely 400 Bad Request)
-		// not a 500 Internal Server Error
-		if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
-			t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
-		}
-
-		// Verify the error message is meaningful
-		if !strings.Contains(err.Error(), "context length") {
-			t.Errorf("expected error message to mention context length, got: %v", err)
-		}
-	})
-
-	t.Run("batch truncation error status code", func(t *testing.T) {
-		truncFalse := false
-		req := api.EmbedRequest{
-			Model: "all-minilm",
-			Input: []string{
-				"short input",
-				strings.Repeat("very long input ", 100),
-				"another short input",
-			},
-			Truncate: &truncFalse,
-			Options:  map[string]any{"num_ctx": 10},
-		}
-
-		_, err := embedTestHelper(ctx, client, t, req)
-		if err == nil {
-			t.Fatal("expected error when one input exceeds context with truncate=false")
-		}
-
-		// Check that it's a StatusError with the correct status code
-		var statusErr api.StatusError
-		if !errors.As(err, &statusErr) {
-			t.Fatalf("expected api.StatusError, got %T: %v", err, err)
-		}
-
-		// The error should be a 4xx client error, not a 500 Internal Server Error
-		if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
-			t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
-		}
-	})
-}
diff --git a/llm/server.go b/llm/server.go
index 302b7a01..f8b232df 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -69,7 +69,7 @@ type LlamaServer interface {
 	Ping(ctx context.Context) error
 	WaitUntilRunning(ctx context.Context) error
 	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
-	Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error)
+	Embedding(ctx context.Context, input string) ([]float32, error)
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Close() error
@@ -1545,16 +1545,14 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 }
 
 type EmbeddingRequest struct {
-	Content  string `json:"content"`
-	Truncate bool   `json:"truncate"`
+	Content string `json:"content"`
 }
 
 type EmbeddingResponse struct {
-	Embedding       []float32 `json:"embedding"`
-	PromptEvalCount int       `json:"prompt_eval_count"`
+	Embedding []float32 `json:"embedding"`
 }
 
-func (s *llmServer) Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) {
+func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
 	logutil.Trace("embedding request", "input", input)
 
 	if err := s.sem.Acquire(ctx, 1); err != nil {
@@ -1563,54 +1561,51 @@ func (s *llmServer) Embedding(ctx context.Context, input string, truncate bool)
 		} else {
 			slog.Error("Failed to acquire semaphore", "error", err)
 		}
-		return nil, 0, err
+		return nil, err
 	}
 	defer s.sem.Release(1)
 
 	// Make sure the server is ready
 	status, err := s.getServerStatusRetry(ctx)
 	if err != nil {
-		return nil, 0, err
+		return nil, err
 	} else if status != ServerStatusReady {
-		return nil, 0, fmt.Errorf("unexpected server status: %s", status)
+		return nil, fmt.Errorf("unexpected server status: %s", status)
 	}
 
-	data, err := json.Marshal(EmbeddingRequest{Content: input, Truncate: truncate})
+	data, err := json.Marshal(EmbeddingRequest{Content: input})
 	if err != nil {
-		return nil, 0, fmt.Errorf("error marshaling embed data: %w", err)
+		return nil, fmt.Errorf("error marshaling embed data: %w", err)
 	}
 
 	r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
 	if err != nil {
-		return nil, 0, fmt.Errorf("error creating embed request: %w", err)
+		return nil, fmt.Errorf("error creating embed request: %w", err)
 	}
 	r.Header.Set("Content-Type", "application/json")
 
 	resp, err := http.DefaultClient.Do(r)
 	if err != nil {
-		return nil, 0, fmt.Errorf("do embedding request: %w", err)
+		return nil, fmt.Errorf("do embedding request: %w", err)
 	}
 	defer resp.Body.Close()
 
 	body, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return nil, 0, fmt.Errorf("error reading embed response: %w", err)
+		return nil, fmt.Errorf("error reading embed response: %w", err)
 	}
 
 	if resp.StatusCode >= 400 {
 		log.Printf("llm embedding error: %s", body)
-		return nil, 0, api.StatusError{
-			StatusCode:   resp.StatusCode,
-			ErrorMessage: string(body),
-		}
+		return nil, fmt.Errorf("%s", body)
 	}
 
 	var e EmbeddingResponse
 	if err := json.Unmarshal(body, &e); err != nil {
-		return nil, 0, fmt.Errorf("unmarshal tokenize response: %w", err)
+		return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
 	}
 
-	return e.Embedding, e.PromptEvalCount, nil
+	return e.Embedding, nil
 }
 
 type TokenizeRequest struct {
diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go
index be899b3b..87b43256 100644
--- a/runner/llamarunner/runner.go
+++ b/runner/llamarunner/runner.go
@@ -709,13 +709,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 
 	seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
 		embedding: true,
-		truncate:  req.Truncate,
+
+		// TODO (jmorganca): this should be provided by the server via the
+		// request options and truncated here in the runner, instead of relying on
+		// the server's truncate logic
+		truncate: true,
 	})
 	if err != nil {
-		if errors.Is(err, errorInputTooLong) {
-			http.Error(w, err.Error(), http.StatusBadRequest)
-			return
-		}
 		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
 		return
 	}
@@ -758,8 +758,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	embedding := <-seq.embedding
 
 	if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
-		Embedding:       embedding,
-		PromptEvalCount: seq.numPromptInputs,
+		Embedding: embedding,
 	}); err != nil {
 		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 	}
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index ea039157..e977d18f 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -948,13 +948,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
 
 	seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
 		embedding: true,
-		truncate:  req.Truncate,
+
+		// TODO (jmorganca): this should be provided by the server via the
+		// request options and truncated here in the runner, instead of relying on
+		// the server's truncate logic
+		truncate: true,
 	})
 	if err != nil {
-		if errors.Is(err, errorInputTooLong) {
-			http.Error(w, err.Error(), http.StatusBadRequest)
-			return
-		}
 		http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError)
 		return
 	}
@@ -995,8 +995,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	}
 
 	if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
-		Embedding:       <-seq.embedding,
-		PromptEvalCount: seq.numPromptInputs,
+		Embedding: <-seq.embedding,
 	}); err != nil {
 		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 	}
diff --git a/server/routes.go b/server/routes.go
index 99a64d19..3d32a9aa 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -21,7 +21,6 @@ import (
 	"os/signal"
 	"slices"
 	"strings"
-	"sync/atomic"
 	"syscall"
 	"time"
 
@@ -660,7 +659,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}
 
-	r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
 		handleScheduleError(c, req.Model, err)
 		return
@@ -673,12 +672,61 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}
 
+	kvData, _, err := getModelData(m.ModelPath, false)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	var count int
+	for i, s := range input {
+		tokens, err := r.Tokenize(c.Request.Context(), s)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
+		if len(tokens) > ctxLen {
+			if !truncate {
+				c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"})
+				return
+			}
+
+			if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
+				ctxLen--
+			}
+
+			if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
+				ctxLen--
+			}
+
+			slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens))
+			if ctxLen <= 0 {
+				// return error if the truncated input would be empty or just special tokens
+				c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"})
+				return
+			}
+
+			tokens = tokens[:ctxLen]
+
+			s, err = r.Detokenize(c.Request.Context(), tokens)
+			if err != nil {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return
+			}
+		}
+
+		count += len(tokens)
+
+		input[i] = s
+	}
+
 	var g errgroup.Group
 	embeddings := make([][]float32, len(input))
-	var totalTokens uint64
 	for i, text := range input {
 		g.Go(func() error {
-			embedding, tokenCount, err := r.Embedding(c.Request.Context(), text, truncate)
+			embedding, err := r.Embedding(c.Request.Context(), text)
 			if err != nil {
 				return err
 			}
@@ -688,18 +736,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 				embedding = normalize(embedding[:req.Dimensions])
 			}
 			embeddings[i] = embedding
-			atomic.AddUint64(&totalTokens, uint64(tokenCount))
 			return nil
 		})
 	}
 
 	if err := g.Wait(); err != nil {
-		var serr api.StatusError
-		if errors.As(err, &serr) {
-			c.AbortWithStatusJSON(serr.StatusCode, gin.H{"error": strings.TrimSpace(serr.ErrorMessage)})
-		} else {
-			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
-		}
+		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
 		return
 	}
 
@@ -708,7 +750,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		Embeddings:      embeddings,
 		TotalDuration:   time.Since(checkpointStart),
 		LoadDuration:    checkpointLoaded.Sub(checkpointStart),
-		PromptEvalCount: int(totalTokens),
+		PromptEvalCount: count,
 	}
 	c.JSON(http.StatusOK, resp)
 }
@@ -754,7 +796,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	embedding, _, err := r.Embedding(c.Request.Context(), req.Prompt, true)
+	embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
 		return
diff --git a/server/sched_test.go b/server/sched_test.go
index c531e7eb..316a817f 100644
--- a/server/sched_test.go
+++ b/server/sched_test.go
@@ -780,8 +780,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
 	return s.completionResp
 }
 
-func (s *mockLlm) Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) {
-	return s.embeddingResp, 0, s.embeddingRespErr
+func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) {
+	return s.embeddingResp, s.embeddingRespErr
 }
 
 func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
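
Note (not part of the patch): the restored `EmbedHandler` hunk above moves truncation back to the server: each input is tokenized, compared against the effective context window, trimmed with room reserved for BOS/EOS tokens, then detokenized before being handed to the runner. Below is a minimal standalone sketch of that flow; the `Runner` interface and `runeRunner` mock are illustrative stand-ins, not ollama's actual API surface.

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// Runner stands in for the two runner calls the handler relies on.
type Runner interface {
	Tokenize(ctx context.Context, s string) ([]int, error)
	Detokenize(ctx context.Context, tokens []int) (string, error)
}

// truncateForEmbedding mirrors the restored server-side flow: tokenize,
// compare against the effective context length, optionally reserve slots
// for BOS/EOS, cut the token slice, and detokenize back to text. It
// returns the (possibly truncated) input and its token count, which is
// what the handler accumulates into PromptEvalCount.
func truncateForEmbedding(ctx context.Context, r Runner, input string, ctxLen int, truncate, reserveBOS, reserveEOS bool) (string, int, error) {
	tokens, err := r.Tokenize(ctx, input)
	if err != nil {
		return "", 0, err
	}
	if len(tokens) <= ctxLen {
		return input, len(tokens), nil
	}
	if !truncate {
		return "", 0, errors.New("input exceeds maximum context length")
	}
	// Leave room for special tokens the runner will add around the input.
	if reserveBOS {
		ctxLen--
	}
	if reserveEOS {
		ctxLen--
	}
	if ctxLen <= 0 {
		// Truncation would leave nothing but special tokens.
		return "", 0, errors.New("input after truncation exceeds maximum context length")
	}
	tokens = tokens[:ctxLen]
	out, err := r.Detokenize(ctx, tokens)
	if err != nil {
		return "", 0, err
	}
	return out, len(tokens), nil
}

// runeRunner is a toy tokenizer (one token per rune) so the sketch runs
// without a model.
type runeRunner struct{}

func (runeRunner) Tokenize(_ context.Context, s string) ([]int, error) {
	rs := []rune(s)
	out := make([]int, len(rs))
	for i, r := range rs {
		out[i] = int(r)
	}
	return out, nil
}

func (runeRunner) Detokenize(_ context.Context, toks []int) (string, error) {
	rs := make([]rune, len(toks))
	for i, t := range toks {
		rs[i] = rune(t)
	}
	return string(rs), nil
}

func main() {
	out, n, err := truncateForEmbedding(context.Background(), runeRunner{}, "why is the sky blue?", 8, true, true, false)
	fmt.Println(out, n, err) // truncated to 7 runes: one slot reserved for BOS
}
```

Because the server now performs this step before dispatching work, the runner-side `truncate`/`PromptEvalCount` plumbing removed in the hunks above is no longer needed for `/api/embed`.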