fix: model load for unsupported embedding models (#12311)

with #12181, there's now support for embeddings in ollama engine.
this is done by mutating the architecture and adding _embed when it
detects an embedding model. however, this introduced a bug: if
an embedding model was run on top of an existing ollama engine model
without an embedding implementation, e.g. llama4, it would pass the
initial arch support check but fail when actually loaded.

there are currently two entry points for creating a model. previously the
second entry point was necessary because calling model.New would also
load the model. since #11818, this is no longer the case, so merge them
to reduce complexity
This commit is contained in:
Michael Yang
2025-09-18 16:11:08 -07:00
committed by GitHub
parent 7460259eb3
commit 9f3a37fd36
2 changed files with 70 additions and 55 deletions

View File

@@ -107,23 +107,12 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
return nil, err return nil, err
} }
arch := b.Config().Architecture() m, err := modelForArch(b.Config())
if pooling.Type(b.Config().Uint("pooling_type")) != pooling.TypeNone {
arch = arch + "_embed"
}
f, ok := models[arch]
if !ok {
return nil, fmt.Errorf("unsupported model architecture %q", arch)
}
m, err := f(b.Config())
if err != nil { if err != nil {
return nil, err return nil, err
} }
base := Base{b: b, config: m.Config()} base := Base{b: b, config: m.Config()}
v := reflect.ValueOf(m) v := reflect.ValueOf(m)
v.Elem().Set(populateFields(base, v.Elem())) v.Elem().Set(populateFields(base, v.Elem()))
return m, nil return m, nil
@@ -135,30 +124,38 @@ func NewTextProcessor(s string) (TextProcessor, error) {
return nil, err return nil, err
} }
defer r.Close() defer r.Close()
meta, err := fsggml.Decode(r, -1) meta, err := fsggml.Decode(r, -1)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return getTextProcessor(meta.KV())
}
func getTextProcessor(kv fsggml.KV) (TextProcessor, error) { m, err := modelForArch(meta.KV())
arch := kv.Architecture()
f, ok := models[arch]
if !ok {
return nil, fmt.Errorf("unsupported model architecture %q", arch)
}
m, err := f(kv)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tp, ok := m.(TextProcessor) tp, ok := m.(TextProcessor)
if !ok { if !ok {
return nil, fmt.Errorf("%v is not a TextProcessor", m) return nil, ErrUnsupportedTokenizer
} }
return tp, nil return tp, nil
} }
// modelForArch looks up the constructor registered for the architecture
// named in c and invokes it. Configs that declare a non-none pooling type
// are treated as embedding models: the architecture key gets an "_embed"
// suffix so the embedding variant is selected. Returns ErrUnsupportedModel
// when no constructor is registered under the resolved key.
func modelForArch(c fs.Config) (Model, error) {
	arch := c.Architecture()
	if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone {
		arch += "_embed"
	}
	construct, ok := models[arch]
	if !ok {
		return nil, ErrUnsupportedModel
	}
	return construct(c)
}
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
t := v.Type() t := v.Type()

View File

@@ -1,9 +1,9 @@
package model package model
import ( import (
"errors"
"reflect" "reflect"
"slices" "slices"
"strings"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
@@ -12,7 +12,6 @@ import (
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/backend/ggml" "github.com/ollama/ollama/ml/backend/ggml"
"github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model/input"
) )
func TestParseTags(t *testing.T) { func TestParseTags(t *testing.T) {
@@ -148,39 +147,58 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
} }
} }
func TestGetTextProcessor(t *testing.T) { func TestModelForArch(t *testing.T) {
tp, err := getTextProcessor(fsggml.KV{}) type fakeModel struct {
if err == nil { Model
t.Error("expected error")
} else if !strings.Contains(err.Error(), "unsupported model architecture") {
t.Errorf("unexpected error: %v", err)
} else if tp != nil {
t.Error("expected nil tp")
} }
models["dummy"] = func(fs.Config) (Model, error) { type fakeEmbeddingModel struct {
return notTextProcessorModel{}, nil Model
} }
tp, err = getTextProcessor(fsggml.KV{"general.architecture": "dummy"})
if err == nil { models["model"] = func(c fs.Config) (Model, error) { return fakeModel{}, nil }
t.Error("expected error") models["model_embed"] = func(c fs.Config) (Model, error) { return fakeEmbeddingModel{}, nil }
} else if !strings.Contains(err.Error(), "not a TextProcessor") {
t.Errorf("unexpected error: %v", err) cases := []struct {
} else if tp != nil { name string
t.Error("expected nil tp") config fs.Config
want any
err error
}{
{
name: "model",
config: fsggml.KV{
"general.architecture": "model",
},
want: fakeModel{},
},
{
name: "embedding",
config: fsggml.KV{
"general.architecture": "model",
"model.pooling_type": uint32(1),
},
want: fakeEmbeddingModel{},
},
{
name: "unsupported",
config: fsggml.KV{
"general.architecture": "unsupported",
},
err: ErrUnsupportedModel,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got, err := modelForArch(tt.config)
if !errors.Is(err, tt.err) {
t.Fatal(err)
}
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("modelForArch() returned unexpected values (-want +got):\n%s", diff)
}
})
} }
} }
type notTextProcessorModel struct{}
func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
panic("unimplemented")
}
func (notTextProcessorModel) Backend() ml.Backend {
panic("unimplemented")
}
func (notTextProcessorModel) Config() config {
panic("unimplemented")
}