mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 14:26:30 +00:00
truncation: fixed runner truncation logic + removed server truncation (#12839)
This PR consolidates all embedding prompt-length checking, truncation, and prompt token counting into the runner to ensure a single source of truth.
This commit is contained in:
@@ -22,6 +22,7 @@ import (
|
||||
"os/signal"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -649,11 +650,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
truncate := true
|
||||
if req.Truncate != nil && !*req.Truncate {
|
||||
truncate = false
|
||||
}
|
||||
|
||||
var input []string
|
||||
|
||||
switch i := req.Input.(type) {
|
||||
@@ -701,55 +697,57 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var count int
|
||||
for i, s := range input {
|
||||
tokens, err := r.Tokenize(c.Request.Context(), s)
|
||||
ctx := c.Request.Context()
|
||||
|
||||
embedWithRetry := func(text string) ([]float32, int, error) {
|
||||
emb, tokCount, err := r.Embedding(ctx, text)
|
||||
if err == nil {
|
||||
return emb, tokCount, nil
|
||||
}
|
||||
|
||||
var serr api.StatusError
|
||||
if !errors.As(err, &serr) || serr.StatusCode != http.StatusBadRequest {
|
||||
return nil, 0, err
|
||||
}
|
||||
if req.Truncate != nil && !*req.Truncate {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
tokens, err := r.Tokenize(ctx, text)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// TODO @nicolepardal: avoid reaching into kvData here; pass required tokenizer metadata via model/options instead
|
||||
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
||||
if len(tokens) > ctxLen {
|
||||
if !truncate {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"})
|
||||
return
|
||||
}
|
||||
|
||||
if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
|
||||
ctxLen--
|
||||
}
|
||||
|
||||
if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
|
||||
ctxLen--
|
||||
}
|
||||
|
||||
slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens))
|
||||
if ctxLen <= 0 {
|
||||
// return error if the truncated input would be empty or just special tokens
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"})
|
||||
return
|
||||
}
|
||||
|
||||
tokens = tokens[:ctxLen]
|
||||
|
||||
s, err = r.Detokenize(c.Request.Context(), tokens)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); len(tokens) > 0 && tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
|
||||
ctxLen--
|
||||
}
|
||||
if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); len(tokens) > 0 && tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
|
||||
ctxLen--
|
||||
}
|
||||
|
||||
count += len(tokens)
|
||||
if len(tokens) <= ctxLen {
|
||||
return nil, 0, fmt.Errorf("input exceeds maximum context length and cannot be truncated further")
|
||||
}
|
||||
if ctxLen <= 0 {
|
||||
return nil, 0, fmt.Errorf("input after truncation exceeds maximum context length")
|
||||
}
|
||||
|
||||
input[i] = s
|
||||
truncatedTokens := tokens[:ctxLen]
|
||||
truncated, err := r.Detokenize(ctx, truncatedTokens)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return r.Embedding(ctx, truncated)
|
||||
}
|
||||
|
||||
var g errgroup.Group
|
||||
embeddings := make([][]float32, len(input))
|
||||
var totalTokens uint64
|
||||
for i, text := range input {
|
||||
g.Go(func() error {
|
||||
embedding, err := r.Embedding(c.Request.Context(), text)
|
||||
embedding, tokenCount, err := embedWithRetry(text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -759,12 +757,23 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
embedding = normalize(embedding[:req.Dimensions])
|
||||
}
|
||||
embeddings[i] = embedding
|
||||
atomic.AddUint64(&totalTokens, uint64(tokenCount))
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
|
||||
var serr api.StatusError
|
||||
if errors.As(err, &serr) {
|
||||
c.AbortWithStatusJSON(serr.StatusCode, gin.H{
|
||||
"error": strings.TrimSpace(serr.ErrorMessage),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
|
||||
"error": strings.TrimSpace(err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -773,7 +782,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
Embeddings: embeddings,
|
||||
TotalDuration: time.Since(checkpointStart),
|
||||
LoadDuration: checkpointLoaded.Sub(checkpointStart),
|
||||
PromptEvalCount: count,
|
||||
PromptEvalCount: int(totalTokens),
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
@@ -819,7 +828,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
|
||||
embedding, _, err := r.Embedding(c.Request.Context(), req.Prompt)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
|
||||
return
|
||||
|
||||
@@ -780,8 +780,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
|
||||
return s.completionResp
|
||||
}
|
||||
|
||||
func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) {
|
||||
return s.embeddingResp, s.embeddingRespErr
|
||||
func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, int, error) {
|
||||
return s.embeddingResp, 0, s.embeddingRespErr
|
||||
}
|
||||
|
||||
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
|
||||
Reference in New Issue
Block a user