Mirror of https://github.com/likelovewant/ollama-for-amd.git (synced 2025-12-25 07:58:01 +00:00)
perf: build graph for next batch async to keep GPU busy (#11863)
* perf: build graph for next batch in parallel to keep GPU busy

  This refactors the main run loop of the ollama runner so that the GPU-intensive work (Compute + Floats) runs in a goroutine, letting the runner prepare the next batch in parallel and reducing the time the GPU stalls while waiting for its next batch of work.

* tests: tune integration tests for ollama engine

  This tunes the integration tests to focus more on models supported by the new engine.
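For readers skimming the diff below, here is a minimal, self-contained Go sketch of the pipelining idea: prepare batch N+1 on the CPU while batch N is still computing on the GPU, chaining channels so compute never starts before its inputs exist. This is an illustration only, not the runner code; the names (batch, prepareBatch, computeBatch) are made up for the sketch, and the real runner signals through buffered channels and ComputeWithNotify rather than by closing channels.

package main

import (
    "fmt"
    "time"
)

// batch mirrors the role of batchState: channels order preparation,
// compute start, and output readiness between consecutive batches.
type batch struct {
    id             int
    inputsReady    chan struct{} // closed when this batch's real input values are in place
    computeStarted chan struct{} // closed once compute begins, so the next batch can be prepared
    outputsReady   chan struct{} // closed when compute finishes; becomes the next batch's inputsReady
}

// prepareBatch does the CPU-side work (tokenize, build the graph) for batch id,
// overlapping with the GPU work of the previous batch.
func prepareBatch(id int, prev *batch) *batch {
    b := &batch{
        id:             id,
        computeStarted: make(chan struct{}),
        outputsReady:   make(chan struct{}),
    }
    if prev != nil {
        <-prev.computeStarted             // wait until the previous batch is actually on the GPU
        b.inputsReady = prev.outputsReady // our inputs are ready once its outputs exist
    } else {
        b.inputsReady = make(chan struct{})
        close(b.inputsReady) // no previous batch: inputs are ready immediately
    }
    // ... build the compute graph with placeholder inputs here ...
    return b
}

// computeBatch stands in for the GPU-side Compute + Floats work.
func computeBatch(b *batch) {
    <-b.inputsReady                   // block until the previous batch's outputs are available
    close(b.computeStarted)           // unblock preparation of the next batch
    time.Sleep(10 * time.Millisecond) // pretend to run the forward pass and read logits
    fmt.Println("finished batch", b.id)
    close(b.outputsReady) // the next batch's inputs are now ready
}

func main() {
    var prev *batch
    for i := 0; i < 4; i++ {
        b := prepareBatch(i, prev) // CPU work overlaps the GPU work of batch i-1
        go computeBatch(b)
        prev = b
    }
    <-prev.outputsReady // drain the last batch before exiting
}

The actual change below applies the same chaining to batchState (inputsReadyCh, computeStartedCh, outputsReadyCh), split across forwardBatch and computeBatch.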
@@ -17,6 +17,7 @@ import (
     "reflect"
     "regexp"
     "runtime"
+    "runtime/debug"
     "strconv"
     "strings"
     "sync"
@@ -51,10 +52,10 @@ type Sequence struct {
     iBatch int
 
     // prompt inputs left to evaluate
-    inputs []input.Input
+    inputs []*input.Input
 
     // inputs that have been added to a batch but not yet submitted to Forward
-    pendingInputs []input.Input
+    pendingInputs []*input.Input
 
     // tokens that have been generated but not returned yet (e.g. for stop sequences)
     pendingResponses []string
@@ -182,8 +183,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-<n>] tags, tokenizing text and
 // decoding images
-func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, multimodalStore, error) {
-    var inputs []input.Input
+func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input, []ml.Context, multimodalStore, error) {
+    var inputs []*input.Input
     var ctxs []ml.Context
     var mmStore multimodalStore
 
@@ -210,7 +211,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
         }
 
         for _, t := range tokens {
-            inputs = append(inputs, input.Input{Token: t})
+            inputs = append(inputs, &input.Input{Token: t})
         }
 
         // image - decode and store
@@ -243,7 +244,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
 
             mmStore.addMultimodal(imageEmbeddings)
 
-            inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
+            inputs = append(inputs, &input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
             postTokenize = true
         }
     }
@@ -259,6 +260,37 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
     return inputs, ctxs, mmStore, nil
 }
 
+type batchState struct {
+    // id provides a counter for trace logging batches
+    id int
+
+    // ctx holds the backend context used for this batch
+    ctx ml.Context
+
+    // modelOutput holds the outputs from this batch
+    modelOutput ml.Tensor
+
+    // batchInputs holds the input token pointers which may start as
+    // placeholders later filled in before calling ctx.Compute
+    batchInputs []*input.Input
+
+    // batch contains the inputs for a model forward pass
+    batch input.Batch
+
+    // full set of seqs at the time this batch was initiated
+    seqs []*Sequence
+
+    // Signaled when this batch's inputs are ready and compute can proceed
+    inputsReadyCh chan struct{}
+
+    // Signaled when Compute is about to begin on this batch, and
+    // seqs have been updated to prepare for the next batch
+    computeStartedCh chan struct{}
+
+    // Signaled when this batch's outputs are complete and the next batch can proceed
+    outputsReadyCh chan struct{}
+}
+
 type Server struct {
     // modelPath is the location of the model to be loaded
     modelPath string
@@ -290,6 +322,12 @@ type Server struct {
     // TODO (jmorganca): make this n_batch
     batchSize int
 
+    // Used to signal a hard failure during async processing which will panic the runner
+    hardErrCh chan error
+
+    // Simple counter used only for trace logging batches
+    batchID int
+
     // protects access to everything below this line
     // this is context state needed for decoding
     mu sync.Mutex
@@ -362,33 +400,66 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
     s.seqsSem.Release(1)
 }
 
+// track batch state between forwardBatch, computeBatch and predictForwardBatch
+
 func (s *Server) run(ctx context.Context) {
     s.ready.Wait()
 
+    var activeBatch batchState
     for {
         select {
         case <-ctx.Done():
             return
+        case err := <-s.hardErrCh:
+            panic(err)
         default:
-            err := s.processBatch()
+            var err error
+            activeBatch, err = s.forwardBatch(activeBatch)
             if err != nil {
                 panic(err)
             }
+            go s.computeBatch(activeBatch)
         }
     }
 }
 
-func (s *Server) processBatch() error {
+// forwardBatch will calculate a batch.
+func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, err error) {
+    // If we have a pending batch still processing, wait until Compute has started
+    // before setting up the next batch so the seqs inputs are ready to receive their
+    // token values and we get the correct input pointers for the batchInputs
+    if pendingBatch.ctx != nil {
+        slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch waiting for compute to start", "pendingBatch.id", pendingBatch.id)
+        <-pendingBatch.computeStartedCh
+        slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch compute started, setting up next batch", "pendingBatch.id", pendingBatch.id, "id", s.batchID)
+        nextBatch.inputsReadyCh = pendingBatch.outputsReadyCh // Chain the outputs from the pending batch to the next inputs batch
+    } else {
+        slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no pending batch detected", "batchID", s.batchID)
+        // No pendingBatch, so the inputs will be ready in the seqs immediately
+        nextBatch.inputsReadyCh = make(chan struct{}, 1)
+        nextBatch.inputsReadyCh <- struct{}{}
+    }
+
     s.mu.Lock()
     for s.allNil() {
         s.cond.Wait() // Wait until an item is added
     }
     defer s.mu.Unlock()
 
-    ctx := s.model.Backend().NewContext()
-    defer ctx.Close()
+    nextBatch.ctx = s.model.Backend().NewContext()
+    defer func() {
+        if err != nil {
+            nextBatch.ctx.Close()
+            nextBatch.ctx = nil
+        }
+    }()
+    nextBatch.id = s.batchID
+    nextBatch.seqs = append([]*Sequence{}, s.seqs...)
+    nextBatch.computeStartedCh = make(chan struct{}, 1)
+    nextBatch.outputsReadyCh = make(chan struct{}, 1)
 
-    var batchInputs []int32
+    // Prepare the seqs and batch, but defer the input token values as we may not be ready yet
+    var batchInputs []*input.Input
     var batch input.Batch
 
     resumeSeq := -1
@@ -396,7 +467,6 @@ func (s *Server) processBatch() error {
     for range s.seqs {
         seqIdx = (seqIdx + 1) % len(s.seqs)
         seq := s.seqs[seqIdx]
 
         if seq == nil {
             continue
         }
@@ -404,12 +474,13 @@ func (s *Server) processBatch() error {
         // if past the num predict limit
         if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
             s.removeSequence(seqIdx, llm.DoneReasonLength)
+            nextBatch.seqs[seqIdx] = nil
             continue
         }
 
         if !s.cache.enabled {
             seq.inputs = append(seq.cache.Inputs, seq.inputs...)
-            seq.cache.Inputs = []input.Input{}
+            seq.cache.Inputs = []*input.Input{}
         }
 
         batchSize := s.batchSize
@@ -442,25 +513,28 @@ func (s *Server) processBatch() error {
                     break
                 }
 
-                err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
+                err = s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
                 if err != nil {
                     var reprocess *ErrReprocessInputs
                     if errors.As(err, &reprocess) {
                         // Prepend these inputs to the sequence's inputs queue for reprocessing
                         seq.inputs = append(reprocess.Inputs, seq.inputs...)
                         // Skip this sequence but continue processing the rest
+                        nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch
+                        err = nil
                         continue
                     } else {
-                        return err
+                        return
                     }
                 }
             }
 
-            batchInputs = append(batchInputs, inp.Token)
+            batchInputs = append(batchInputs, seq.inputs[i])
             if inp.Multimodal != nil {
-                mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false)
+                var mm []input.Multimodal
+                mm, err = seq.mmStore.getMultimodal(s.model.Backend(), nextBatch.ctx, inp.Multimodal, false)
                 if err != nil {
-                    return err
+                    return
                 }
                 batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
             }
@@ -472,6 +546,7 @@ func (s *Server) processBatch() error {
             if i+1 == len(seq.inputs) {
                 batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
             }
+            slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
             seq.pendingInputs = append(seq.pendingInputs, inp)
         }
 
@@ -485,36 +560,129 @@ func (s *Server) processBatch() error {
     }
 
     if len(batchInputs) == 0 {
-        return nil
+        slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no batchInputs, going idle", "batchID", s.batchID)
+        nextBatch.ctx.Close()
+        nextBatch.ctx = nil
+        return
     }
+    s.batchID++
 
-    modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
+    // Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
+    batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
+    nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
     if err != nil {
-        return fmt.Errorf("failed to decode batch: %w", err)
+        err = fmt.Errorf("failed to build graph: %w", err)
+        return
     }
+    nextBatch.batchInputs = batchInputs
+    nextBatch.batch = batch
+
+    return
+}
+
+// Async processing of the next batch
+func (s *Server) computeBatch(activeBatch batchState) {
+    if activeBatch.ctx == nil {
+        // Nothing to compute
+        return
+    }
+    defer activeBatch.ctx.Close()
+
+    // Wait until inputs are ready
+    slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: waiting for inputs to be ready", "batchID", activeBatch.id)
+    <-activeBatch.inputsReadyCh
+    slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: inputs are ready", "batchID", activeBatch.id)
+
+    // Once we complete, signal that the next batch of inputs is ready
+    // This will unblock the next computeBatch, or forwardBatch if new seqs come in
+    defer func() {
+        slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: outputs are ready", "batchID", activeBatch.id)
+        activeBatch.outputsReadyCh <- struct{}{}
+    }()
+
+    s.mu.Lock()
+
+    // Gather the actual input token values now that they're ready
+    batchInputs := make([]int32, len(activeBatch.batchInputs))
+    for i := range batchInputs {
+        batchInputs[i] = activeBatch.batchInputs[i].Token
+    }
 
-    logits := modelOutput.Floats()
-
+    // Now we run part of the decoding algorithm to adjust the seq.inputs with placeholder tokens
+    // so that forwardBatch can build a batchInputs set which will eventually contain the actual
+    // decoded tokens.
+    nextBatchTokens := make([]*input.Input, len(s.seqs))
+    iBatches := make([]int, len(s.seqs)) // Record the iBatch values before releasing the lock
     for i, seq := range s.seqs {
+        iBatches[i] = -1
         if seq == nil {
             continue
         }
+        // Skip over any newly added or skipped sequences
+        if activeBatch.seqs[i] == nil {
+            continue
+        }
 
-        // After calling Forward, pending inputs are now in the cache
+        // Detect if the sequence we're processing has already been completed and replaced
+        // with a new sequence
+        if seq != activeBatch.seqs[i] {
+            slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
+            continue
+        }
+
+        // Pending inputs will actually be in the cache after we call Compute.
+        // However, we have already resolved any placeholder tokens.
+        //
+        // It's possible for incoming sequences to look at the values that we've
+        // added to the cache here and start relying on them before we've done
+        // the computation. This is OK as long as we ensure that this batch's
+        // computation happens before any future batch's and we never fail
+        // (unless we take down the whole runner).
        if len(seq.pendingInputs) > 0 {
             seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
-            seq.pendingInputs = []input.Input{}
+            seq.pendingInputs = []*input.Input{}
         }
 
         // don't sample prompt processing
         if len(seq.inputs) != 0 {
             if !s.cache.enabled {
-                return errors.New("caching disabled but unable to fit entire input in a batch")
+                s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch")
+                s.mu.Unlock()
+                return
             }
             continue
         }
 
         seq.numPredicted++
+        nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
+        seq.inputs = []*input.Input{nextToken}
+        nextBatchTokens[i] = nextToken
+        iBatches[i] = seq.iBatch
+    }
+
+    // At this point the seqs are ready for forwardBatch to move forward so unblock
+    s.mu.Unlock()
+
+    activeBatch.batch.Inputs.SetValueFromIntSlice(batchInputs)
+    activeBatch.ctx.ComputeWithNotify(
+        func() {
+            slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: signaling computeStartedCh", "batchID", activeBatch.id)
+            activeBatch.computeStartedCh <- struct{}{}
+        },
+        activeBatch.modelOutput)
+    logits := activeBatch.modelOutput.Floats()
+
+    slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: logits ready", "batchID", activeBatch.id)
+
+    s.mu.Lock()
+    defer s.mu.Unlock()
+
+    slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: decoding", "batchID", activeBatch.id)
+    for i, seq := range s.seqs {
+        if seq == nil || nextBatchTokens[i] == nil {
+            continue
+        }
 
         if seq.numPredicted == 1 {
             seq.startGenerationTime = time.Now()
         }
@@ -522,36 +690,38 @@ func (s *Server) processBatch() error {
         // if done processing the prompt, generate an embedding and return
         if seq.embeddingOnly {
             // TODO(jessegross): Embedding support
-            slog.Warn("generation of embedding outputs not yet supported")
+            slog.Warn("generation of embedding outputs not yet supported", "id", activeBatch.id, "seqIdx", i)
             s.removeSequence(i, llm.DoneReasonStop)
             continue
         }
 
         // sample a token
-        vocabSize := len(logits) / len(batch.Outputs)
-
-        token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
+        vocabSize := len(logits) / len(activeBatch.batch.Outputs)
+        slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(logits), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
+        token, err := seq.sampler.Sample(logits[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
         if err != nil {
-            return fmt.Errorf("failed to sample token: %w", err)
+            s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
+            return
         }
 
+        nextBatchTokens[i].Token = token
+
         // if it's an end of sequence token, break
         if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
             // TODO (jmorganca): we should send this back
             // as it's important for the /api/generate context
             // seq.responses <- piece
 
+            slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: EOS", "batchID", activeBatch.id, "seqIdx", i)
             s.removeSequence(i, llm.DoneReasonStop)
             continue
         }
 
         piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
         if err != nil {
-            return err
+            s.hardErrCh <- fmt.Errorf("failed to decode token: %w", err)
+            return
         }
 
-        seq.inputs = []input.Input{{Token: token}}
-
         seq.pendingResponses = append(seq.pendingResponses, piece)
         sequence := strings.Join(seq.pendingResponses, "")
 
@@ -575,6 +745,7 @@ func (s *Server) processBatch() error {
             if tokenTruncated || origLen == newLen {
                 tokenLen--
             }
 
             seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
 
             s.removeSequence(i, llm.DoneReasonStop)
@@ -593,8 +764,6 @@ func (s *Server) processBatch() error {
             s.removeSequence(i, llm.DoneReasonConnectionClosed)
         }
     }
-
-    return nil
 }
 
 func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
@@ -736,7 +905,10 @@ func (s *Server) reserveWorstCaseGraph() error {
     defer ctx.Close()
 
     var err error
-    inputs := make([]input.Input, s.batchSize)
+    inputs := make([]*input.Input, s.batchSize)
+    for i := range inputs {
+        inputs[i] = &input.Input{}
+    }
     mmStore := newMultimodalStore()
 
     // Multimodal strategy:
@@ -778,8 +950,11 @@ func (s *Server) reserveWorstCaseGraph() error {
         }
 
         if len(inputs) < s.batchSize {
-            newInputs := make([]input.Input, s.batchSize)
+            newInputs := make([]*input.Input, s.batchSize)
             copy(newInputs, inputs)
+            for i := len(inputs); i < s.batchSize; i++ {
+                newInputs[i] = &input.Input{}
+            }
             inputs = newInputs
         }
     }
@@ -842,6 +1017,7 @@ func (s *Server) allocModel(
     // Convert memory allocation panics to errors
     defer func() {
         if r := recover(); r != nil {
+            debug.PrintStack()
             if err, ok := r.(error); ok {
                 panicErr = err
             } else {
@@ -1011,6 +1187,7 @@ func Execute(args []string) error {
     server := &Server{
         modelPath: *mpath,
         status:    llm.ServerStatusLaunched,
+        hardErrCh: make(chan error, 1),
     }
 
     server.cond = sync.NewCond(&server.mu)