ollamarunner: measure only active time

This commit is contained in:
Michael Yang
2025-09-29 12:29:26 -07:00
committed by Michael Yang
parent bbbc73d637
commit 967a82f52f

View File

@@ -91,10 +91,11 @@ type Sequence struct {
doneReason llm.DoneReason
// Metrics
startProcessingTime time.Time
startGenerationTime time.Time
numPredicted int
numPromptInputs int
startedAt, lastUpdatedAt time.Time
processingDuration time.Duration
samplingDuration time.Duration
numPredicted int
numPromptInputs int
}
type NewSequenceParams struct {
@@ -108,8 +109,6 @@ type NewSequenceParams struct {
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
s.ready.Wait()
startTime := time.Now()
inputs, ctxs, mmStore, err := s.inputs(prompt, images)
if err != nil {
return nil, fmt.Errorf("failed to process inputs: %w", err)
@@ -164,20 +163,19 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
// TODO(jessegross): Ingest cached history for grammar
return &Sequence{
ctxs: ctxs,
mmStore: mmStore,
inputs: inputs,
numPromptInputs: len(inputs),
startProcessingTime: startTime,
numPredict: params.numPredict,
pendingResponses: make([]string, 0),
responses: make(chan string, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
sampler: params.sampler,
embeddingOnly: params.embedding,
stop: params.stop,
numKeep: params.numKeep,
ctxs: ctxs,
mmStore: mmStore,
inputs: inputs,
numPromptInputs: len(inputs),
numPredict: params.numPredict,
pendingResponses: make([]string, 0),
responses: make(chan string, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
sampler: params.sampler,
embeddingOnly: params.embedding,
stop: params.stop,
numKeep: params.numKeep,
}, nil
}
@@ -408,7 +406,7 @@ func (s *Server) run(ctx context.Context) {
supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone
var activeBatch batchState
var previousBatch batchState
for {
select {
case <-ctx.Done():
@@ -417,16 +415,18 @@ func (s *Server) run(ctx context.Context) {
panic(err)
default:
var err error
activeBatch, err = s.forwardBatch(activeBatch)
nextBatch, err := s.forwardBatch(previousBatch)
if err != nil {
panic(err)
}
if supportsAsync {
go s.computeBatch(activeBatch)
go s.computeBatch(nextBatch)
} else {
s.computeBatch(activeBatch)
s.computeBatch(nextBatch)
}
previousBatch = nextBatch
}
}
}
@@ -562,6 +562,13 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
seq.inputs = seq.inputs[len(seq.pendingInputs):]
}
startedAt := time.Now()
for i := range nextBatch.seqs {
if nextBatch.seqs[i] != nil && nextBatch.seqs[i].startedAt.IsZero() {
nextBatch.seqs[i].startedAt = startedAt
}
}
if resumeSeq != -1 {
s.nextSeq = resumeSeq
} else {
@@ -682,6 +689,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
activeBatch.modelOutput)
outputs := activeBatch.modelOutput.Floats()
t := time.Now()
logutil.Trace("computeBatch: logits ready", "batchID", activeBatch.id)
@@ -694,8 +702,10 @@ func (s *Server) computeBatch(activeBatch batchState) {
continue
}
seq.lastUpdatedAt = t
if seq.numPredicted == 1 {
seq.startGenerationTime = time.Now()
seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)
seq.startedAt = seq.lastUpdatedAt
}
// if done processing the prompt, generate an embedding and return
@@ -774,6 +784,13 @@ func (s *Server) computeBatch(activeBatch batchState) {
s.removeSequence(i, llm.DoneReasonConnectionClosed)
}
}
samplingDuration := time.Since(t)
for i, seq := range s.seqs {
if seq != nil && nextBatchTokens[i] != nil {
s.seqs[i].samplingDuration += samplingDuration
}
}
}
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
@@ -887,9 +904,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
Done: true,
DoneReason: seq.doneReason,
PromptEvalCount: seq.numPromptInputs,
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
PromptEvalDuration: seq.processingDuration,
EvalCount: seq.numPredicted,
EvalDuration: time.Since(seq.startGenerationTime),
EvalDuration: seq.lastUpdatedAt.Sub(seq.startedAt) - seq.samplingDuration,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
}