perf: build graph for next batch async to keep GPU busy (#11863)

* perf: build graph for next batch in parallel to keep GPU busy

This refactors the main run loop of the ollama runner to perform the main GPU
intensive tasks (Compute+Floats) in a go routine so we can prepare the next
batch in parallel to reduce the amount of time the GPU stalls waiting for the
next batch of work.

* tests: tune integration tests for ollama engine

This tunes the integration tests to focus more on models supported
by the new engine.
This commit is contained in:
Daniel Hiltgen
2025-08-29 14:20:28 -07:00
committed by GitHub
parent ead4a9a1d0
commit 517807cdf2
20 changed files with 591 additions and 235 deletions

View File

@@ -64,7 +64,7 @@ type MultimodalProcessor interface {
// This function is also responsible for updating MultimodalHash for any Multimodal
// that is modified to ensure that there is a unique hash value that accurately
// represents the contents.
PostTokenize([]input.Input) ([]input.Input, error)
PostTokenize([]*input.Input) ([]*input.Input, error)
}
// Base implements the common fields and methods for all models
@@ -278,7 +278,7 @@ func canNil(t reflect.Type) bool {
t.Kind() == reflect.Slice
}
func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
if len(batch.Positions) != len(batch.Sequences) {
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
}
@@ -287,8 +287,6 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
return nil, errors.New("batch size cannot be less than 1")
}
batch.Inputs = ctx.Input().FromIntSlice(inputs, len(inputs))
cache := m.Config().Cache
if cache != nil {
err := cache.StartForward(ctx, batch, false)
@@ -302,7 +300,7 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
return nil, err
}
ctx.Forward(t).Compute(t)
ctx.Forward(t)
return t, nil
}