embedding gemma model (#12181)

* ollama: add embeddings
Author: Michael Yang
Date: 2025-09-04 09:09:07 -07:00
Committed by: GitHub
Parent: b3e6120736
Commit: 5994e8e8fd
10 changed files with 175 additions and 27 deletions

View File

@@ -95,7 +95,7 @@ type InputCacheSlot struct {
 	lastUsed time.Time
 }
 
-func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*input.Input, error) {
+func (c *InputCache) LoadCacheSlot(prompt []*input.Input, cachePrompt bool) (*InputCacheSlot, []*input.Input, error) {
 	var slot *InputCacheSlot
 	var numPast int32
 	var err error
@@ -113,6 +113,10 @@ func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*input.Input, error) {
 		return nil, nil, err
 	}
 
+	if !cachePrompt {
+		numPast = 0
+	}
+
 	slot.InUse = true
 	slot.lastUsed = time.Now()
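The new cachePrompt parameter controls whether a matched cache prefix may be reused: when it is false, numPast is reset to 0, so the slot is still claimed but the whole prompt is recomputed from position zero. The embeddings path added below depends on this, since a pooled output requires a full forward pass over the input. A minimal standalone sketch of that decision (hypothetical names, not the actual cache code):

package main

import "fmt"

// reusablePrefix mimics the idea behind LoadCacheSlot's new flag: a slot may
// hold numPast tokens that match the incoming prompt, but callers that need
// the full forward pass (e.g. embeddings) pass cachePrompt=false to force
// reprocessing from position 0.
func reusablePrefix(numPast int32, cachePrompt bool) int32 {
	if !cachePrompt {
		return 0 // ignore any cached prefix; recompute the whole prompt
	}
	return numPast
}

func main() {
	fmt.Println(reusablePrefix(42, true))  // 42: the completion path reuses the prefix
	fmt.Println(reusablePrefix(42, false)) // 0: the embedding path recomputes everything
}

Keeping the reset inside LoadCacheSlot, rather than at each call site, means every caller gets the same slot-selection logic and only opts out of prefix reuse.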

View File

@@ -393,7 +393,7 @@ func TestLoadCacheSlot(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
+			slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt, true)
 
 			// Check error state
 			if (err != nil) != tt.wantErr {

View File

@@ -11,6 +11,7 @@ import (
 	"image"
 	"log"
 	"log/slog"
+	"math"
 	"net"
 	"net/http"
 	"os"
@@ -405,6 +406,8 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
 func (s *Server) run(ctx context.Context) {
 	s.ready.Wait()
 
+	supportsAsync := s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32
+
 	var activeBatch batchState
 	for {
 		select {
@@ -418,7 +421,12 @@ func (s *Server) run(ctx context.Context) {
 			if err != nil {
 				panic(err)
 			}
-			go s.computeBatch(activeBatch)
+
+			if supportsAsync {
+				go s.computeBatch(activeBatch)
+			} else {
+				s.computeBatch(activeBatch)
+			}
 		}
 	}
 }
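Two things happen here. First, pooling_type is probed with math.MaxUint32 as a sentinel default: if the key is absent, Config().Uint returns the default, the comparison holds, and the model is treated as generative. Second, only generative models compute batches asynchronously; pooled (embedding) models run computeBatch inline so the batch finishes before the loop prepares the next one. A reduced sketch of the sentinel pattern (kvConfig is a stand-in, not the runner's config API):

package main

import (
	"fmt"
	"math"
)

// kvConfig stands in for the model config reader; Uint returns the stored
// value if the key is present, else the caller-supplied default, which is
// the behavior assumed by the sentinel comparison above.
type kvConfig map[string]uint32

func (c kvConfig) Uint(key string, def uint32) uint32 {
	if v, ok := c[key]; ok {
		return v
	}
	return def
}

func main() {
	configs := map[string]kvConfig{
		"generative": {},                  // no pooling_type key
		"embedder":   {"pooling_type": 1}, // pooled outputs configured
	}
	for name, cfg := range configs {
		// math.MaxUint32 acts as a sentinel meaning "key not present":
		// an absent pooling_type marks a generative model, so batches may run async.
		supportsAsync := cfg.Uint("pooling_type", math.MaxUint32) == math.MaxUint32
		fmt.Println(name, "supportsAsync:", supportsAsync)
	}
}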
@@ -670,7 +678,8 @@ func (s *Server) computeBatch(activeBatch batchState) {
 			activeBatch.computeStartedCh <- struct{}{}
 		},
 		activeBatch.modelOutput)
-	logits := activeBatch.modelOutput.Floats()
+
+	outputs := activeBatch.modelOutput.Floats()
 	logutil.Trace("computeBatch: logits ready", "batchID", activeBatch.id)
@@ -689,16 +698,15 @@ func (s *Server) computeBatch(activeBatch batchState) {
 		// if done processing the prompt, generate an embedding and return
 		if seq.embeddingOnly {
-			// TODO(jessegross): Embedding support
-			slog.Warn("generation of embedding outputs not yet supported", "id", activeBatch.id, "seqIdx", i)
+			seq.embedding <- outputs
 			s.removeSequence(i, llm.DoneReasonStop)
 			continue
 		}
 
 		// sample a token
-		vocabSize := len(logits) / len(activeBatch.batch.Outputs)
-		logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(logits), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
-		token, err := seq.sampler.Sample(logits[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
+		vocabSize := len(outputs) / len(activeBatch.batch.Outputs)
+		logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
+		token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
 		if err != nil {
 			s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
 			return
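The rename from logits to outputs reflects that the same flat Floats() buffer now carries either pooled embeddings (sent whole over seq.embedding) or token logits. For sampling, the buffer is treated as len(activeBatch.batch.Outputs) rows of vocabSize floats each, and iBatches[i] selects the row belonging to sequence i. A toy illustration of that row arithmetic (made-up shapes and values):

package main

import "fmt"

func main() {
	// A flat output buffer for 3 batch outputs over a vocabulary of 4 tokens.
	outputs := []float32{
		0.1, 0.2, 0.3, 0.4, // row 0
		1.1, 1.2, 1.3, 1.4, // row 1
		2.1, 2.2, 2.3, 2.4, // row 2
	}
	numOutputs := 3
	vocabSize := len(outputs) / numOutputs // 4, computed exactly as in the diff

	// iBatches maps a sequence index to its row in the output buffer.
	iBatches := []int{2, 0, 1}
	for i, row := range iBatches {
		logits := outputs[row*vocabSize : (row+1)*vocabSize]
		fmt.Printf("seq %d samples from row %d: %v\n", i, row, logits)
	}
}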
@@ -834,7 +842,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	found := false
 	for i, sq := range s.seqs {
 		if sq == nil {
-			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
 			if err != nil {
 				s.mu.Unlock()
 				s.seqsSem.Release(1)
@@ -890,6 +898,67 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
+func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
+	if s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 {
+		http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
+		return
+	}
+
+	var req llm.EmbeddingRequest
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
+		return
+	}
+
+	w.Header().Set("Content-Type", "application/json")
+
+	seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
+	if err != nil {
+		http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError)
+		return
+	}
+
+	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
+		if errors.Is(err, context.Canceled) {
+			slog.Info("aborting embedding request due to client closing the connection")
+		} else {
+			http.Error(w, fmt.Sprintf("failed to acquire semaphore: %v", err), http.StatusInternalServerError)
+		}
+		return
+	}
+
+	s.mu.Lock()
+	found := false
+	for i, sq := range s.seqs {
+		if sq == nil {
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
+			if err != nil {
+				s.mu.Unlock()
+				s.seqsSem.Release(1)
+				http.Error(w, fmt.Sprintf("failed to load cache: %v", err), http.StatusInternalServerError)
+				return
+			}
+
+			s.seqs[i] = seq
+			s.cond.Signal()
+			found = true
+			break
+		}
+	}
+	s.mu.Unlock()
+
+	if !found {
+		s.seqsSem.Release(1)
+		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
+		return
+	}
+
+	if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
+		Embedding: <-seq.embedding,
+	}); err != nil {
+		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
+	}
+}
+
 func (s *Server) health(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
 	if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
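The handler and the batch loop meet on the seq.embedding channel: embeddings blocks on the receive until computeBatch finishes the forward pass and sends the pooled vector, after which the response is encoded. Note that the handler passes cachePrompt=false to LoadCacheSlot, so the full input is always processed. A stripped-down sketch of that handoff (a toy sequence type; the real one carries far more state):

package main

import "fmt"

type sequence struct {
	embedding chan []float32
}

// computeBatch stands in for the batch loop: it produces the pooled
// output and delivers it on the sequence's channel.
func computeBatch(seq *sequence) {
	pooled := []float32{0.12, -0.48, 0.33}
	seq.embedding <- pooled
}

func main() {
	seq := &sequence{embedding: make(chan []float32)}
	go computeBatch(seq)

	// The HTTP-handler side: block until the batch loop hands over the
	// vector, then encode it into the response.
	fmt.Println("embedding:", <-seq.embedding)
}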
@@ -1206,10 +1275,7 @@ func Execute(args []string) error {
 	mux := http.NewServeMux()
 
-	// TODO: support embeddings
 	mux.HandleFunc("POST /load", server.load)
-	mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) {
-		http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
-	})
+	mux.HandleFunc("POST /embedding", server.embeddings)
 	mux.HandleFunc("POST /completion", server.completion)
 	mux.HandleFunc("GET /health", server.health)