server: Consolidate embedding truncation in runner (#12730)

Currently, checking the length of prompts for embeddings to ensure
they fit in the context window (and possible truncation) occurs in
two places - the Ollama server and runner. This can lead to
inconsistencies in both the checks and reported number of tokens
processed. Since we have to do this processing in the runner, this
consolidates all of the logic there.
This commit is contained in:
nicole pardal
2025-10-27 11:59:12 -07:00
committed by GitHub
parent b97eb2b858
commit 5d347f6d6f
6 changed files with 264 additions and 84 deletions

View File

@@ -69,7 +69,7 @@ type LlamaServer interface {
Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embedding(ctx context.Context, input string) ([]float32, error)
Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error)
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
@@ -1545,14 +1545,16 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
}
type EmbeddingRequest struct {
Content string `json:"content"`
Content string `json:"content"`
Truncate bool `json:"truncate"`
}
type EmbeddingResponse struct {
Embedding []float32 `json:"embedding"`
Embedding []float32 `json:"embedding"`
PromptEvalCount int `json:"prompt_eval_count"`
}
func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
func (s *llmServer) Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) {
logutil.Trace("embedding request", "input", input)
if err := s.sem.Acquire(ctx, 1); err != nil {
@@ -1561,51 +1563,54 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
} else {
slog.Error("Failed to acquire semaphore", "error", err)
}
return nil, err
return nil, 0, err
}
defer s.sem.Release(1)
// Make sure the server is ready
status, err := s.getServerStatusRetry(ctx)
if err != nil {
return nil, err
return nil, 0, err
} else if status != ServerStatusReady {
return nil, fmt.Errorf("unexpected server status: %s", status)
return nil, 0, fmt.Errorf("unexpected server status: %s", status)
}
data, err := json.Marshal(EmbeddingRequest{Content: input})
data, err := json.Marshal(EmbeddingRequest{Content: input, Truncate: truncate})
if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
return nil, 0, fmt.Errorf("error marshaling embed data: %w", err)
}
r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
if err != nil {
return nil, fmt.Errorf("error creating embed request: %w", err)
return nil, 0, fmt.Errorf("error creating embed request: %w", err)
}
r.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(r)
if err != nil {
return nil, fmt.Errorf("do embedding request: %w", err)
return nil, 0, fmt.Errorf("do embedding request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("error reading embed response: %w", err)
return nil, 0, fmt.Errorf("error reading embed response: %w", err)
}
if resp.StatusCode >= 400 {
log.Printf("llm embedding error: %s", body)
return nil, fmt.Errorf("%s", body)
return nil, 0, api.StatusError{
StatusCode: resp.StatusCode,
ErrorMessage: string(body),
}
}
var e EmbeddingResponse
if err := json.Unmarshal(body, &e); err != nil {
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
return nil, 0, fmt.Errorf("unmarshal tokenize response: %w", err)
}
return e.Embedding, nil
return e.Embedding, e.PromptEvalCount, nil
}
type TokenizeRequest struct {