diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index fafd850b..22ec7685 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -91,10 +91,11 @@ type Sequence struct {
 	doneReason llm.DoneReason
 
 	// Metrics
-	startProcessingTime time.Time
-	startGenerationTime time.Time
-	numPredicted        int
-	numPromptInputs     int
+	startedAt, lastUpdatedAt time.Time
+	processingDuration       time.Duration
+	samplingDuration         time.Duration
+	numPredicted             int
+	numPromptInputs          int
 }
 
 type NewSequenceParams struct {
@@ -108,8 +109,6 @@ type NewSequenceParams struct {
 func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
 	s.ready.Wait()
 
-	startTime := time.Now()
-
 	inputs, ctxs, mmStore, err := s.inputs(prompt, images)
 	if err != nil {
 		return nil, fmt.Errorf("failed to process inputs: %w", err)
@@ -164,20 +163,19 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 	// TODO(jessegross): Ingest cached history for grammar
 
 	return &Sequence{
-		ctxs:                ctxs,
-		mmStore:             mmStore,
-		inputs:              inputs,
-		numPromptInputs:     len(inputs),
-		startProcessingTime: startTime,
-		numPredict:          params.numPredict,
-		pendingResponses:    make([]string, 0),
-		responses:           make(chan string, 100),
-		quit:                make(chan bool, 1),
-		embedding:           make(chan []float32, 1),
-		sampler:             params.sampler,
-		embeddingOnly:       params.embedding,
-		stop:                params.stop,
-		numKeep:             params.numKeep,
+		ctxs:             ctxs,
+		mmStore:          mmStore,
+		inputs:           inputs,
+		numPromptInputs:  len(inputs),
+		numPredict:       params.numPredict,
+		pendingResponses: make([]string, 0),
+		responses:        make(chan string, 100),
+		quit:             make(chan bool, 1),
+		embedding:        make(chan []float32, 1),
+		sampler:          params.sampler,
+		embeddingOnly:    params.embedding,
+		stop:             params.stop,
+		numKeep:          params.numKeep,
 	}, nil
 }
 
@@ -408,7 +406,7 @@ func (s *Server) run(ctx context.Context) {
 
 	supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone
 
-	var activeBatch batchState
+	var previousBatch batchState
 	for {
 		select {
 		case <-ctx.Done():
@@ -417,16 +415,18 @@ func (s *Server) run(ctx context.Context) {
 			panic(err)
 		default:
 			var err error
-			activeBatch, err = s.forwardBatch(activeBatch)
+			nextBatch, err := s.forwardBatch(previousBatch)
 			if err != nil {
 				panic(err)
 			}
 
 			if supportsAsync {
-				go s.computeBatch(activeBatch)
+				go s.computeBatch(nextBatch)
 			} else {
-				s.computeBatch(activeBatch)
+				s.computeBatch(nextBatch)
 			}
+
+			previousBatch = nextBatch
 		}
 	}
 }
@@ -562,6 +562,13 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
 		seq.inputs = seq.inputs[len(seq.pendingInputs):]
 	}
 
+	startedAt := time.Now()
+	for i := range nextBatch.seqs {
+		if nextBatch.seqs[i] != nil && nextBatch.seqs[i].startedAt.IsZero() {
+			nextBatch.seqs[i].startedAt = startedAt
+		}
+	}
+
 	if resumeSeq != -1 {
 		s.nextSeq = resumeSeq
 	} else {
@@ -682,6 +689,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
 		activeBatch.modelOutput)
 
 	outputs := activeBatch.modelOutput.Floats()
+	t := time.Now()
 
 	logutil.Trace("computeBatch: logits ready", "batchID", activeBatch.id)
 
@@ -694,8 +702,10 @@ func (s *Server) computeBatch(activeBatch batchState) {
 			continue
 		}
 
+		seq.lastUpdatedAt = t
 		if seq.numPredicted == 1 {
-			seq.startGenerationTime = time.Now()
+			seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)
+			seq.startedAt = seq.lastUpdatedAt
 		}
 
 		// if done processing the prompt, generate an embedding and return
@@ -774,6 +784,13 @@ func (s *Server) computeBatch(activeBatch batchState) {
 			s.removeSequence(i, llm.DoneReasonConnectionClosed)
 		}
 	}
+
+	samplingDuration := time.Since(t)
+	for i, seq := range s.seqs {
+		if seq != nil && nextBatchTokens[i] != nil {
+			s.seqs[i].samplingDuration += samplingDuration
+		}
+	}
 }
 
 func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
@@ -887,9 +904,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 				Done:               true,
 				DoneReason:         seq.doneReason,
 				PromptEvalCount:    seq.numPromptInputs,
-				PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
+				PromptEvalDuration: seq.processingDuration,
 				EvalCount:          seq.numPredicted,
-				EvalDuration:       time.Since(seq.startGenerationTime),
+				EvalDuration:       seq.lastUpdatedAt.Sub(seq.startedAt) - seq.samplingDuration,
 			}); err != nil {
 				http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
 			}
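
Note on the metrics change above: the final response now reports PromptEvalDuration straight from the stored processingDuration, and derives EvalDuration from the generation window (startedAt to lastUpdatedAt) minus accumulated sampling time. A minimal standalone sketch of that arithmetic follows; the field names mirror the diff, but the literal values are made up for illustration and are not the runner's real state.

// Sketch only: shows how the new per-sequence timing fields combine into the
// durations reported in the final completion response.
package main

import (
	"fmt"
	"time"
)

func main() {
	startedAt := time.Now()                         // reset to the first-token time once numPredicted == 1
	lastUpdatedAt := startedAt.Add(2 * time.Second) // stamped each time a batch of logits is ready
	processingDuration := 300 * time.Millisecond    // prompt processing time, captured at the first token
	samplingDuration := 150 * time.Millisecond      // accumulated per batch after sampling

	promptEvalDuration := processingDuration
	evalDuration := lastUpdatedAt.Sub(startedAt) - samplingDuration

	fmt.Println(promptEvalDuration, evalDuration) // 300ms 1.85s
}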