mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 22:33:56 +00:00
fix: model load for unsupported embedding models (#12311)
with #12181, there's now support for embeddings in ollama engine. this is done by mutating the architecture and adding _embed when it detects an embedding model. however this introduced a bug where if an embedding model was run based on an existing ollama engine model without an embedding implementation, e.g. llama4, it will pass the initial arch support check but fail when actually loaded. there are currently two entry points to creating a model. previously this second entry point was necessary because calling model.New would also load the model. since #11818, this is no longer the case so merge them to reduce complexity
This commit is contained in:
@@ -107,23 +107,12 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
arch := b.Config().Architecture()
|
m, err := modelForArch(b.Config())
|
||||||
if pooling.Type(b.Config().Uint("pooling_type")) != pooling.TypeNone {
|
|
||||||
arch = arch + "_embed"
|
|
||||||
}
|
|
||||||
|
|
||||||
f, ok := models[arch]
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("unsupported model architecture %q", arch)
|
|
||||||
}
|
|
||||||
|
|
||||||
m, err := f(b.Config())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
base := Base{b: b, config: m.Config()}
|
base := Base{b: b, config: m.Config()}
|
||||||
|
|
||||||
v := reflect.ValueOf(m)
|
v := reflect.ValueOf(m)
|
||||||
v.Elem().Set(populateFields(base, v.Elem()))
|
v.Elem().Set(populateFields(base, v.Elem()))
|
||||||
return m, nil
|
return m, nil
|
||||||
@@ -135,30 +124,38 @@ func NewTextProcessor(s string) (TextProcessor, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer r.Close()
|
defer r.Close()
|
||||||
|
|
||||||
meta, err := fsggml.Decode(r, -1)
|
meta, err := fsggml.Decode(r, -1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return getTextProcessor(meta.KV())
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTextProcessor(kv fsggml.KV) (TextProcessor, error) {
|
m, err := modelForArch(meta.KV())
|
||||||
arch := kv.Architecture()
|
|
||||||
f, ok := models[arch]
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("unsupported model architecture %q", arch)
|
|
||||||
}
|
|
||||||
m, err := f(kv)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
tp, ok := m.(TextProcessor)
|
tp, ok := m.(TextProcessor)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("%v is not a TextProcessor", m)
|
return nil, ErrUnsupportedTokenizer
|
||||||
}
|
}
|
||||||
return tp, nil
|
return tp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func modelForArch(c fs.Config) (Model, error) {
|
||||||
|
arch := c.Architecture()
|
||||||
|
if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone {
|
||||||
|
arch = arch + "_embed"
|
||||||
|
}
|
||||||
|
|
||||||
|
f, ok := models[arch]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrUnsupportedModel
|
||||||
|
}
|
||||||
|
|
||||||
|
return f(c)
|
||||||
|
}
|
||||||
|
|
||||||
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
||||||
t := v.Type()
|
t := v.Type()
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"reflect"
|
"reflect"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/ml/backend/ggml"
|
"github.com/ollama/ollama/ml/backend/ggml"
|
||||||
"github.com/ollama/ollama/ml/nn"
|
"github.com/ollama/ollama/ml/nn"
|
||||||
"github.com/ollama/ollama/model/input"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseTags(t *testing.T) {
|
func TestParseTags(t *testing.T) {
|
||||||
@@ -148,39 +147,58 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetTextProcessor(t *testing.T) {
|
func TestModelForArch(t *testing.T) {
|
||||||
tp, err := getTextProcessor(fsggml.KV{})
|
type fakeModel struct {
|
||||||
if err == nil {
|
Model
|
||||||
t.Error("expected error")
|
|
||||||
} else if !strings.Contains(err.Error(), "unsupported model architecture") {
|
|
||||||
t.Errorf("unexpected error: %v", err)
|
|
||||||
} else if tp != nil {
|
|
||||||
t.Error("expected nil tp")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
models["dummy"] = func(fs.Config) (Model, error) {
|
type fakeEmbeddingModel struct {
|
||||||
return notTextProcessorModel{}, nil
|
Model
|
||||||
}
|
|
||||||
tp, err = getTextProcessor(fsggml.KV{"general.architecture": "dummy"})
|
|
||||||
if err == nil {
|
|
||||||
t.Error("expected error")
|
|
||||||
} else if !strings.Contains(err.Error(), "not a TextProcessor") {
|
|
||||||
t.Errorf("unexpected error: %v", err)
|
|
||||||
} else if tp != nil {
|
|
||||||
t.Error("expected nil tp")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type notTextProcessorModel struct{}
|
models["model"] = func(c fs.Config) (Model, error) { return fakeModel{}, nil }
|
||||||
|
models["model_embed"] = func(c fs.Config) (Model, error) { return fakeEmbeddingModel{}, nil }
|
||||||
|
|
||||||
func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
|
cases := []struct {
|
||||||
panic("unimplemented")
|
name string
|
||||||
|
config fs.Config
|
||||||
|
want any
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "model",
|
||||||
|
config: fsggml.KV{
|
||||||
|
"general.architecture": "model",
|
||||||
|
},
|
||||||
|
want: fakeModel{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "embedding",
|
||||||
|
config: fsggml.KV{
|
||||||
|
"general.architecture": "model",
|
||||||
|
"model.pooling_type": uint32(1),
|
||||||
|
},
|
||||||
|
want: fakeEmbeddingModel{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unsupported",
|
||||||
|
config: fsggml.KV{
|
||||||
|
"general.architecture": "unsupported",
|
||||||
|
},
|
||||||
|
err: ErrUnsupportedModel,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func (notTextProcessorModel) Backend() ml.Backend {
|
for _, tt := range cases {
|
||||||
panic("unimplemented")
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := modelForArch(tt.config)
|
||||||
|
if !errors.Is(err, tt.err) {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (notTextProcessorModel) Config() config {
|
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||||
panic("unimplemented")
|
t.Errorf("modelForArch() returned unexpected values (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user