From a8d9c2648e1ea7fb0b209d49d42b54ac000c06ea Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Thu, 16 Oct 2025 16:27:45 -0700 Subject: [PATCH] llamarunner: Record the time for all batches during prompt processing Currently, we only record the time for the last batch when processing the prompt. This results in unrealistically high numbers for the old llama runner. Before: total duration: 31.273112939s load duration: 4.97054657s prompt eval count: 32768 token(s) prompt eval duration: 235.137439ms prompt eval rate: 139356.80 tokens/s eval count: 1873 token(s) eval duration: 18.173182374s eval rate: 103.06 tokens/s After: total duration: 30.024798033s load duration: 4.758588663s prompt eval count: 32768 token(s) prompt eval duration: 7.779621548s prompt eval rate: 4212.03 tokens/s eval count: 1769 token(s) eval duration: 17.148014223s eval rate: 103.16 tokens/s --- runner/llamarunner/runner.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index a5e7eb33..87b43256 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -384,6 +384,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) defer s.mu.Unlock() var batch *llama.Batch + var numOutputs int seqIdx := s.nextSeq - 1 for range s.seqs { @@ -446,7 +447,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) break } - batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), i+1 == len(seq.inputs), seq.cache.Id) + output := i+1 == len(seq.inputs) + batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), output, seq.cache.Id) + if output { + numOutputs++ + } + seq.pendingInputs = append(seq.pendingInputs, input) seq.iBatch = batch.NumTokens() - 1 } @@ -463,6 +469,10 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) return fmt.Errorf("failed to decode batch: %w", err) } + if numOutputs > 0 { + s.lc.Synchronize() + } + for i, seq := range s.seqs { if seq == nil { continue @@ -476,10 +486,10 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) // don't sample prompt processing if len(seq.inputs) != 0 { + seq.processingDuration += time.Since(t) continue } - s.lc.Synchronize() seq.numDecoded++ if seq.numDecoded > 1 { seq.generationDuration += time.Since(t)