From 9f3a37fd36bf1c46cc86a47bc5372535f8ee3547 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 18 Sep 2025 16:11:08 -0700 Subject: [PATCH] fix: model load for unsupported embedding models (#12311) with #12181, there's now support for embeddings in ollama engine. this is done by mutating the architecture and adding _embed when it detects an embedding model. however this introduced a bug where if an embedding model was run based on an existing ollama engine model without an embedding implementation, e.g. llama4, it will pass the initial arch support check but fail when actually loaded. there's currently two entrypoints to creating a model. previously this second entrypoint was necessary because calling model.New would also load the model. since #11818, this is no longer th case so merge them to reduce complexity --- model/model.go | 41 ++++++++++------------ model/model_test.go | 84 +++++++++++++++++++++++++++------------------ 2 files changed, 70 insertions(+), 55 deletions(-) diff --git a/model/model.go b/model/model.go index 5493a4e6..f3d6bb3d 100644 --- a/model/model.go +++ b/model/model.go @@ -107,23 +107,12 @@ func New(modelPath string, params ml.BackendParams) (Model, error) { return nil, err } - arch := b.Config().Architecture() - if pooling.Type(b.Config().Uint("pooling_type")) != pooling.TypeNone { - arch = arch + "_embed" - } - - f, ok := models[arch] - if !ok { - return nil, fmt.Errorf("unsupported model architecture %q", arch) - } - - m, err := f(b.Config()) + m, err := modelForArch(b.Config()) if err != nil { return nil, err } base := Base{b: b, config: m.Config()} - v := reflect.ValueOf(m) v.Elem().Set(populateFields(base, v.Elem())) return m, nil @@ -135,30 +124,38 @@ func NewTextProcessor(s string) (TextProcessor, error) { return nil, err } defer r.Close() + meta, err := fsggml.Decode(r, -1) if err != nil { return nil, err } - return getTextProcessor(meta.KV()) -} -func getTextProcessor(kv fsggml.KV) (TextProcessor, error) { - arch := kv.Architecture() - f, ok := models[arch] - if !ok { - return nil, fmt.Errorf("unsupported model architecture %q", arch) - } - m, err := f(kv) + m, err := modelForArch(meta.KV()) if err != nil { return nil, err } + tp, ok := m.(TextProcessor) if !ok { - return nil, fmt.Errorf("%v is not a TextProcessor", m) + return nil, ErrUnsupportedTokenizer } return tp, nil } +func modelForArch(c fs.Config) (Model, error) { + arch := c.Architecture() + if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone { + arch = arch + "_embed" + } + + f, ok := models[arch] + if !ok { + return nil, ErrUnsupportedModel + } + + return f(c) +} + func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { t := v.Type() diff --git a/model/model_test.go b/model/model_test.go index 020f9ffb..01080ffd 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -1,9 +1,9 @@ package model import ( + "errors" "reflect" "slices" - "strings" "testing" "github.com/google/go-cmp/cmp" @@ -12,7 +12,6 @@ import ( "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/backend/ggml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/model/input" ) func TestParseTags(t *testing.T) { @@ -148,39 +147,58 @@ func TestPopulateFieldsAlternateName(t *testing.T) { } } -func TestGetTextProcessor(t *testing.T) { - tp, err := getTextProcessor(fsggml.KV{}) - if err == nil { - t.Error("expected error") - } else if !strings.Contains(err.Error(), "unsupported model architecture") { - t.Errorf("unexpected error: %v", err) - } else if tp != nil { - t.Error("expected nil tp") +func TestModelForArch(t *testing.T) { + type fakeModel struct { + Model } - models["dummy"] = func(fs.Config) (Model, error) { - return notTextProcessorModel{}, nil + type fakeEmbeddingModel struct { + Model } - tp, err = getTextProcessor(fsggml.KV{"general.architecture": "dummy"}) - if err == nil { - t.Error("expected error") - } else if !strings.Contains(err.Error(), "not a TextProcessor") { - t.Errorf("unexpected error: %v", err) - } else if tp != nil { - t.Error("expected nil tp") + + models["model"] = func(c fs.Config) (Model, error) { return fakeModel{}, nil } + models["model_embed"] = func(c fs.Config) (Model, error) { return fakeEmbeddingModel{}, nil } + + cases := []struct { + name string + config fs.Config + want any + err error + }{ + { + name: "model", + config: fsggml.KV{ + "general.architecture": "model", + }, + want: fakeModel{}, + }, + { + name: "embedding", + config: fsggml.KV{ + "general.architecture": "model", + "model.pooling_type": uint32(1), + }, + want: fakeEmbeddingModel{}, + }, + { + name: "unsupported", + config: fsggml.KV{ + "general.architecture": "unsupported", + }, + err: ErrUnsupportedModel, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got, err := modelForArch(tt.config) + if !errors.Is(err, tt.err) { + t.Fatal(err) + } + + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("modelForArch() returned unexpected values (-want +got):\n%s", diff) + } + }) } } - -type notTextProcessorModel struct{} - -func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) { - panic("unimplemented") -} - -func (notTextProcessorModel) Backend() ml.Backend { - panic("unimplemented") -} - -func (notTextProcessorModel) Config() config { - panic("unimplemented") -}