diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index ae26b52b..7ed7ebb2 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -82,10 +82,10 @@ type Sequence struct { doneReason llm.DoneReason // Metrics - startProcessingTime time.Time - startGenerationTime time.Time - numDecoded int - numPromptInputs int + processingDuration time.Duration + generationDuration time.Duration + numDecoded int + numPromptInputs int } type NewSequenceParams struct { @@ -99,8 +99,6 @@ type NewSequenceParams struct { func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) { s.ready.Wait() - startTime := time.Now() - inputs, err := s.inputs(prompt, images) if err != nil { return nil, fmt.Errorf("failed to process inputs: %w", err) @@ -142,18 +140,17 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe } return &Sequence{ - inputs: inputs, - numPromptInputs: len(inputs), - startProcessingTime: startTime, - numPredict: params.numPredict, - pendingResponses: make([]string, 0), - responses: make(chan string, 100), - quit: make(chan bool, 1), - embedding: make(chan []float32, 1), - samplingCtx: sc, - embeddingOnly: params.embedding, - stop: params.stop, - numKeep: params.numKeep, + inputs: inputs, + numPromptInputs: len(inputs), + numPredict: params.numPredict, + pendingResponses: make([]string, 0), + responses: make(chan string, 100), + quit: make(chan bool, 1), + embedding: make(chan []float32, 1), + samplingCtx: sc, + embeddingOnly: params.embedding, + stop: params.stop, + numKeep: params.numKeep, }, nil } @@ -438,8 +435,8 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) return nil } - err := s.lc.Decode(batch) - if err != nil { + t := time.Now() + if err := s.lc.Decode(batch); err != nil { return fmt.Errorf("failed to decode batch: %w", err) } @@ -459,9 +456,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) continue } - seq.numDecoded += 1 - if seq.numDecoded == 1 { - seq.startGenerationTime = time.Now() + s.lc.Synchronize() + seq.numDecoded++ + if seq.numDecoded > 1 { + seq.generationDuration += time.Since(t) + } else { + seq.processingDuration += time.Since(t) } // if done processing the prompt, generate an embedding and return @@ -646,9 +646,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { Done: true, DoneReason: seq.doneReason, PromptEvalCount: seq.numPromptInputs, - PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime), + PromptEvalDuration: seq.processingDuration, EvalCount: seq.numDecoded, - EvalDuration: time.Since(seq.startGenerationTime), + EvalDuration: seq.generationDuration, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError) }