mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 14:26:30 +00:00
server: Consolidate embedding truncation in runner (#12730)
Currently, checking the length of prompts for embeddings to ensure they fit in the context window (and possible truncation) occurs in two places - the Ollama server and runner. This can lead to inconsistencies in both the checks and reported number of tokens processed. Since we have to do this processing in the runner, this consolidates all of the logic there.
This commit is contained in:
@@ -69,7 +69,7 @@ type LlamaServer interface {
|
||||
Ping(ctx context.Context) error
|
||||
WaitUntilRunning(ctx context.Context) error
|
||||
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
|
||||
Embedding(ctx context.Context, input string) ([]float32, error)
|
||||
Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error)
|
||||
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||
Close() error
|
||||
@@ -1545,14 +1545,16 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Content string `json:"content"`
|
||||
Content string `json:"content"`
|
||||
Truncate bool `json:"truncate"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
Embedding []float32 `json:"embedding"`
|
||||
PromptEvalCount int `json:"prompt_eval_count"`
|
||||
}
|
||||
|
||||
func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
|
||||
func (s *llmServer) Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) {
|
||||
logutil.Trace("embedding request", "input", input)
|
||||
|
||||
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||
@@ -1561,51 +1563,54 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
|
||||
} else {
|
||||
slog.Error("Failed to acquire semaphore", "error", err)
|
||||
}
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
defer s.sem.Release(1)
|
||||
|
||||
// Make sure the server is ready
|
||||
status, err := s.getServerStatusRetry(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
} else if status != ServerStatusReady {
|
||||
return nil, fmt.Errorf("unexpected server status: %s", status)
|
||||
return nil, 0, fmt.Errorf("unexpected server status: %s", status)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(EmbeddingRequest{Content: input})
|
||||
data, err := json.Marshal(EmbeddingRequest{Content: input, Truncate: truncate})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling embed data: %w", err)
|
||||
return nil, 0, fmt.Errorf("error marshaling embed data: %w", err)
|
||||
}
|
||||
|
||||
r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating embed request: %w", err)
|
||||
return nil, 0, fmt.Errorf("error creating embed request: %w", err)
|
||||
}
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do embedding request: %w", err)
|
||||
return nil, 0, fmt.Errorf("do embedding request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading embed response: %w", err)
|
||||
return nil, 0, fmt.Errorf("error reading embed response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
log.Printf("llm embedding error: %s", body)
|
||||
return nil, fmt.Errorf("%s", body)
|
||||
return nil, 0, api.StatusError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ErrorMessage: string(body),
|
||||
}
|
||||
}
|
||||
|
||||
var e EmbeddingResponse
|
||||
if err := json.Unmarshal(body, &e); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
|
||||
return nil, 0, fmt.Errorf("unmarshal tokenize response: %w", err)
|
||||
}
|
||||
|
||||
return e.Embedding, nil
|
||||
return e.Embedding, e.PromptEvalCount, nil
|
||||
}
|
||||
|
||||
type TokenizeRequest struct {
|
||||
|
||||
Reference in New Issue
Block a user