diff --git a/integration/embed_test.go b/integration/embed_test.go index 3a8bcd24..432df9ab 100644 --- a/integration/embed_test.go +++ b/integration/embed_test.go @@ -4,7 +4,9 @@ package integration import ( "context" + "errors" "math" + "strings" "testing" "time" @@ -299,3 +301,216 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req return client.Embed(ctx, &req) } + +func TestEmbedTruncation(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + t.Run("single input token count", func(t *testing.T) { + req := api.EmbedRequest{ + Model: "all-minilm", + Input: "why is the sky blue?", + } + + res, err := embedTestHelper(ctx, client, t, req) + if err != nil { + t.Fatal(err) + } + + if res.PromptEvalCount <= 0 { + t.Fatalf("expected positive token count, got %d", res.PromptEvalCount) + } + }) + + t.Run("batch parallel token counting", func(t *testing.T) { + req := api.EmbedRequest{ + Model: "all-minilm", + Input: []string{"cat", "dog and mouse", "bird"}, + } + + res, err := embedTestHelper(ctx, client, t, req) + if err != nil { + t.Fatal(err) + } + + if len(res.Embeddings) != 3 { + t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings)) + } + + if res.PromptEvalCount <= 0 { + t.Fatalf("expected positive token count, got %d", res.PromptEvalCount) + } + }) + + t.Run("truncation single input", func(t *testing.T) { + truncTrue := true + longInput := strings.Repeat("word ", 100) + + req := api.EmbedRequest{ + Model: "all-minilm", + Input: longInput, + Truncate: &truncTrue, + Options: map[string]any{"num_ctx": 50}, + } + + res, err := embedTestHelper(ctx, client, t, req) + if err != nil { + t.Fatal(err) + } + + if res.PromptEvalCount > 50 { + t.Fatalf("expected tokens <= 50 after truncation, got %d", res.PromptEvalCount) + } + + if res.PromptEvalCount == 0 { + t.Fatal("expected non-zero token count after 
truncation") + } + }) + + t.Run("truncation batch", func(t *testing.T) { + truncTrue := true + req := api.EmbedRequest{ + Model: "all-minilm", + Input: []string{"short", strings.Repeat("long ", 100), "medium text"}, + Truncate: &truncTrue, + Options: map[string]any{"num_ctx": 30}, + } + + res, err := embedTestHelper(ctx, client, t, req) + if err != nil { + t.Fatal(err) + } + + if len(res.Embeddings) != 3 { + t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings)) + } + + if res.PromptEvalCount > 90 { + t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount) + } + }) + + t.Run("truncate false error", func(t *testing.T) { + truncFalse := false + req := api.EmbedRequest{ + Model: "all-minilm", + Input: strings.Repeat("word ", 100), + Truncate: &truncFalse, + Options: map[string]any{"num_ctx": 10}, + } + + _, err := embedTestHelper(ctx, client, t, req) + if err == nil { + t.Fatal("expected error when truncate=false with long input") + } + + if !strings.Contains(err.Error(), "exceeds maximum context length") { + t.Fatalf("expected context length error, got: %v", err) + } + }) + + t.Run("runner token count accuracy", func(t *testing.T) { + baseline := api.EmbedRequest{Model: "all-minilm", Input: "test"} + baseRes, err := embedTestHelper(ctx, client, t, baseline) + if err != nil { + t.Fatal(err) + } + + batch := api.EmbedRequest{ + Model: "all-minilm", + Input: []string{"test", "test", "test"}, + } + batchRes, err := embedTestHelper(ctx, client, t, batch) + if err != nil { + t.Fatal(err) + } + + expectedCount := baseRes.PromptEvalCount * 3 + if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 { + t.Fatalf("expected ~%d tokens (3 × %d), got %d", + expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount) + } + }) +} + +// TestEmbedStatusCode tests that errors from the embedding endpoint +// properly preserve their HTTP status codes when returned to the client. 
+// This test specifically checks the error handling path in EmbedHandler +// where api.StatusError errors should maintain their original status code. +func TestEmbedStatusCode(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + // Pull the model if needed + if err := PullIfMissing(ctx, client, "all-minilm"); err != nil { + t.Fatal(err) + } + + t.Run("truncation error status code", func(t *testing.T) { + truncFalse := false + longInput := strings.Repeat("word ", 100) + + req := api.EmbedRequest{ + Model: "all-minilm", + Input: longInput, + Truncate: &truncFalse, + Options: map[string]any{"num_ctx": 10}, + } + + _, err := embedTestHelper(ctx, client, t, req) + if err == nil { + t.Fatal("expected error when truncate=false with long input") + } + + // Check that it's a StatusError with the correct status code + var statusErr api.StatusError + if !errors.As(err, &statusErr) { + t.Fatalf("expected api.StatusError, got %T: %v", err, err) + } + + // The error should be a 4xx client error (likely 400 Bad Request) + // not a 500 Internal Server Error + if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 { + t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode) + } + + // Verify the error message is meaningful + if !strings.Contains(err.Error(), "context length") { + t.Errorf("expected error message to mention context length, got: %v", err) + } + }) + + t.Run("batch truncation error status code", func(t *testing.T) { + truncFalse := false + req := api.EmbedRequest{ + Model: "all-minilm", + Input: []string{ + "short input", + strings.Repeat("very long input ", 100), + "another short input", + }, + Truncate: &truncFalse, + Options: map[string]any{"num_ctx": 10}, + } + + _, err := embedTestHelper(ctx, client, t, req) + if err == nil { + t.Fatal("expected error when one input exceeds context with truncate=false") + } + + // Check 
that it's a StatusError with the correct status code + var statusErr api.StatusError + if !errors.As(err, &statusErr) { + t.Fatalf("expected api.StatusError, got %T: %v", err, err) + } + + // The error should be a 4xx client error, not a 500 Internal Server Error + if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 { + t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode) + } + }) +} diff --git a/llm/server.go b/llm/server.go index f8b232df..302b7a01 100644 --- a/llm/server.go +++ b/llm/server.go @@ -69,7 +69,7 @@ type LlamaServer interface { Ping(ctx context.Context) error WaitUntilRunning(ctx context.Context) error Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error - Embedding(ctx context.Context, input string) ([]float32, error) + Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) Tokenize(ctx context.Context, content string) ([]int, error) Detokenize(ctx context.Context, tokens []int) (string, error) Close() error @@ -1545,14 +1545,16 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu } type EmbeddingRequest struct { - Content string `json:"content"` + Content string `json:"content"` + Truncate bool `json:"truncate"` } type EmbeddingResponse struct { - Embedding []float32 `json:"embedding"` + Embedding []float32 `json:"embedding"` + PromptEvalCount int `json:"prompt_eval_count"` } -func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) { +func (s *llmServer) Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) { logutil.Trace("embedding request", "input", input) if err := s.sem.Acquire(ctx, 1); err != nil { @@ -1561,51 +1563,54 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err } else { slog.Error("Failed to acquire semaphore", "error", err) } - return nil, err + return nil, 0, err } defer s.sem.Release(1) // Make sure the server is ready 
 status, err := s.getServerStatusRetry(ctx) 	if err != nil { -		return nil, err +		return nil, 0, err 	} else if status != ServerStatusReady { -		return nil, fmt.Errorf("unexpected server status: %s", status) +		return nil, 0, fmt.Errorf("unexpected server status: %s", status) 	} -	data, err := json.Marshal(EmbeddingRequest{Content: input}) +	data, err := json.Marshal(EmbeddingRequest{Content: input, Truncate: truncate}) 	if err != nil { -		return nil, fmt.Errorf("error marshaling embed data: %w", err) +		return nil, 0, fmt.Errorf("error marshaling embed data: %w", err) 	} 	r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data)) 	if err != nil { -		return nil, fmt.Errorf("error creating embed request: %w", err) +		return nil, 0, fmt.Errorf("error creating embed request: %w", err) 	} 	r.Header.Set("Content-Type", "application/json") 	resp, err := http.DefaultClient.Do(r) 	if err != nil { -		return nil, fmt.Errorf("do embedding request: %w", err) +		return nil, 0, fmt.Errorf("do embedding request: %w", err) 	} 	defer resp.Body.Close() 	body, err := io.ReadAll(resp.Body) 	if err != nil { -		return nil, fmt.Errorf("error reading embed response: %w", err) +		return nil, 0, fmt.Errorf("error reading embed response: %w", err) 	} 	if resp.StatusCode >= 400 { 		log.Printf("llm embedding error: %s", body) -		return nil, fmt.Errorf("%s", body) +		return nil, 0, api.StatusError{ +			StatusCode:   resp.StatusCode, +			ErrorMessage: string(body), +		} 	} 	var e EmbeddingResponse 	if err := json.Unmarshal(body, &e); err != nil { -		return nil, fmt.Errorf("unmarshal tokenize response: %w", err) +		return nil, 0, fmt.Errorf("unmarshal embedding response: %w", err) 	} -	return e.Embedding, nil +	return e.Embedding, e.PromptEvalCount, nil } type TokenizeRequest struct { diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index 87b43256..be899b3b 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -709,13 
+709,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{ embedding: true, - - // TODO (jmorganca): this should be provided by the server via the - // request options and truncated here in the runner, instead of relying on - // the server's truncate logic - truncate: true, + truncate: req.Truncate, }) if err != nil { + if errors.Is(err, errorInputTooLong) { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError) return } @@ -758,7 +758,8 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { embedding := <-seq.embedding if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{ - Embedding: embedding, + Embedding: embedding, + PromptEvalCount: seq.numPromptInputs, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 7b72bf92..b0cf6373 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -948,13 +948,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{ embedding: true, - - // TODO (jmorganca): this should be provided by the server via the - // request options and truncated here in the runner, instead of relying on - // the server's truncate logic - truncate: true, + truncate: req.Truncate, }) if err != nil { + if errors.Is(err, errorInputTooLong) { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError) return } @@ -995,7 +995,8 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { } if err := 
json.NewEncoder(w).Encode(&llm.EmbeddingResponse{ - Embedding: <-seq.embedding, + Embedding: <-seq.embedding, + PromptEvalCount: seq.numPromptInputs, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) } diff --git a/server/routes.go b/server/routes.go index 3d32a9aa..99a64d19 100644 --- a/server/routes.go +++ b/server/routes.go @@ -21,6 +21,7 @@ import ( "os/signal" "slices" "strings" + "sync/atomic" "syscall" "time" @@ -659,7 +660,7 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } - r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive) + r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive) if err != nil { handleScheduleError(c, req.Model, err) return @@ -672,61 +673,12 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } - kvData, _, err := getModelData(m.ModelPath, false) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - var count int - for i, s := range input { - tokens, err := r.Tokenize(c.Request.Context(), s) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - ctxLen := min(opts.NumCtx, int(kvData.ContextLength())) - if len(tokens) > ctxLen { - if !truncate { - c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"}) - return - } - - if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) { - ctxLen-- - } - - if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) { - ctxLen-- - } - - slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens)) - if ctxLen <= 0 { - // return error if the truncated input would be empty or just special tokens - c.JSON(http.StatusBadRequest, gin.H{"error": 
"input after truncation exceeds maximum context length"}) - return - } - - tokens = tokens[:ctxLen] - - s, err = r.Detokenize(c.Request.Context(), tokens) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - } - - count += len(tokens) - - input[i] = s - } - var g errgroup.Group embeddings := make([][]float32, len(input)) + var totalTokens uint64 for i, text := range input { g.Go(func() error { - embedding, err := r.Embedding(c.Request.Context(), text) + embedding, tokenCount, err := r.Embedding(c.Request.Context(), text, truncate) if err != nil { return err } @@ -736,12 +688,18 @@ func (s *Server) EmbedHandler(c *gin.Context) { embedding = normalize(embedding[:req.Dimensions]) } embeddings[i] = embedding + atomic.AddUint64(&totalTokens, uint64(tokenCount)) return nil }) } if err := g.Wait(); err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())}) + var serr api.StatusError + if errors.As(err, &serr) { + c.AbortWithStatusJSON(serr.StatusCode, gin.H{"error": strings.TrimSpace(serr.ErrorMessage)}) + } else { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())}) + } return } @@ -750,7 +708,7 @@ func (s *Server) EmbedHandler(c *gin.Context) { Embeddings: embeddings, TotalDuration: time.Since(checkpointStart), LoadDuration: checkpointLoaded.Sub(checkpointStart), - PromptEvalCount: count, + PromptEvalCount: int(totalTokens), } c.JSON(http.StatusOK, resp) } @@ -796,7 +754,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { return } - embedding, err := r.Embedding(c.Request.Context(), req.Prompt) + embedding, _, err := r.Embedding(c.Request.Context(), req.Prompt, true) if err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())}) return diff --git a/server/sched_test.go b/server/sched_test.go index 316a817f..c531e7eb 100644 --- a/server/sched_test.go +++ 
b/server/sched_test.go @@ -780,8 +780,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn return s.completionResp } -func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) { - return s.embeddingResp, s.embeddingRespErr +func (s *mockLlm) Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) { + return s.embeddingResp, 0, s.embeddingRespErr } func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {