perf: build graph for next batch async to keep GPU busy (#11863)

* perf: build graph for next batch in parallel to keep GPU busy

This refactors the main run loop of the ollama runner so that the GPU-intensive
work (Compute + Floats) runs in a goroutine, letting us prepare the next batch
in parallel and reducing the time the GPU stalls waiting for its next batch of
work.
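
In effect the run loop now overlaps CPU-side graph construction with GPU-side compute. A minimal, self-contained Go sketch of that overlap (the names below are stand-ins, not the runner's actual code):

```go
package main

import "fmt"

// Illustrative stand-ins only: buildGraph is the CPU-side graph construction,
// compute is the GPU-side Compute+Floats work. Neither is the runner's real API.
type graph struct{ id int }

func buildGraph(id int) graph { return graph{id: id} }
func compute(g graph)         { fmt.Println("computing batch", g.id) }

func main() {
	done := make(chan struct{}, 1)
	done <- struct{}{} // nothing is in flight before the first batch

	for id := 0; id < 4; id++ {
		g := buildGraph(id) // build the next graph while the previous batch runs on the GPU
		<-done              // wait for the in-flight batch to finish
		go func(g graph) {
			compute(g) // GPU work happens off the main loop
			done <- struct{}{}
		}(g)
	}
	<-done // drain the last batch
}
```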

* tests: tune integration tests for ollama engine

This tunes the integration tests to focus more on models supported
by the new engine.
Daniel Hiltgen
2025-08-29 14:20:28 -07:00
committed by GitHub
parent ead4a9a1d0
commit 517807cdf2
20 changed files with 591 additions and 235 deletions

View File

@@ -86,7 +86,7 @@ type InputCacheSlot struct {
Id int
// Inputs that are stored in the KV cache
Inputs []input.Input
Inputs []*input.Input
// is this cache actively being processed as part of a sequence?
InUse bool
@@ -95,7 +95,7 @@ type InputCacheSlot struct {
lastUsed time.Time
}
func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*input.Input, error) {
var slot *InputCacheSlot
var numPast int32
var err error
@@ -146,7 +146,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
return slot, prompt, nil
}
func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
func (c *InputCache) findLongestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
longest := int32(-1)
var longestSlot *InputCacheSlot
@@ -169,7 +169,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot
return longestSlot, longest, nil
}
func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
func (c *InputCache) findBestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
oldest := time.Now()
var oldestSlot *InputCacheSlot
@@ -205,7 +205,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
if longest > 0 && longestSlot != oldestSlot {
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
len(longestSlot.Inputs))
oldestSlot.Inputs = make([]input.Input, longest)
oldestSlot.Inputs = make([]*input.Input, longest)
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
if c.cache != nil {
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
@@ -215,7 +215,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
return oldestSlot, longest, nil
}
func countCommonPrefix(a []input.Input, b []input.Input) int32 {
func countCommonPrefix(a []*input.Input, b []*input.Input) int32 {
var count int32
for i := range a {
@@ -250,7 +250,7 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
}
type ErrReprocessInputs struct {
Inputs []input.Input
Inputs []*input.Input
}
func (e *ErrReprocessInputs) Error() string {
@@ -283,13 +283,13 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
"id", slot.Id, "error", err)
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard))
newInputs := make([]*input.Input, numKeep+inputLen-(numKeep+discard))
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
// Reset the cache
_ = c.cache.Remove(slot.Id, 0, math.MaxInt32)
slot.Inputs = []input.Input{}
slot.Inputs = []*input.Input{}
// Return error with inputs that need to be reprocessed
return &ErrReprocessInputs{Inputs: newInputs}

View File

@@ -13,50 +13,50 @@ import (
func TestCountCommon(t *testing.T) {
tests := []struct {
name string
t1 []input.Input
t2 []input.Input
t1 []*input.Input
t2 []*input.Input
expected int32
}{
{
name: "Equal",
t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t1: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 3,
},
{
name: "Prefix",
t1: []input.Input{{Token: 1}},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t1: []*input.Input{{Token: 1}},
t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 1,
},
{
name: "Image Prefix",
t1: []input.Input{{MultimodalHash: 1}},
t2: []input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
t1: []*input.Input{{MultimodalHash: 1}},
t2: []*input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
expected: 1,
},
{
name: "Mixed",
t1: []input.Input{{Token: 1}, {MultimodalHash: 1}},
t2: []input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}},
t2: []*input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
expected: 2,
},
{
name: "Mixed, Same Length",
t1: []input.Input{{Token: 1}, {MultimodalHash: 1}},
t2: []input.Input{{Token: 1}, {MultimodalHash: 2}},
t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}},
t2: []*input.Input{{Token: 1}, {MultimodalHash: 2}},
expected: 1,
},
{
name: "Empty",
t1: []input.Input{},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t1: []*input.Input{},
t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 0,
},
{
name: "Both Empty",
t1: []input.Input{},
t2: []input.Input{},
t1: []*input.Input{},
t2: []*input.Input{},
expected: 0,
},
}
@@ -80,7 +80,7 @@ func TestFindCacheSlot(t *testing.T) {
tests := []struct {
name string
cache InputCache
prompt []input.Input
prompt []*input.Input
longest expected
best expected
}{
@@ -89,18 +89,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{},
Inputs: []*input.Input{},
InUse: false,
lastUsed: time.Time{},
},
{
Id: 1,
Inputs: []input.Input{},
Inputs: []*input.Input{},
InUse: false,
lastUsed: time.Time{},
},
}},
prompt: []input.Input{{Token: 1}},
prompt: []*input.Input{{Token: 1}},
longest: expected{result: 0, len: 0},
best: expected{result: 0, len: 0},
},
@@ -109,18 +109,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}},
Inputs: []*input.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []input.Input{{Token: 1}, {Token: 2}},
prompt: []*input.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 2},
best: expected{result: 1, len: 2},
},
@@ -129,18 +129,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{},
Inputs: []*input.Input{},
InUse: false,
lastUsed: time.Time{},
},
}},
prompt: []input.Input{{Token: 2}},
prompt: []*input.Input{{Token: 2}},
longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0},
},
@@ -150,19 +150,19 @@ func TestFindCacheSlot(t *testing.T) {
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{},
Inputs: []*input.Input{},
InUse: false,
lastUsed: time.Time{},
},
},
},
prompt: []input.Input{{Token: 1}},
prompt: []*input.Input{{Token: 1}},
longest: expected{result: 0, len: 1},
best: expected{result: 1, len: 1},
},
@@ -171,18 +171,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}},
Inputs: []*input.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []input.Input{{Token: 2}, {Token: 3}},
prompt: []*input.Input{{Token: 2}, {Token: 3}},
longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0},
},
@@ -191,18 +191,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: true,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{{Token: 1}},
Inputs: []*input.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []input.Input{{Token: 1}, {Token: 2}},
prompt: []*input.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 1},
best: expected{result: 1, len: 2},
},
@@ -300,7 +300,7 @@ func TestLoadCacheSlot(t *testing.T) {
tests := []struct {
name string
cache InputCache
prompt []input.Input
prompt []*input.Input
wantErr bool
expectedSlotId int
expectedPrompt int // expected length of remaining prompt
@@ -312,19 +312,19 @@ func TestLoadCacheSlot(t *testing.T) {
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{},
Inputs: []*input.Input{},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Only token 3 remains
@@ -336,19 +336,19 @@ func TestLoadCacheSlot(t *testing.T) {
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{},
Inputs: []*input.Input{},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Only token 3 remains
@@ -360,13 +360,13 @@ func TestLoadCacheSlot(t *testing.T) {
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}},
prompt: []*input.Input{{Token: 1}, {Token: 2}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Should leave 1 token for sampling
@@ -378,13 +378,13 @@ func TestLoadCacheSlot(t *testing.T) {
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: true,
lastUsed: time.Now().Add(-time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: true,
expectedSlotId: -1,
expectedPrompt: -1,
@@ -452,7 +452,7 @@ func TestShiftCacheSlot(t *testing.T) {
tests := []struct {
name string
numCtx int32
inputs []input.Input
inputs []*input.Input
numKeep int32
cacheErr bool
wantErr any
@@ -461,7 +461,7 @@ func TestShiftCacheSlot(t *testing.T) {
{
name: "Normal shift",
numCtx: 10,
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
numKeep: 2,
cacheErr: false, // No error
wantErr: nil,
@@ -470,7 +470,7 @@ func TestShiftCacheSlot(t *testing.T) {
{
name: "Cache removal fails",
numCtx: 10,
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
numKeep: 2,
cacheErr: true,
wantErr: &ErrReprocessInputs{},
@@ -487,7 +487,7 @@ func TestShiftCacheSlot(t *testing.T) {
}
slot := &InputCacheSlot{
Id: 123,
Inputs: make([]input.Input, len(tt.inputs)),
Inputs: make([]*input.Input, len(tt.inputs)),
}
copy(slot.Inputs, tt.inputs)

View File

@@ -17,6 +17,7 @@ import (
"reflect"
"regexp"
"runtime"
"runtime/debug"
"strconv"
"strings"
"sync"
@@ -51,10 +52,10 @@ type Sequence struct {
iBatch int
// prompt inputs left to evaluate
inputs []input.Input
inputs []*input.Input
// inputs that have been added to a batch but not yet submitted to Forward
pendingInputs []input.Input
pendingInputs []*input.Input
// tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string
@@ -182,8 +183,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, multimodalStore, error) {
var inputs []input.Input
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input, []ml.Context, multimodalStore, error) {
var inputs []*input.Input
var ctxs []ml.Context
var mmStore multimodalStore
@@ -210,7 +211,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
}
for _, t := range tokens {
inputs = append(inputs, input.Input{Token: t})
inputs = append(inputs, &input.Input{Token: t})
}
// image - decode and store
@@ -243,7 +244,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
mmStore.addMultimodal(imageEmbeddings)
inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
inputs = append(inputs, &input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
postTokenize = true
}
}
@@ -259,6 +260,37 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
return inputs, ctxs, mmStore, nil
}
type batchState struct {
// id provides a counter for trace logging batches
id int
// ctx holds the backend context used for this batch
ctx ml.Context
// modelOutput holds the outputs from this batch
modelOutput ml.Tensor
// batchInputs holds the input token pointers, which may start as
// placeholders that are filled in before calling ctx.Compute
batchInputs []*input.Input
// batch contains the inputs for a model forward pass
batch input.Batch
// full set of seqs at the time this batch was initiated
seqs []*Sequence
// Signaled when this batch's inputs are ready and compute can proceed
inputsReadyCh chan struct{}
// Signaled when Compute is about to begin on this batch and the
// seqs have been updated to prepare for the next batch
computeStartedCh chan struct{}
// Signaled when this batch's outputs are complete and the next batch can proceed
outputsReadyCh chan struct{}
}
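
To illustrate how these channels chain one batch to the next (the next batch's inputsReadyCh is the previous batch's outputsReadyCh, gated by computeStartedCh), here is a hedged, self-contained sketch; only the three channel field names mirror the struct above, everything else is made up:

```go
package main

import "fmt"

// Minimal sketch of how consecutive batch states are chained. Only the three
// channel fields mirror the batchState struct; the rest is illustrative.
type sketchBatch struct {
	id               int
	inputsReadyCh    chan struct{}
	computeStartedCh chan struct{}
	outputsReadyCh   chan struct{}
}

// compute stands in for computeBatch: wait for inputs, announce that compute
// has started, then announce that outputs are ready.
func compute(b *sketchBatch) {
	<-b.inputsReadyCh
	b.computeStartedCh <- struct{}{}
	fmt.Println("compute batch", b.id)
	b.outputsReadyCh <- struct{}{}
}

func main() {
	var prev *sketchBatch
	for id := 0; id < 3; id++ {
		next := &sketchBatch{
			id:               id,
			computeStartedCh: make(chan struct{}, 1),
			outputsReadyCh:   make(chan struct{}, 1),
		}
		if prev != nil {
			// Like forwardBatch: wait for the pending batch to claim its inputs,
			// then chain its outputs channel in as this batch's inputs channel.
			<-prev.computeStartedCh
			next.inputsReadyCh = prev.outputsReadyCh
		} else {
			// No pending batch: inputs are ready immediately.
			next.inputsReadyCh = make(chan struct{}, 1)
			next.inputsReadyCh <- struct{}{}
		}
		go compute(next)
		prev = next
	}
	<-prev.outputsReadyCh // wait for the final batch to finish
}
```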
type Server struct {
// modelPath is the location of the model to be loaded
modelPath string
@@ -290,6 +322,12 @@ type Server struct {
// TODO (jmorganca): make this n_batch
batchSize int
// Used to signal a hard failure during async processing which will panic the runner
hardErrCh chan error
// Simple counter used only for trace logging batches
batchID int
// protects access to everything below this line
// this is context state needed for decoding
mu sync.Mutex
@@ -362,33 +400,66 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
s.seqsSem.Release(1)
}
// track batch state between forwardBatch, computeBatch and predictForwardBatch
func (s *Server) run(ctx context.Context) {
s.ready.Wait()
var activeBatch batchState
for {
select {
case <-ctx.Done():
return
case err := <-s.hardErrCh:
panic(err)
default:
err := s.processBatch()
var err error
activeBatch, err = s.forwardBatch(activeBatch)
if err != nil {
panic(err)
}
go s.computeBatch(activeBatch)
}
}
}
func (s *Server) processBatch() error {
// forwardBatch will calculate a batch.
func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, err error) {
// If we have a pending batch still processing, wait until Compute has started
// before setting up the next batch so the seqs' inputs are ready to receive their
// token values and we get the correct input pointers for batchInputs
if pendingBatch.ctx != nil {
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch waiting for compute to start", "pendingBatch.id", pendingBatch.id)
<-pendingBatch.computeStartedCh
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch compute started, setting up next batch", "pendingBatch.id", pendingBatch.id, "id", s.batchID)
nextBatch.inputsReadyCh = pendingBatch.outputsReadyCh // Chain the outputs from the pending batch to the next batch's inputs
} else {
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no pending batch detected", "batchID", s.batchID)
// No pendingBatch, so the inputs will be ready in the seqs immediately
nextBatch.inputsReadyCh = make(chan struct{}, 1)
nextBatch.inputsReadyCh <- struct{}{}
}
s.mu.Lock()
for s.allNil() {
s.cond.Wait() // Wait until an item is added
}
defer s.mu.Unlock()
ctx := s.model.Backend().NewContext()
defer ctx.Close()
nextBatch.ctx = s.model.Backend().NewContext()
defer func() {
if err != nil {
nextBatch.ctx.Close()
nextBatch.ctx = nil
}
}()
nextBatch.id = s.batchID
nextBatch.seqs = append([]*Sequence{}, s.seqs...)
nextBatch.computeStartedCh = make(chan struct{}, 1)
nextBatch.outputsReadyCh = make(chan struct{}, 1)
var batchInputs []int32
// Prepare the seqs and batch, but defer the input token values since they may not be ready yet
var batchInputs []*input.Input
var batch input.Batch
resumeSeq := -1
@@ -396,7 +467,6 @@ func (s *Server) processBatch() error {
for range s.seqs {
seqIdx = (seqIdx + 1) % len(s.seqs)
seq := s.seqs[seqIdx]
if seq == nil {
continue
}
@@ -404,12 +474,13 @@ func (s *Server) processBatch() error {
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(seqIdx, llm.DoneReasonLength)
nextBatch.seqs[seqIdx] = nil
continue
}
if !s.cache.enabled {
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
seq.cache.Inputs = []input.Input{}
seq.cache.Inputs = []*input.Input{}
}
batchSize := s.batchSize
@@ -442,25 +513,28 @@ func (s *Server) processBatch() error {
break
}
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
err = s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
var reprocess *ErrReprocessInputs
if errors.As(err, &reprocess) {
// Prepend these inputs to the sequence's inputs queue for reprocessing
seq.inputs = append(reprocess.Inputs, seq.inputs...)
// Skip this sequence but continue processing the rest
nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch
err = nil
continue
} else {
return err
return
}
}
}
batchInputs = append(batchInputs, inp.Token)
batchInputs = append(batchInputs, seq.inputs[i])
if inp.Multimodal != nil {
mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false)
var mm []input.Multimodal
mm, err = seq.mmStore.getMultimodal(s.model.Backend(), nextBatch.ctx, inp.Multimodal, false)
if err != nil {
return err
return
}
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
}
@@ -472,6 +546,7 @@ func (s *Server) processBatch() error {
if i+1 == len(seq.inputs) {
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
}
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
seq.pendingInputs = append(seq.pendingInputs, inp)
}
@@ -485,36 +560,129 @@ func (s *Server) processBatch() error {
}
if len(batchInputs) == 0 {
return nil
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no batchInputs, going idle", "batchID", s.batchID)
nextBatch.ctx.Close()
nextBatch.ctx = nil
return
}
s.batchID++
modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
if err != nil {
return fmt.Errorf("failed to decode batch: %w", err)
err = fmt.Errorf("failed to build graph: %w", err)
return
}
nextBatch.batchInputs = batchInputs
nextBatch.batch = batch
return
}
// Async processing of the next batch
func (s *Server) computeBatch(activeBatch batchState) {
if activeBatch.ctx == nil {
// Nothing to compute
return
}
defer activeBatch.ctx.Close()
// Wait until inputs are ready
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: waiting for inputs to be ready", "batchID", activeBatch.id)
<-activeBatch.inputsReadyCh
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: inputs are ready", "batchID", activeBatch.id)
// Once we complete, signal that the next batch's inputs are ready.
// This will unblock the next computeBatch, or forwardBatch if new seqs come in
defer func() {
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: outputs are ready", "batchID", activeBatch.id)
activeBatch.outputsReadyCh <- struct{}{}
}()
s.mu.Lock()
// Gather the actual input token values now that they're ready
batchInputs := make([]int32, len(activeBatch.batchInputs))
for i := range batchInputs {
batchInputs[i] = activeBatch.batchInputs[i].Token
}
logits := modelOutput.Floats()
// Now we run part of the decoding algorithm to adjust the seq.inputs with placeholder tokens
// so that forwardBatch can build a batchInputs set which will eventually contain the actual
// decoded tokens.
nextBatchTokens := make([]*input.Input, len(s.seqs))
iBatches := make([]int, len(s.seqs)) // Record the iBatch values before releasing the lock
for i, seq := range s.seqs {
iBatches[i] = -1
if seq == nil {
continue
}
// Skip over any newly added or skipped sequences
if activeBatch.seqs[i] == nil {
continue
}
// After calling Forward, pending inputs are now in the cache
// Detect if the sequence we're processing has already been completed and replaced
// with a new sequence
if seq != activeBatch.seqs[i] {
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
continue
}
// Pending inputs will actually be in the cache after we call Compute.
// However, we have already resolved any placeholder tokens.
//
// It's possible for incoming sequences to look at the values that we've
// added to the cache here and start relying on them before we've done
// the computation. This is OK as long as we ensure that this batch's
// computation happens before any future batch's and we never fail
// (unless we take down the whole runner).
if len(seq.pendingInputs) > 0 {
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
seq.pendingInputs = []input.Input{}
seq.pendingInputs = []*input.Input{}
}
// don't sample prompt processing
if len(seq.inputs) != 0 {
if !s.cache.enabled {
return errors.New("caching disabled but unable to fit entire input in a batch")
s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch")
s.mu.Unlock()
return
}
continue
}
seq.numPredicted++
nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
seq.inputs = []*input.Input{nextToken}
nextBatchTokens[i] = nextToken
iBatches[i] = seq.iBatch
}
// At this point the seqs are ready for forwardBatch to move forward, so unblock it
s.mu.Unlock()
activeBatch.batch.Inputs.SetValueFromIntSlice(batchInputs)
activeBatch.ctx.ComputeWithNotify(
func() {
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: signaling computeStartedCh", "batchID", activeBatch.id)
activeBatch.computeStartedCh <- struct{}{}
},
activeBatch.modelOutput)
logits := activeBatch.modelOutput.Floats()
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: logits ready", "batchID", activeBatch.id)
s.mu.Lock()
defer s.mu.Unlock()
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: decoding", "batchID", activeBatch.id)
for i, seq := range s.seqs {
if seq == nil || nextBatchTokens[i] == nil {
continue
}
if seq.numPredicted == 1 {
seq.startGenerationTime = time.Now()
}
@@ -522,36 +690,38 @@ func (s *Server) processBatch() error {
// if done processing the prompt, generate an embedding and return
if seq.embeddingOnly {
// TODO(jessegross): Embedding support
slog.Warn("generation of embedding outputs not yet supported")
slog.Warn("generation of embedding outputs not yet supported", "id", activeBatch.id, "seqIdx", i)
s.removeSequence(i, llm.DoneReasonStop)
continue
}
// sample a token
vocabSize := len(logits) / len(batch.Outputs)
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
vocabSize := len(logits) / len(activeBatch.batch.Outputs)
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(logits), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
token, err := seq.sampler.Sample(logits[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
if err != nil {
return fmt.Errorf("failed to sample token: %w", err)
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
return
}
nextBatchTokens[i].Token = token
// if it's an end of sequence token, break
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
// TODO (jmorganca): we should send this back
// as it's important for the /api/generate context
// seq.responses <- piece
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: EOS", "batchID", activeBatch.id, "seqIdx", i)
s.removeSequence(i, llm.DoneReasonStop)
continue
}
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
if err != nil {
return err
s.hardErrCh <- fmt.Errorf("failed to decode token: %w", err)
return
}
seq.inputs = []input.Input{{Token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "")
@@ -575,6 +745,7 @@ func (s *Server) processBatch() error {
if tokenTruncated || origLen == newLen {
tokenLen--
}
seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
s.removeSequence(i, llm.DoneReasonStop)
@@ -593,8 +764,6 @@ func (s *Server) processBatch() error {
s.removeSequence(i, llm.DoneReasonConnectionClosed)
}
}
return nil
}
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
@@ -736,7 +905,10 @@ func (s *Server) reserveWorstCaseGraph() error {
defer ctx.Close()
var err error
inputs := make([]input.Input, s.batchSize)
inputs := make([]*input.Input, s.batchSize)
for i := range inputs {
inputs[i] = &input.Input{}
}
mmStore := newMultimodalStore()
// Multimodal strategy:
@@ -778,8 +950,11 @@ func (s *Server) reserveWorstCaseGraph() error {
}
if len(inputs) < s.batchSize {
newInputs := make([]input.Input, s.batchSize)
newInputs := make([]*input.Input, s.batchSize)
copy(newInputs, inputs)
for i := len(inputs); i < s.batchSize; i++ {
newInputs[i] = &input.Input{}
}
inputs = newInputs
}
}
@@ -842,6 +1017,7 @@ func (s *Server) allocModel(
// Convert memory allocation panics to errors
defer func() {
if r := recover(); r != nil {
debug.PrintStack()
if err, ok := r.(error); ok {
panicErr = err
} else {
@@ -1011,6 +1187,7 @@ func Execute(args []string) error {
server := &Server{
modelPath: *mpath,
status: llm.ServerStatusLaunched,
hardErrCh: make(chan error, 1),
}
server.cond = sync.NewCond(&server.mu)