Revert "server: Consolidate embedding truncation in runner (#12730)" (#12810)

This reverts commit 5d347f6d6f.
This commit is contained in:
Patrick Devine
2025-10-28 14:49:14 -07:00
committed by GitHub
parent 3d99d9779a
commit 29f63f37c8
6 changed files with 84 additions and 264 deletions

View File

@@ -21,7 +21,6 @@ import (
"os/signal"
"slices"
"strings"
"sync/atomic"
"syscall"
"time"
@@ -660,7 +659,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
@@ -673,12 +672,61 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
kvData, _, err := getModelData(m.ModelPath, false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var count int
for i, s := range input {
tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if len(tokens) > ctxLen {
if !truncate {
c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"})
return
}
if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
ctxLen--
}
if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
ctxLen--
}
slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens))
if ctxLen <= 0 {
// return error if the truncated input would be empty or just special tokens
c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"})
return
}
tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
count += len(tokens)
input[i] = s
}
var g errgroup.Group
embeddings := make([][]float32, len(input))
var totalTokens uint64
for i, text := range input {
g.Go(func() error {
embedding, tokenCount, err := r.Embedding(c.Request.Context(), text, truncate)
embedding, err := r.Embedding(c.Request.Context(), text)
if err != nil {
return err
}
@@ -688,18 +736,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
embedding = normalize(embedding[:req.Dimensions])
}
embeddings[i] = embedding
atomic.AddUint64(&totalTokens, uint64(tokenCount))
return nil
})
}
if err := g.Wait(); err != nil {
var serr api.StatusError
if errors.As(err, &serr) {
c.AbortWithStatusJSON(serr.StatusCode, gin.H{"error": strings.TrimSpace(serr.ErrorMessage)})
} else {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
}
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
return
}
@@ -708,7 +750,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
Embeddings: embeddings,
TotalDuration: time.Since(checkpointStart),
LoadDuration: checkpointLoaded.Sub(checkpointStart),
PromptEvalCount: int(totalTokens),
PromptEvalCount: count,
}
c.JSON(http.StatusOK, resp)
}
@@ -754,7 +796,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
embedding, _, err := r.Embedding(c.Request.Context(), req.Prompt, true)
embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
return