diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index a5e7eb33..87b43256 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -384,6 +384,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) defer s.mu.Unlock() var batch *llama.Batch + var numOutputs int seqIdx := s.nextSeq - 1 for range s.seqs { @@ -446,7 +447,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) break } - batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), i+1 == len(seq.inputs), seq.cache.Id) + output := i+1 == len(seq.inputs) + batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), output, seq.cache.Id) + if output { + numOutputs++ + } + seq.pendingInputs = append(seq.pendingInputs, input) seq.iBatch = batch.NumTokens() - 1 } @@ -463,6 +469,10 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) return fmt.Errorf("failed to decode batch: %w", err) } + if numOutputs > 0 { + s.lc.Synchronize() + } + for i, seq := range s.seqs { if seq == nil { continue @@ -476,10 +486,10 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) // don't sample prompt processing if len(seq.inputs) != 0 { + seq.processingDuration += time.Since(t) continue } - s.lc.Synchronize() seq.numDecoded++ if seq.numDecoded > 1 { seq.generationDuration += time.Since(t)