From 3c14461d5d2280723b3f961fb99ad128b3eee9af Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Mon, 5 May 2025 13:32:11 -0700 Subject: [PATCH] ollamarunner: Separate text and multimodal graphs For some multimodal models (such as gemma3), we create a single graph that generates the image embedding and then use this in the text model. The embedding tensor is completely opaque to the runner. However, this doesn't work if we need to use the embedding in multiple batches. This can arise if the embedding is larger than the batch size. In these cases (as with llama4), we would like to create views that are more appropriately sized. However, if we do this then the original source tensor is used in multiple graphs, which isn't allowed. To avoid that problem, models with this pattern compute the embedding tensor on first use and recreate the individual views. There is no longer a single vision and text graph. This codifies the pattern of separating vision and text graphs. The logic of computing tensors on demand is moved to the runner, so models no longer have to worry about this. It also gives the runner visibility into the multimodal tensors, which is important for memory management. --- model/input/input.go | 26 +++++-- model/model.go | 9 +-- model/models/gemma3/model.go | 14 ++-- model/models/gemma3/model_text.go | 2 +- model/models/llama4/model.go | 102 ++++++++++++++------------- model/models/llama4/model_text.go | 7 +- model/models/mistral3/model.go | 41 +++-------- model/models/mistral3/model_text.go | 18 +---- model/models/mllama/model.go | 8 ++- model/models/qwen25vl/model.go | 39 ++--------- model/models/qwen25vl/model_text.go | 7 +- runner/ollamarunner/cache_test.go | 17 ++--- runner/ollamarunner/multimodal.go | 103 ++++++++++++++++++++++++++++ runner/ollamarunner/runner.go | 34 ++++++--- 14 files changed, 241 insertions(+), 186 deletions(-) create mode 100644 runner/ollamarunner/multimodal.go diff --git a/model/input/input.go b/model/input/input.go index d66f52a0..bd9b53ec 100644 --- a/model/input/input.go +++ b/model/input/input.go @@ -2,16 +2,30 @@ package input import "github.com/ollama/ollama/ml" +// Multimodal is a multimodal embedding or a component of one. +// For example, it could be a row of an image that can be processed +// independently. +type Multimodal struct { + // Tensor is the embedding data. Implementations may chose what to + // store here or it may be nil if not needed. However, any ml.Tensor + // objects must be stored here and not in Data. + Tensor ml.Tensor + + // Data is implementation-specific opaque data, such as metadata on how + // to layout Tensor. It may be nil if not needed. It may also store larger + // objects such as complete images if they are to be processed later. + Data any +} + // Input represents one token in the input stream type Input struct { // Token is a single element of text. Token int32 - // Multimodal is opaque data representing a non-text - // element such as an image (or part of one if the image - // can be processed in pieces). It may be either together - // with Token or on its own. - Multimodal any + // Multimodal is represents a non-text element such as an + // image (or part of one if the image can be processed in pieces). + // It may be used either together with Token or on its own. + Multimodal []Multimodal // MultimodalHash is a unique representation of the data // stored in Multimodal, used for caching and comparing @@ -32,7 +46,7 @@ type Input struct { // Positions slice. 
type MultimodalIndex struct { Index int - Multimodal any + Multimodal []Multimodal } // Batch contains the inputs for a model forward pass diff --git a/model/model.go b/model/model.go index 7883b851..98381c90 100644 --- a/model/model.go +++ b/model/model.go @@ -40,12 +40,13 @@ type MultimodalProcessor interface { // EncodeMultimodal processes a single input (such as an image) and // generates an output (typically an embedding) that can be used by the model. // - // The return value is most typically an ml.Tensor, however, different - // type are possible, such as an object containing a tensor plus - // additional metadata, a slice of tensors or even just the original input. + // The return value is one or more tensors, each with optional model-specific + // opaque metadata. Typically, the tensors might be views into an embedding + // with each view representing a chunk of data that can be processed independently + // in different batches. // // The result may be cached by the runner. - EncodeMultimodal(ml.Context, []byte) (any, error) + EncodeMultimodal(ml.Context, []byte) ([]input.Multimodal, error) // PostTokenize is called after tokenization to allow the model to edit the // input stream to correctly arrange multimodal elements. diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index bf396b6a..d53eb6cc 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -82,7 +82,7 @@ func New(c fs.Config) (model.Model, error) { return &m, nil } -func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) { +func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) { if len(m.VisionModel.Layers) == 0 { return nil, model.ErrNoVisionModel } @@ -108,22 +108,22 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er visionOutputs := m.VisionModel.Forward(ctx, pixelValues) visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps) - return visionOutputs, nil + return []input.Multimodal{{Tensor: visionOutputs}}, nil } func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { var result []input.Input for _, inp := range inputs { - if inp.Multimodal == nil { + if len(inp.Multimodal) == 0 { result = append(result, inp) } else { - inputMultimodal := inp.Multimodal.(ml.Tensor) + inputMultimodal := inp.Multimodal[0].Tensor result = append(result, - input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n" - input.Input{Token: 255999}, // """ - input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder + input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n" + input.Input{Token: 255999}, // """ + input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder ) // add image token placeholders diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index 741818a2..a40614af 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -165,7 +165,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor // set image embeddings var except []int for _, image := range batch.Multimodal { - visionOutputs := image.Multimodal.(ml.Tensor) + visionOutputs := image.Multimodal[0].Tensor ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, 
image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1)))) for i := range visionOutputs.Dim(1) { diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go index 798f0d16..c94aa72f 100644 --- a/model/models/llama4/model.go +++ b/model/models/llama4/model.go @@ -4,7 +4,6 @@ import ( "bytes" "image" "slices" - "sync" "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" @@ -63,7 +62,7 @@ func New(c fs.Config) (model.Model, error) { return &m, nil } -func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) { +func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) { if len(m.VisionModel.Layers) < 1 { return nil, model.ErrNoVisionModel } @@ -103,70 +102,79 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er visionOutputs := m.VisionModel.Forward(ctx, pixelValues) visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0), visionOutputs.Dim(1)*visionOutputs.Dim(2)*visionOutputs.Dim(3)) projectedOutputs := m.Projector.Forward(ctx, visionOutputs) - return &chunks{Model: m, Tensor: projectedOutputs, aspectRatio: image.Point{ratioW, ratioH}}, nil + + var multimodal []input.Multimodal + aspectRatio := image.Point{ratioW, ratioH} + + var offset int + patchesPerChunk := projectedOutputs.Dim(1) + if aspectRatio.Y*aspectRatio.X > 1 { + patchesPerChunk = projectedOutputs.Dim(1) / (aspectRatio.X*aspectRatio.Y + 1) + + for range aspectRatio.Y { + for x := range aspectRatio.X { + view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset, + projectedOutputs.Dim(0), projectedOutputs.Stride(1), + patchesPerChunk) + var separator separator + if x < aspectRatio.X-1 { + separator.x = true // <|tile_x_separator|> + } else { + separator.y = true // <|tile_y_separator|> + } + multimodal = append(multimodal, input.Multimodal{Tensor: view, Data: &separator}) + offset += patchesPerChunk + } + } + } + + view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset, + projectedOutputs.Dim(0), projectedOutputs.Stride(1), + patchesPerChunk) + multimodal = append(multimodal, input.Multimodal{Tensor: view, Data: &separator{}}) + + return multimodal, nil } -type chunks struct { - *Model - ml.Tensor - aspectRatio image.Point - - dataOnce sync.Once - data []float32 -} - -type chunk struct { - *chunks - s, n int -} - -func (r *chunk) floats() []float32 { - r.dataOnce.Do(func() { - temp := r.Backend().NewContext() - defer temp.Close() - temp.Forward(r.Tensor).Compute(r.Tensor) - r.data = r.Floats() - }) - - return r.data[r.s*r.Dim(0) : (r.s+r.n)*r.Dim(0)] +type separator struct { + x bool + y bool } func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { var result []input.Input for _, inp := range inputs { - if inp.Multimodal == nil { + if len(inp.Multimodal) == 0 { result = append(result, inp) continue } - t := inp.Multimodal.(*chunks) var imageInputs []input.Input imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|> - var offset int - patchesPerChunk := t.Dim(1) - if t.aspectRatio.Y*t.aspectRatio.X > 1 { - patchesPerChunk = t.Dim(1) / (t.aspectRatio.X*t.aspectRatio.Y + 1) + for i, mm := range inp.Multimodal { + patchesPerChunk := mm.Tensor.Dim(1) - for range t.aspectRatio.Y { - for x := range t.aspectRatio.X { - imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: &chunk{t, offset, patchesPerChunk}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|> - imageInputs 
= append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...) - if x < t.aspectRatio.X-1 { - imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|> - } - offset += patchesPerChunk + if i < len(inp.Multimodal)-1 { + separator := mm.Data.(*separator) + + imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|> + imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...) + + if separator.x { + imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|> } - - imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|> + if separator.y { + imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|> + } + } else { + imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|> + imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|> + imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...) + imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|> } } - imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|> - imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: &chunk{t, offset, patchesPerChunk}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|> - imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...) - imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|> - result = append(result, imageInputs...) 
} diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go index 3f9f578f..d98587bd 100644 --- a/model/models/llama4/model_text.go +++ b/model/models/llama4/model_text.go @@ -210,12 +210,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx) for _, mi := range batch.Multimodal { - f32s := mi.Multimodal.(*chunk).floats() - img, err := ctx.Input().FromFloatSlice(f32s, len(f32s)/m.hiddenSize, m.hiddenSize) - if err != nil { - panic(err) - } - + img := mi.Multimodal[0].Tensor ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1)))) } diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go index c9685244..b93882a9 100644 --- a/model/models/mistral3/model.go +++ b/model/models/mistral3/model.go @@ -4,7 +4,6 @@ import ( "bytes" "image" "slices" - "sync" "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" @@ -105,7 +104,7 @@ func newMultiModalProjector(c fs.Config) *MultiModalProjector { } } -func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) { +func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) { if len(m.VisionModel.Layers) == 0 { return nil, model.ErrNoVisionModel } @@ -129,37 +128,14 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size) // split into patches to be sent to the text transformer - parent := imageFeatures{tensor: features} - rows := make([]*imageRow, size.Y) + rows := make([]input.Multimodal, size.Y) for i := range rows { - rows[i] = &imageRow{parent: &parent, s: i, shape: []int{features.Dim(0), size.X}} + rows[i].Tensor = features.View(ctx, features.Stride(1)*size.X*i, features.Dim(0), features.Stride(1), size.X) } return rows, nil } -type imageFeatures struct { - tensor ml.Tensor - - dataOnce sync.Once - data []float32 -} - -type imageRow struct { - parent *imageFeatures - s int - shape []int -} - -func (r *imageRow) data() []float32 { - n := 1 - for _, s := range r.shape { - n *= s - } - - return r.parent.data[r.s*n : (r.s+1)*n] -} - // PostTokenize arranges Mistral 3's inputs for the forward pass // In Mistral 3 and Pixtral, the input patches are arranged as follows: // [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END] @@ -168,15 +144,14 @@ func (r *imageRow) data() []float32 { func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { var result []input.Input for _, inp := range inputs { - if inp.Multimodal == nil { + if len(inp.Multimodal) == 0 { result = append(result, inp) } else { - inputMultimodal := inp.Multimodal.([]*imageRow) - for i, row := range inputMultimodal { + for i, row := range inp.Multimodal { // [IMG] - result = append(result, input.Input{Token: 10, Multimodal: row, MultimodalHash: inp.MultimodalHash, SameBatch: row.shape[1]}) - result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.shape[1]-1)...) - if i == len(inputMultimodal)-1 { + result = append(result, input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)}) + result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...) 
+ if i == len(inp.Multimodal)-1 { // [IMG_END] result = append(result, input.Input{Token: 13}) } else { diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go index 565b001a..17939800 100644 --- a/model/models/mistral3/model_text.go +++ b/model/models/mistral3/model_text.go @@ -9,7 +9,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" ) @@ -20,8 +19,6 @@ type TextOptions struct { } type TextModel struct { - model.Base - TokenEmbedding *nn.Embedding `gguf:"token_embd"` Layers []Layer `gguf:"blk"` OutputNorm *nn.RMSNorm `gguf:"output_norm"` @@ -109,20 +106,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor // image embeddings for _, image := range batch.Multimodal { - row := image.Multimodal.(*imageRow) - row.parent.dataOnce.Do(func() { - // use a new, throwaway context so the image tensor is not added to the graph - temp := m.Backend().NewContext() - temp.Forward(row.parent.tensor).Compute(row.parent.tensor) - row.parent.data = row.parent.tensor.Floats() - temp.Close() - }) - - imageFeature, err := ctx.Input().FromFloatSlice(row.data(), row.shape...) - if err != nil { - panic(err) - } - + imageFeature := image.Multimodal[0].Tensor ctx.Forward(imageFeature.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), imageFeature.Dim(0)*imageFeature.Dim(1)))) } diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index 4d5bdd4a..15571d9c 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -59,7 +59,7 @@ func New(c fs.Config) (model.Model, error) { return &m, nil } -func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) { +func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) { if len(m.VisionModel.Transformer.Layers) == 0 || len(m.GlobalTransformer.Layers) == 0 { return nil, model.ErrNoVisionModel } @@ -92,7 +92,9 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er positionIDs := ctx.Arange(0, 1601, 1, ml.DTypeI32) crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio) - return m.Projector.Forward(ctx, crossAttentionStates), nil + projectedOutputs := m.Projector.Forward(ctx, crossAttentionStates) + + return []input.Multimodal{{Tensor: projectedOutputs}}, nil } func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { @@ -108,7 +110,7 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { var crossAttentionStates ml.Tensor if len(batch.Multimodal) > 0 { - crossAttentionStates = batch.Multimodal[len(batch.Multimodal)-1].Multimodal.(ml.Tensor) + crossAttentionStates = batch.Multimodal[len(batch.Multimodal)-1].Multimodal[0].Tensor } positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go index 9d243c30..48655450 100644 --- a/model/models/qwen25vl/model.go +++ b/model/models/qwen25vl/model.go @@ -5,7 +5,6 @@ import ( "fmt" "image" "slices" - "sync" "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" @@ -77,7 +76,7 @@ func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, * return pixelValues, grid, nil } -func (m *Model) EncodeMultimodal(ctx 
ml.Context, multimodalData []byte) (any, error) { +func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) { if len(m.VisionModel.Layers) == 0 { return nil, model.ErrNoVisionModel } @@ -88,31 +87,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er } visionOutputs := m.VisionModel.Forward(ctx, pixels, grid) - return &chunks{Model: m, Tensor: visionOutputs}, nil -} - -type chunks struct { - *Model - ml.Tensor - - dataOnce sync.Once - data []float32 -} - -type chunk struct { - *chunks - s, n int -} - -func (r *chunk) floats() []float32 { - r.dataOnce.Do(func() { - temp := r.Backend().NewContext() - defer temp.Close() - temp.Forward(r.Tensor).Compute(r.Tensor) - r.data = r.Floats() - }) - - return r.data[r.s*r.Dim(0) : (r.s+r.n)*r.Dim(0)] + return []input.Multimodal{{Tensor: visionOutputs}}, nil } // PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass @@ -142,20 +117,16 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { result = append(result, input.Input{Token: pre[i]}) } - // This is an image token with multimodal data - chunksData := inp.Multimodal.(*chunks) - patchesPerChunk := chunksData.Dim(1) + patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1) // First add the vision start token - result = append(result, input.Input{Token: visionStartToken, SameBatch: patchesPerChunk + 2}) + result = append(result, input.Input{Token: visionStartToken, SameBatch: patchesPerChunk + 1}) // Add the image token with the multimodal tensor data at the first position - // Create a chunk with proper s and n values result = append(result, input.Input{ Token: imageToken, - Multimodal: &chunk{chunks: chunksData, s: 0, n: patchesPerChunk}, + Multimodal: inp.Multimodal, MultimodalHash: inp.MultimodalHash, - SameBatch: patchesPerChunk, }) // Add the placeholder tokens for the remaining positions (tokensPerGrid-1) diff --git a/model/models/qwen25vl/model_text.go b/model/models/qwen25vl/model_text.go index 6b062f8c..800fd961 100644 --- a/model/models/qwen25vl/model_text.go +++ b/model/models/qwen25vl/model_text.go @@ -129,12 +129,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx) for _, mi := range batch.Multimodal { - f32s := mi.Multimodal.(*chunk).floats() - img, err := ctx.Input().FromFloatSlice(f32s, len(f32s)/m.hiddenSize, m.hiddenSize) - if err != nil { - panic(err) - } - + img := mi.Multimodal[0].Tensor ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1)))) } diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go index 062b654c..6897b5e4 100644 --- a/runner/ollamarunner/cache_test.go +++ b/runner/ollamarunner/cache_test.go @@ -3,7 +3,6 @@ package ollamarunner import ( "errors" "fmt" - "image" "testing" "time" @@ -12,10 +11,6 @@ import ( ) func TestCountCommon(t *testing.T) { - imgA := image.NewRGBA(image.Rect(0, 0, 100, 100)) - imgB := image.NewRGBA(image.Rect(0, 0, 50, 50)) - imgC := image.NewRGBA(image.Rect(50, 50, 100, 100)) - tests := []struct { name string t1 []input.Input @@ -36,20 +31,20 @@ func TestCountCommon(t *testing.T) { }, { name: "Image Prefix", - t1: []input.Input{{Multimodal: imgA, MultimodalHash: 1}}, - t2: []input.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}}, + t1: []input.Input{{MultimodalHash: 1}}, + t2: 
[]input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}}, expected: 1, }, { name: "Mixed", - t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}}, - t2: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}}, + t1: []input.Input{{Token: 1}, {MultimodalHash: 1}}, + t2: []input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}}, expected: 2, }, { name: "Mixed, Same Length", - t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}}, - t2: []input.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}}, + t1: []input.Input{{Token: 1}, {MultimodalHash: 1}}, + t2: []input.Input{{Token: 1}, {MultimodalHash: 2}}, expected: 1, }, { diff --git a/runner/ollamarunner/multimodal.go b/runner/ollamarunner/multimodal.go new file mode 100644 index 00000000..16d35921 --- /dev/null +++ b/runner/ollamarunner/multimodal.go @@ -0,0 +1,103 @@ +package ollamarunner + +import ( + "errors" + + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model/input" +) + +// Tensors can't be used across multiple compute graphs. This is a problem +// if a single embedding is split across batches using views since all of +// the views will have the same source tensor. We also don't want to +// recompute the entire embedding for each batch. +// +// To avoid this, we compute all of the tensors for the embedding on the +// first use and then store the result in system memory. When we need +// additional tensors, we recreate them from the stored data. + +// multimodalEntry represents the embeddings of a single object (such +// as an image). +type multimodalEntry struct { + // mm is the original set of tensors created by EncodeMultimodal + mm []input.Multimodal + + // data is the computed result of mm. Nil if not yet computed + data [][]float32 +} + +// multimodalStore maps from an individual tensor (of which there +// may be many in a single multimodal object) to its parent embedding +type multimodalStore map[ml.Tensor]*multimodalEntry + +func newMultimodalStore() multimodalStore { + return make(multimodalStore) +} + +// addMultimodal stores an embedding for later use in a compute graph +func (m multimodalStore) addMultimodal(embedding []input.Multimodal) { + entry := &multimodalEntry{mm: embedding} + + for _, e := range embedding { + if e.Tensor != nil { + m[e.Tensor] = entry + } + } +} + +// getMultimodal takes a source set of tensors (which may contain a whole or +// parts of one or more images) and returns the equivalent that can be used in +// the current context +func (m multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in []input.Multimodal) ([]input.Multimodal, error) { + out := make([]input.Multimodal, len(in)) + for i := range out { + if in[i].Tensor != nil { + var err error + out[i].Tensor, err = m.getTensor(backend, ctx, in[i].Tensor) + if err != nil { + return nil, err + } + } + + out[i].Data = in[i].Data + } + + return out, nil +} + +func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Tensor) (ml.Tensor, error) { + entry := m[in] + + if entry.data == nil { + computeCtx := backend.NewContext() + defer computeCtx.Close() + + var tensors []ml.Tensor + for _, t := range entry.mm { + if t.Tensor != nil { + tensors = append(tensors, t.Tensor) + } + } + + if len(tensors) == 0 { + return nil, nil + } + + computeCtx.Forward(tensors...).Compute(tensors...) 
+ + entry.data = make([][]float32, len(entry.mm)) + for i, t := range entry.mm { + if t.Tensor != nil { + entry.data[i] = t.Tensor.Floats() + } + } + } + + for i, t := range entry.mm { + if in == t.Tensor { + return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...) + } + } + + return nil, errors.New("multimodal tensor not found") +} diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 9a522223..4e203b7b 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -40,6 +40,9 @@ type Sequence struct { // multimodal embeddings ctxs []ml.Context + // mmStore holds multimodal embeddings to mange memory and enable splitting across batches + mmStore multimodalStore + // batch index iBatch int @@ -101,7 +104,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe startTime := time.Now() - inputs, ctxs, err := s.inputs(prompt, images) + inputs, ctxs, mmStore, err := s.inputs(prompt, images) if err != nil { return nil, fmt.Errorf("failed to process inputs: %w", err) } else if len(inputs) == 0 { @@ -156,6 +159,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe return &Sequence{ ctxs: ctxs, + mmStore: mmStore, inputs: inputs, numPromptInputs: len(inputs), startProcessingTime: startTime, @@ -174,9 +178,10 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe // inputs processes the prompt and images into a list of inputs // by splitting the prompt on [img-] tags, tokenizing text and // decoding images -func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, error) { +func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, multimodalStore, error) { var inputs []input.Input var ctxs []ml.Context + var mmStore multimodalStore var parts []string var matches [][]string @@ -187,6 +192,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [ re := regexp.MustCompile(`\[img-(\d+)\]`) parts = re.Split(prompt, -1) matches = re.FindAllStringSubmatch(prompt, -1) + mmStore = newMultimodalStore() } else { parts = []string{prompt} } @@ -196,7 +202,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [ // text - tokenize tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0) if err != nil { - return nil, nil, err + return nil, nil, nil, err } for _, t := range tokens { @@ -216,7 +222,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [ } if imageIndex < 0 { - return nil, nil, fmt.Errorf("invalid image index: %d", n) + return nil, nil, nil, fmt.Errorf("invalid image index: %d", n) } ctx := s.model.Backend().NewContext() @@ -224,13 +230,15 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [ ctxs = append(ctxs, ctx) imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data) if err != nil { - return nil, nil, err + return nil, nil, nil, err } s.multimodalHash.Reset() _, _ = s.multimodalHash.Write(images[imageIndex].Data) imageHash := s.multimodalHash.Sum64() + mmStore.addMultimodal(imageEmbeddings) + inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash}) postTokenize = true } @@ -240,11 +248,11 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [ var err error inputs, err = multimodalProcessor.PostTokenize(inputs) if err != nil { - return nil, nil, err + return 
nil, nil, nil, err } } - return inputs, ctxs, nil + return inputs, ctxs, mmStore, nil } type Server struct { @@ -363,6 +371,9 @@ func (s *Server) processBatch() error { } defer s.mu.Unlock() + ctx := s.model.Backend().NewContext() + defer ctx.Close() + var batchInputs []int32 var batch input.Batch @@ -433,7 +444,11 @@ func (s *Server) processBatch() error { batchInputs = append(batchInputs, inp.Token) if inp.Multimodal != nil { - batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal}) + mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal) + if err != nil { + return err + } + batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm}) } batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs))) @@ -459,9 +474,6 @@ func (s *Server) processBatch() error { return nil } - ctx := s.model.Backend().NewContext() - defer ctx.Close() - modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch) if err != nil { return fmt.Errorf("failed to decode batch: %w", err)
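
Note (not part of the patch hunks above): for models that need to split an embedding across batches, the model-side half of this change reduces to returning views instead of a single opaque tensor. The sketch below condenses the mistral3 hunk into a standalone helper; the name splitRows and the rows/cols parameters are illustrative only, and the calls it uses (View, Stride, Dim, input.Multimodal) are exactly the ones the patch relies on.

package example

import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

// splitRows slices a projected image embedding into one view per row of
// patches so the runner can place each row in whichever batch it fits.
// The views all share the same source tensor; that is safe now because the
// runner recreates them from stored data before adding them to a graph.
func splitRows(ctx ml.Context, features ml.Tensor, rows, cols int) []input.Multimodal {
	mm := make([]input.Multimodal, rows)
	for i := range mm {
		mm[i].Tensor = features.View(ctx,
			features.Stride(1)*cols*i, // byte offset of row i
			features.Dim(0), features.Stride(1), cols) // cols patches of width Dim(0)
	}
	return mm
}

EncodeMultimodal would return splitRows(ctx, features, size.Y, size.X) in place of the old *imageRow slice, and PostTokenize can size SameBatch from Tensor.Dim(1), as the mistral3 and llama4 diffs do.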
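When a chunk needs more than a tensor, the new Data field carries model-specific metadata alongside it, as the llama4 hunk does with its separator struct. A minimal illustration follows; tileTag, view, x, cols, chunk, and rowSeparatorToken are stand-ins, not identifiers from the patch.

// tileTag is illustrative metadata describing where a chunk sits in the tile grid.
type tileTag struct {
	lastInRow bool
}

// In EncodeMultimodal: attach the metadata next to the view.
mm = append(mm, input.Multimodal{Tensor: view, Data: &tileTag{lastInRow: x == cols-1}})

// In PostTokenize: read it back. Per the input.Multimodal contract, any
// ml.Tensor must live in Tensor, never in Data.
if tag, ok := chunk.Data.(*tileTag); ok && tag.lastInRow {
	result = append(result, input.Input{Token: rowSeparatorToken})
}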
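On the runner side, the lifecycle of an embedding is now: encode once into a per-sequence context, register the resulting views with the store, and then, for every batch that touches the image, ask the store for graph-local copies. The fragment below restates the runner.go wiring from this patch in one place; seq, s, batchCtx, encodeCtx, mmStore, imageHash, and batchInputs are stand-ins for the surrounding runner state, and unrelated error handling is elided.

// Prompt processing (Server.inputs): encode once and remember the views.
embeddings, err := multimodalProcessor.EncodeMultimodal(encodeCtx, images[imageIndex].Data)
if err != nil {
	return nil, nil, nil, err
}
mmStore.addMultimodal(embeddings)
inputs = append(inputs, input.Input{Multimodal: embeddings, MultimodalHash: imageHash})

// Batch assembly (Server.processBatch): rebuild the tensors inside this
// batch's graph. On first use the store computes the embedding once and keeps
// the floats in system memory; later batches only re-upload, via
// FromFloatSlice, the slice that their view covers.
mm, err := seq.mmStore.getMultimodal(s.model.Backend(), batchCtx, inp.Multimodal)
if err != nil {
	return err
}
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})

Because getMultimodal returns tensors created in the batch's own context, the original source tensor never appears in more than one compute graph, which is the constraint the commit message describes.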