Mirror of https://github.com/likelovewant/ollama-for-amd.git, synced 2025-12-21 14:26:30 +00:00
server: Consolidate embedding truncation in runner (#12730)
Checking that embedding prompts fit in the context window (and truncating them when they don't) currently happens in two places: the Ollama server and the runner. This can lead to inconsistencies in both the checks and the reported number of tokens processed. Since the runner has to do this processing anyway, this change consolidates all of the logic there.
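For illustration, here is a minimal, self-contained Go sketch of the decision the runner now makes in a single place. The helper name fitToContext and the sentinel errInputTooLong are hypothetical; in the actual change the equivalent logic lives in the runner's NewSequence (see the diffs below), and the number of tokens actually processed is reported back as prompt_eval_count.

package main

import (
    "errors"
    "fmt"
)

// errInputTooLong stands in for the runner's errorInputTooLong sentinel,
// which the runner maps to HTTP 400 and the server surfaces to clients
// as an api.StatusError.
var errInputTooLong = errors.New("input exceeds maximum context length")

// fitToContext keeps the input if it fits, truncates it when truncation
// is allowed, and rejects it otherwise. Returning the kept tokens lets
// the caller report exactly how many tokens were processed.
func fitToContext(tokens []int, numCtx int, truncate bool) ([]int, error) {
    if len(tokens) <= numCtx {
        return tokens, nil
    }
    if !truncate {
        return nil, errInputTooLong
    }
    return tokens[:numCtx], nil
}

func main() {
    tokens := make([]int, 100) // stands in for a tokenized prompt
    kept, err := fitToContext(tokens, 50, true)
    if err != nil {
        panic(err)
    }
    fmt.Println("prompt_eval_count:", len(kept)) // 50
}

With truncate=false, an over-long input is rejected up front; the runner returns a 400, which the server propagates as an api.StatusError, as exercised by the integration tests in the first diff below.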
@@ -4,7 +4,9 @@ package integration
 
 import (
     "context"
+    "errors"
     "math"
+    "strings"
     "testing"
     "time"
 
@@ -299,3 +301,216 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req
 
     return client.Embed(ctx, &req)
 }
+
+func TestEmbedTruncation(t *testing.T) {
+    ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
+    defer cancel()
+    client, _, cleanup := InitServerConnection(ctx, t)
+    defer cleanup()
+
+    t.Run("single input token count", func(t *testing.T) {
+        req := api.EmbedRequest{
+            Model: "all-minilm",
+            Input: "why is the sky blue?",
+        }
+
+        res, err := embedTestHelper(ctx, client, t, req)
+        if err != nil {
+            t.Fatal(err)
+        }
+
+        if res.PromptEvalCount <= 0 {
+            t.Fatalf("expected positive token count, got %d", res.PromptEvalCount)
+        }
+    })
+
+    t.Run("batch parallel token counting", func(t *testing.T) {
+        req := api.EmbedRequest{
+            Model: "all-minilm",
+            Input: []string{"cat", "dog and mouse", "bird"},
+        }
+
+        res, err := embedTestHelper(ctx, client, t, req)
+        if err != nil {
+            t.Fatal(err)
+        }
+
+        if len(res.Embeddings) != 3 {
+            t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
+        }
+
+        if res.PromptEvalCount <= 0 {
+            t.Fatalf("expected positive token count, got %d", res.PromptEvalCount)
+        }
+    })
+
+    t.Run("truncation single input", func(t *testing.T) {
+        truncTrue := true
+        longInput := strings.Repeat("word ", 100)
+
+        req := api.EmbedRequest{
+            Model:    "all-minilm",
+            Input:    longInput,
+            Truncate: &truncTrue,
+            Options:  map[string]any{"num_ctx": 50},
+        }
+
+        res, err := embedTestHelper(ctx, client, t, req)
+        if err != nil {
+            t.Fatal(err)
+        }
+
+        if res.PromptEvalCount > 50 {
+            t.Fatalf("expected tokens <= 50 after truncation, got %d", res.PromptEvalCount)
+        }
+
+        if res.PromptEvalCount == 0 {
+            t.Fatal("expected non-zero token count after truncation")
+        }
+    })
+
+    t.Run("truncation batch", func(t *testing.T) {
+        truncTrue := true
+        req := api.EmbedRequest{
+            Model:    "all-minilm",
+            Input:    []string{"short", strings.Repeat("long ", 100), "medium text"},
+            Truncate: &truncTrue,
+            Options:  map[string]any{"num_ctx": 30},
+        }
+
+        res, err := embedTestHelper(ctx, client, t, req)
+        if err != nil {
+            t.Fatal(err)
+        }
+
+        if len(res.Embeddings) != 3 {
+            t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
+        }
+
+        if res.PromptEvalCount > 90 {
+            t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount)
+        }
+    })
+
+    t.Run("truncate false error", func(t *testing.T) {
+        truncFalse := false
+        req := api.EmbedRequest{
+            Model:    "all-minilm",
+            Input:    strings.Repeat("word ", 100),
+            Truncate: &truncFalse,
+            Options:  map[string]any{"num_ctx": 10},
+        }
+
+        _, err := embedTestHelper(ctx, client, t, req)
+        if err == nil {
+            t.Fatal("expected error when truncate=false with long input")
+        }
+
+        if !strings.Contains(err.Error(), "exceeds maximum context length") {
+            t.Fatalf("expected context length error, got: %v", err)
+        }
+    })
+
+    t.Run("runner token count accuracy", func(t *testing.T) {
+        baseline := api.EmbedRequest{Model: "all-minilm", Input: "test"}
+        baseRes, err := embedTestHelper(ctx, client, t, baseline)
+        if err != nil {
+            t.Fatal(err)
+        }
+
+        batch := api.EmbedRequest{
+            Model: "all-minilm",
+            Input: []string{"test", "test", "test"},
+        }
+        batchRes, err := embedTestHelper(ctx, client, t, batch)
+        if err != nil {
+            t.Fatal(err)
+        }
+
+        expectedCount := baseRes.PromptEvalCount * 3
+        if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 {
+            t.Fatalf("expected ~%d tokens (3 × %d), got %d",
+                expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount)
+        }
+    })
+}
+
+// TestEmbedStatusCode tests that errors from the embedding endpoint
+// properly preserve their HTTP status codes when returned to the client.
+// This test specifically checks the error handling path in EmbedHandler
+// where api.StatusError errors should maintain their original status code.
+func TestEmbedStatusCode(t *testing.T) {
+    ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
+    defer cancel()
+    client, _, cleanup := InitServerConnection(ctx, t)
+    defer cleanup()
+
+    // Pull the model if needed
+    if err := PullIfMissing(ctx, client, "all-minilm"); err != nil {
+        t.Fatal(err)
+    }
+
+    t.Run("truncation error status code", func(t *testing.T) {
+        truncFalse := false
+        longInput := strings.Repeat("word ", 100)
+
+        req := api.EmbedRequest{
+            Model:    "all-minilm",
+            Input:    longInput,
+            Truncate: &truncFalse,
+            Options:  map[string]any{"num_ctx": 10},
+        }
+
+        _, err := embedTestHelper(ctx, client, t, req)
+        if err == nil {
+            t.Fatal("expected error when truncate=false with long input")
+        }
+
+        // Check that it's a StatusError with the correct status code
+        var statusErr api.StatusError
+        if !errors.As(err, &statusErr) {
+            t.Fatalf("expected api.StatusError, got %T: %v", err, err)
+        }
+
+        // The error should be a 4xx client error (likely 400 Bad Request)
+        // not a 500 Internal Server Error
+        if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
+            t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
+        }
+
+        // Verify the error message is meaningful
+        if !strings.Contains(err.Error(), "context length") {
+            t.Errorf("expected error message to mention context length, got: %v", err)
+        }
+    })
+
+    t.Run("batch truncation error status code", func(t *testing.T) {
+        truncFalse := false
+        req := api.EmbedRequest{
+            Model: "all-minilm",
+            Input: []string{
+                "short input",
+                strings.Repeat("very long input ", 100),
+                "another short input",
+            },
+            Truncate: &truncFalse,
+            Options:  map[string]any{"num_ctx": 10},
+        }
+
+        _, err := embedTestHelper(ctx, client, t, req)
+        if err == nil {
+            t.Fatal("expected error when one input exceeds context with truncate=false")
+        }
+
+        // Check that it's a StatusError with the correct status code
+        var statusErr api.StatusError
+        if !errors.As(err, &statusErr) {
+            t.Fatalf("expected api.StatusError, got %T: %v", err, err)
+        }
+
+        // The error should be a 4xx client error, not a 500 Internal Server Error
+        if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
+            t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
+        }
+    })
+}

@@ -69,7 +69,7 @@ type LlamaServer interface {
     Ping(ctx context.Context) error
     WaitUntilRunning(ctx context.Context) error
     Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
-    Embedding(ctx context.Context, input string) ([]float32, error)
+    Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error)
    Tokenize(ctx context.Context, content string) ([]int, error)
     Detokenize(ctx context.Context, tokens []int) (string, error)
     Close() error

@@ -1545,14 +1545,16 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 }
 
 type EmbeddingRequest struct {
     Content string `json:"content"`
+    Truncate bool `json:"truncate"`
 }
 
 type EmbeddingResponse struct {
     Embedding []float32 `json:"embedding"`
+    PromptEvalCount int `json:"prompt_eval_count"`
 }
 
-func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
+func (s *llmServer) Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) {
     logutil.Trace("embedding request", "input", input)
 
     if err := s.sem.Acquire(ctx, 1); err != nil {

@@ -1561,51 +1563,54 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
         } else {
             slog.Error("Failed to acquire semaphore", "error", err)
         }
-        return nil, err
+        return nil, 0, err
     }
     defer s.sem.Release(1)
 
     // Make sure the server is ready
     status, err := s.getServerStatusRetry(ctx)
     if err != nil {
-        return nil, err
+        return nil, 0, err
     } else if status != ServerStatusReady {
-        return nil, fmt.Errorf("unexpected server status: %s", status)
+        return nil, 0, fmt.Errorf("unexpected server status: %s", status)
     }
 
-    data, err := json.Marshal(EmbeddingRequest{Content: input})
+    data, err := json.Marshal(EmbeddingRequest{Content: input, Truncate: truncate})
     if err != nil {
-        return nil, fmt.Errorf("error marshaling embed data: %w", err)
+        return nil, 0, fmt.Errorf("error marshaling embed data: %w", err)
     }
 
     r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
     if err != nil {
-        return nil, fmt.Errorf("error creating embed request: %w", err)
+        return nil, 0, fmt.Errorf("error creating embed request: %w", err)
     }
     r.Header.Set("Content-Type", "application/json")
 
     resp, err := http.DefaultClient.Do(r)
     if err != nil {
-        return nil, fmt.Errorf("do embedding request: %w", err)
+        return nil, 0, fmt.Errorf("do embedding request: %w", err)
     }
     defer resp.Body.Close()
 
     body, err := io.ReadAll(resp.Body)
     if err != nil {
-        return nil, fmt.Errorf("error reading embed response: %w", err)
+        return nil, 0, fmt.Errorf("error reading embed response: %w", err)
     }
 
     if resp.StatusCode >= 400 {
         log.Printf("llm embedding error: %s", body)
-        return nil, fmt.Errorf("%s", body)
+        return nil, 0, api.StatusError{
+            StatusCode:   resp.StatusCode,
+            ErrorMessage: string(body),
+        }
     }
 
     var e EmbeddingResponse
     if err := json.Unmarshal(body, &e); err != nil {
-        return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
+        return nil, 0, fmt.Errorf("unmarshal tokenize response: %w", err)
     }
 
-    return e.Embedding, nil
+    return e.Embedding, e.PromptEvalCount, nil
 }
 
 type TokenizeRequest struct {

@@ -709,13 +709,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 
     seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
         embedding: true,
-        // TODO (jmorganca): this should be provided by the server via the
-        // request options and truncated here in the runner, instead of relying on
-        // the server's truncate logic
-        truncate: true,
+        truncate:  req.Truncate,
     })
     if err != nil {
+        if errors.Is(err, errorInputTooLong) {
+            http.Error(w, err.Error(), http.StatusBadRequest)
+            return
+        }
         http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
         return
     }

@@ -758,7 +758,8 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
     embedding := <-seq.embedding
 
     if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
         Embedding: embedding,
+        PromptEvalCount: seq.numPromptInputs,
     }); err != nil {
         http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
     }

@@ -948,13 +948,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
     w.Header().Set("Content-Type", "application/json")
     seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
         embedding: true,
-        // TODO (jmorganca): this should be provided by the server via the
-        // request options and truncated here in the runner, instead of relying on
-        // the server's truncate logic
-        truncate: true,
+        truncate:  req.Truncate,
     })
     if err != nil {
+        if errors.Is(err, errorInputTooLong) {
+            http.Error(w, err.Error(), http.StatusBadRequest)
+            return
+        }
         http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError)
         return
     }

@@ -995,7 +995,8 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
     }
 
     if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
         Embedding: <-seq.embedding,
+        PromptEvalCount: seq.numPromptInputs,
     }); err != nil {
         http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
     }

@@ -21,6 +21,7 @@ import (
     "os/signal"
     "slices"
     "strings"
+    "sync/atomic"
     "syscall"
     "time"
 

@@ -659,7 +660,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
         return
     }
 
-    r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
+    r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
     if err != nil {
         handleScheduleError(c, req.Model, err)
         return

@@ -672,61 +673,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
         return
     }
 
-    kvData, _, err := getModelData(m.ModelPath, false)
-    if err != nil {
-        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-        return
-    }
-
-    var count int
-    for i, s := range input {
-        tokens, err := r.Tokenize(c.Request.Context(), s)
-        if err != nil {
-            c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-            return
-        }
-
-        ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
-        if len(tokens) > ctxLen {
-            if !truncate {
-                c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"})
-                return
-            }
-
-            if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
-                ctxLen--
-            }
-
-            if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
-                ctxLen--
-            }
-
-            slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens))
-            if ctxLen <= 0 {
-                // return error if the truncated input would be empty or just special tokens
-                c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"})
-                return
-            }
-
-            tokens = tokens[:ctxLen]
-
-            s, err = r.Detokenize(c.Request.Context(), tokens)
-            if err != nil {
-                c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-                return
-            }
-        }
-
-        count += len(tokens)
-
-        input[i] = s
-    }
-
     var g errgroup.Group
     embeddings := make([][]float32, len(input))
+    var totalTokens uint64
     for i, text := range input {
         g.Go(func() error {
-            embedding, err := r.Embedding(c.Request.Context(), text)
+            embedding, tokenCount, err := r.Embedding(c.Request.Context(), text, truncate)
             if err != nil {
                 return err
             }

@@ -736,12 +688,18 @@ func (s *Server) EmbedHandler(c *gin.Context) {
                 embedding = normalize(embedding[:req.Dimensions])
             }
             embeddings[i] = embedding
+            atomic.AddUint64(&totalTokens, uint64(tokenCount))
             return nil
         })
     }
 
     if err := g.Wait(); err != nil {
-        c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
+        var serr api.StatusError
+        if errors.As(err, &serr) {
+            c.AbortWithStatusJSON(serr.StatusCode, gin.H{"error": strings.TrimSpace(serr.ErrorMessage)})
+        } else {
+            c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
+        }
         return
     }
 

@@ -750,7 +708,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
         Embeddings:    embeddings,
         TotalDuration: time.Since(checkpointStart),
         LoadDuration:  checkpointLoaded.Sub(checkpointStart),
-        PromptEvalCount: count,
+        PromptEvalCount: int(totalTokens),
     }
     c.JSON(http.StatusOK, resp)
 }

@@ -796,7 +754,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
         return
     }
 
-    embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
+    embedding, _, err := r.Embedding(c.Request.Context(), req.Prompt, true)
     if err != nil {
         c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
         return

@@ -780,8 +780,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
     return s.completionResp
 }
 
-func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) {
-    return s.embeddingResp, s.embeddingRespErr
+func (s *mockLlm) Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) {
+    return s.embeddingResp, 0, s.embeddingRespErr
 }
 
 func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {