chore: update mllama to use ollama engine (#10637)

This commit is contained in:
Michael Yang
2025-05-13 17:36:02 -07:00
committed by GitHub
parent 0478d440f0
commit 23125648b8
67 changed files with 785 additions and 4354 deletions

View File

@@ -3,47 +3,32 @@ package server
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"log/slog"
"slices"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/model/models/mllama"
"github.com/ollama/ollama/template"
)
type tokenizeFunc func(context.Context, string) ([]int, error)
var errTooManyImages = errors.New("vision model only supports a single image per message")
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
var system []api.Message
isMllama := checkMllamaModelFamily(m)
var imageNumTokens int
// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
if isMllama {
// Our mllama implementation packs all of the embeddings into a single token
imageNumTokens = 1
} else {
// Clip images are represented as 768 tokens, each an embedding
imageNumTokens = 768
}
// Clip images are represented as 768 tokens, each an embedding
imageNumTokens := 768
n := len(msgs) - 1
// in reverse, find all messages that fit into context window
for i := n; i >= 0; i-- {
if isMllama && len(msgs[i].Images) > 1 {
return "", nil, errTooManyImages
}
// always include the last message
if i == n {
continue
@@ -84,48 +69,17 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
currMsgIdx := n
for cnt, msg := range msgs[currMsgIdx:] {
prefix := ""
imgPrompt := ""
if slices.Contains(m.Config.ModelFamilies, "mllama") && len(msg.Images) > 1 {
return "", nil, errors.New("this model only supports one image while more than one image requested")
}
var prefix string
prompt := msg.Content
for _, i := range msg.Images {
var imgData llm.ImageData
if isMllama {
if len(m.ProjectorPaths) == 0 {
imgData = llm.ImageData{
ID: len(images),
Data: i,
}
} else {
data, opts, err := mllama.Preprocess(bytes.NewReader(i))
if err != nil {
return "", nil, err
}
buf := new(bytes.Buffer)
err = binary.Write(buf, binary.LittleEndian, data)
if err != nil {
return "", nil, err
}
ar, ok := opts["aspectRatioIndex"].(int)
if !ok {
return "", nil, fmt.Errorf("missing aspect ratio for image")
}
imgData = llm.ImageData{
ID: len(images),
Data: buf.Bytes(),
AspectRatioID: ar,
}
}
imgPrompt = "<|image|>"
} else {
imgData = llm.ImageData{
ID: len(images),
Data: i,
}
imgData := llm.ImageData{
ID: len(images),
Data: i,
}
imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
@@ -137,7 +91,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
images = append(images, imgData)
}
msgs[currMsgIdx+cnt].Content = prefix + imgPrompt + prompt
msgs[currMsgIdx+cnt].Content = prefix + prompt
}
// truncate any messages that do not fit into the context window
@@ -148,12 +102,3 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
return b.String(), images, nil
}
func checkMllamaModelFamily(m *Model) bool {
for _, arch := range m.Config.ModelFamilies {
if arch == "mllama" {
return true
}
}
return false
}

View File

@@ -2,8 +2,6 @@ package server
import (
"bytes"
"image"
"image/png"
"testing"
"github.com/google/go-cmp/cmp"
@@ -14,10 +12,9 @@ import (
func TestChatPrompt(t *testing.T) {
type expect struct {
prompt string
images [][]byte
aspectRatioID int
error error
prompt string
images [][]byte
error error
}
tmpl, err := template.Parse(`
@@ -28,28 +25,6 @@ func TestChatPrompt(t *testing.T) {
t.Fatal(err)
}
visionModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
mllamaModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}, Config: ConfigV2{ModelFamilies: []string{"mllama"}}}
createImg := func(width, height int) ([]byte, error) {
img := image.NewRGBA(image.Rect(0, 0, width, height))
var buf bytes.Buffer
if err := png.Encode(&buf, img); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
imgBuf, err := createImg(5, 5)
if err != nil {
t.Fatal(err)
}
imgBuf2, err := createImg(6, 6)
if err != nil {
t.Fatal(err)
}
cases := []struct {
name string
@@ -227,90 +202,6 @@ func TestChatPrompt(t *testing.T) {
images: [][]byte{[]byte("one hotdog"), []byte("two hotdogs")},
},
},
{
name: "messages with mllama (no images)",
model: mllamaModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
},
},
{
name: "messages with mllama single prompt",
model: mllamaModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "How many hotdogs are in this image?", Images: []api.ImageData{imgBuf}},
},
expect: expect{
prompt: "[img-0]<|image|>How many hotdogs are in this image? ",
images: [][]byte{imgBuf},
aspectRatioID: 1,
},
},
{
name: "messages with mllama",
model: mllamaModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf}},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? [img-0]<|image|>A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{imgBuf},
aspectRatioID: 1,
},
},
{
name: "multiple messages with mllama",
model: mllamaModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{imgBuf}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf2}},
},
expect: expect{
prompt: "[img-0]<|image|>You're a test, Harry! I-I'm a what? [img-1]<|image|>A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{imgBuf, imgBuf2},
aspectRatioID: 1,
},
},
{
name: "earlier image with mllama",
model: mllamaModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "How many hotdogs are in this image?", Images: []api.ImageData{imgBuf}},
{Role: "assistant", Content: "There are four hotdogs."},
{Role: "user", Content: "Which ones have mustard?"},
},
expect: expect{
prompt: "[img-0]<|image|>How many hotdogs are in this image? There are four hotdogs. Which ones have mustard? ",
images: [][]byte{imgBuf},
aspectRatioID: 1,
},
},
{
name: "too many images with mllama",
model: mllamaModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf, imgBuf}},
},
expect: expect{
error: errTooManyImages,
},
},
}
for _, tt := range cases {
@@ -341,10 +232,6 @@ func TestChatPrompt(t *testing.T) {
if !bytes.Equal(images[i].Data, tt.images[i]) {
t.Errorf("expected %q, got %q", tt.images[i], images[i].Data)
}
} else {
if images[i].AspectRatioID != tt.aspectRatioID {
t.Errorf("expected aspect ratio %d, got %d", tt.aspectRatioID, images[i].AspectRatioID)
}
}
}
})

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"cmp"
"context"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
@@ -35,7 +34,6 @@ import (
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/model/models/mllama"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry"
@@ -100,6 +98,10 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return nil, nil, nil, err
}
if slices.Contains(model.Config.ModelFamilies, "mllama") && len(model.ProjectorPaths) > 0 {
return nil, nil, nil, fmt.Errorf("'llama3.2-vision' is no longer compatible with your version of Ollama and has been replaced by a newer version. To re-download, run 'ollama pull llama3.2-vision'")
}
if err := model.CheckCapabilities(caps...); err != nil {
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
}
@@ -206,38 +208,14 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
isMllama := checkMllamaModelFamily(m)
if isMllama && len(req.Images) > 1 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image: more than one image sent"})
if slices.Contains(m.Config.ModelFamilies, "mllama") && len(req.Images) > 1 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image while more than one image requested"})
return
}
images := make([]llm.ImageData, len(req.Images))
for i := range req.Images {
if isMllama && len(m.ProjectorPaths) > 0 {
data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i]))
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
return
}
ar, ok := opts["aspectRatioIndex"].(int)
if !ok {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
return
}
buf := new(bytes.Buffer)
err = binary.Write(buf, binary.LittleEndian, data)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
return
}
images[i] = llm.ImageData{ID: i, Data: buf.Bytes(), AspectRatioID: ar}
} else {
images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
}
images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
}
prompt := req.Prompt
@@ -269,9 +247,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
for _, i := range images {
imgPrompt := ""
if isMllama {
imgPrompt = "<|image|>"
}
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]"+imgPrompt, i.ID)})
}

View File

@@ -8,6 +8,7 @@ import (
"os"
"reflect"
"runtime"
"slices"
"sort"
"strconv"
"strings"
@@ -132,11 +133,11 @@ func (s *Scheduler) processPending(ctx context.Context) {
continue
}
numParallel := int(envconfig.NumParallel())
// TODO (jmorganca): mllama doesn't support parallel yet
// see https://github.com/ollama/ollama/issues/4165
if checkMllamaModelFamily(pending.model) && numParallel != 1 {
// `mllama` is a snowflake and uses an encoder cache which cannot be used with num_parallel > 1
// ref: https://github.com/ollama/ollama/issues/4165
if slices.Contains(pending.model.Config.ModelFamilies, "mllama") && numParallel != 1 {
numParallel = 1
slog.Warn("mllama doesn't support parallel requests yet")
slog.Warn("mllama does not currently support parallel requests")
}
for {