From 333e360422744e92275af2c1c2d5bc039ad97e8f Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Fri, 16 May 2025 13:40:23 -0700 Subject: [PATCH 01/26] model: handle multiple eos tokens (#10577) * get eos_token_id from generation_config.json * refactor * include both ids and strings in trace * comments * remove special case for gemma3 special vocab (#10743) --- convert/convert.go | 5 +- convert/tokenizer.go | 32 +++++ convert/tokenizer_test.go | 61 +++++++++ llama/llama.go | 4 +- .../{process_text.go => bytepairencoding.go} | 126 +----------------- ..._text_test.go => bytepairencoding_test.go} | 0 model/models/gemma2/model.go | 11 +- model/models/gemma3/model.go | 12 +- model/models/llama/model.go | 10 +- model/models/llama4/model.go | 10 +- model/models/mistral3/model.go | 18 +-- model/models/mllama/model.go | 10 +- model/models/qwen25vl/model.go | 11 +- .../{process_text_spm.go => sentencepiece.go} | 23 +--- ...text_spm_test.go => sentencepiece_test.go} | 0 model/textprocessor.go | 17 +++ model/vocabulary.go | 112 ++++++++++++++++ sample/samplers.go | 2 +- 18 files changed, 282 insertions(+), 182 deletions(-) rename model/{process_text.go => bytepairencoding.go} (66%) rename model/{process_text_test.go => bytepairencoding_test.go} (100%) rename model/{process_text_spm.go => sentencepiece.go} (89%) rename model/{process_text_spm_test.go => sentencepiece_test.go} (100%) create mode 100644 model/textprocessor.go create mode 100644 model/vocabulary.go diff --git a/convert/convert.go b/convert/convert.go index 309b0ce1..4a6df66c 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -53,8 +53,11 @@ func (ModelParameters) KV(t *Tokenizer) ggml.KV { } for _, sv := range t.SpecialVocabulary { - kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID) kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken + kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID) + if len(sv.IDs) > 0 { + kv[fmt.Sprintf("tokenizer.ggml.%s_token_ids", sv.Key())] = sv.IDs + } } return kv diff --git a/convert/tokenizer.go b/convert/tokenizer.go index 74e2efed..bedcd4f8 100644 --- a/convert/tokenizer.go +++ b/convert/tokenizer.go @@ -110,6 +110,7 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error) } if f, err := fsys.Open("tokenizer_config.json"); errors.Is(err, os.ErrNotExist) { + // noop } else if err != nil { return nil, err } else { @@ -171,6 +172,34 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error) } } + if f, err := fsys.Open("generation_config.json"); errors.Is(err, os.ErrNotExist) { + } else if err != nil { + return nil, err + } else { + defer f.Close() + + var p map[string]json.RawMessage + if err := json.NewDecoder(f).Decode(&p); err != nil { + return nil, err + } + + for _, st := range specialTokenTypes { + if bts, ok := p[fmt.Sprintf("%s_token_id", st)]; ok { + var ids []int32 + if err := json.Unmarshal(bts, &ids); err != nil { + // value is not a list so the existing ID is used + continue + } + + if i := slices.IndexFunc(t.SpecialVocabulary, func(sv *SpecialVocabulary) bool { + return sv.Type == st + }); i >= 0 { + t.SpecialVocabulary[i].IDs = ids + } + } + } + } + return t, nil } @@ -280,6 +309,9 @@ type SpecialVocabulary struct { ID int Content string AddToken bool + + // IDs is populated by generation_config.json + IDs []int32 } func (sv SpecialVocabulary) Key() string { diff --git a/convert/tokenizer_test.go b/convert/tokenizer_test.go index c6ef9732..813096fd 100644 --- a/convert/tokenizer_test.go +++ b/convert/tokenizer_test.go @@ -247,6 +247,67 @@ func TestParseTokenizer(t *testing.T) { Pre: "default", }, }, + { + name: "generation config eos token ids", + fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{ + "tokenizer.json": strings.NewReader(`{ + "added_tokens": [ + { + "id": 0, + "content": "", + "special": true + }, + { + "id": 1, + "content": "", + "special": true + }, + { + "id": 2, + "content": "", + "special": true + }, + { + "id": 3, + "content": "", + "special": true + } + ], + "model": { + "vocab": { + "": 0, + "": 1, + "": 2, + "": 3 + } + } + }`), + "tokenizer_config.json": strings.NewReader(`{ + "add_bos_token": true, + "add_eos_token": false, + "bos_token": "", + "eos_token": "" + }`), + "generation_config.json": strings.NewReader(`{ + "bos_token_id": 0, + "eos_token_id": [1, 2, 3] + }`), + }), + specialTokenTypes: []string{"pad", "eos", "bos", "unk"}, + want: &Tokenizer{ + Vocabulary: &Vocabulary{ + Model: "gpt2", + Tokens: []string{"", "", "", ""}, + Scores: []float32{0, 1, 2, 3}, + Types: []int32{3, 3, 3, 3}, + }, + SpecialVocabulary: []*SpecialVocabulary{ + {Type: "eos", Content: "", ID: 1, IDs: []int32{1, 2, 3}, AddToken: false}, + {Type: "bos", Content: "", ID: 0, AddToken: true}, + }, + Pre: "default", + }, + }, } for _, tt := range cases { diff --git a/llama/llama.go b/llama/llama.go index 1251be3a..626ea13a 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -602,7 +602,7 @@ type Grammar struct { mu sync.Mutex } -func NewGrammar(grammar string, vocabIds []uint32, vocabValues []string, eogTokens []uint32) *Grammar { +func NewGrammar(grammar string, vocabIds []uint32, vocabValues []string, eogTokens []int32) *Grammar { cGrammar := C.CString(grammar) defer C.free(unsafe.Pointer(cGrammar)) @@ -622,7 +622,7 @@ func NewGrammar(grammar string, vocabIds []uint32, vocabValues []string, eogToke cEogTokens[i] = C.uint32_t(token) } - g := C.grammar_init(cGrammar, (*C.uint32_t)(unsafe.Pointer(&cTokens[0])), C.size_t(len(cTokens)), (**C.char)(unsafe.Pointer(&cPieces[0])), (*C.uint32_t)(unsafe.Pointer(&cEogTokens[0])), C.size_t(len(cEogTokens))) + g := C.grammar_init(cGrammar, unsafe.SliceData(cTokens), C.size_t(len(cTokens)), unsafe.SliceData(cPieces), unsafe.SliceData(cEogTokens), C.size_t(len(cEogTokens))) if g == nil { return nil } diff --git a/model/process_text.go b/model/bytepairencoding.go similarity index 66% rename from model/process_text.go rename to model/bytepairencoding.go index 7b87ccc3..6bb9a003 100644 --- a/model/process_text.go +++ b/model/bytepairencoding.go @@ -5,116 +5,13 @@ import ( "context" "iter" "log/slog" - "slices" "strings" - "sync" "github.com/dlclark/regexp2" heap "github.com/emirpasic/gods/v2/trees/binaryheap" "github.com/ollama/ollama/logutil" ) -type Special int32 - -const ( - SpecialBOS Special = iota - SpecialEOS -) - -const ( - TOKEN_TYPE_NORMAL = iota + 1 - TOKEN_TYPE_UNKNOWN - TOKEN_TYPE_CONTROL - TOKEN_TYPE_USER_DEFINED - TOKEN_TYPE_UNUSED - TOKEN_TYPE_BYTE -) - -type TextProcessor interface { - Encode(s string, addSpecial bool) ([]int32, error) - Decode([]int32) (string, error) - Is(int32, Special) bool - Vocabulary() *Vocabulary -} - -type Vocabulary struct { - Values []string - Types []int32 - Scores []float32 - Merges []string - - BOS, EOS, EOT int32 - AddBOS, AddEOS, AddEOT bool - - specialOnce sync.Once - special []string - - valuesOnce sync.Once - values map[string]int32 - - mergeOnce sync.Once - merge map[string]int32 -} - -func (v *Vocabulary) Is(id int32, special Special) bool { - switch special { - case SpecialBOS: - return id == v.BOS - case SpecialEOS: - return id == v.EOS || id == v.EOT - default: - return false - } -} - -func (v *Vocabulary) Encode(s string) int32 { - v.valuesOnce.Do(func() { - v.values = make(map[string]int32, len(v.Values)) - for i, value := range v.Values { - v.values[value] = int32(i) - } - }) - - if id, ok := v.values[s]; ok { - return id - } - - return -1 -} - -func (v *Vocabulary) Decode(id int32) string { - return v.Values[id] -} - -func (v *Vocabulary) SpecialVocabulary() []string { - v.specialOnce.Do(func() { - for i := range v.Values { - if slices.Contains([]int{105, 106}, i) { - v.special = append(v.special, v.Values[i]) - } else if v.Types[i] == TOKEN_TYPE_CONTROL { - v.special = append(v.special, v.Values[i]) - } - } - }) - - return v.special -} - -func (v *Vocabulary) Merge(left, right string) int { - v.mergeOnce.Do(func() { - v.merge = make(map[string]int32, len(v.Merges)) - for i, merge := range v.Merges { - v.merge[merge] = int32(i) - } - }) - - if id, ok := v.merge[left+" "+right]; ok { - return int(id) - } - - return -1 -} - type BytePairEncoding struct { pre *regexp2.Regexp vocab *Vocabulary @@ -304,27 +201,12 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) { } } + slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids) + if addSpecial && len(ids) > 0 { - if bpe.vocab.AddBOS { - if ids[0] == bpe.vocab.BOS { - slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS) - } - - slog.Debug("adding bos token to prompt", "id", bpe.vocab.BOS) - ids = append([]int32{bpe.vocab.BOS}, ids...) - } - - if bpe.vocab.AddEOS { - if ids[len(ids)-1] == bpe.vocab.EOS { - slog.Warn("adding eos token to prompt which already has it", "id", bpe.vocab.EOS) - } - - slog.Debug("adding eos token to prompt", "id", bpe.vocab.EOS) - ids = append(ids, bpe.vocab.EOS) - } + ids = bpe.vocab.addSpecials(ids) } - slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "ids", ids) return ids, nil } @@ -352,6 +234,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) { } } - slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "string", sb.String()) + slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "ids", ids, "string", sb.String()) return sb.String(), nil } diff --git a/model/process_text_test.go b/model/bytepairencoding_test.go similarity index 100% rename from model/process_text_test.go rename to model/bytepairencoding_test.go diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 3156b006..a87534c5 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -43,10 +43,13 @@ func New(c fs.Config) (model.Model, error) { Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), Types: c.Ints("tokenizer.ggml.token_type"), - BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), - EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), - // TODO: set EOT to EOS otherwise 0 will stop generation - EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), }, ), Layers: make([]Layer, c.Uint("block_count")), diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index d53eb6cc..89d1788e 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -60,12 +60,16 @@ func New(c fs.Config) (model.Model, error) { Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), Types: c.Ints("tokenizer.ggml.token_type"), - BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), - EOS: int32(1), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - EOT: int32(106), - AddEOT: c.Bool("tokenizer.ggml.add_eot_token", false), + EOS: append( + []int32{ + int32(c.Uint("tokenizer.ggml.eos_token_id")), + int32(c.Uint("tokenizer.ggml.eot_token_id", 106)), + }, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), }, ), ImageProcessor: newImageProcessor(c), diff --git a/model/models/llama/model.go b/model/models/llama/model.go index c75d7eb2..6e214f0f 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -43,13 +43,13 @@ func New(c fs.Config) (model.Model, error) { Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), Merges: c.Strings("tokenizer.ggml.merges"), - BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), - EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - // TODO: set EOT to EOS otherwise 0 will stop generation - EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")), - AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), }, ), Layers: make([]Layer, c.Uint("block_count")), diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go index c94aa72f..af5173a1 100644 --- a/model/models/llama4/model.go +++ b/model/models/llama4/model.go @@ -40,13 +40,13 @@ func New(c fs.Config) (model.Model, error) { Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), Merges: c.Strings("tokenizer.ggml.merges"), - BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), - EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - // TODO: set EOT to EOS otherwise 0 will stop generation - EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")), - AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), }, ), ImageProcessor: newImageProcessor(c), diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go index b93882a9..0d384b94 100644 --- a/model/models/mistral3/model.go +++ b/model/models/mistral3/model.go @@ -37,25 +37,25 @@ func New(c fs.Config) (model.Model, error) { } m := &Model{ - TextModel: textModel, - VisionModel: newVisionModel(c), - ImageProcessor: newImageProcessor(c), - MultiModalProjector: newMultiModalProjector(c), BytePairEncoding: model.NewBytePairEncoding( c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), Merges: c.Strings("tokenizer.ggml.merges"), - BOS: int32(c.Uint("tokenizer.ggml.bos_token_id", 1)), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), - EOS: int32(c.Uint("tokenizer.ggml.eos_token_id", 2)), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - // TODO: set EOT to EOS otherwise 0 will stop generation - EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")), - AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), }, ), + TextModel: textModel, + VisionModel: newVisionModel(c), + ImageProcessor: newImageProcessor(c), + MultiModalProjector: newMultiModalProjector(c), } m.Cache = kvcache.NewCausalCache(m.TextModel.Shift) diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index 15571d9c..547e2cb3 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -38,13 +38,13 @@ func New(c fs.Config) (model.Model, error) { Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), Merges: c.Strings("tokenizer.ggml.merges"), - BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), - EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - // TODO: set EOT to EOS otherwise 0 will stop generation - EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")), - AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), }, ), ImageProcessor: newImageProcessor(c), diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go index 48655450..7de9b6eb 100644 --- a/model/models/qwen25vl/model.go +++ b/model/models/qwen25vl/model.go @@ -34,12 +34,13 @@ func New(c fs.Config) (model.Model, error) { Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), Merges: c.Strings("tokenizer.ggml.merges"), - BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), - AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false), - EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")), - AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), }, ), TextModel: NewTextModel(c), diff --git a/model/process_text_spm.go b/model/sentencepiece.go similarity index 89% rename from model/process_text_spm.go rename to model/sentencepiece.go index b1cff7d2..7d725f04 100644 --- a/model/process_text_spm.go +++ b/model/sentencepiece.go @@ -182,27 +182,12 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) } } + slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids) + if addSpecial && len(ids) > 0 { - if spm.vocab.AddBOS { - if ids[0] == spm.vocab.BOS { - slog.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS) - } - - slog.Debug("adding bos token to prompt", "id", spm.vocab.BOS) - ids = append([]int32{spm.vocab.BOS}, ids...) - } - - if spm.vocab.AddEOS { - if ids[len(ids)-1] == spm.vocab.EOS { - slog.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS) - } - - slog.Debug("adding eos token to prompt", "id", spm.vocab.EOS) - ids = append(ids, spm.vocab.EOS) - } + ids = spm.vocab.addSpecials(ids) } - slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "ids", ids) return ids, nil } @@ -261,6 +246,6 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) { } } - slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "string", sb.String()) + slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "ids", ids, "string", sb.String()) return sb.String(), nil } diff --git a/model/process_text_spm_test.go b/model/sentencepiece_test.go similarity index 100% rename from model/process_text_spm_test.go rename to model/sentencepiece_test.go diff --git a/model/textprocessor.go b/model/textprocessor.go new file mode 100644 index 00000000..4a36f235 --- /dev/null +++ b/model/textprocessor.go @@ -0,0 +1,17 @@ +package model + +const ( + TOKEN_TYPE_NORMAL = iota + 1 + TOKEN_TYPE_UNKNOWN + TOKEN_TYPE_CONTROL + TOKEN_TYPE_USER_DEFINED + TOKEN_TYPE_UNUSED + TOKEN_TYPE_BYTE +) + +type TextProcessor interface { + Encode(s string, addSpecial bool) ([]int32, error) + Decode([]int32) (string, error) + Is(int32, Special) bool + Vocabulary() *Vocabulary +} diff --git a/model/vocabulary.go b/model/vocabulary.go new file mode 100644 index 00000000..24adbaca --- /dev/null +++ b/model/vocabulary.go @@ -0,0 +1,112 @@ +package model + +import ( + "log/slog" + "slices" + "sync" +) + +type Special int32 + +const ( + SpecialBOS Special = iota + SpecialEOS +) + +type Vocabulary struct { + Values []string + Types []int32 + Scores []float32 + Merges []string + + BOS, EOS []int32 + AddBOS, AddEOS bool + + specialOnce sync.Once + special []string + + valuesOnce sync.Once + values map[string]int32 + + mergeOnce sync.Once + merge map[string]int32 +} + +func (v *Vocabulary) Is(id int32, special Special) bool { + switch special { + case SpecialBOS: + return slices.Contains(v.BOS, id) + case SpecialEOS: + return slices.Contains(v.EOS, id) + default: + return false + } +} + +func (v *Vocabulary) addSpecials(ids []int32) []int32 { + if v.AddBOS && len(v.BOS) > 0 { + if slices.Contains(v.BOS, ids[0]) { + slog.Warn("adding bos token to prompt which already has it", "id", v.BOS) + } + + slog.Debug("adding bos token to prompt", "id", v.BOS) + ids = append([]int32{v.BOS[0]}, ids...) + } + + if v.AddEOS && len(v.EOS) > 0 { + if slices.Contains(v.BOS, ids[len(ids)-1]) { + slog.Warn("adding eos token to prompt which already has it", "id", v.EOS) + } + + slog.Debug("adding eos token to prompt", "id", v.EOS) + ids = append(ids, v.EOS[0]) + } + + return ids +} + +func (v *Vocabulary) Encode(s string) int32 { + v.valuesOnce.Do(func() { + v.values = make(map[string]int32, len(v.Values)) + for i, value := range v.Values { + v.values[value] = int32(i) + } + }) + + if id, ok := v.values[s]; ok { + return id + } + + return -1 +} + +func (v *Vocabulary) Decode(id int32) string { + return v.Values[id] +} + +func (v *Vocabulary) SpecialVocabulary() []string { + v.specialOnce.Do(func() { + for i := range v.Values { + if v.Types[i] == TOKEN_TYPE_CONTROL { + v.special = append(v.special, v.Values[i]) + } + } + }) + + return v.special +} + +func (v *Vocabulary) Merge(left, right string) int { + v.mergeOnce.Do(func() { + v.merge = make(map[string]int32, len(v.Merges)) + for i, merge := range v.Merges { + v.merge[merge] = int32(i) + } + }) + + if id, ok := v.merge[left+" "+right]; ok { + return int(id) + } + + return -1 +} diff --git a/sample/samplers.go b/sample/samplers.go index f0846c8d..d395650d 100644 --- a/sample/samplers.go +++ b/sample/samplers.go @@ -176,7 +176,7 @@ func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSa vocabIds[i] = uint32(i) } - grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, []uint32{uint32(model.Vocabulary().EOS), uint32(model.Vocabulary().EOT)}) + grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, model.Vocabulary().EOS) if grammar == nil { return nil, errors.New("sample: failed to initialize grammar") } From 7edfdd2f5f48a7be035cec23b4acd12f7c112e1c Mon Sep 17 00:00:00 2001 From: Ronald Wilson <57818133+ronxldwilson@users.noreply.github.com> Date: Mon, 19 May 2025 01:13:22 +0530 Subject: [PATCH 02/26] readme: add TinyNotepad to community integrations (#10763) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds Tiny Notepad, a lightweight, notepad-like interface to chat with local LLMs via Ollama. - It’s designed as a simple, distraction-free alternative. - The app supports basic note-taking, timestamped logs, and model parameter controls. - Built with Tkinter, it runs entirely offline and available via PyPI. Aims to provide a lightweight easy to run and install interface for ollama. --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 7cb61ac2..6a4815c1 100644 --- a/README.md +++ b/README.md @@ -405,6 +405,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama) - [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable) - [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers) +- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI) ### Cloud From a2cc8571c5b2b8f77a8a5e2f65cb7aaa56482dc4 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Tue, 13 May 2025 13:04:20 -0700 Subject: [PATCH 03/26] llm: Consistently track unassigned model data In some cases, if we fail to assign a piece of the model to a GPU then we lose track of this data. Although it doesn't change the memory allocation, it does affect the total size of the model reported by tools such as ollama ps (and also the percent offloaded). This makes it look like setting num_gpu isn't reflected in ollama ps, which isn't true but the offloading percent may appear to not change. Spreading the model across more GPUs will continue to impact the reported total size of the model. --- llm/memory.go | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/llm/memory.go b/llm/memory.go index b5a8dd5c..46472330 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -216,6 +216,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin if len(gpusWithSpace) > 0 { gpuZeroID = gpusWithSpace[0].i gpuAllocations[gpuZeroID] += gpuZeroOverhead + } else { + overflow += gpuZeroOverhead } // For all the layers, find where they can fit on the GPU(s) @@ -256,15 +258,17 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin } // Determine if we need to consider output then find where it fits - if memoryLayerOutput > 0 && (opts.NumGPU < 0 || layerCount < opts.NumGPU) { - for j := len(gpusWithSpace); j > 0; j-- { - g := gpusWithSpace[layerCount%j] - used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload) - if g.g.FreeMemory > overhead+used+memoryLayerOutput { - gpuAllocations[g.i] += memoryLayerOutput - layerCounts[g.i]++ - layerCount++ - break + if memoryLayerOutput > 0 { + if opts.NumGPU < 0 || layerCount < opts.NumGPU { + for j := len(gpusWithSpace); j > 0; j-- { + g := gpusWithSpace[layerCount%j] + used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload) + if g.g.FreeMemory > overhead+used+memoryLayerOutput { + gpuAllocations[g.i] += memoryLayerOutput + layerCounts[g.i]++ + layerCount++ + break + } } } From d75557747357bfb3afd441a0cc207ec944bd3a18 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Tue, 13 May 2025 11:36:52 -0700 Subject: [PATCH 04/26] llm: Estimate projector memory correctly for Ollama engine The Llama engine always places vision projectors on the first GPU if one exists. However, the Ollama engine groups it with the output layer, which means the projector is only offloaded if all other layers are offloaded. The memory estimation code always assumes the former layout - this changes it to use the correct layout based on the engine. This addresses two impacts of the current behavior: - In multi-GPU setups, we can crash with OOM errors when we try to allocate memory on a full GPU while another still has space. - If the vision projector is large, it may prevent us from offloading anything when we could have fit some of the text layers. --- llm/memory.go | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/llm/memory.go b/llm/memory.go index 46472330..32774d28 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -85,8 +85,11 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin var graphOffload uint64 // Projectors loaded into GPU0 only - var projectorWeights uint64 - var projectorGraph uint64 + var llamaEngineProjectorWeights uint64 + + // Projectors loaded with output layer + var ollamaEngineProjectorWeights uint64 + var ollamaEngineProjectorGraph uint64 // Conditional output size on GPU 0 var memoryLayerOutput uint64 @@ -111,14 +114,14 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", availableList) for _, projector := range projectors { - weight := projectorMemoryRequirements(projector) - projectorWeights += weight + llamaEngineProjectorWeights += projectorMemoryRequirements(projector) // multimodal models require at least 2048 context opts.NumCtx = max(opts.NumCtx, 2048) } - if projectorWeights == 0 && projectorGraph == 0 { - projectorWeights, projectorGraph = f.VisionGraphSize() + if llamaEngineProjectorWeights == 0 { + ollamaEngineProjectorWeights, ollamaEngineProjectorGraph = f.VisionGraphSize() + opts.NumCtx = max(opts.NumCtx, 2048) } layers := f.Tensors().GroupLayers() @@ -163,6 +166,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin graphFullOffload = graphPartialOffload } + // Output layer handled at the end if we have space if layer, ok := layers["output_norm"]; ok { memoryLayerOutput += layer.Size() } @@ -172,8 +176,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin memoryLayerOutput += layer.Size() } - // Output layer handled at the end if we have space - gpuZeroOverhead := projectorWeights + projectorGraph + gpuZeroOverhead := llamaEngineProjectorWeights // Reduce set of GPUs to only those that have sufficient space to fit overhead and at least one layer var layerCount int @@ -258,13 +261,14 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin } // Determine if we need to consider output then find where it fits - if memoryLayerOutput > 0 { + memoryLastLayer := memoryLayerOutput + ollamaEngineProjectorWeights + ollamaEngineProjectorGraph + if memoryLastLayer > 0 { if opts.NumGPU < 0 || layerCount < opts.NumGPU { for j := len(gpusWithSpace); j > 0; j-- { g := gpusWithSpace[layerCount%j] used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload) - if g.g.FreeMemory > overhead+used+memoryLayerOutput { - gpuAllocations[g.i] += memoryLayerOutput + if g.g.FreeMemory > overhead+used+memoryLastLayer { + gpuAllocations[g.i] += memoryLastLayer layerCounts[g.i]++ layerCount++ break @@ -274,7 +278,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin if layerCount < int(f.KV().BlockCount())+1 { fullyLoaded = false - overflow += memoryLayerOutput + overflow += memoryLastLayer } } @@ -332,8 +336,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin memoryLayerOutput: memoryLayerOutput, graphFullOffload: graphFullOffload, graphPartialOffload: graphPartialOffload, - projectorWeights: projectorWeights, - projectorGraph: projectorGraph, + projectorWeights: llamaEngineProjectorWeights + ollamaEngineProjectorWeights, + projectorGraph: ollamaEngineProjectorGraph, } if gpus[0].Library == "cpu" { From 94ab428e3f77fdd9d9c833b369bb40980c65049a Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Thu, 17 Apr 2025 13:42:40 -0700 Subject: [PATCH 05/26] ggml: Seperate tensor load from backend creation Currently, when the backend is created, the tensors are loaded at the same time, which is a slow operation. This separates them to be two steps: - Create backend, including enumerating tensors and memory allocation - Loading tensor data This allows more flexibility in managing model loading. --- convert/convert_test.go | 4 +- fs/ggml/ggml.go | 14 +-- fs/ggml/gguf_test.go | 2 +- llm/memory.go | 2 +- llm/server.go | 2 +- ml/backend.go | 14 ++- ml/backend/ggml/ggml.go | 163 +++++++++++++++++++--------------- model/model.go | 12 +-- runner/ollamarunner/runner.go | 13 ++- server/create.go | 12 +-- server/images.go | 2 +- server/model.go | 2 +- server/quantization_test.go | 4 +- 13 files changed, 131 insertions(+), 115 deletions(-) diff --git a/convert/convert_test.go b/convert/convert_test.go index b9db6fa1..105fbb3d 100644 --- a/convert/convert_test.go +++ b/convert/convert_test.go @@ -47,7 +47,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) { } t.Cleanup(func() { r.Close() }) - m, _, err := ggml.Decode(r, -1) + m, err := ggml.Decode(r, -1) if err != nil { t.Fatal(err) } @@ -332,7 +332,7 @@ func TestConvertAdapter(t *testing.T) { } defer r.Close() - m, _, err := ggml.Decode(r, -1) + m, err := ggml.Decode(r, -1) if err != nil { t.Fatal(err) } diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 8c0a2ae5..aa85aec2 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -15,6 +15,7 @@ import ( type GGML struct { container model + Length int64 } type model interface { @@ -386,12 +387,12 @@ func DetectContentType(b []byte) string { // // It collects array values for arrays with a size less than or equal to // maxArraySize. If the maxArraySize is negative, all arrays are collected. -func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) { +func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) { rs = bufioutil.NewBufferedSeeker(rs, 32<<10) var magic uint32 if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil { - return nil, 0, err + return nil, err } var c container @@ -401,24 +402,25 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) { case FILE_MAGIC_GGUF_BE: c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize} default: - return nil, 0, errors.New("invalid file magic") + return nil, errors.New("invalid file magic") } model, err := c.Decode(rs) if err != nil { - return nil, 0, err + return nil, err } offset, err := rs.Seek(0, io.SeekCurrent) if err != nil { - return nil, 0, err + return nil, err } // final model type return &GGML{ container: c, model: model, - }, offset, nil + Length: offset, + }, nil } func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) { diff --git a/fs/ggml/gguf_test.go b/fs/ggml/gguf_test.go index 10d3b684..0e071800 100644 --- a/fs/ggml/gguf_test.go +++ b/fs/ggml/gguf_test.go @@ -35,7 +35,7 @@ func TestWriteGGUF(t *testing.T) { } defer r.Close() - ff, _, err := Decode(r, 0) + ff, err := Decode(r, 0) if err != nil { t.Fatal(err) } diff --git a/llm/memory.go b/llm/memory.go index 32774d28..e9ed1738 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -423,7 +423,7 @@ func projectorMemoryRequirements(filename string) (weights uint64) { } defer file.Close() - ggml, _, err := ggml.Decode(file, 1024) + ggml, err := ggml.Decode(file, 1024) if err != nil { return 0 } diff --git a/llm/server.go b/llm/server.go index c07315fa..4abb569f 100644 --- a/llm/server.go +++ b/llm/server.go @@ -121,7 +121,7 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) { } defer f.Close() - ggml, _, err := ggml.Decode(f, maxArraySize) + ggml, err := ggml.Decode(f, maxArraySize) return ggml, err } diff --git a/ml/backend.go b/ml/backend.go index cb32d818..2599d774 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -6,7 +6,6 @@ import ( "encoding/binary" "fmt" "math" - "os" "slices" "strconv" "strings" @@ -15,6 +14,7 @@ import ( ) type Backend interface { + Load(ctx context.Context, progress func(float32)) error Config() fs.Config Get(name string) Tensor NewContext() Context @@ -52,10 +52,6 @@ type CacheConfig struct { // BackendParams controls how the backend loads and executes models type BackendParams struct { - // Progress is a callback function that allows reporting percentage completion - // of model loading - Progress func(float32) - // NumThreads sets the number of threads to use if running on the CPU NumThreads int @@ -72,9 +68,9 @@ type BackendParams struct { FlashAttention bool } -var backends = make(map[string]func(context.Context, *os.File, BackendParams) (Backend, error)) +var backends = make(map[string]func(string, BackendParams) (Backend, error)) -func RegisterBackend(name string, f func(context.Context, *os.File, BackendParams) (Backend, error)) { +func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) { if _, ok := backends[name]; ok { panic("backend: backend already registered") } @@ -82,9 +78,9 @@ func RegisterBackend(name string, f func(context.Context, *os.File, BackendParam backends[name] = f } -func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend, error) { +func NewBackend(modelPath string, params BackendParams) (Backend, error) { if backend, ok := backends["ggml"]; ok { - return backend(ctx, f, params) + return backend(modelPath, params) } return nil, fmt.Errorf("unsupported backend") diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 2821ad11..f0b26b2f 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -44,8 +44,15 @@ func devices() []*C.struct_ggml_backend_device { } type Backend struct { + // modelPath is the location of the model data + modelPath string + meta *fsggml.GGML + // tensorLoadTargets maps from the name of the tensor in the file + // to the name that is used by the model definition + tensorLoadTargets map[string][]string + sched *C.struct_ggml_backend_sched schedBackends []*C.struct_ggml_backend schedBufts []*C.struct_ggml_backend_buffer_type @@ -64,8 +71,14 @@ type Backend struct { maxGraphNodes int } -func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) { - meta, n, err := fsggml.Decode(r, -1) +func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { + r, err := os.Open(modelPath) + if err != nil { + return nil, err + } + defer r.Close() + + meta, err := fsggml.Decode(r, -1) if err != nil { return nil, err } @@ -307,73 +320,6 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, } } - var doneBytes atomic.Uint64 - totalBytes := uint64(n) - meta.Tensors().Offset - - g, ctx := errgroup.WithContext(ctx) - g.SetLimit(runtime.GOMAXPROCS(0)) - for _, t := range meta.Tensors().Items() { - t := t - g.Go(func() error { - tts := make([]*C.struct_ggml_tensor, max(1, len(targets[t.Name]))) - for i := range tts { - target := targets[t.Name][i] - if target == "" { - target = t.Name - } - - tt, ok := tensors[target] - if !ok { - return fmt.Errorf("unassigned tensor: %s", t.Name) - } - - tts[i] = tt - } - - // Create a new FD for each goroutine so that each FD is read sequentially, rather than - // seeking around within an FD shared between all goroutines. - file, err := os.Open(r.Name()) - if err != nil { - slog.Warn("file open error", "file", r.Name(), "error", err) - return err - } - defer file.Close() - sr := io.NewSectionReader(file, int64(meta.Tensors().Offset+t.Offset), int64(t.Size())) - bts := make([]byte, 128*format.KibiByte) - - var s uint64 - for s < t.Size() { - // Stop if either the parent context has been canceled or if any of the other tensors returned an error - if err := ctx.Err(); err != nil { - return err - } - - n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))]) - if err != nil { - slog.Warn("file read error", "file", r.Name(), "error", err) - return err - } - - for _, tt := range tts { - C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n)) - } - - s += uint64(n) - - if params.Progress != nil { - done := doneBytes.Add(uint64(n)) - params.Progress(float32(done) / float32(totalBytes)) - } - } - - return nil - }) - } - - if err := g.Wait(); err != nil { - return nil, err - } - // map devices to backend buffer types so new tensors can be assigned to the correct device deviceBufferTypes := make(map[*C.struct_ggml_backend_device]*C.struct_ggml_backend_buffer_type) @@ -397,9 +343,11 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, maxGraphNodes := max(8192, len(meta.Tensors().Items())*5) return &Backend{ - flashAttention: params.FlashAttention, - meta: meta, - tensors: tensors, + modelPath: modelPath, + flashAttention: params.FlashAttention, + meta: meta, + tensorLoadTargets: targets, + tensors: tensors, sched: C.ggml_backend_sched_new( (*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])), (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])), @@ -426,6 +374,77 @@ func init() { ml.RegisterBackend("ggml", New) } +func (b *Backend) Load(ctx context.Context, progress func(float32)) error { + var doneBytes atomic.Uint64 + totalBytes := uint64(b.meta.Length) - b.meta.Tensors().Offset + + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(runtime.GOMAXPROCS(0)) + for _, t := range b.meta.Tensors().Items() { + t := t + g.Go(func() error { + tts := make([]*C.struct_ggml_tensor, max(1, len(b.tensorLoadTargets[t.Name]))) + for i := range tts { + target := b.tensorLoadTargets[t.Name][i] + if target == "" { + target = t.Name + } + + tt, ok := b.tensors[target] + if !ok { + return fmt.Errorf("unassigned tensor: %s", t.Name) + } + + tts[i] = tt + } + + // Create a new FD for each goroutine so that each FD is read sequentially, rather than + // seeking around within an FD shared between all goroutines. + file, err := os.Open(b.modelPath) + if err != nil { + slog.Warn("file open error", "file", b.modelPath, "error", err) + return err + } + defer file.Close() + sr := io.NewSectionReader(file, int64(b.meta.Tensors().Offset+t.Offset), int64(t.Size())) + bts := make([]byte, 128*format.KibiByte) + + var s uint64 + for s < t.Size() { + // Stop if either the parent context has been canceled or if any of the other tensors returned an error + if err := ctx.Err(); err != nil { + return err + } + + n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))]) + if err != nil { + slog.Warn("file read error", "file", b.modelPath, "error", err) + return err + } + + for _, tt := range tts { + C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n)) + } + + s += uint64(n) + + if progress != nil { + done := doneBytes.Add(uint64(n)) + progress(float32(done) / float32(totalBytes)) + } + } + + return nil + }) + } + + if err := g.Wait(); err != nil { + return err + } + + return nil +} + func (b *Backend) Config() fs.Config { return b.meta.KV() } diff --git a/model/model.go b/model/model.go index 98381c90..39b68db1 100644 --- a/model/model.go +++ b/model/model.go @@ -98,14 +98,8 @@ func Register(name string, f func(fs.Config) (Model, error)) { } // New initializes a new model instance with the provided configuration based on the metadata in the model file -func New(ctx context.Context, modelPath string, params ml.BackendParams) (Model, error) { - r, err := os.Open(modelPath) - if err != nil { - return nil, err - } - defer r.Close() - - b, err := ml.NewBackend(ctx, r, params) +func New(modelPath string, params ml.BackendParams) (Model, error) { + b, err := ml.NewBackend(modelPath, params) if err != nil { return nil, err } @@ -134,7 +128,7 @@ func NewTextProcessor(s string) (TextProcessor, error) { return nil, err } defer r.Close() - meta, _, err := fsggml.Decode(r, -1) + meta, err := fsggml.Decode(r, -1) if err != nil { return nil, err } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index cd42d434..a488a104 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -845,7 +845,7 @@ func (s *Server) loadModel( multiUserCache bool, ) { var err error - s.model, err = model.New(ctx, mpath, params) + s.model, err = model.New(mpath, params) if err != nil { panic(err) } @@ -874,6 +874,14 @@ func (s *Server) loadModel( panic(err) } + err = s.model.Backend().Load(ctx, + func(progress float32) { + s.progress = progress + }) + if err != nil { + panic(err) + } + s.status = llm.ServerStatusReady s.ready.Done() } @@ -928,9 +936,6 @@ func Execute(args []string) error { } params := ml.BackendParams{ - Progress: func(progress float32) { - server.progress = progress - }, NumThreads: *threads, NumGPULayers: *numGPULayers, MainGPU: *mainGPU, diff --git a/server/create.go b/server/create.go index 68e003df..82856c4d 100644 --- a/server/create.go +++ b/server/create.go @@ -295,7 +295,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is } defer bin.Close() - f, _, err := ggml.Decode(bin, -1) + f, err := ggml.Decode(bin, -1) if err != nil { return nil, err } @@ -467,7 +467,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr return nil, err } - f, _, err := ggml.Decode(temp, 1024) + f, err := ggml.Decode(temp, 1024) if err != nil { slog.Error(fmt.Sprintf("error decoding ggml: %s\n", err)) return nil, err @@ -508,7 +508,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML var offset int64 for offset < stat.Size() { - f, n, err := ggml.Decode(blob, 1024) + f, err := ggml.Decode(blob, 1024) if errors.Is(err, io.EOF) { break } else if err != nil { @@ -523,7 +523,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML } var layer Layer - if digest != "" && n == stat.Size() && offset == 0 { + if digest != "" && f.Length == stat.Size() && offset == 0 { layer, err = NewLayerFromLayer(digest, mediatype, blob.Name()) if err != nil { slog.Debug("could not create new layer from layer", "error", err) @@ -533,14 +533,14 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML // Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size()) if layer.Digest == "" { - layer, err = NewLayer(io.NewSectionReader(blob, offset, n), mediatype) + layer, err = NewLayer(io.NewSectionReader(blob, offset, f.Length), mediatype) if err != nil { return nil, err } } layers = append(layers, &layerGGML{layer, f}) - offset = n + offset = f.Length } return detectChatTemplate(layers) diff --git a/server/images.go b/server/images.go index 352f10f2..a69e2a9f 100644 --- a/server/images.go +++ b/server/images.go @@ -75,7 +75,7 @@ func (m *Model) Capabilities() []model.Capability { if err == nil { defer r.Close() - f, _, err := ggml.Decode(r, 1024) + f, err := ggml.Decode(r, 1024) if err == nil { if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok { capabilities = append(capabilities, model.CapabilityEmbedding) diff --git a/server/model.go b/server/model.go index 2149ff85..6b5439a4 100644 --- a/server/model.go +++ b/server/model.go @@ -64,7 +64,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe } defer blob.Close() - f, _, err := ggml.Decode(blob, -1) + f, err := ggml.Decode(blob, -1) if err != nil { return nil, err } diff --git a/server/quantization_test.go b/server/quantization_test.go index 495297df..4f717c2c 100644 --- a/server/quantization_test.go +++ b/server/quantization_test.go @@ -271,7 +271,7 @@ func TestQuantizeModel(t *testing.T) { t.Fatal(err.Error()) } defer fp.Close() - meta, _, err := fsggml.Decode(fp, -1) + meta, err := fsggml.Decode(fp, -1) if err != nil { t.Fatal(err.Error()) } @@ -303,7 +303,7 @@ func TestQuantizeModel(t *testing.T) { t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err) } defer fpNew.Close() - newMeta, _, err := fsggml.Decode(fpNew, -1) + newMeta, err := fsggml.Decode(fpNew, -1) if err != nil { t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err) } From 1a0cfd080a2d3e65519c241b7561bf5aa49468ff Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Mon, 19 May 2025 13:54:54 -0700 Subject: [PATCH 06/26] avoid kv truncation during create (#10761) --- server/create.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/create.go b/server/create.go index 82856c4d..9488de35 100644 --- a/server/create.go +++ b/server/create.go @@ -508,7 +508,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML var offset int64 for offset < stat.Size() { - f, err := ggml.Decode(blob, 1024) + f, err := ggml.Decode(blob, -1) if errors.Is(err, io.EOF) { break } else if err != nil { From 3fe74fba42b8d496a1ab3e8298bdc9b8ffb0f336 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Mon, 19 May 2025 11:40:44 -0700 Subject: [PATCH 07/26] llm: Use first layer as memory buffer in estimation This is a partial revert of 0478d44 "Fixed over vram allcation dure to small initial layer sizes." Previously we used the size of the first layer as an extra reserved amount of space to buffer our memory estimates. The above commit changed this to use the largest layer. However, this had performance impacts on more models than the original commit was trying to fix. There is just a heuristic without an ideal solution so this goes back to the historic behavior. Fixes: #10765, #10756, #10752, #10726 --- llm/memory.go | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/llm/memory.go b/llm/memory.go index e9ed1738..05b3b2fd 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -1,12 +1,9 @@ package llm import ( - "cmp" "fmt" "log/slog" - "maps" "os" - "slices" "strconv" "strings" @@ -125,10 +122,12 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin } layers := f.Tensors().GroupLayers() - // add one layer (chosing the max layer) worth of memory as a buffer - layerSize = slices.MaxFunc(slices.Collect(maps.Values(layers)), func(a, b ggml.Layer) int { - return cmp.Compare(a.Size(), b.Size()) - }).Size() + // add one layer worth of memory as a buffer + if blk0, ok := layers["blk.0"]; ok { + layerSize = blk0.Size() + } else { + slog.Warn("model missing blk.0 layer size") + } var kvct string if envconfig.FlashAttention() && From ff180c3466e7f3ee21658465958c9ece6de2d5c0 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 19 May 2025 15:06:35 -0700 Subject: [PATCH 08/26] fix llama and mistral3 models (#10774) * fix llama model * fix mistral3.1 model do not set default vision layers --- model/models/llama/model.go | 18 +++++++----------- model/models/mistral3/model.go | 7 +------ model/models/mistral3/model_text.go | 18 ++++-------------- model/models/mistral3/model_vision.go | 2 +- 4 files changed, 13 insertions(+), 32 deletions(-) diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 6e214f0f..9b85f1c7 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -1,9 +1,8 @@ package llama import ( - "fmt" + "cmp" "math" - "strings" "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" @@ -14,9 +13,9 @@ import ( ) type Options struct { - hiddenSize, numHeads, numKVHeads int - eps, ropeBase, ropeScale float32 - ropeDim uint32 + hiddenSize, numHeads, numKVHeads, headDim int + eps, ropeBase, ropeScale float32 + ropeDim uint32 } type Model struct { @@ -32,10 +31,6 @@ type Model struct { } func New(c fs.Config) (model.Model, error) { - if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") { - return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model")) - } - m := Model{ BytePairEncoding: model.NewBytePairEncoding( c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), @@ -57,6 +52,7 @@ func New(c fs.Config) (model.Model, error) { hiddenSize: int(c.Uint("embedding_length")), numHeads: int(c.Uint("attention.head_count")), numKVHeads: int(c.Uint("attention.head_count_kv")), + headDim: int(c.Uint("attention.key_length")), eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base"), ropeScale: c.Float("rope.freq_scale", 1), @@ -79,7 +75,7 @@ type SelfAttention struct { func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { batchSize := hiddenState.Dim(1) - headDim := opts.hiddenSize / opts.numHeads + headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads) ropeType := uint32(0) q := sa.Query.Forward(ctx, hiddenState) @@ -95,7 +91,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten scaleFactor := 1.0 / math.Sqrt(float64(headDim)) kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache) - kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize) + kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize) return sa.Output.Forward(ctx, kqv) } diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go index 0d384b94..dd01a587 100644 --- a/model/models/mistral3/model.go +++ b/model/models/mistral3/model.go @@ -31,11 +31,6 @@ var _ model.MultimodalProcessor = (*Model)(nil) var _ model.TextProcessor = (*Model)(nil) func New(c fs.Config) (model.Model, error) { - textModel, err := NewTextModel(c) - if err != nil { - return nil, err - } - m := &Model{ BytePairEncoding: model.NewBytePairEncoding( c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`), @@ -52,7 +47,7 @@ func New(c fs.Config) (model.Model, error) { ), }, ), - TextModel: textModel, + TextModel: newTextModel(c), VisionModel: newVisionModel(c), ImageProcessor: newImageProcessor(c), MultiModalProjector: newMultiModalProjector(c), diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go index 17939800..57e2a40a 100644 --- a/model/models/mistral3/model_text.go +++ b/model/models/mistral3/model_text.go @@ -1,9 +1,8 @@ package mistral3 import ( - "fmt" + "cmp" "math" - "strings" "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" @@ -37,10 +36,7 @@ type SelfAttention struct { func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { batchSize := hiddenState.Dim(1) ropeType := uint32(0) - headDim := opts.headDim - if headDim == 0 { - headDim = opts.hiddenSize / opts.numHeads - } + headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads) q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) @@ -125,12 +121,8 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor return m.Output.Forward(ctx, hiddenState) } -func NewTextModel(c fs.Config) (*TextModel, error) { - if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") { - return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model")) - } - - textModel := &TextModel{ +func newTextModel(c fs.Config) *TextModel { + return &TextModel{ Layers: make([]Layer, c.Uint("block_count")), TextOptions: &TextOptions{ hiddenSize: int(c.Uint("embedding_length")), @@ -143,6 +135,4 @@ func NewTextModel(c fs.Config) (*TextModel, error) { ropeDim: c.Uint("rope.dimension_count"), }, } - - return textModel, nil } diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go index 469dc40c..24541004 100644 --- a/model/models/mistral3/model_vision.go +++ b/model/models/mistral3/model_vision.go @@ -170,7 +170,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor { func newVisionModel(c fs.Config) *VisionModel { return &VisionModel{ - Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 24)), + Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")), VisionModelOptions: &VisionModelOptions{ hiddenSize: int(c.Uint("vision.embedding_length", 1024)), numHeads: int(c.Uint("vision.attention.head_count", 16)), From e6a800ca11cb52b24fa5afc5245ed1277811fbe9 Mon Sep 17 00:00:00 2001 From: DarkCaster Date: Tue, 20 May 2025 20:41:15 +0300 Subject: [PATCH 09/26] llama: fix incorrect initialization of C.struct_common_sampler_cparams.penalty_present (#10779) --- llama/llama.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama/llama.go b/llama/llama.go index 626ea13a..adee6f63 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -544,7 +544,7 @@ func NewSamplingContext(model *Model, params SamplingParams) (*SamplingContext, cparams.penalty_last_n = C.int32_t(params.RepeatLastN) cparams.penalty_repeat = C.float(params.PenaltyRepeat) cparams.penalty_freq = C.float(params.PenaltyFreq) - cparams.penalty_present = C.float(params.PenaltyFreq) + cparams.penalty_present = C.float(params.PenaltyPresent) cparams.seed = C.uint32_t(params.Seed) grammar := C.CString(params.Grammar) From 9ed8bf14cb885509281d63731cda16637a7e0bd2 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 20 May 2025 15:51:08 -0700 Subject: [PATCH 10/26] ml: add more rope options (#10775) --- ml/backend.go | 16 -------------- ml/backend/ggml/ggml.go | 26 ++++++----------------- ml/nn/fast/rope.go | 21 ++++++++++++++++++ ml/nn/rope/rope.go | 33 +++++++++++++++++++++++++++++ model/models/gemma2/model.go | 9 ++++---- model/models/gemma3/model_text.go | 9 ++++---- model/models/llama/model.go | 17 ++++++++------- model/models/llama4/model_text.go | 8 ++++--- model/models/mistral3/model_text.go | 16 +++++++------- model/models/mllama/model_text.go | 13 ++++++------ model/models/qwen25vl/model_text.go | 31 ++++++++++++++------------- 11 files changed, 116 insertions(+), 83 deletions(-) create mode 100644 ml/nn/fast/rope.go create mode 100644 ml/nn/rope/rope.go diff --git a/ml/backend.go b/ml/backend.go index 2599d774..70497e6e 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -115,21 +115,6 @@ type Context interface { Layer(int) Context } -// RopeOptions contains optional parameters for RoPE function -type RopeOptions struct { - OriginalContextLen uint32 -} - -// RopeOption defines a function that modifies RopeOpts -type RopeOption func(*RopeOptions) - -// WithContextLen sets a custom context length -func WithContextLen(len uint32) RopeOption { - return func(opts *RopeOptions) { - opts.OriginalContextLen = len - } -} - type Tensor interface { Dim(n int) int Stride(n int) int @@ -155,7 +140,6 @@ type Tensor interface { AvgPool2D(ctx Context, k, s int, p float32) Tensor Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor - RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32, options ...RopeOption) Tensor IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor Sin(ctx Context) Tensor diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index f0b26b2f..3ae50ace 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -30,6 +30,7 @@ import ( "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src" + "github.com/ollama/ollama/ml/nn/rope" "golang.org/x/sync/errgroup" ) @@ -1074,28 +1075,15 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { } } -const ( - ropeTypeNorm C.int = 0 - ropeTypeNeox C.int = 2 - ropeTypeMrope C.int = 8 - ropeTypeVision C.int = 24 -) - -func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32, options ...ml.RopeOption) ml.Tensor { +func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase, ropeScale float32, options ...func(*rope.Options)) ml.Tensor { // Default options - opts := &ml.RopeOptions{ - OriginalContextLen: 131072, - } + opts := &rope.Options{OriginalContextLength: 131072, Factors: &Tensor{}} // Apply any provided options for _, option := range options { option(opts) } - if ropeFactors == nil { - ropeFactors = &Tensor{b: t.b} - } - dequant := t.t if C.ggml_is_quantized(t.t._type) { dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32) @@ -1106,11 +1094,11 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi t: C.ggml_rope_ext( ctx.(*Context).ctx, dequant, - positionIDs.(*Tensor).t, - ropeFactors.(*Tensor).t, + positions.(*Tensor).t, + opts.Factors.(*Tensor).t, C.int(ropeDim), - C.int(ropeType), - C.int(opts.OriginalContextLen), + C.int(opts.Type), + C.int(opts.OriginalContextLength), C.float(ropeBase), C.float(ropeScale), C.float(0.0), diff --git a/ml/nn/fast/rope.go b/ml/nn/fast/rope.go new file mode 100644 index 00000000..b45938eb --- /dev/null +++ b/ml/nn/fast/rope.go @@ -0,0 +1,21 @@ +// fast provides implementations of fast (fused) operations for increased performance. +package fast + +import ( + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn/rope" +) + +// fastRoPE is an interface for tensors that support fast rotary positional embedding. +type fastRoPE interface { + RoPE(ctx ml.Context, positionIDs ml.Tensor, dim int, base, scale float32, options ...func(*rope.Options)) ml.Tensor +} + +// RoPE applies rotary positional embedding to tensor `t`. +func RoPE(ctx ml.Context, t, positions ml.Tensor, dim int, base, scale float32, options ...func(*rope.Options)) ml.Tensor { + if t, ok := t.(fastRoPE); ok { + return t.RoPE(ctx, positions, dim, base, scale, options...) + } + + panic("RoPE not implemented for this tensor type") +} diff --git a/ml/nn/rope/rope.go b/ml/nn/rope/rope.go new file mode 100644 index 00000000..b0c00a5b --- /dev/null +++ b/ml/nn/rope/rope.go @@ -0,0 +1,33 @@ +package rope + +import "github.com/ollama/ollama/ml" + +// Options contains optional parameters for RoPE function +type Options struct { + OriginalContextLength int + Type int + Factors ml.Tensor +} + +// WithOriginalContextLength sets a custom context length +func WithOriginalContextLength(n int) func(*Options) { + return func(opts *Options) { + opts.OriginalContextLength = n + } +} + +// WithType sets RoPE type to NeoX +func WithTypeNeoX() func(*Options) { + return func(opts *Options) { + opts.Type = 2 + } +} + +// WithFactors sets custom rope factors +func WithFactors(factors ml.Tensor) func(*Options) { + return func(opts *Options) { + if factors != nil { + opts.Factors = factors + } + } +} diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index a87534c5..3c5a7ea5 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -7,6 +7,8 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" + "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" ) @@ -83,11 +85,10 @@ type SelfAttention struct { func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { batchSize := hiddenState.Dim(1) - ropeType := uint32(2) q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) - q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale) + q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) if opts.largeModelScaling { q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) @@ -97,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) - k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale) + k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) @@ -127,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten } func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil + return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, m.Options.ropeScale, rope.WithTypeNeoX()), nil } type MLP struct { diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index a40614af..70d7797e 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -7,6 +7,8 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" + "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model/input" ) @@ -73,7 +75,6 @@ type TextSelfAttention struct { func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor { batchSize := hiddenState.Dim(1) - ropeType := uint32(2) ropeBase := opts.ropeLocalBase if (layer+1)%gemmaGlobalCacheCount == 0 { @@ -83,7 +84,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) q = sa.QueryNorm.Forward(ctx, q, opts.eps) - q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale) + q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX()) if opts.largeModelScaling { q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) @@ -94,7 +95,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) k = sa.KeyNorm.Forward(ctx, k, opts.eps) - k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale) + k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX()) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) @@ -112,7 +113,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T ropeBase = m.TextConfig.ropeGlobalBase } - return key.RoPE(ctx, shift, nil, uint32(m.TextConfig.attnKeyLen), uint32(2), ropeBase, m.TextConfig.ropeScale), nil + return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil } type TextMLP struct { diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 9b85f1c7..7b475512 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -8,14 +8,16 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" + "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" ) type Options struct { - hiddenSize, numHeads, numKVHeads, headDim int - eps, ropeBase, ropeScale float32 - ropeDim uint32 + hiddenSize, numHeads, numKVHeads int + headDim, ropeDim int + eps, ropeBase, ropeScale float32 } type Model struct { @@ -53,10 +55,10 @@ func New(c fs.Config) (model.Model, error) { numHeads: int(c.Uint("attention.head_count")), numKVHeads: int(c.Uint("attention.head_count_kv")), headDim: int(c.Uint("attention.key_length")), + ropeDim: int(c.Uint("rope.dimension_count")), eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base"), ropeScale: c.Float("rope.freq_scale", 1), - ropeDim: c.Uint("rope.dimension_count"), }, } @@ -76,15 +78,14 @@ type SelfAttention struct { func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { batchSize := hiddenState.Dim(1) headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads) - ropeType := uint32(0) q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -97,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten } func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil + return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil } type MLP struct { diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go index d98587bd..829c805c 100644 --- a/model/models/llama4/model_text.go +++ b/model/models/llama4/model_text.go @@ -8,6 +8,8 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" + "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model/input" ) @@ -31,8 +33,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) if useRope { - query = query.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale) - key = key.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale) + query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) } if opts.useQKNorm { @@ -250,5 +252,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return key.RoPE(ctx, shift, m.Layers[layer].Attention.RopeFactors, uint32(0), uint32(m.ropeDim), m.ropeBase, m.ropeScale), nil + return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil } diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go index 57e2a40a..19c36f9f 100644 --- a/model/models/mistral3/model_text.go +++ b/model/models/mistral3/model_text.go @@ -8,13 +8,14 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/model/input" ) type TextOptions struct { - hiddenSize, numHeads, numKVHeads, headDim int - eps, ropeBase, ropeScale float32 - ropeDim uint32 + hiddenSize, numHeads, numKVHeads int + headDim, ropeDim int + eps, ropeBase, ropeScale float32 } type TextModel struct { @@ -35,16 +36,15 @@ type SelfAttention struct { func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { batchSize := hiddenState.Dim(1) - ropeType := uint32(0) headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads) q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale) k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -55,7 +55,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return key.RoPE(ctx, shift, nil, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil + return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale), nil } type MLP struct { @@ -129,10 +129,10 @@ func newTextModel(c fs.Config) *TextModel { numHeads: int(c.Uint("attention.head_count")), numKVHeads: int(c.Uint("attention.head_count_kv")), headDim: int(c.Uint("attention.key_length")), + ropeDim: int(c.Uint("rope.dimension_count")), eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base"), ropeScale: c.Float("rope.freq_scale", 1), - ropeDim: c.Uint("rope.dimension_count"), }, } } diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index 9bd414af..47a518ce 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -8,6 +8,8 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" + "github.com/ollama/ollama/ml/nn/rope" ) type TextSelfAttention struct { @@ -21,15 +23,14 @@ type TextSelfAttention struct { func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { batchSize := hiddenState.Dim(1) headDim := opts.hiddenSize / opts.numHeads - ropeType := uint32(0) query := sa.Query.Forward(ctx, hiddenState) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) - query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) key := sa.Key.Forward(ctx, hiddenState) key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -44,7 +45,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { // This will only get called for layers in the cache, which are just the self attention layers if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok { - return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil + return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil } return key, nil @@ -199,8 +200,8 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, type TextModelOptions struct { hiddenSize, numHeads, numKVHeads int + ropeDim int eps, ropeBase, ropeScale float32 - ropeDim uint32 crossAttentionLayers []int32 } @@ -240,10 +241,10 @@ func newTextModel(c fs.Config) *TextModel { hiddenSize: int(c.Uint("embedding_length")), numHeads: int(c.Uint("attention.head_count")), numKVHeads: int(c.Uint("attention.head_count_kv")), + ropeDim: int(c.Uint("rope.dimension_count")), eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base"), ropeScale: c.Float("rope.freq_scale", 1), - ropeDim: c.Uint("rope.dimension_count"), crossAttentionLayers: c.Ints("attention.cross_attention_layers"), }, } diff --git a/model/models/qwen25vl/model_text.go b/model/models/qwen25vl/model_text.go index 800fd961..4b6bc166 100644 --- a/model/models/qwen25vl/model_text.go +++ b/model/models/qwen25vl/model_text.go @@ -7,13 +7,15 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" + "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model/input" ) type TextOptions struct { - ctxLen, hiddenSize, numHeads, numKVHeads int - eps, ropeBase, ropeScale float32 - ropeDim, defaultContextLen uint32 + hiddenSize, numHeads, numKVHeads int + ropeDim, originalContextLength int + eps, ropeBase, ropeScale float32 } type TextModel struct { @@ -29,15 +31,14 @@ func NewTextModel(c fs.Config) *TextModel { m := TextModel{ Layers: make([]Layer, c.Uint("block_count")), TextOptions: &TextOptions{ - ctxLen: int(c.Uint("context_length")), - hiddenSize: int(c.Uint("embedding_length")), - numHeads: int(c.Uint("attention.head_count")), - numKVHeads: int(c.Uint("attention.head_count_kv")), - eps: c.Float("attention.layer_norm_rms_epsilon"), - ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), - ropeDim: c.Uint("rope.dimension_count", 128), - defaultContextLen: c.Uint("context_length", 128000), + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + ropeDim: int(c.Uint("rope.dimension_count", 128)), + originalContextLength: int(c.Uint("context_length", 128000)), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeBase: c.Float("rope.freq_base"), + ropeScale: c.Float("rope.freq_scale", 1), }, } @@ -59,11 +60,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen)) + q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen)) + k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -77,7 +78,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten // Shift applies rotary position embeddings to the key tensor for causal attention caching func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return key.RoPE(ctx, shift, nil, m.ropeDim, 2, m.ropeBase, m.ropeScale, ml.WithContextLen(m.defaultContextLen)), nil + return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil } // MLP implements the feed-forward network component with SwiGLU activation From 69b2fe9282323a57cd3557bed9b598b465d1b3a6 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 21 May 2025 09:39:20 -0700 Subject: [PATCH 11/26] fix: qwen25vl assign samebatch in multimodal input (#10789) setting samebatch on the vision start token is problematic because it will be shared with other inputs that also use images. this will cause the input to be cached and the runner will not see SameBatch. SameBatch will also be incorrect since it may be for a different image. assigning samebatch to the input tokens resolves this by ensure it's assigned correctly to inputs corresponding to the image. not setting same batch correctly may cause panics during inference since images are no longer guaranteed to be in the same batch. --- model/models/qwen25vl/model.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go index 7de9b6eb..32cca560 100644 --- a/model/models/qwen25vl/model.go +++ b/model/models/qwen25vl/model.go @@ -121,13 +121,14 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1) // First add the vision start token - result = append(result, input.Input{Token: visionStartToken, SameBatch: patchesPerChunk + 1}) + result = append(result, input.Input{Token: visionStartToken}) // Add the image token with the multimodal tensor data at the first position result = append(result, input.Input{ Token: imageToken, Multimodal: inp.Multimodal, MultimodalHash: inp.MultimodalHash, + SameBatch: patchesPerChunk, }) // Add the placeholder tokens for the remaining positions (tokensPerGrid-1) From 375839ea2d05a056d02f934f02e953b41f1d444d Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 21 May 2025 09:39:38 -0700 Subject: [PATCH 12/26] chore: disable debug in binary libraries (#10788) --- CMakeLists.txt | 2 ++ ml/backend/ggml/ggml/src/ggml-cpu/cpu.go | 2 +- ml/backend/ggml/ggml/src/ggml-metal/metal.go | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5343d877..ce41288b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,6 +51,8 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx) +add_compiler_definitions(NDEBUG) + set(GGML_CPU ON) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src) set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE) diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/cpu.go b/ml/backend/ggml/ggml/src/ggml-cpu/cpu.go index 895d093c..895b7f6e 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/cpu.go +++ b/ml/backend/ggml/ggml/src/ggml-cpu/cpu.go @@ -3,7 +3,7 @@ package cpu // #cgo CFLAGS: -O3 -Wno-implicit-function-declaration // #cgo CXXFLAGS: -std=c++17 // #cgo CPPFLAGS: -I${SRCDIR}/amx -I${SRCDIR}/llamafile -I${SRCDIR}/.. -I${SRCDIR}/../../include -// #cgo CPPFLAGS: -DGGML_USE_LLAMAFILE +// #cgo CPPFLAGS: -DNDEBUG -DGGML_USE_LLAMAFILE // #cgo linux CPPFLAGS: -D_GNU_SOURCE // #cgo darwin,arm64 CPPFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64 // #cgo darwin,arm64 LDFLAGS: -framework Accelerate diff --git a/ml/backend/ggml/ggml/src/ggml-metal/metal.go b/ml/backend/ggml/ggml/src/ggml-metal/metal.go index eb65dfde..0ee017dd 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/metal.go +++ b/ml/backend/ggml/ggml/src/ggml-metal/metal.go @@ -4,6 +4,6 @@ package metal //go:generate sh -c "{ echo // Code generated by 'go generate'. DO NOT EDIT.; sed -e '/__embed_ggml-common.h__/r ../ggml-common.h' -e '/__embed_ggml-common.h__/d' -e '/#include \"ggml-metal-impl.h\"/r ggml-metal-impl.h' -e '/#include \"ggml-metal-impl.h\"/d' ggml-metal.metal; } >ggml-metal-embed.metal" -// #cgo CPPFLAGS: -DGGML_METAL_EMBED_LIBRARY -I.. -I../../include +// #cgo CPPFLAGS: -DGGML_METAL_NDEBUG -DGGML_METAL_EMBED_LIBRARY -I.. -I../../include // #cgo LDFLAGS: -framework Metal -framework MetalKit import "C" From 139f84cf21f8d8107f69c1404f17a8840c6d67d0 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 21 May 2025 09:52:52 -0700 Subject: [PATCH 13/26] fix cmakelists (#10804) this fixes an issue introduced in #10788 --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ce41288b..c005d014 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,7 +51,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx) -add_compiler_definitions(NDEBUG) +add_compile_definitions(NDEBUG) set(GGML_CPU ON) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src) From e0ed984cde1f6191e38ac2d7f4415ffd619a631f Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 21 May 2025 10:21:07 -0700 Subject: [PATCH 14/26] feat: qwen3 dense and sparse models (#10708) * feat: qwen3 dense * feat: qwen3moe * fix llama4 moe --- ml/backend.go | 3 + ml/backend/ggml/ggml.go | 14 ++ model/models/llama4/model_text.go | 2 +- model/models/models.go | 1 + model/models/qwen3/model.go | 239 ++++++++++++++++++++++++++++++ 5 files changed, 258 insertions(+), 1 deletion(-) create mode 100644 model/models/qwen3/model.go diff --git a/ml/backend.go b/ml/backend.go index 70497e6e..3c417ef9 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -128,6 +128,8 @@ type Tensor interface { Neg(ctx Context) Tensor Add(ctx Context, t2 Tensor) Tensor Mul(ctx Context, t2 Tensor) Tensor + Div(ctx Context, t2 Tensor) Tensor + Mulmat(ctx Context, t2 Tensor) Tensor MulmatFullPrec(ctx Context, t2 Tensor) Tensor MulmatID(ctx Context, t2, ids Tensor) Tensor @@ -136,6 +138,7 @@ type Tensor interface { LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor RMSNorm(ctx Context, weight Tensor, eps float32) Tensor Scale(ctx Context, s float64) Tensor + SumRows(ctx Context) Tensor AvgPool2D(ctx Context, k, s int, p float32) Tensor Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 3ae50ace..44e3b61b 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -887,6 +887,13 @@ func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor { } } +func (t *Tensor) Div(ctx ml.Context, t2 ml.Tensor) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_div(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), + } +} + func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor { return &Tensor{ b: t.b, @@ -1004,6 +1011,13 @@ func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor { } } +func (t *Tensor) SumRows(ctx ml.Context) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_sum_rows(ctx.(*Context).ctx, t.t), + } +} + func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor { return &Tensor{ b: t.b, diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go index 829c805c..ff9f5e20 100644 --- a/model/models/llama4/model_text.go +++ b/model/models/llama4/model_text.go @@ -82,7 +82,7 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)) for i := 1; i < opts.numExpertsUsed; i++ { - nextStates.Add(ctx, downStates.View(ctx, i*downStates.Stride(1), hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))) + nextStates = nextStates.Add(ctx, downStates.View(ctx, i*downStates.Stride(1), hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))) } return nextStates diff --git a/model/models/models.go b/model/models/models.go index 133e5176..fd935f30 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -8,4 +8,5 @@ import ( _ "github.com/ollama/ollama/model/models/mistral3" _ "github.com/ollama/ollama/model/models/mllama" _ "github.com/ollama/ollama/model/models/qwen25vl" + _ "github.com/ollama/ollama/model/models/qwen3" ) diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go new file mode 100644 index 00000000..44c32f9e --- /dev/null +++ b/model/models/qwen3/model.go @@ -0,0 +1,239 @@ +package qwen3 + +import ( + "cmp" + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" + "github.com/ollama/ollama/ml/nn/rope" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type Options struct { + hiddenSize, numHeads, numKVHeads int + eps float32 + ropeBase, ropeScale float32 + + keyLength, valueLength int + + numExperts, numExpertsUsed int + normTopKProb bool +} + +func (o Options) headDim() int { + return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads) +} + +type Attention struct { + QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"` + Query *nn.Linear `gguf:"attn_q"` + KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` +} + +func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + batchSize := hiddenStates.Dim(1) + + query := sa.Query.Forward(ctx, hiddenStates) + key := sa.Key.Forward(ctx, hiddenStates) + value := sa.Value.Forward(ctx, hiddenStates) + + query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize) + key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize) + value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize) + + query = sa.QueryNorm.Forward(ctx, query, opts.eps) + key = sa.KeyNorm.Forward(ctx, key, opts.eps) + + query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + + attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache) + attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) + return sa.Output.Forward(ctx, attention) +} + +type MLP interface { + Forward(ml.Context, ml.Tensor, *Options) ml.Tensor +} + +type sparse struct { + Router *nn.Linear `gguf:"ffn_gate_inp"` + Gate ml.Tensor `gguf:"ffn_gate_exps.weight"` + Up ml.Tensor `gguf:"ffn_up_exps.weight"` + Down ml.Tensor `gguf:"ffn_down_exps.weight"` +} + +func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2) + hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize) + routerLogits := mlp.Router.Forward(ctx, hiddenStates) + + routingWeights := routerLogits.Softmax(ctx) + selectedExperts := routingWeights.TopK(ctx, opts.numExpertsUsed) + routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, selectedExperts) + if opts.normTopKProb { + routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1)) + routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx)) + routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1)) + } + + hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1)) + + upStates := mlp.Up.MulmatID(ctx, hiddenStates, selectedExperts) + + hiddenStates = mlp.Gate.MulmatID(ctx, hiddenStates, selectedExperts) + hiddenStates = hiddenStates.SILU(ctx) + hiddenStates = hiddenStates.Mul(ctx, upStates) + + experts := mlp.Down.MulmatID(ctx, hiddenStates, selectedExperts) + experts = experts.Mul(ctx, routingWeights) + + nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2)) + for i := 1; i < opts.numExpertsUsed; i++ { + nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2))) + } + + return nextStates +} + +type dense struct { + Gate *nn.Linear `gguf:"ffn_gate"` + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor { + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + return mlp.Down.Forward(ctx, hiddenStates) +} + +type Layer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + *Attention + + MLPNorm *nn.RMSNorm `gguf:"ffn_norm"` + MLP +} + +func (d *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + residual := hiddenStates + hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, cache, opts) + + if outputs != nil { + hiddenStates = hiddenStates.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + + hiddenStates = hiddenStates.Add(ctx, residual) + + residual = hiddenStates + hiddenStates = d.MLPNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = d.MLP.Forward(ctx, hiddenStates, opts) + return hiddenStates.Add(ctx, residual) +} + +type Model struct { + model.Base + model.BytePairEncoding + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + OutputNorm *nn.RMSNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + Layers []Layer `gguf:"blk"` + + *Options +} + +// Forward implements model.Model. +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + if err != nil { + return nil, err + } + + hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) + + for i, layer := range m.Layers { + m.Cache.SetLayer(i) + + var outputs ml.Tensor + if i == len(m.Layers)-1 { + outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + if err != nil { + return nil, err + } + } + + hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options) + } + + hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) + return m.Output.Forward(ctx, hiddenStates), nil +} + +func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil +} + +var _ model.Model = (*Model)(nil) + +func New(c fs.Config) (model.Model, error) { + layers := make([]Layer, c.Uint("block_count")) + for i := range layers { + if c.String("general.architecture") == "qwen3moe" { + layers[i].MLP = &sparse{} + } else { + layers[i].MLP = &dense{} + } + } + + m := Model{ + BytePairEncoding: model.NewBytePairEncoding( + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), + }, + ), + Layers: layers, + Options: &Options{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + keyLength: int(c.Uint("attention.key_length")), + valueLength: int(c.Uint("attention.value_length")), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeBase: c.Float("rope.freq_base"), + ropeScale: c.Float("rope.freq_scale", 1), + numExperts: int(c.Uint("expert_count")), + numExpertsUsed: int(c.Uint("expert_used_count")), + normTopKProb: c.Bool("norm_top_k_prob", true), + }, + } + + m.Cache = kvcache.NewCausalCache(m.Shift) + return &m, nil +} + +func init() { + model.Register("qwen3", New) + model.Register("qwen3moe", New) +} From c890011322fbdd325ef9f16e425fe1f5213a24fe Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 21 May 2025 10:21:24 -0700 Subject: [PATCH 15/26] feat: port qwen2 model (#10782) --- model/models/llama/model.go | 47 +++++----- model/models/models.go | 1 + model/models/qwen2/model.go | 170 ++++++++++++++++++++++++++++++++++++ 3 files changed, 194 insertions(+), 24 deletions(-) create mode 100644 model/models/qwen2/model.go diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 7b475512..507f1ebc 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -75,30 +75,31 @@ type SelfAttention struct { RopeFactors ml.Tensor `gguf:"rope_freqs.weight"` } -func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { +func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { batchSize := hiddenState.Dim(1) headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads) + ropeDim := cmp.Or(opts.ropeDim, headDim) - q := sa.Query.Forward(ctx, hiddenState) - q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + query := sa.Query.Forward(ctx, hiddenState) + query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) - k := sa.Key.Forward(ctx, hiddenState) - k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + key := sa.Key.Forward(ctx, hiddenState) + key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - v := sa.Value.Forward(ctx, hiddenState) - v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + value := sa.Value.Forward(ctx, hiddenState) + value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - scaleFactor := 1.0 / math.Sqrt(float64(headDim)) - kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache) - kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize) + query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) - return sa.Output.Forward(ctx, kqv) + attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) + attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) + return sa.Output.Forward(ctx, attention) } func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil + ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads) + return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil } type MLP struct { @@ -119,11 +120,11 @@ type Layer struct { MLP *MLP } -func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { +func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { residual := hiddenState hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps) - hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts) + hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, opts) // In the final layer (outputs != nil), optimize by pruning to just the token positions // we need logits for. @@ -146,22 +147,20 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { return nil, err } - outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } - hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) for i, layer := range m.Layers { m.Cache.SetLayer(i) - var lastLayerOutputs ml.Tensor + var outputs ml.Tensor if i == len(m.Layers)-1 { - lastLayerOutputs = outputs + outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + if err != nil { + return nil, err + } } - hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options) + hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options) } hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) diff --git a/model/models/models.go b/model/models/models.go index fd935f30..5471ce89 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -7,6 +7,7 @@ import ( _ "github.com/ollama/ollama/model/models/llama4" _ "github.com/ollama/ollama/model/models/mistral3" _ "github.com/ollama/ollama/model/models/mllama" + _ "github.com/ollama/ollama/model/models/qwen2" _ "github.com/ollama/ollama/model/models/qwen25vl" _ "github.com/ollama/ollama/model/models/qwen3" ) diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go new file mode 100644 index 00000000..3c3d81aa --- /dev/null +++ b/model/models/qwen2/model.go @@ -0,0 +1,170 @@ +package qwen2 + +import ( + "cmp" + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" + "github.com/ollama/ollama/ml/nn/rope" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type Options struct { + hiddenSize, numHeads, numKVHeads int + headDim, ropeDim int + eps, ropeBase, ropeScale float32 +} + +type Attention struct { + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` +} + +func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + batchSize := hiddenStates.Dim(1) + headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads) + ropeDim := cmp.Or(opts.ropeDim, headDim) + + query := attn.Query.Forward(ctx, hiddenStates) + query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) + + key := attn.Key.Forward(ctx, hiddenStates) + key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + + value := attn.Value.Forward(ctx, hiddenStates) + value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + + query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + + attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) + attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) + + return attn.Output.Forward(ctx, attention) +} + +type MLP struct { + Gate *nn.Linear `gguf:"ffn_gate"` + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (mlp MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor { + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + return mlp.Down.Forward(ctx, hiddenStates) +} + +type DecoderLayer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + Attention *Attention + MLPNorm *nn.RMSNorm `gguf:"ffn_norm"` + MLP *MLP +} + +func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + residual := hiddenStates + + hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, cache, opts) + if outputs != nil { + hiddenStates = hiddenStates.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + + hiddenStates = hiddenStates.Add(ctx, residual) + residual = hiddenStates + + hiddenStates = d.MLPNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = d.MLP.Forward(ctx, hiddenStates) + return hiddenStates.Add(ctx, residual) +} + +type Model struct { + model.Base + model.BytePairEncoding + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + Layers []DecoderLayer `gguf:"blk"` + OutputNorm *nn.RMSNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + Options +} + +// Forward implements model.Model. +func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + if err != nil { + return nil, err + } + + hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) + + for i, layer := range m.Layers { + m.Cache.SetLayer(i) + + var outputs ml.Tensor + if i == len(m.Layers)-1 { + outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + if err != nil { + return nil, err + } + } + + hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options) + } + + hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) + hiddenStates = m.Output.Forward(ctx, hiddenStates) + return hiddenStates, nil +} + +func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads) + return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil +} + +func New(c fs.Config) (model.Model, error) { + m := Model{ + Layers: make([]DecoderLayer, c.Uint("block_count")), + BytePairEncoding: model.NewBytePairEncoding( + c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), + }, + ), + Options: Options{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + headDim: int(c.Uint("attention.key_length")), + ropeDim: int(c.Uint("rope.dimension_count")), + ropeBase: c.Float("rope.freq_base"), + ropeScale: c.Float("rope.freq_scale", 1), + eps: c.Float("attention.layer_norm_rms_epsilon"), + }, + } + + m.Cache = kvcache.NewCausalCache(m.Shift) + return &m, nil +} + +func init() { + model.Register("qwen2", New) +} From 7359b0270767d1ebda33598d738a4263a4238b3a Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Wed, 21 May 2025 10:46:56 -0700 Subject: [PATCH 16/26] win: detect background upgrade in progress (#10785) Give the user a helpful error instead of showing connection refused errors. --- cmd/cmd.go | 4 ++-- cmd/start_windows.go | 50 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index ad4be7f9..b9047529 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1236,11 +1236,11 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error { return err } if err := client.Heartbeat(cmd.Context()); err != nil { - if !strings.Contains(err.Error(), " refused") { + if !(strings.Contains(err.Error(), " refused") || strings.Contains(err.Error(), "could not connect")) { return err } if err := startApp(cmd.Context(), client); err != nil { - return errors.New("could not connect to ollama app, is it running?") + return fmt.Errorf("ollama server not responding - %w", err) } } return nil diff --git a/cmd/start_windows.go b/cmd/start_windows.go index 5bca2433..bcc51057 100644 --- a/cmd/start_windows.go +++ b/cmd/start_windows.go @@ -4,17 +4,27 @@ import ( "context" "errors" "fmt" + "log/slog" "os" "os/exec" + "path" "path/filepath" "strings" "syscall" + "unsafe" "github.com/ollama/ollama/api" + "golang.org/x/sys/windows" +) + +const ( + Installer = "OllamaSetup.exe" ) func startApp(ctx context.Context, client *api.Client) error { - // log.Printf("XXX Attempting to find and start ollama app") + if len(isProcRunning(Installer)) > 0 { + return fmt.Errorf("upgrade in progress...") + } AppName := "ollama app.exe" exe, err := os.Executable() if err != nil { @@ -56,3 +66,41 @@ func startApp(ctx context.Context, client *api.Client) error { } return waitForServer(ctx, client) } + +func isProcRunning(procName string) []uint32 { + pids := make([]uint32, 2048) + var ret uint32 + if err := windows.EnumProcesses(pids, &ret); err != nil || ret == 0 { + slog.Debug("failed to check for running installers", "error", err) + return nil + } + pids = pids[:ret] + var matches []uint32 + for _, pid := range pids { + if pid == 0 { + continue + } + hProcess, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION|windows.PROCESS_VM_READ, false, pid) + if err != nil { + continue + } + defer windows.CloseHandle(hProcess) + var module windows.Handle + var cbNeeded uint32 + cb := (uint32)(unsafe.Sizeof(module)) + if err := windows.EnumProcessModules(hProcess, &module, cb, &cbNeeded); err != nil { + continue + } + var sz uint32 = 1024 * 8 + moduleName := make([]uint16, sz) + cb = uint32(len(moduleName)) * (uint32)(unsafe.Sizeof(uint16(0))) + if err := windows.GetModuleBaseName(hProcess, module, &moduleName[0], cb); err != nil && err != syscall.ERROR_INSUFFICIENT_BUFFER { + continue + } + exeFile := path.Base(strings.ToLower(syscall.UTF16ToString(moduleName))) + if strings.EqualFold(exeFile, procName) { + matches = append(matches, pid) + } + } + return matches +} From 61aeaf7e813bf307f1b28480c7ee2aed639b28f7 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 21 May 2025 13:55:31 -0700 Subject: [PATCH 17/26] remove support for multiple ggufs in a single file (#10722) * remove support for multiple ggufs in a single file this was an attempt to make it easier to import multimodal models into ollama. this was rarely used and error prone so remove it * fix: create fused model from blob --- server/create.go | 51 ++++++++++++++---------------------------------- 1 file changed, 15 insertions(+), 36 deletions(-) diff --git a/server/create.go b/server/create.go index 9488de35..bd970876 100644 --- a/server/create.go +++ b/server/create.go @@ -501,48 +501,27 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML return nil, errOnlyGGUFSupported } - stat, err := blob.Stat() + f, err := ggml.Decode(blob, -1) if err != nil { return nil, err } - var offset int64 - for offset < stat.Size() { - f, err := ggml.Decode(blob, -1) - if errors.Is(err, io.EOF) { - break - } else if err != nil { - return nil, err - } - - mediatype := "application/vnd.ollama.image.model" - if f.KV().Kind() == "adapter" { - mediatype = "application/vnd.ollama.image.adapter" - } else if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok || f.KV().Kind() == "projector" { - mediatype = "application/vnd.ollama.image.projector" - } - - var layer Layer - if digest != "" && f.Length == stat.Size() && offset == 0 { - layer, err = NewLayerFromLayer(digest, mediatype, blob.Name()) - if err != nil { - slog.Debug("could not create new layer from layer", "error", err) - return nil, err - } - } - - // Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size()) - if layer.Digest == "" { - layer, err = NewLayer(io.NewSectionReader(blob, offset, f.Length), mediatype) - if err != nil { - return nil, err - } - } - - layers = append(layers, &layerGGML{layer, f}) - offset = f.Length + mediatype := "application/vnd.ollama.image.model" + if f.KV().Kind() == "adapter" { + mediatype = "application/vnd.ollama.image.adapter" + } else if (f.KV().Uint("block_count") == 0 && f.KV().Uint("vision.block_count") > 0) || f.KV().Kind() == "projector" { + // if a model has vision.block_count but not block_count, it is a standalone vision model + mediatype = "application/vnd.ollama.image.projector" } + layer, err := NewLayerFromLayer(digest, mediatype, blob.Name()) + if err != nil { + slog.Debug("could not create new layer from layer", "error", err) + return nil, err + } + + layers = append(layers, &layerGGML{layer, f}) + return detectChatTemplate(layers) } From fdd4d479a3cbaff1a7fe849e38a145652f87d611 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Thu, 22 May 2025 09:12:32 -0700 Subject: [PATCH 18/26] integration: add qwen2.5-vl (#10815) Replace the older llava model with qwen2.5 for vision tests Skip split-batch test on small VRAM systems to avoid excessive test time --- integration/llm_image_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/integration/llm_image_test.go b/integration/llm_image_test.go index b9726c8f..bbd031a9 100644 --- a/integration/llm_image_test.go +++ b/integration/llm_image_test.go @@ -19,7 +19,7 @@ func TestVisionModels(t *testing.T) { } testCases := []testCase{ { - model: "llava:7b", + model: "qwen2.5vl", }, { model: "llama3.2-vision", @@ -60,6 +60,7 @@ func TestVisionModels(t *testing.T) { } func TestIntegrationSplitBatch(t *testing.T) { + skipUnderMinVRAM(t, 6) image, err := base64.StdEncoding.DecodeString(imageEncoding) require.NoError(t, err) req := api.GenerateRequest{ From fbe6ae285a23baddb14c5bbce26d4fcb837503e4 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Thu, 22 May 2025 10:48:08 -0700 Subject: [PATCH 19/26] server: improve tensor quantization fallback logic (#10806) Fall back to alternative quantization types when a tensor's dimensions aren't divisible by the block size required for the original desired quantization type. If retried quantization types fail, the system ultimately falls back to F16 (half-precision floating point) which has a block size of 1 and can handle any tensor dimension. --- server/quantization.go | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/server/quantization.go b/server/quantization.go index adfc948e..e57e8a4d 100644 --- a/server/quantization.go +++ b/server/quantization.go @@ -120,14 +120,30 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType if newType.IsQuantized() { nx := shape[0] - ny := uint64(1) - if len(shape) > 1 { - ny = shape[1] - } qk_k := newType.BlockSize() + + // Check if first dimension is divisible by block size if nx%qk_k != 0 { - slog.Warn(fmt.Sprintf("tensor cols %d x %d are not divisible by %d, required for %s. Falling back to quantization %s", nx, ny, qk_k, newType.String(), fsggml.TensorTypeF16.String())) - newType = fsggml.TensorTypeF16 + // Store the original type for logging + originalType := newType + + // Select appropriate fallback based on original type + switch newType { + case fsggml.TensorTypeQ4_K: + newType = fsggml.TensorTypeQ5_0 + case fsggml.TensorTypeQ5_K: + newType = fsggml.TensorTypeQ5_1 + case fsggml.TensorTypeQ6_K: + newType = fsggml.TensorTypeQ8_0 + } + + // Final check - if still incompatible, fall back to F16 + if nx%newType.BlockSize() != 0 { + newType = fsggml.TensorTypeF16 + } + + slog.Warn(fmt.Sprintf("tensor cols %d are not divisible by %d, required for %s - using fallback quantization %s", + nx, qk_k, originalType.String(), newType.String())) } } return newType From adff143bcda0c7ab4ca3a85dc3db5a81552368c7 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 22 May 2025 11:30:49 -0700 Subject: [PATCH 20/26] fix: mllama quality (#10807) * fix mllama convert - transform attn_gate and ffn_gate - swap attention heads for vision models * fix mllama the mlp gate which was applied in the wrong place --- convert/convert_mllama.go | 65 +++++++++++++++++++---------- model/models/mllama/model_vision.go | 22 ++-------- 2 files changed, 45 insertions(+), 42 deletions(-) diff --git a/convert/convert_mllama.go b/convert/convert_mllama.go index 12478be7..69d7f588 100644 --- a/convert/convert_mllama.go +++ b/convert/convert_mllama.go @@ -94,7 +94,9 @@ func (m *mllamaModel) Tensors(ts []Tensor) []*ggml.Tensor { var out []*ggml.Tensor var text []Tensor for _, t := range ts { - if t.Name() == "v.position_embd.gate" { + if !strings.HasPrefix(t.Name(), "v.") && !strings.HasPrefix(t.Name(), "mm.") { + text = append(text, t) + } else if t.Name() == "v.position_embd.gate" { for _, name := range []string{"v.position_embd.gate", "v.tile_position_embd.gate"} { tt := t.Clone() tt.SetRepacker(m.repack(name)) @@ -105,23 +107,21 @@ func (m *mllamaModel) Tensors(ts []Tensor) []*ggml.Tensor { WriterTo: tt, }) } - } else if t.Name() == "v.pre_tile_position_embd.gate" || t.Name() == "v.post_tile_position_embd.gate" { - t.SetRepacker(m.repack(t.Name())) - out = append(out, &ggml.Tensor{ - Name: t.Name(), - Kind: t.Kind(), - Shape: t.Shape(), - WriterTo: t, - }) - } else if strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm.") { - out = append(out, &ggml.Tensor{ - Name: t.Name(), - Kind: t.Kind(), - Shape: t.Shape(), - WriterTo: t, - }) } else { - text = append(text, t) + if t.Name() == "v.pre_tile_position_embd.gate" || t.Name() == "v.post_tile_position_embd.gate" { + t.SetRepacker(m.repack(t.Name())) + } else if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") { + t.SetRepacker(m.repack(t.Name())) + } else if strings.HasSuffix(t.Name(), "attn_gate") || strings.HasSuffix(t.Name(), "ffn_gate") { + t.SetRepacker(m.repack(t.Name())) + } + + out = append(out, &ggml.Tensor{ + Name: t.Name(), + Kind: t.Kind(), + Shape: t.Shape(), + WriterTo: t, + }) } } @@ -137,16 +137,35 @@ func (m *mllamaModel) repack(name string) Repacker { var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data)) - t, err = tensor.Tanh(t) - if err != nil { - return nil, err - } + if strings.HasSuffix(name, "attn_q.weight") || strings.HasSuffix(name, "attn_k.weight") { + heads := m.VisionModel.AttentionHeads + if err := t.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil { + return nil, err + } - if name == "v.position_embd.gate" { - t, err = tensor.Sub(float32(1), t) + if err := t.T(0, 2, 1, 3); err != nil { + return nil, err + } + + if err := t.Reshape(dims...); err != nil { + return nil, err + } + + if err := t.Transpose(); err != nil { + return nil, err + } + } else { + t, err = tensor.Tanh(t) if err != nil { return nil, err } + + if name == "v.position_embd.gate" { + t, err = tensor.Sub(float32(1), t) + if err != nil { + return nil, err + } + } } t = tensor.Materialize(t) diff --git a/model/models/mllama/model_vision.go b/model/models/mllama/model_vision.go index 77ea5373..2d424947 100644 --- a/model/models/mllama/model_vision.go +++ b/model/models/mllama/model_vision.go @@ -16,8 +16,6 @@ type VisionSelfAttention struct { Key *nn.Linear `gguf:"attn_k"` Value *nn.Linear `gguf:"attn_v"` Output *nn.Linear `gguf:"attn_output"` - - Gate ml.Tensor `gguf:"attn_gate"` } func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor { @@ -25,27 +23,16 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op query := sa.Query.Forward(ctx, hiddenState) query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize) - query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) key := sa.Key.Forward(ctx, hiddenState) key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize) - key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize) - value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - scores := key.Mulmat(ctx, query) - scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim))) - scores = scores.Softmax(ctx) - - attention := value.Mulmat(ctx, scores) - attention = attention.Reshape(ctx, headDim, attention.Dim(1), opts.numHeads, batchSize) - attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil) attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize) - - hiddenState = sa.Output.Forward(ctx, attention) - return hiddenState + return sa.Output.Forward(ctx, attention) } type VisionMLP struct { @@ -76,21 +63,18 @@ func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts // self attention hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps) hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts) - if e.AttentionGate != nil { hiddenState = hiddenState.Mul(ctx, e.AttentionGate) } hiddenState = hiddenState.Add(ctx, residual) residual = hiddenState - // feed forward hiddenState = e.MLPNorm.Forward(ctx, hiddenState, opts.eps) hiddenState = e.MLP.Forward(ctx, hiddenState, opts) - hiddenState = hiddenState.Add(ctx, residual) if e.MLPGate != nil { hiddenState = hiddenState.Mul(ctx, e.MLPGate) } - + hiddenState = hiddenState.Add(ctx, residual) return hiddenState } From d950ff12c09c07a1cda7242373071fb9e7af9ddc Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Thu, 22 May 2025 14:31:36 -0700 Subject: [PATCH 21/26] sched: fix runner leak during reloading unload (#10819) When the same model is being reloaded rapidly with client connections being canceled before the model finishes loading, the queued unload event could cause a leak of runners by deleting a different runner from the loaded list. --- server/sched.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/server/sched.go b/server/sched.go index 3fc54e55..612e4702 100644 --- a/server/sched.go +++ b/server/sched.go @@ -387,6 +387,17 @@ func (s *Scheduler) processCompleted(ctx context.Context) { s.loadedMu.Unlock() runner.refMu.Unlock() slog.Debug("duplicate expired event, ignoring", "runner", runner) + } else if runner.pid != runnerToUnload.pid { + // If the pids do not match, we likely had multiple load + // failures for the same model in quick succession due to + // request context canceled and are draining the queue of + // events. Ensure the orphaned runner is properly shut down, but + // do not delete the mismatched loaded runner, or wait for VRAM + // convergence. + slog.Debug("orphaned runner shutting down", "orphan", runner, "loaded", runnerToUnload) + runner.unload() + s.loadedMu.Unlock() + runner.refMu.Unlock() } else { slog.Debug("starting background wait for VRAM recovery", "runner", runner) finished := runner.waitForVRAMRecovery() From 6db8a3771c29d070ef165cca0d7e8dbda3fc341e Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Fri, 16 May 2025 14:05:08 -0700 Subject: [PATCH 22/26] ggml: Report graph memory for failed allocations GGML has a function to report the allocated size of a backend buffer. However, this returns 0 if we tried to allocate a buffer and it failed. For memory management purposes, it's important to know how much we were trying to allocate. This extends the API to report attempted sizes for all buffers and whether it succeeeded. --- ...16-graph-memory-reporting-on-failure.patch | 156 ++++++++++++++++++ ml/backend/ggml/ggml/include/ggml-alloc.h | 6 + ml/backend/ggml/ggml/include/ggml-backend.h | 6 + ml/backend/ggml/ggml/src/ggml-alloc.c | 38 ++++- ml/backend/ggml/ggml/src/ggml-backend.cpp | 10 ++ 5 files changed, 212 insertions(+), 4 deletions(-) create mode 100644 llama/patches/0016-graph-memory-reporting-on-failure.patch diff --git a/llama/patches/0016-graph-memory-reporting-on-failure.patch b/llama/patches/0016-graph-memory-reporting-on-failure.patch new file mode 100644 index 00000000..92188224 --- /dev/null +++ b/llama/patches/0016-graph-memory-reporting-on-failure.patch @@ -0,0 +1,156 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Jesse Gross +Date: Fri, 18 Apr 2025 15:58:19 -0700 +Subject: [PATCH] graph memory reporting on failure + +--- + ggml/include/ggml-alloc.h | 6 ++++++ + ggml/include/ggml-backend.h | 6 ++++++ + ggml/src/ggml-alloc.c | 38 +++++++++++++++++++++++++++++++++---- + ggml/src/ggml-backend.cpp | 10 ++++++++++ + 4 files changed, 56 insertions(+), 4 deletions(-) + +diff --git a/ggml/include/ggml-alloc.h b/ggml/include/ggml-alloc.h +index 2cb150fd..781b1e10 100644 +--- a/ggml/include/ggml-alloc.h ++++ b/ggml/include/ggml-alloc.h +@@ -66,6 +66,12 @@ GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph + + GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id); + ++struct ggml_allocr_buffer_status { ++ size_t size; ++ bool allocated; ++}; ++GGML_API struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id); ++ + // Utils + // Create a buffer and allocate all the tensors in a ggml_context + GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); +diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h +index 778927f6..74e46716 100644 +--- a/ggml/include/ggml-backend.h ++++ b/ggml/include/ggml-backend.h +@@ -304,6 +304,12 @@ extern "C" { + + GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); + ++ struct ggml_backend_buffer_status { ++ size_t size; ++ bool allocated; ++ }; ++ GGML_API struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); ++ + GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); + GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node); + +diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c +index 5fd379f6..04812990 100644 +--- a/ggml/src/ggml-alloc.c ++++ b/ggml/src/ggml-alloc.c +@@ -364,6 +364,7 @@ struct node_alloc { + struct ggml_gallocr { + ggml_backend_buffer_type_t * bufts; // [n_buffers] + ggml_backend_buffer_t * buffers; // [n_buffers] ++ size_t *buffer_sizes; // [n_buffers] + struct ggml_dyn_tallocr ** buf_tallocs; // [n_buffers] + int n_buffers; + +@@ -387,6 +388,9 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs + galloc->buffers = calloc(n_bufs, sizeof(ggml_backend_buffer_t)); + GGML_ASSERT(galloc->buffers != NULL); + ++ galloc->buffer_sizes = calloc(n_bufs, sizeof(size_t)); ++ GGML_ASSERT(galloc->buffer_sizes != NULL); ++ + galloc->buf_tallocs = calloc(n_bufs, sizeof(struct ggml_dyn_tallocr *)); + GGML_ASSERT(galloc->buf_tallocs != NULL); + +@@ -453,6 +457,7 @@ void ggml_gallocr_free(ggml_gallocr_t galloc) { + ggml_hash_set_free(&galloc->hash_set); + free(galloc->hash_values); + free(galloc->bufts); ++ free(galloc->buffer_sizes); + free(galloc->buffers); + free(galloc->buf_tallocs); + free(galloc->node_allocs); +@@ -748,6 +753,8 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c + } + } + ++ bool success = true; ++ + // reallocate buffers if needed + for (int i = 0; i < galloc->n_buffers; i++) { + // if the buffer type is used multiple times, we reuse the same buffer +@@ -769,15 +776,20 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c + + ggml_backend_buffer_free(galloc->buffers[i]); + galloc->buffers[i] = ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size); +- if (galloc->buffers[i] == NULL) { ++ if (galloc->buffers[i]) { ++ galloc->buffer_sizes[i] = ggml_backend_buffer_get_size(galloc->buffers[i]); ++ ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); ++ } else { + GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size); +- return false; ++ galloc->buffer_sizes[i] = new_size; ++ success = false; + } +- ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); ++ } else { ++ galloc->buffer_sizes[i] = ggml_backend_buffer_get_size(galloc->buffers[i]); + } + } + +- return true; ++ return success; + } + + bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) { +@@ -934,6 +946,24 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) { + return ggml_backend_buffer_get_size(galloc->buffers[buffer_id]); + } + ++struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id) { ++ GGML_ASSERT(buffer_id >= 0 && buffer_id < galloc->n_buffers); ++ ++ for (int i = 0; i < buffer_id; i++) { ++ if (galloc->buf_tallocs[i] == galloc->buf_tallocs[buffer_id]) { ++ // This buffer is the same as a previous one due to the same buffer type being used multiple times ++ // (See above.) However, we need a different check because multiple buffers might be NULL in our ++ // case and we still want to know the attempted size. ++ ++ struct ggml_allocr_buffer_status status = {0, true}; ++ return status; ++ } ++ } ++ ++ struct ggml_allocr_buffer_status status = {galloc->buffer_sizes[buffer_id], galloc->buffers[buffer_id] != NULL}; ++ return status; ++} ++ + // utils + + static void free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) { +diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp +index 0ce73a99..be335e8c 100644 +--- a/ggml/src/ggml-backend.cpp ++++ b/ggml/src/ggml-backend.cpp +@@ -1629,6 +1629,16 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe + return ggml_gallocr_get_buffer_size(sched->galloc, backend_index); + } + ++struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) { ++ int backend_index = ggml_backend_sched_backend_id(sched, backend); ++ GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); ++ ++ struct ggml_allocr_buffer_status allocr_status = ggml_gallocr_get_attempted_buffer_size(sched->galloc, backend_index); ++ struct ggml_backend_buffer_status status = {allocr_status.size, allocr_status.allocated}; ++ ++ return status; ++} ++ + void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { + int backend_index = ggml_backend_sched_backend_id(sched, backend); + GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); diff --git a/ml/backend/ggml/ggml/include/ggml-alloc.h b/ml/backend/ggml/ggml/include/ggml-alloc.h index 2cb150fd..781b1e10 100644 --- a/ml/backend/ggml/ggml/include/ggml-alloc.h +++ b/ml/backend/ggml/ggml/include/ggml-alloc.h @@ -66,6 +66,12 @@ GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id); +struct ggml_allocr_buffer_status { + size_t size; + bool allocated; +}; +GGML_API struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id); + // Utils // Create a buffer and allocate all the tensors in a ggml_context GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); diff --git a/ml/backend/ggml/ggml/include/ggml-backend.h b/ml/backend/ggml/ggml/include/ggml-backend.h index 778927f6..74e46716 100644 --- a/ml/backend/ggml/ggml/include/ggml-backend.h +++ b/ml/backend/ggml/ggml/include/ggml-backend.h @@ -304,6 +304,12 @@ extern "C" { GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); + struct ggml_backend_buffer_status { + size_t size; + bool allocated; + }; + GGML_API struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); + GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node); diff --git a/ml/backend/ggml/ggml/src/ggml-alloc.c b/ml/backend/ggml/ggml/src/ggml-alloc.c index 5fd379f6..04812990 100644 --- a/ml/backend/ggml/ggml/src/ggml-alloc.c +++ b/ml/backend/ggml/ggml/src/ggml-alloc.c @@ -364,6 +364,7 @@ struct node_alloc { struct ggml_gallocr { ggml_backend_buffer_type_t * bufts; // [n_buffers] ggml_backend_buffer_t * buffers; // [n_buffers] + size_t *buffer_sizes; // [n_buffers] struct ggml_dyn_tallocr ** buf_tallocs; // [n_buffers] int n_buffers; @@ -387,6 +388,9 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs galloc->buffers = calloc(n_bufs, sizeof(ggml_backend_buffer_t)); GGML_ASSERT(galloc->buffers != NULL); + galloc->buffer_sizes = calloc(n_bufs, sizeof(size_t)); + GGML_ASSERT(galloc->buffer_sizes != NULL); + galloc->buf_tallocs = calloc(n_bufs, sizeof(struct ggml_dyn_tallocr *)); GGML_ASSERT(galloc->buf_tallocs != NULL); @@ -453,6 +457,7 @@ void ggml_gallocr_free(ggml_gallocr_t galloc) { ggml_hash_set_free(&galloc->hash_set); free(galloc->hash_values); free(galloc->bufts); + free(galloc->buffer_sizes); free(galloc->buffers); free(galloc->buf_tallocs); free(galloc->node_allocs); @@ -748,6 +753,8 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c } } + bool success = true; + // reallocate buffers if needed for (int i = 0; i < galloc->n_buffers; i++) { // if the buffer type is used multiple times, we reuse the same buffer @@ -769,15 +776,20 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c ggml_backend_buffer_free(galloc->buffers[i]); galloc->buffers[i] = ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size); - if (galloc->buffers[i] == NULL) { + if (galloc->buffers[i]) { + galloc->buffer_sizes[i] = ggml_backend_buffer_get_size(galloc->buffers[i]); + ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); + } else { GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size); - return false; + galloc->buffer_sizes[i] = new_size; + success = false; } - ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); + } else { + galloc->buffer_sizes[i] = ggml_backend_buffer_get_size(galloc->buffers[i]); } } - return true; + return success; } bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) { @@ -934,6 +946,24 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) { return ggml_backend_buffer_get_size(galloc->buffers[buffer_id]); } +struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id) { + GGML_ASSERT(buffer_id >= 0 && buffer_id < galloc->n_buffers); + + for (int i = 0; i < buffer_id; i++) { + if (galloc->buf_tallocs[i] == galloc->buf_tallocs[buffer_id]) { + // This buffer is the same as a previous one due to the same buffer type being used multiple times + // (See above.) However, we need a different check because multiple buffers might be NULL in our + // case and we still want to know the attempted size. + + struct ggml_allocr_buffer_status status = {0, true}; + return status; + } + } + + struct ggml_allocr_buffer_status status = {galloc->buffer_sizes[buffer_id], galloc->buffers[buffer_id] != NULL}; + return status; +} + // utils static void free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) { diff --git a/ml/backend/ggml/ggml/src/ggml-backend.cpp b/ml/backend/ggml/ggml/src/ggml-backend.cpp index 0ce73a99..be335e8c 100644 --- a/ml/backend/ggml/ggml/src/ggml-backend.cpp +++ b/ml/backend/ggml/ggml/src/ggml-backend.cpp @@ -1629,6 +1629,16 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe return ggml_gallocr_get_buffer_size(sched->galloc, backend_index); } +struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) { + int backend_index = ggml_backend_sched_backend_id(sched, backend); + GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); + + struct ggml_allocr_buffer_status allocr_status = ggml_gallocr_get_attempted_buffer_size(sched->galloc, backend_index); + struct ggml_backend_buffer_status status = {allocr_status.size, allocr_status.allocated}; + + return status; +} + void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { int backend_index = ggml_backend_sched_backend_id(sched, backend); GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); From 73d6a82cce18f84ff5c67148783224cf25b30b32 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Thu, 17 Apr 2025 11:00:25 -0700 Subject: [PATCH 23/26] ollamarunner: Memory usage reporting This provides granular information about the backend memory allocations required by the runner: - Per backend - Per layer - Weights, cache and graph - Allocation status This can be used for debugging and validating memory estimates. --- kvcache/causal_test.go | 2 +- ml/backend.go | 84 ++++++++++++++- ml/backend/ggml/ggml.go | 163 ++++++++++++++++++++---------- runner/ollamarunner/multimodal.go | 5 +- runner/ollamarunner/runner.go | 48 +++++---- 5 files changed, 224 insertions(+), 78 deletions(-) diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 79698708..820d496d 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -508,7 +508,7 @@ func (c *testContext) Forward(...ml.Tensor) ml.Context { return c } func (c *testContext) Compute(...ml.Tensor) {} -func (c *testContext) Reserve() error { return nil } +func (c *testContext) Reserve() {} func (c *testContext) MaxGraphNodes() int { return 10 diff --git a/ml/backend.go b/ml/backend.go index 3c417ef9..7c9b9e31 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -15,6 +15,10 @@ import ( type Backend interface { Load(ctx context.Context, progress func(float32)) error + + // BackendMemory returns the memory allocations that were made for this model + BackendMemory() BackendMemory + Config() fs.Config Get(name string) Tensor NewContext() Context @@ -68,6 +72,84 @@ type BackendParams struct { FlashAttention bool } +// ErrNoMem is returned when panicing due to insufficient memory. It includes +// the attempted memory allocation. +type ErrNoMem struct { + BackendMemory +} + +func (e ErrNoMem) Error() string { + return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory) +} + +type AllocationStatus int + +const ( + // Unallocated memory - have not yet attempted to allocate + Unallocated AllocationStatus = iota + + // Failed memory - tried to allocate the memory and did not succeed + Failed + + // Allocated memory = tried and succeeded to allocate memory + Allocated +) + +// Memory is the size of an allocation and whether it was successful. +type Memory struct { + Size uint64 + Status AllocationStatus +} + +func (m Memory) String() string { + s := fmt.Sprint(m.Size) + + switch m.Status { + case Unallocated: + s += "U" + case Failed: + s += "F" + case Allocated: + s += "A" + } + + return s +} + +// DeviceMemory provides a breakdown of the memory needed +// per device, such as a CPU or GPU. +type DeviceMemory struct { + // Name is the name of the device as labeled by the backend. It + // may not be persistent across instances of the runner. + Name string + + // Weights is the per-layer memory needed for the model weights. + Weights []Memory + + // Cache is the per-layer memory needed for the KV cache. + Cache []Memory + + // Graph is the size of the compute graph. It is not per-layer. + Graph Memory +} + +// BackendMemory provides the amount of memory required to load the model +// per device based on the BackendParams. In some cases, not all required +// allocations will be known at this point. However, the size of the most recent +// allocation is guaranteed to be provided so that if it failed, the caller can +// accommodate that to make forward progress. +type BackendMemory struct { + // InputsWeights are always located on the CPU and cannot be moved + InputWeights Memory + + // CPU model components are located in system memory. This does not + // include unified memory allocated through the GPU. + CPU DeviceMemory + + // GPU model components are located on one or more GPUs. + GPUs []DeviceMemory +} + var backends = make(map[string]func(string, BackendParams) (Backend, error)) func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) { @@ -102,7 +184,7 @@ type Context interface { // graph, simply preallocates memory. Typically called with a // worst case graph to ensure all resources are available for // for future inference. - Reserve() error + Reserve() MaxGraphNodes() int Close() diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 44e3b61b..496ba8a6 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -10,7 +10,6 @@ import "C" import ( "context" - "errors" "fmt" "io" "log/slog" @@ -66,6 +65,12 @@ type Backend struct { // layers is the backend used for repeating layers layers map[int]*C.struct_ggml_backend_buffer_type + // requiredMemory is the cumulative memory allocations needed by the backend + requiredMemory *ml.BackendMemory + + // btDeviceMemory maps from a buffer type to the memory allocations associated with that device + btDeviceMemory map[*C.struct_ggml_backend_buffer_type]*ml.DeviceMemory + flashAttention bool // maxGraphNodes is the maximum allowed number of graph nodes in this scheduler @@ -94,6 +99,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { "num_key_values", len(meta.KV()), ) + var requiredMemory ml.BackendMemory + btDeviceMemory := make(map[*C.struct_ggml_backend_buffer_type]*ml.DeviceMemory) + type deviceBufferType struct { d *C.struct_ggml_backend_device bts []*C.struct_ggml_backend_buffer_type @@ -114,6 +122,8 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { } } + blocks := int(meta.KV().BlockCount()) + // create list of buffer types for the cpu cpuDeviceBufferType := deviceBufferType{d: C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU)} for _, d := range append(accels, append(gpus, cpus...)...) { @@ -121,17 +131,27 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { case C.GGML_BACKEND_DEVICE_TYPE_CPU, C.GGML_BACKEND_DEVICE_TYPE_ACCEL: cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, C.ggml_backend_dev_buffer_type(d)) + btDeviceMemory[C.ggml_backend_dev_buffer_type(d)] = &requiredMemory.CPU } } + requiredMemory.CPU.Name = C.GoString(C.ggml_backend_dev_name(cpuDeviceBufferType.d)) + requiredMemory.CPU.Weights = make([]ml.Memory, blocks+1) + requiredMemory.CPU.Cache = make([]ml.Memory, blocks+1) + // create list of buffer types for each gpu var gpuDeviceBufferTypes []deviceBufferType - for _, d := range gpus { + requiredMemory.GPUs = make([]ml.DeviceMemory, len(gpus)) + for i, d := range gpus { bt := C.ggml_backend_dev_buffer_type(d) gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{ d: d, bts: append([]*C.struct_ggml_backend_buffer_type{bt}, cpuDeviceBufferType.bts...), }) + btDeviceMemory[bt] = &requiredMemory.GPUs[i] + requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d)) + requiredMemory.GPUs[i].Weights = make([]ml.Memory, blocks+1) + requiredMemory.GPUs[i].Cache = make([]ml.Memory, blocks+1) } useDefaultSplit := true @@ -170,8 +190,6 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { // inputs always use cpu input := cpuDeviceBufferType - blocks := int(meta.KV().BlockCount()) - // define a range of gpu layers. anything outside of this range is assigned to the cpu gpuRangeStart := max(0, blocks-params.NumGPULayers) gpuRangeStop := min(gpuRangeStart+params.NumGPULayers, blocks+1) @@ -212,7 +230,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { // contexts are shared by tensors of the same buffer type ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context) - createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type) *C.struct_ggml_tensor { + createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type, layer int) *C.struct_ggml_tensor { for _, bt := range bts { if _, ok := ctxs[bt]; !ok { ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{ @@ -238,6 +256,16 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { C.ggml_set_name(tt, cname) slog.Log(context.TODO(), logutil.LevelTrace, "created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt))) + + size := pad(C.ggml_backend_buft_get_alloc_size(bt, tt), C.ggml_backend_buft_get_alignment(bt)) + if layer == -1 { + // Assume that InputWeights can be allocated - they're always in system memory and can't be moved in any case + requiredMemory.InputWeights.Status = ml.Allocated + requiredMemory.InputWeights.Size += uint64(size) + } else { + btDeviceMemory[bt].Weights[layer].Size += uint64(size) + } + //nolint:staticcheck // TODO: check if buffer type supports this tensor return tt } @@ -259,22 +287,22 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { for _, t := range meta.Tensors().Items() { switch { case contains(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"): - createTensor(tensor{source: t}, input.bts) + createTensor(tensor{source: t}, input.bts, -1) if _, ok := meta.Tensors().GroupLayers()["output"]; !ok && t.Name == "token_embd.weight" { - createTensor(tensor{source: t, target: "output.weight"}, output.bts) + createTensor(tensor{source: t, target: "output.weight"}, output.bts, blocks) } case contains(t.Name, "cls", "output", "output_norm"): - createTensor(tensor{source: t}, output.bts) + createTensor(tensor{source: t}, output.bts, blocks) case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."): // TODO: assign vision tensors to the gpu if possible - createTensor(tensor{source: t}, output.bts) + createTensor(tensor{source: t}, output.bts, blocks) case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"): // these tensors should be repeated per layer for i, layer := range layers { createTensor(tensor{ source: t, target: "blk." + strconv.Itoa(i) + "." + t.Name, - }, layer.bts) + }, layer.bts, i) } default: layerIndex := -1 @@ -285,10 +313,10 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { } if layerIndex >= 0 { - createTensor(tensor{source: t}, layers[layerIndex].bts) + createTensor(tensor{source: t}, layers[layerIndex].bts, layerIndex) } else { // load all other tensors on the cpu - createTensor(tensor{source: t}, input.bts) + createTensor(tensor{source: t}, input.bts, -1) } } } @@ -301,8 +329,18 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { } b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt) + for i := range btDeviceMemory[bt].Weights { + if btDeviceMemory[bt].Weights[i].Size != 0 { + if b != nil { + btDeviceMemory[bt].Weights[i].Status = ml.Allocated + } else { + btDeviceMemory[bt].Weights[i].Status = ml.Failed + } + } + } + if b == nil { - return nil, fmt.Errorf("unable to allocate memory from device %v for model weights", C.GoString(C.ggml_backend_buft_name(bt))) + panic(ml.ErrNoMem{BackendMemory: requiredMemory}) } C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS) @@ -367,7 +405,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { } return m }(), - maxGraphNodes: maxGraphNodes, + requiredMemory: &requiredMemory, + btDeviceMemory: btDeviceMemory, + maxGraphNodes: maxGraphNodes, }, nil } @@ -446,6 +486,10 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error { return nil } +func (b *Backend) BackendMemory() ml.BackendMemory { + return *b.requiredMemory +} + func (b *Backend) Config() fs.Config { return b.meta.KV() } @@ -477,6 +521,7 @@ func (b *Backend) NewContextSize(n int) ml.Context { no_alloc: true, }), allocatedBuffers: &allocatedBuffers, + layer: -1, } } @@ -503,6 +548,9 @@ type Context struct { // maxGraphNodes is the maximum allowed number of graph nodes in this context maxGraphNodes int + + // layer is the graph layer that this context is allocating for - assumed to be cache + layer int } func (c *Context) Input() ml.Context { @@ -513,6 +561,7 @@ func (c *Context) Input() ml.Context { buft: c.b.input, allocatedBuffers: c.allocatedBuffers, maxGraphNodes: c.maxGraphNodes, + layer: -1, } } @@ -527,6 +576,7 @@ func (c *Context) Layer(i int) ml.Context { buft: buft, allocatedBuffers: c.allocatedBuffers, maxGraphNodes: c.maxGraphNodes, + layer: i, } } @@ -564,22 +614,34 @@ func (c *Context) Compute(tensors ...ml.Tensor) { } } -func (c *Context) Reserve() error { - if !C.ggml_backend_sched_reserve(c.b.sched, c.graph) { - C.ggml_backend_sched_reset(c.b.sched) - return errors.New("failed to reserve graph") - } +func (c *Context) Reserve() { + reserved := C.ggml_backend_sched_reserve(c.b.sched, c.graph) slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(c.b.sched)) - for i := range c.b.schedBackends { - size := C.ggml_backend_sched_get_buffer_size(c.b.sched, c.b.schedBackends[i]) - slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])), - "size", format.HumanBytes2(uint64(size))) + + // Reserve may get called multiple times for different graphs - we just want the last run, which will contain the max allocations + for _, bt := range c.b.schedBufts { + c.b.btDeviceMemory[bt].Graph = ml.Memory{} } - C.ggml_backend_sched_reset(c.b.sched) + for i := range c.b.schedBackends { + bufferStatus := C.ggml_backend_sched_get_attempted_buffer_size(c.b.sched, c.b.schedBackends[i]) - return nil + graph := &c.b.btDeviceMemory[c.b.schedBufts[i]].Graph + graph.Size += uint64(bufferStatus.size) + if bufferStatus.allocated && graph.Status != ml.Failed { + graph.Status = ml.Allocated + } else { + graph.Status = ml.Failed + } + + slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])), + "size", format.HumanBytes2(uint64(bufferStatus.size))) + } + + if !reserved { + panic(ml.ErrNoMem{BackendMemory: *c.b.requiredMemory}) + } } func (c *Context) MaxGraphNodes() int { @@ -599,7 +661,7 @@ func pad(length, pad C.size_t) C.size_t { return ((length + pad - 1) / pad) * pad } -func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) { +func (c *Context) newTensor(dtype ml.DType, shape []int) ml.Tensor { if c.buft == nil { panic("set Input or Layer before creating tensors") } @@ -622,7 +684,7 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) { if len(shape) < 1 || shape[0] == 0 { var shape C.int64_t = 0 - return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}, nil + return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)} } else if len(shape) > 4 { panic("unsupported number of dimensions") } @@ -635,31 +697,34 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) { t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape)) size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft)) - b := C.ggml_backend_buft_alloc_buffer(c.buft, size) - if b == nil { - return nil, fmt.Errorf("unable to allocate %v from device %v for new tensor", format.HumanBytes2(uint64(size)), C.GoString(C.ggml_backend_buft_name(c.buft))) - } - *c.allocatedBuffers = append(*c.allocatedBuffers, b) + b := C.ggml_backend_buft_alloc_buffer(c.buft, size) + if c.layer >= 0 { + cache := &c.b.btDeviceMemory[c.buft].Cache[c.layer] + + cache.Size += uint64(size) + if b != nil { + cache.Status = ml.Allocated + } else { + cache.Status = ml.Failed + } + } + + if b == nil { + panic(ml.ErrNoMem{BackendMemory: *c.b.requiredMemory}) + } + + *c.allocatedBuffers = append(*c.allocatedBuffers, b) C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b)) - return &Tensor{b: c.b, t: t}, nil + return &Tensor{b: c.b, t: t} } func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor { - t, err := c.newTensor(dtype, shape) - if err != nil { - panic(err) - } - - return t + return c.newTensor(dtype, shape) } func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { - t, err := c.newTensor(dtype, shape) - if err != nil { - panic(err) - } - + t := c.newTensor(dtype, shape) C.ggml_set_zero(t.(*Tensor).t) return t } @@ -687,10 +752,7 @@ func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { return nil, err } - t, err := c.newTensor(ml.DTypeF32, shape) - if err != nil { - return nil, err - } + t := c.newTensor(ml.DTypeF32, shape) if len(s) > 0 { C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) @@ -704,10 +766,7 @@ func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { return nil, err } - t, err := c.newTensor(ml.DTypeI32, shape) - if err != nil { - return nil, err - } + t := c.newTensor(ml.DTypeI32, shape) if len(s) > 0 { C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) diff --git a/runner/ollamarunner/multimodal.go b/runner/ollamarunner/multimodal.go index d78612fe..dbe6bba1 100644 --- a/runner/ollamarunner/multimodal.go +++ b/runner/ollamarunner/multimodal.go @@ -95,10 +95,7 @@ func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Ten } } } else { - err := computeCtx.Reserve() - if err != nil { - return nil, err - } + computeCtx.Reserve() } } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index a488a104..99bee106 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -826,16 +826,12 @@ func (s *Server) reserveWorstCaseGraph() error { return err } - err = ctx.Forward(t).Reserve() - if err != nil { - return err - } + ctx.Forward(t).Reserve() return nil } -func (s *Server) loadModel( - ctx context.Context, +func (s *Server) initModel( mpath string, params ml.BackendParams, lpath multiLPath, @@ -843,21 +839,21 @@ func (s *Server) loadModel( kvCacheType string, kvSize int, multiUserCache bool, -) { +) error { var err error s.model, err = model.New(mpath, params) if err != nil { - panic(err) + return err } // TODO(jessegross): LoRA loading if lpath.String() != "" { - panic("loras are not yet implemented") + return errors.New("loras are not yet implemented") } s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache) if err != nil { - panic(err) + return err } if !s.cache.enabled && parallel > 1 { @@ -869,11 +865,25 @@ func (s *Server) loadModel( s.seqs = make([]*Sequence, s.parallel) s.seqsSem = semaphore.NewWeighted(int64(s.parallel)) - err = s.reserveWorstCaseGraph() + return s.reserveWorstCaseGraph() +} + +func (s *Server) load( + ctx context.Context, + mpath string, + params ml.BackendParams, + lpath multiLPath, + parallel int, + kvCacheType string, + kvSize int, + multiUserCache bool) { + err := s.initModel(mpath, params, lpath, parallel, kvCacheType, kvSize, multiUserCache) if err != nil { panic(err) } + slog.Debug("memory", "allocated", s.model.Backend().BackendMemory()) + err = s.model.Backend().Load(ctx, func(progress float32) { s.progress = progress @@ -921,9 +931,14 @@ func Execute(args []string) error { status: llm.ServerStatusLoadingModel, } + server.cond = sync.NewCond(&server.mu) + server.ready.Add(1) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // TODO(jessegross): Parameters that need to be implemented: // no-mmap - // mlock var tensorSplitFloats []float32 if *tensorSplit != "" { @@ -943,14 +958,7 @@ func Execute(args []string) error { FlashAttention: *flashAttention, } - server.ready.Add(1) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache) - - server.cond = sync.NewCond(&server.mu) - + go server.load(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache) go server.run(ctx) addr := "127.0.0.1:" + strconv.Itoa(*port) From 1f371ea92f7ebe4edd208b6732753473b2c4d0cd Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Mon, 19 May 2025 10:43:56 -0700 Subject: [PATCH 24/26] ml: Panic rather than return error on tensor allocation failure FromFloatSlice and FromIntSlice return an error if the shape doesn't match the passed data or if memory can't be allocated. Since these are inputs, the memory being allocated is system memory rather than VRAM. In many cases, the caller can't really handle the error and panics. Empty and Zeros directly panic if they can't allocate memory. This makes things consistent by panicing for the first two cases, removing a fair amount of error handling code. This is also consistent with how Go typically handles these situations. --- kvcache/causal.go | 26 ++++++---------------- kvcache/causal_test.go | 18 ++++++++-------- ml/backend.go | 4 ++-- ml/backend/ggml/ggml.go | 31 +++++++++------------------ model/model.go | 6 +----- model/models/gemma2/model.go | 11 ++-------- model/models/gemma3/model.go | 16 +++----------- model/models/llama/model.go | 10 ++------- model/models/llama4/model.go | 22 ++++--------------- model/models/llama4/model_text.go | 6 +----- model/models/llama4/model_vision.go | 5 +---- model/models/mistral3/model.go | 16 +++----------- model/models/mistral3/model_vision.go | 16 +++----------- model/models/mllama/model.go | 22 ++++--------------- model/models/qwen2/model.go | 10 ++------- model/models/qwen25vl/model.go | 16 +++----------- model/models/qwen25vl/model_vision.go | 22 +++++-------------- model/models/qwen3/model.go | 10 ++------- runner/ollamarunner/multimodal.go | 2 +- runner/ollamarunner/runner.go | 8 +++---- 20 files changed, 68 insertions(+), 209 deletions(-) diff --git a/kvcache/causal.go b/kvcache/causal.go index 9bc1d5da..f6bacaaf 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -211,10 +211,9 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e c.curCellRange.max = len(c.cells) - 1 } - var err error - c.curMask, err = c.buildMask(ctx) + c.curMask = c.buildMask(ctx) - return err + return nil } func newRange() cellRange { @@ -297,7 +296,7 @@ func roundUp(length, pad int) int { // Builds a mask of history x batch indicating whether for each token in the batch the // token in the history should apply. This is based on both the sequence and causality (the // position of the history is not ahead of the token in the batch). -func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) { +func (c *Causal) buildMask(ctx ml.Context) ml.Tensor { // Align and pad the two dimensions as required by the backend batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding) @@ -325,10 +324,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) { mask[i] = float32(math.Inf(-1)) } - maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize) - if err != nil { - return nil, err - } + maskTensor := ctx.Input().FromFloatSlice(mask, length, batchSize) if c.config.MaskDType != ml.DTypeF32 { out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...) @@ -336,7 +332,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) { maskTensor = out } - return maskTensor, nil + return maskTensor } func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) { @@ -491,12 +487,7 @@ func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) { if !slices.Equal(c.opts.Except, opts.Except) { c.opts = opts if ctx != nil { - var err error - c.curMask, err = c.buildMask(ctx) - if err != nil { - // This error should never occur because we have previously built a mask with the same shape - panic(fmt.Errorf("SetCausal: %w", err)) - } + c.curMask = c.buildMask(ctx) } } } @@ -652,10 +643,7 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error { } } - kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets)) - if err != nil { - return err - } + kShift := ctx.Input().FromIntSlice(offsets, len(offsets)) for i, key := range c.keys { if key == nil { diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 820d496d..5b1dbe86 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -344,7 +344,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) } cache.SetLayer(0) - tensor, _ := context.FromFloatSlice(test.in, test.inShape...) + tensor := context.FromFloatSlice(test.in, test.inShape...) cache.Put(context, tensor, tensor) out, _, mask := cache.Get(context) @@ -386,7 +386,7 @@ func TestCanResume(t *testing.T) { } cache.SetLayer(0) - tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4) + tensor := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4) cache.Put(context, tensor, tensor) // with window size 4, nothing has slid out of the window yet @@ -413,7 +413,7 @@ func TestCanResume(t *testing.T) { } cache.SetLayer(0) - tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2) + tensor = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2) cache.Put(context, tensor, tensor) // only the latest position has overlapping windows @@ -470,24 +470,24 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor { return c.Empty(dtype, shape...) } -func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { +func (c *testContext) FromFloatSlice(s []float32, shape ...int) ml.Tensor { t := c.Empty(ml.DTypeF32, shape...).(*testTensor) copy(t.data, s) - return t, nil + return t } -func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { +func (c *testContext) FromIntSlice(s []int32, shape ...int) ml.Tensor { f := make([]float32, len(s)) for i := range f { f[i] = float32(s[i]) } - out, _ := c.FromFloatSlice(f, shape...) + out := c.FromFloatSlice(f, shape...) out.(*testTensor).dtype = ml.DTypeI32 - return out, nil + return out } func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor { @@ -496,7 +496,7 @@ func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tenso s = append(s, i) } - out, _ := c.FromFloatSlice(s, len(s)) + out := c.FromFloatSlice(s, len(s)) out.(*testTensor).dtype = dtype return out } diff --git a/ml/backend.go b/ml/backend.go index 7c9b9e31..6beb7d2b 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -171,8 +171,8 @@ func NewBackend(modelPath string, params BackendParams) (Backend, error) { type Context interface { Empty(dtype DType, shape ...int) Tensor Zeros(dtype DType, shape ...int) Tensor - FromFloatSlice(s []float32, shape ...int) (Tensor, error) - FromIntSlice(s []int32, shape ...int) (Tensor, error) + FromFloatSlice(s []float32, shape ...int) Tensor + FromIntSlice(s []int32, shape ...int) Tensor // Arange creates a 1D tensor with values within an interval (start, stop] increased by step. Arange(start, stop, step float32, dtype DType) Tensor diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 496ba8a6..76172ae1 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -729,11 +729,11 @@ func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { return t } -func checkShape[S ~[]E, E any](s S, shape ...int) error { +func checkShape[S ~[]E, E any](s S, shape ...int) { n := len(s) if n == 0 { - return nil + return } for _, v := range shape { @@ -741,16 +741,12 @@ func checkShape[S ~[]E, E any](s S, shape ...int) error { } if n != 1 { - return fmt.Errorf("invalid shape: %v", shape) + panic(fmt.Errorf("invalid shape: %v", shape)) } - - return nil } -func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { - if err := checkShape(s, shape...); err != nil { - return nil, err - } +func (c *Context) FromFloatSlice(s []float32, shape ...int) ml.Tensor { + checkShape(s, shape...) t := c.newTensor(ml.DTypeF32, shape) @@ -758,13 +754,11 @@ func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) } - return t, nil + return t } -func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { - if err := checkShape(s, shape...); err != nil { - return nil, err - } +func (c *Context) FromIntSlice(s []int32, shape ...int) ml.Tensor { + checkShape(s, shape...) t := c.newTensor(ml.DTypeI32, shape) @@ -772,7 +766,7 @@ func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) } - return t, nil + return t } func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor { @@ -790,12 +784,7 @@ func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor { arange = append(arange, int32(i)) } - t, err := c.Input().FromIntSlice(arange, len(arange)) - if err != nil { - panic(err) - } - - return t + return c.Input().FromIntSlice(arange, len(arange)) default: panic("unsupported dtype for arange") } diff --git a/model/model.go b/model/model.go index 39b68db1..25097e01 100644 --- a/model/model.go +++ b/model/model.go @@ -287,11 +287,7 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten return nil, errors.New("batch size cannot be less than 1") } - var err error - batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs)) - if err != nil { - return nil, err - } + batch.Inputs = ctx.Input().FromIntSlice(inputs, len(inputs)) cache := m.Config().Cache if cache != nil { diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 3c5a7ea5..e621d03a 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -175,15 +175,8 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } - - outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize))) diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index 89d1788e..53bf8275 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -101,14 +101,11 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input return nil, err } - pixelValues, err := ctx.Input().FromFloatSlice(f32s, + pixelValues := ctx.Input().FromFloatSlice(f32s, m.ImageProcessor.imageSize, m.ImageProcessor.imageSize, m.ImageProcessor.numChannels, ) - if err != nil { - return nil, err - } visionOutputs := m.VisionModel.Forward(ctx, pixelValues) visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps) @@ -144,15 +141,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } - - outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil } diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 507f1ebc..3cf782d0 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -142,10 +142,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tenso } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) @@ -154,10 +151,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { var outputs ml.Tensor if i == len(m.Layers)-1 { - outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) } hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options) diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go index af5173a1..8084760b 100644 --- a/model/models/llama4/model.go +++ b/model/models/llama4/model.go @@ -77,10 +77,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input return nil, err } - tilesLocal, err := ctx.Input().FromFloatSlice(pixelsLocal, size.X, size.Y, m.numChannels) - if err != nil { - return nil, err - } + tilesLocal := ctx.Input().FromFloatSlice(pixelsLocal, size.X, size.Y, m.numChannels) ratioW, ratioH := size.X/m.imageSize, size.Y/m.imageSize @@ -91,11 +88,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input pixelValues := tilesLocal if len(pixelsGlobal) > 0 { - tilesGlobal, err := ctx.Input().FromFloatSlice(pixelsGlobal, m.imageSize, m.imageSize, m.numChannels) - if err != nil { - return nil, err - } - + tilesGlobal := ctx.Input().FromFloatSlice(pixelsGlobal, m.imageSize, m.imageSize, m.numChannels) pixelValues = pixelValues.Concat(ctx, tilesGlobal, 3) } @@ -182,15 +175,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } - - outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil } diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go index ff9f5e20..27935f40 100644 --- a/model/models/llama4/model_text.go +++ b/model/models/llama4/model_text.go @@ -223,11 +223,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor scales[i] = float32(math.Log(math.Floor(((float64(p)+1.0)/float64(m.attentionFloorScale))+1.0))*m.attentionScale + 1.0) } - var err error - attentionScales, err = ctx.Input().FromFloatSlice(scales, 1, 1, len(scales)) - if err != nil { - panic(err) - } + attentionScales = ctx.Input().FromFloatSlice(scales, 1, 1, len(scales)) } for i, layer := range m.Layers { diff --git a/model/models/llama4/model_vision.go b/model/models/llama4/model_vision.go index e6b1afef..dc6f82b8 100644 --- a/model/models/llama4/model_vision.go +++ b/model/models/llama4/model_vision.go @@ -245,10 +245,7 @@ func (m *VisionModel) rotaryEmbedding(ctx ml.Context) (ml.Tensor, ml.Tensor) { } } - ropeFreqs, err := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2) - if err != nil { - panic(err) - } + ropeFreqs := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2) ropeFreqs = ropeFreqs.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) ropeFreqs = ropeFreqs.Reshape(ctx, freqDim, 1, numPatches) diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go index dd01a587..9d662fc1 100644 --- a/model/models/mistral3/model.go +++ b/model/models/mistral3/model.go @@ -114,10 +114,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input return nil, err } - pixelValues, err := ctx.Input().FromFloatSlice(f32s, size.X, size.Y, m.ImageProcessor.numChannels) - if err != nil { - return nil, err - } + pixelValues := ctx.Input().FromFloatSlice(f32s, size.X, size.Y, m.ImageProcessor.numChannels) visionOutputs := m.VisionModel.Forward(ctx, pixelValues) features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size) @@ -161,15 +158,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } - - outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil } diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go index 24541004..65bdcff2 100644 --- a/model/models/mistral3/model_vision.go +++ b/model/models/mistral3/model_vision.go @@ -110,15 +110,8 @@ func (m *VisionModel) positionalEmbedding(ctx ml.Context, positionIDs ml.Tensor) } } - h, err := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2) - if err != nil { - panic(err) - } - - w, err := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2) - if err != nil { - panic(err) - } + h := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2) + w := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2) h = h.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) w = w.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) @@ -151,10 +144,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor { } } - positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions)) - if err != nil { - panic(err) - } + positionIDs := ctx.Input().FromIntSlice(positions, len(positions)) positionEmbedding := m.positionalEmbedding(ctx, positionIDs) cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx) diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index 547e2cb3..45cb3e02 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -80,15 +80,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input f32s = f32s[:m.imageSize*m.imageSize*m.numChannels*m.maxNumTiles] } - pixelValues, err := ctx.Input().FromFloatSlice(f32s, m.imageSize, m.imageSize, m.numChannels, m.maxNumTiles) - if err != nil { - return nil, err - } - - aspectRatio, err := ctx.Input().FromIntSlice([]int32{int32(ratio.rank)}, 1) - if err != nil { - return nil, err - } + pixelValues := ctx.Input().FromFloatSlice(f32s, m.imageSize, m.imageSize, m.numChannels, m.maxNumTiles) + aspectRatio := ctx.Input().FromIntSlice([]int32{int32(ratio.rank)}, 1) positionIDs := ctx.Arange(0, 1601, 1, ml.DTypeI32) crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio) @@ -113,15 +106,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { crossAttentionStates = batch.Multimodal[len(batch.Multimodal)-1].Multimodal[0].Tensor } - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } - - outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) // TODO: attention mask, cross attention mask return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go index 3c3d81aa..42338d0d 100644 --- a/model/models/qwen2/model.go +++ b/model/models/qwen2/model.go @@ -100,10 +100,7 @@ type Model struct { // Forward implements model.Model. func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) @@ -112,10 +109,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { var outputs ml.Tensor if i == len(m.Layers)-1 { - outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) } hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options) diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go index 32cca560..ee38cad9 100644 --- a/model/models/qwen25vl/model.go +++ b/model/models/qwen25vl/model.go @@ -69,10 +69,7 @@ func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, * m.ImageProcessor.patchSize * m.ImageProcessor.patchSize numPatches := grid.Temporal * grid.Height * grid.Width - pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches) - if err != nil { - return nil, nil, fmt.Errorf("failed to create tensor from image: %w", err) - } + pixelValues := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches) return pixelValues, grid, nil } @@ -142,15 +139,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } - - outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache) } diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go index 01eef392..4d7afaa1 100644 --- a/model/models/qwen25vl/model_vision.go +++ b/model/models/qwen25vl/model_vision.go @@ -1,7 +1,6 @@ package qwen25vl import ( - "fmt" "math" "slices" @@ -44,10 +43,8 @@ func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int } } - mask, err := ctx.Input().FromFloatSlice(flat, seqLength, seqLength) - if err != nil { - panic(err) - } + mask := ctx.Input().FromFloatSlice(flat, seqLength, seqLength) + // Reshape to match [seqLength, seqLength, 1] for broadcasting mask = mask.Reshape(ctx, seqLength, seqLength, 1) @@ -303,10 +300,7 @@ func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) } } - t, err := ctx.Input().FromIntSlice(index, len(index)) - if err != nil { - panic(err) - } + t := ctx.Input().FromIntSlice(index, len(index)) return t, bounds } @@ -326,10 +320,7 @@ func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor freqVals[i*freq+j] = float32(i) / float32(math.Pow(theta, float64(j*2)/float64(dim))) } } - freqs, err := ctx.Input().FromFloatSlice(freqVals, freq, maxGridSize) - if err != nil { - panic(fmt.Errorf("failed to create tensor from frequencies: %w", err)) - } + freqs := ctx.Input().FromFloatSlice(freqVals, freq, maxGridSize) // Create position coordinates (y,x pairs) for the grid // In PyTorch: Equivalent to generating position ids with torch.arange() @@ -339,10 +330,7 @@ func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor coords = append(coords, int32(y), int32(x)) } } - pos, err := ctx.Input().FromIntSlice(coords, 2, grid.Width, grid.Height) - if err != nil { - panic(fmt.Errorf("failed to create tensor from positions: %w", err)) - } + pos := ctx.Input().FromIntSlice(coords, 2, grid.Width, grid.Height) // Reshape and permute positions to match spatial merging pattern pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge) diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go index 44c32f9e..1930da7e 100644 --- a/model/models/qwen3/model.go +++ b/model/models/qwen3/model.go @@ -156,10 +156,7 @@ type Model struct { // Forward implements model.Model. func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) @@ -168,10 +165,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { var outputs ml.Tensor if i == len(m.Layers)-1 { - outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) } hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options) diff --git a/runner/ollamarunner/multimodal.go b/runner/ollamarunner/multimodal.go index dbe6bba1..fbdc7d72 100644 --- a/runner/ollamarunner/multimodal.go +++ b/runner/ollamarunner/multimodal.go @@ -102,7 +102,7 @@ func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Ten for i, t := range entry.mm { if in == t.Tensor { if !reserve { - return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...) + return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...), nil } else { return ctx.Input().Empty(t.Tensor.DType(), t.Tensor.Shape()...), nil } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 99bee106..a7a889f1 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -808,10 +808,7 @@ func (s *Server) reserveWorstCaseGraph() error { batch.Outputs[i] = int32(i) } - batch.Inputs, err = ctx.Input().FromIntSlice(batchInputs, len(batchInputs)) - if err != nil { - return err - } + batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs)) cache := s.model.Config().Cache if cache != nil { @@ -876,7 +873,8 @@ func (s *Server) load( parallel int, kvCacheType string, kvSize int, - multiUserCache bool) { + multiUserCache bool, +) { err := s.initModel(mpath, params, lpath, parallel, kvCacheType, kvSize, multiUserCache) if err != nil { panic(err) From 884d26093c80491a3fe07f606fc04851dc317199 Mon Sep 17 00:00:00 2001 From: Parth Sareen Date: Thu, 22 May 2025 18:53:31 -0700 Subject: [PATCH 25/26] llama: add minimum memory for grammar (#10820) --- llama/llama.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama/llama.go b/llama/llama.go index adee6f63..0dc64e57 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -580,7 +580,7 @@ func SchemaToGrammar(schema []byte) []byte { defer C.free(unsafe.Pointer(cStr)) // Allocate buffer for grammar based on schema length but with upper bound - maxLen := min(1024*1024, len(schema)*4) + maxLen := max(32768, min(1024*1024, len(schema)*4)) buf := make([]byte, maxLen) // Call C function to convert schema to grammar From e8b981fa5d7c1875ec0c290068bcfe3b4662f5c4 Mon Sep 17 00:00:00 2001 From: Parth Sareen Date: Fri, 23 May 2025 14:19:31 -0700 Subject: [PATCH 26/26] tools: refactor tool call parsing and enable streaming (#10415) --- server/model.go | 124 ---- server/model_test.go | 179 ----- server/routes.go | 68 +- .../testdata}/command-r-plus.gotmpl | 0 .../testdata}/command-r-plus.out | 0 .../testdata}/firefunction.gotmpl | 0 .../tools => tools/testdata}/firefunction.out | 0 .../testdata}/llama3-groq-tool-use.gotmpl | 0 .../testdata}/llama3-groq-tool-use.out | 0 tools/testdata/llama3.2.gotmpl | 44 ++ tools/testdata/llama3.2.out | 24 + .../tools => tools/testdata}/messages.json | 0 .../tools => tools/testdata}/mistral.gotmpl | 0 .../tools => tools/testdata}/mistral.out | 0 .../tools => tools/testdata}/nemotron.gotmpl | 0 .../tools => tools/testdata}/nemotron.out | 0 tools/testdata/qwen2.5.gotmpl | 51 ++ tools/testdata/qwen2.5.out | 31 + tools/testdata/qwen3.gotmpl | 50 ++ tools/testdata/qwen3.out | 31 + .../tools => tools/testdata}/tools.json | 0 .../tools => tools/testdata}/xlam.gotmpl | 0 .../tools => tools/testdata}/xlam.out | 0 tools/tools.go | 271 ++++++++ tools/tools_test.go | 644 ++++++++++++++++++ tools/tools_utils.go | 227 ++++++ tools/tools_utils_test.go | 464 +++++++++++++ 27 files changed, 1868 insertions(+), 340 deletions(-) delete mode 100644 server/model_test.go rename {server/testdata/tools => tools/testdata}/command-r-plus.gotmpl (100%) rename {server/testdata/tools => tools/testdata}/command-r-plus.out (100%) rename {server/testdata/tools => tools/testdata}/firefunction.gotmpl (100%) rename {server/testdata/tools => tools/testdata}/firefunction.out (100%) rename {server/testdata/tools => tools/testdata}/llama3-groq-tool-use.gotmpl (100%) rename {server/testdata/tools => tools/testdata}/llama3-groq-tool-use.out (100%) create mode 100644 tools/testdata/llama3.2.gotmpl create mode 100644 tools/testdata/llama3.2.out rename {server/testdata/tools => tools/testdata}/messages.json (100%) rename {server/testdata/tools => tools/testdata}/mistral.gotmpl (100%) rename {server/testdata/tools => tools/testdata}/mistral.out (100%) rename {server/testdata/tools => tools/testdata}/nemotron.gotmpl (100%) rename {server/testdata/tools => tools/testdata}/nemotron.out (100%) create mode 100644 tools/testdata/qwen2.5.gotmpl create mode 100644 tools/testdata/qwen2.5.out create mode 100644 tools/testdata/qwen3.gotmpl create mode 100644 tools/testdata/qwen3.out rename {server/testdata/tools => tools/testdata}/tools.json (100%) rename {server/testdata/tools => tools/testdata}/xlam.gotmpl (100%) rename {server/testdata/tools => tools/testdata}/xlam.out (100%) create mode 100644 tools/tools.go create mode 100644 tools/tools_test.go create mode 100644 tools/tools_utils.go create mode 100644 tools/tools_utils_test.go diff --git a/server/model.go b/server/model.go index 6b5439a4..401547e4 100644 --- a/server/model.go +++ b/server/model.go @@ -10,9 +10,6 @@ import ( "log/slog" "net/http" "os" - "slices" - "strings" - "text/template/parse" "github.com/ollama/ollama/api" "github.com/ollama/ollama/fs/ggml" @@ -128,124 +125,3 @@ func detectContentType(r io.Reader) (string, error) { return "unknown", nil } - -func parseObjects(s string) []map[string]any { - var objs []map[string]any - for offset := 0; offset < len(s); { - var obj map[string]any - decoder := json.NewDecoder(strings.NewReader(s[offset:])) - if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { - break - } else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) { - // skip over any syntax errors - offset += int(syntax.Offset) - } else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) { - // skip over any unmarshalable types - offset += int(unmarshalType.Offset) - } else if err != nil { - return nil - } else { - offset += int(decoder.InputOffset()) - objs = append(objs, obj) - } - } - - return objs -} - -// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls. -// mxyng: this only really works if the input contains tool calls in some JSON format -func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { - // create a subtree from the node that ranges over .ToolCalls - tmpl := m.Template.Subtree(func(n parse.Node) bool { - if t, ok := n.(*parse.RangeNode); ok { - return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") - } - - return false - }) - - if tmpl == nil { - return nil, false - } - - var b bytes.Buffer - if err := tmpl.Execute(&b, map[string][]api.ToolCall{ - "ToolCalls": { - { - Function: api.ToolCallFunction{ - Name: "@@name@@", - Arguments: api.ToolCallFunctionArguments{ - "@@argument@@": 1, - }, - }, - }, - }, - }); err != nil { - return nil, false - } - - templateObjects := parseObjects(b.String()) - if len(templateObjects) == 0 { - return nil, false - } - - // find the keys that correspond to the name and arguments fields - var name, arguments string - for k, v := range templateObjects[0] { - switch v.(type) { - case string: - name = k - case map[string]any: - arguments = k - } - } - - if name == "" || arguments == "" { - return nil, false - } - - responseObjects := parseObjects(s) - if len(responseObjects) == 0 { - return nil, false - } - - // collect all nested objects - var collect func(any) []map[string]any - collect = func(obj any) (all []map[string]any) { - switch o := obj.(type) { - case map[string]any: - all = append(all, o) - for _, v := range o { - all = append(all, collect(v)...) - } - case []any: - for _, v := range o { - all = append(all, collect(v)...) - } - } - - return all - } - - var objs []map[string]any - for _, p := range responseObjects { - objs = append(objs, collect(p)...) - } - - var toolCalls []api.ToolCall - for _, kv := range objs { - n, nok := kv[name].(string) - a, aok := kv[arguments].(map[string]any) - if nok && aok { - toolCalls = append(toolCalls, api.ToolCall{ - Function: api.ToolCallFunction{ - Name: n, - Arguments: a, - }, - }) - } - } - - return toolCalls, len(toolCalls) > 0 -} diff --git a/server/model_test.go b/server/model_test.go deleted file mode 100644 index e5c2f2bb..00000000 --- a/server/model_test.go +++ /dev/null @@ -1,179 +0,0 @@ -package server - -import ( - "bytes" - "encoding/json" - "fmt" - "os" - "path/filepath" - "testing" - - "github.com/google/go-cmp/cmp" - - "github.com/ollama/ollama/api" - "github.com/ollama/ollama/template" -) - -func readFile(t *testing.T, base, name string) *bytes.Buffer { - t.Helper() - - bts, err := os.ReadFile(filepath.Join(base, name)) - if err != nil { - t.Fatal(err) - } - - return bytes.NewBuffer(bts) -} - -func TestExecuteWithTools(t *testing.T) { - p := filepath.Join("testdata", "tools") - cases := []struct { - model string - output string - ok bool - }{ - {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, - {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] - -The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true}, - {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false}, - {"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: - - [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, - {"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, - {"command-r-plus", "Action: ```json" + ` -[ - { - "tool_name": "get_current_weather", - "parameters": { - "format": "fahrenheit", - "location": "San Francisco, CA" - } - }, - { - "tool_name": "get_current_weather", - "parameters": { - "format": "celsius", - "location": "Toronto, Canada" - } - } -] -` + "```", true}, - {"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, - {"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, - {"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, - {"llama3-groq-tool-use", ` -{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} -{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}} -`, true}, - {"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true}, - {"nemotron", `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} `, true}, - } - - var tools []api.Tool - if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil { - t.Fatal(err) - } - - var messages []api.Message - if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil { - t.Fatal(err) - } - - calls := []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: api.ToolCallFunctionArguments{ - "format": "fahrenheit", - "location": "San Francisco, CA", - }, - }, - }, - { - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: api.ToolCallFunctionArguments{ - "format": "celsius", - "location": "Toronto, Canada", - }, - }, - }, - } - - for _, tt := range cases { - t.Run(tt.model, func(t *testing.T) { - tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String()) - if err != nil { - t.Fatal(err) - } - - t.Run("template", func(t *testing.T) { - var actual bytes.Buffer - if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil { - t.Fatal(err) - } - - if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) - } - }) - - t.Run("parse", func(t *testing.T) { - m := &Model{Template: tmpl} - actual, ok := m.parseToolCalls(tt.output) - if ok != tt.ok { - t.Fatalf("expected %t, got %t", tt.ok, ok) - } - - if tt.ok { - if diff := cmp.Diff(actual, calls); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) - } - } - }) - }) - } -} - -func TestParseObjects(t *testing.T) { - tests := []struct { - input string - want []map[string]any - }{ - { - input: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - want: []map[string]any{ - {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, - {"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, Canada"}}, - }, - }, - { - input: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - want: []map[string]any{ - {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, - }, - }, - { - input: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, ON"}} `, - want: []map[string]any{ - {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, - {"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, ON"}}, - }, - }, - { - input: `{"name": "get_current_weather", "arguments": `, - want: nil, - }, - } - - for _, tc := range tests { - t.Run(tc.input, func(t *testing.T) { - got := parseObjects(tc.input) - - if diff := cmp.Diff(got, tc.want); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) - } - }) - } -} diff --git a/server/routes.go b/server/routes.go index d0b8f487..42e8cdd1 100644 --- a/server/routes.go +++ b/server/routes.go @@ -38,6 +38,7 @@ import ( "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/registry" "github.com/ollama/ollama/template" + "github.com/ollama/ollama/tools" "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" @@ -1482,11 +1483,20 @@ func (s *Server) ChatHandler(c *gin.Context) { return } + var toolParser *tools.Parser + if len(req.Tools) > 0 { + toolParser, err = tools.NewParser(m.Template.Template) + if err != nil { + slog.Error("failed to create tool parser", "error", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + ch := make(chan any) go func() { defer close(ch) - var sb strings.Builder - var toolCallIndex int = 0 + if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: prompt, Images: images, @@ -1512,37 +1522,21 @@ func (s *Server) ChatHandler(c *gin.Context) { res.LoadDuration = checkpointLoaded.Sub(checkpointStart) } - // TODO: tool call checking and filtering should be moved outside of this callback once streaming - // however this was a simple change for now without reworking streaming logic of this (and other) - // handlers - if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 { - ch <- res - return - } - - // Streaming tool calls: - // If tools are recognized, use a flag to track the sending of a tool downstream - // This ensures that content is cleared from the message on the last chunk sent - sb.WriteString(r.Content) - if toolCalls, ok := m.parseToolCalls(sb.String()); ok { - res.Message.ToolCalls = toolCalls - for i := range toolCalls { - toolCalls[i].Function.Index = toolCallIndex - toolCallIndex++ + if len(req.Tools) > 0 { + toolCalls, content := toolParser.Add(r.Content) + if len(content) > 0 { + res.Message.Content = content + } else if len(toolCalls) > 0 { + res.Message.ToolCalls = toolCalls + res.Message.Content = "" + } else { + if r.Done { + ch <- res + } + return } - res.Message.Content = "" - sb.Reset() - ch <- res - return - } - - if r.Done { - // Send any remaining content if no tool calls were detected - if toolCallIndex == 0 { - res.Message.Content = sb.String() - } - ch <- res } + ch <- res }); err != nil { ch <- gin.H{"error": err.Error()} } @@ -1551,11 +1545,15 @@ func (s *Server) ChatHandler(c *gin.Context) { if req.Stream != nil && !*req.Stream { var resp api.ChatResponse var sb strings.Builder + var toolCalls []api.ToolCall for rr := range ch { switch t := rr.(type) { case api.ChatResponse: sb.WriteString(t.Message.Content) resp = t + if len(req.Tools) > 0 { + toolCalls = append(toolCalls, t.Message.ToolCalls...) + } case gin.H: msg, ok := t["error"].(string) if !ok { @@ -1571,12 +1569,8 @@ func (s *Server) ChatHandler(c *gin.Context) { } resp.Message.Content = sb.String() - - if len(req.Tools) > 0 { - if toolCalls, ok := m.parseToolCalls(sb.String()); ok { - resp.Message.ToolCalls = toolCalls - resp.Message.Content = "" - } + if len(toolCalls) > 0 { + resp.Message.ToolCalls = toolCalls } c.JSON(http.StatusOK, resp) diff --git a/server/testdata/tools/command-r-plus.gotmpl b/tools/testdata/command-r-plus.gotmpl similarity index 100% rename from server/testdata/tools/command-r-plus.gotmpl rename to tools/testdata/command-r-plus.gotmpl diff --git a/server/testdata/tools/command-r-plus.out b/tools/testdata/command-r-plus.out similarity index 100% rename from server/testdata/tools/command-r-plus.out rename to tools/testdata/command-r-plus.out diff --git a/server/testdata/tools/firefunction.gotmpl b/tools/testdata/firefunction.gotmpl similarity index 100% rename from server/testdata/tools/firefunction.gotmpl rename to tools/testdata/firefunction.gotmpl diff --git a/server/testdata/tools/firefunction.out b/tools/testdata/firefunction.out similarity index 100% rename from server/testdata/tools/firefunction.out rename to tools/testdata/firefunction.out diff --git a/server/testdata/tools/llama3-groq-tool-use.gotmpl b/tools/testdata/llama3-groq-tool-use.gotmpl similarity index 100% rename from server/testdata/tools/llama3-groq-tool-use.gotmpl rename to tools/testdata/llama3-groq-tool-use.gotmpl diff --git a/server/testdata/tools/llama3-groq-tool-use.out b/tools/testdata/llama3-groq-tool-use.out similarity index 100% rename from server/testdata/tools/llama3-groq-tool-use.out rename to tools/testdata/llama3-groq-tool-use.out diff --git a/tools/testdata/llama3.2.gotmpl b/tools/testdata/llama3.2.gotmpl new file mode 100644 index 00000000..b132423e --- /dev/null +++ b/tools/testdata/llama3.2.gotmpl @@ -0,0 +1,44 @@ +<|start_header_id|>system<|end_header_id|> + +Cutting Knowledge Date: December 2023 + +{{ if .System }}{{ .System }} +{{- end }} +{{- if .Tools }}When you receive a tool call response, use the output to format an answer to the orginal user question. + +You are a helpful assistant with tool calling capabilities. +{{- end }}<|eot_id|> +{{- range $i, $_ := .Messages }} +{{- $last := eq (len (slice $.Messages $i)) 1 }} +{{- if eq .Role "user" }}<|start_header_id|>user<|end_header_id|> +{{- if and $.Tools $last }} + +Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. + +Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables. + +{{ range $.Tools }} +{{- . }} +{{ end }} +{{ .Content }}<|eot_id|> +{{- else }} + +{{ .Content }}<|eot_id|> +{{- end }}{{ if $last }}<|start_header_id|>assistant<|end_header_id|> + +{{ end }} +{{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|> +{{- if .ToolCalls }} +{{ range .ToolCalls }} +{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }} +{{- else }} + +{{ .Content }} +{{- end }}{{ if not $last }}<|eot_id|>{{ end }} +{{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|> + +{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|> + +{{ end }} +{{- end }} +{{- end }} \ No newline at end of file diff --git a/tools/testdata/llama3.2.out b/tools/testdata/llama3.2.out new file mode 100644 index 00000000..a27c6eaf --- /dev/null +++ b/tools/testdata/llama3.2.out @@ -0,0 +1,24 @@ +<|start_header_id|>system<|end_header_id|> + +Cutting Knowledge Date: December 2023 + +You are a knowledgeable assistant. You can answer questions and perform tasks.When you receive a tool call response, use the output to format an answer to the orginal user question. + +You are a helpful assistant with tool calling capabilities.<|eot_id|><|start_header_id|>user<|end_header_id|> + +What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +{"name": "get_current_weather", "parameters": {"format":"celsius","location":"Paris, France"}}<|eot_id|><|start_header_id|>ipython<|end_header_id|> + +22<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|> + +Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. + +Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables. + +{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} + +What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|> + diff --git a/server/testdata/tools/messages.json b/tools/testdata/messages.json similarity index 100% rename from server/testdata/tools/messages.json rename to tools/testdata/messages.json diff --git a/server/testdata/tools/mistral.gotmpl b/tools/testdata/mistral.gotmpl similarity index 100% rename from server/testdata/tools/mistral.gotmpl rename to tools/testdata/mistral.gotmpl diff --git a/server/testdata/tools/mistral.out b/tools/testdata/mistral.out similarity index 100% rename from server/testdata/tools/mistral.out rename to tools/testdata/mistral.out diff --git a/server/testdata/tools/nemotron.gotmpl b/tools/testdata/nemotron.gotmpl similarity index 100% rename from server/testdata/tools/nemotron.gotmpl rename to tools/testdata/nemotron.gotmpl diff --git a/server/testdata/tools/nemotron.out b/tools/testdata/nemotron.out similarity index 100% rename from server/testdata/tools/nemotron.out rename to tools/testdata/nemotron.out diff --git a/tools/testdata/qwen2.5.gotmpl b/tools/testdata/qwen2.5.gotmpl new file mode 100644 index 00000000..cbd7302c --- /dev/null +++ b/tools/testdata/qwen2.5.gotmpl @@ -0,0 +1,51 @@ +{{- if .Suffix }}<|fim_prefix|>{{ .Prompt }}<|fim_suffix|>{{ .Suffix }}<|fim_middle|> +{{- else if .Messages }} +{{- if or .System .Tools }}<|im_start|>system +{{- if .System }} +{{ .System }} +{{- end }} +{{- if .Tools }} + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{{- range .Tools }} +{"type": "function", "function": {{ .Function }}} +{{- end }} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } + +{{- end }}<|im_end|> +{{ end }} +{{- range $i, $_ := .Messages }} +{{- $last := eq (len (slice $.Messages $i)) 1 -}} +{{- if eq .Role "user" }}<|im_start|>user +{{ .Content }}<|im_end|> +{{ else if eq .Role "assistant" }}<|im_start|>assistant +{{ if .Content }}{{ .Content }} +{{- else if .ToolCalls }} +{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} +{{ end }} +{{- end }}{{ if not $last }}<|im_end|> +{{ end }} +{{- else if eq .Role "tool" }}<|im_start|>user + +{{ .Content }} +<|im_end|> +{{ end }} +{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant +{{ end }} +{{- end }} +{{- else }} +{{- if .System }}<|im_start|>system +{{ .System }}<|im_end|> +{{ end }}{{ if .Prompt }}<|im_start|>user +{{ .Prompt }}<|im_end|> +{{ end }}<|im_start|>assistant +{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }} \ No newline at end of file diff --git a/tools/testdata/qwen2.5.out b/tools/testdata/qwen2.5.out new file mode 100644 index 00000000..76bfbfa9 --- /dev/null +++ b/tools/testdata/qwen2.5.out @@ -0,0 +1,31 @@ +<|im_start|>system +You are a knowledgeable assistant. You can answer questions and perform tasks. + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } +<|im_end|> +<|im_start|>user +What's the weather like today in Paris?<|im_end|> +<|im_start|>assistant + +{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}} +<|im_end|> +<|im_start|>user + +22 +<|im_end|> +<|im_start|>assistant +The current temperature in Paris, France is 22 degrees Celsius.<|im_end|> +<|im_start|>user +What's the weather like today in San Francisco and Toronto?<|im_end|> +<|im_start|>assistant diff --git a/tools/testdata/qwen3.gotmpl b/tools/testdata/qwen3.gotmpl new file mode 100644 index 00000000..26f6656f --- /dev/null +++ b/tools/testdata/qwen3.gotmpl @@ -0,0 +1,50 @@ +{{- if .Messages }} +{{- if or .System .Tools }}<|im_start|>system +{{- if .System }} +{{ .System }} +{{- end }} +{{- if .Tools }} + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{{- range .Tools }} +{"type": "function", "function": {{ .Function }}} +{{- end }} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } + +{{- end }}<|im_end|> +{{ end }} +{{- range $i, $_ := .Messages }} +{{- $last := eq (len (slice $.Messages $i)) 1 -}} +{{- if eq .Role "user" }}<|im_start|>user +{{ .Content }}<|im_end|> +{{ else if eq .Role "assistant" }}<|im_start|>assistant +{{ if .Content }}{{ .Content }} +{{- else if .ToolCalls }} +{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} +{{ end }} +{{- end }}{{ if not $last }}<|im_end|> +{{ end }} +{{- else if eq .Role "tool" }}<|im_start|>user + +{{ .Content }} +<|im_end|> +{{ end }} +{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant +{{ end }} +{{- end }} +{{- else }} +{{- if .System }}<|im_start|>system +{{ .System }}<|im_end|> +{{ end }}{{ if .Prompt }}<|im_start|>user +{{ .Prompt }}<|im_end|> +{{ end }}<|im_start|>assistant +{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }} \ No newline at end of file diff --git a/tools/testdata/qwen3.out b/tools/testdata/qwen3.out new file mode 100644 index 00000000..76bfbfa9 --- /dev/null +++ b/tools/testdata/qwen3.out @@ -0,0 +1,31 @@ +<|im_start|>system +You are a knowledgeable assistant. You can answer questions and perform tasks. + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } +<|im_end|> +<|im_start|>user +What's the weather like today in Paris?<|im_end|> +<|im_start|>assistant + +{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}} +<|im_end|> +<|im_start|>user + +22 +<|im_end|> +<|im_start|>assistant +The current temperature in Paris, France is 22 degrees Celsius.<|im_end|> +<|im_start|>user +What's the weather like today in San Francisco and Toronto?<|im_end|> +<|im_start|>assistant diff --git a/server/testdata/tools/tools.json b/tools/testdata/tools.json similarity index 100% rename from server/testdata/tools/tools.json rename to tools/testdata/tools.json diff --git a/server/testdata/tools/xlam.gotmpl b/tools/testdata/xlam.gotmpl similarity index 100% rename from server/testdata/tools/xlam.gotmpl rename to tools/testdata/xlam.gotmpl diff --git a/server/testdata/tools/xlam.out b/tools/testdata/xlam.out similarity index 100% rename from server/testdata/tools/xlam.out rename to tools/testdata/xlam.out diff --git a/tools/tools.go b/tools/tools.go new file mode 100644 index 00000000..509ca90a --- /dev/null +++ b/tools/tools.go @@ -0,0 +1,271 @@ +package tools + +import ( + "encoding/json" + "errors" + "log/slog" + "strings" + gotmpl "text/template" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/template" +) + +var ( + errInvalidToolCall = errors.New("invalid tool call format") + errAccumulateMore = errors.New("need to accumulate more content") +) + +type Parser struct { + parseLeadingJSON bool + prefix string + prefixFound bool + tmpl gotmpl.Template + sb strings.Builder + index int + name string + arguments string + done bool +} + +// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls. +// +// Parameters: +// - s: The string to parse +// - name: The field name from template that identifies the tool call name +// - arguments: The field name from template that identifies the tool call arguments +// +// Returns: +// - []api.ToolCall: The parsed tool calls if successful +// - error: ErrAccumulateMore if braces unbalanced, ErrInvalidToolCall if invalid, or nil if successful +func parseJSONToolCalls(s string, name, arguments string, prefix string) ([]api.ToolCall, error) { + // Check for balanced braces before attempting to parse + braceCount := 0 + squareCount := 0 + startIndex := -1 + var rawToolCalls []string + s = strings.TrimSpace(s) + + // Only track these if we don't have a prefix as it will be cut off from the prefix. Also track in the parseLeadingJSON case. + trackSquareBrackets := prefix == "" || !strings.HasSuffix(prefix, "[") || strings.HasPrefix(s, "[") + for i, c := range s { + switch c { + case '{': + braceCount++ + if startIndex == -1 { + startIndex = i + } + case '}': + braceCount-- + if braceCount == 0 { + rawToolCalls = append(rawToolCalls, s[startIndex:i+1]) + startIndex = -1 + } + case '[': + if trackSquareBrackets { + squareCount++ + } + case ']': + if trackSquareBrackets { + squareCount-- + } + } + + // Negative means we have an extra closing brace/bracket + if braceCount < 0 || squareCount < 0 { + return nil, errInvalidToolCall + } + } + + // If braces/brackets aren't balanced, need more input + if braceCount > 0 || squareCount > 0 { + return nil, errAccumulateMore + } + + t := strings.TrimSpace(s) + if len(t) == 0 { + return nil, errAccumulateMore + } + // If the input is a single square bracket, it's not a valid tool call + if t[0] == '[' && len(t) == 1 { + return nil, errAccumulateMore + } + + // Attempt full unmarshal of the JSON + var toolCalls []api.ToolCall + for _, rawToolCall := range rawToolCalls { + var resp map[string]any + if err := json.Unmarshal([]byte(rawToolCall), &resp); err != nil { + continue + } + + // Collect nested objects that could contain tool calls + objs := collect(resp) + if len(objs) == 0 { + continue + } + + // Extract tool calls from objects + for _, kv := range objs { + n, nok := kv[name].(string) + a, aok := kv[arguments].(map[string]any) + if nok && aok { + toolCalls = append(toolCalls, api.ToolCall{ + Function: api.ToolCallFunction{ + Name: n, + Arguments: a, + }, + }) + } else { + slog.Debug("No valid tool call found in object.", "object", kv) + } + } + } + + // Valid JSON, no tool calls found + if len(toolCalls) == 0 { + slog.Debug("No valid tool calls found in any raw tool calls.", "rawToolCalls", rawToolCalls) + return nil, errInvalidToolCall + } + + return toolCalls, nil +} + +// checkPrefix processes a string to find and handle a prefix pattern. +// +// Returns: +// - The processed string with prefix removed if found +// - error: ErrAccumulateMore if prefix is incomplete, or nil if successful +func (p *Parser) checkPrefix(s string) (string, error) { + original := s + if strings.ContainsRune(s, '\n') { + s = strings.ReplaceAll(s, "\n", " ") + } + + if s == "" || p.prefix == "" { + return s, nil + } + + // Check for prefix at start of string + if cut, hasPrefix := strings.CutPrefix(s, p.prefix); hasPrefix { + // Found prefix at start - accumulate for potential tool + p.prefixFound = true + return cut, nil + } + + // Check if prefix overlaps end of string + if idx := suffixOverlap(s, p.prefix); idx != -1 { + // Return everything except overlapping portion + p.sb.Reset() + p.sb.WriteString(s[idx:]) + return original[:idx], errAccumulateMore + } + + // Check if prefix appears in middle of string + if idx := strings.Index(s, p.prefix); idx != -1 { + // Save remainder starting at prefix for next pass + p.sb.Reset() + p.sb.WriteString(strings.TrimSpace(s[idx:])) + // Return everything before prefix + return original[:idx], errAccumulateMore + } + + // No partial prefix found + return s, nil +} + +// Add processes a string input to parse tool calls and content. +// It handles prefix detection and JSON parsing to extract tool calls. +// +// Returns: +// - tools: Any parsed tool calls +// - content: Non-tool call content +func (p *Parser) Add(s string) (tools []api.ToolCall, content string) { + if strings.TrimSpace(s) == "" { + return nil, s + } + if p.done { + if p.index == 0 { + // Return original string if no tool calls found at start + return nil, s + } + // Return empty if no tool calls found after start + return nil, "" + } + p.sb.WriteString(s) + s = p.sb.String() + + // Check for prefix pattern in input + s, err := p.checkPrefix(s) + if err != nil { + // Need more input to complete prefix + return nil, s + } + + // Exit if prefix exists in template, greedy parsing is off, and prefix not found + if !p.parseLeadingJSON && !p.prefixFound { + p.sb.Reset() + return nil, s + } + + toolCalls, err := parseJSONToolCalls(s, p.name, p.arguments, p.prefix) + if err != nil { + if errors.Is(err, errAccumulateMore) { + return nil, "" + } + p.sb.Reset() + // Do not try parsing leading JSON if JSON not found + p.parseLeadingJSON = false + if p.prefix == "" { + p.done = true + } + if p.index != 0 && p.prefix == "" { + return nil, "" + } + if p.prefixFound { + // Drop tokens since prefix was found + return nil, "" + } + return nil, s + } + + for _, tc := range toolCalls { + tc.Function.Index = p.index + p.index++ + } + + p.sb.Reset() + return toolCalls, "" +} + +// NewParser creates a new tool call parser from a template. It extracts the tool call format, +// prefix, and field names from the template to use for parsing tool calls from model output. +// +// Returns an error if the template does not contain valid tool call formatting. +func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) { + parsed, err := template.Parse(templateToProcess.Root.String()) + if err != nil { + return nil, err + } + + tt, err := toolTemplate(parsed) + if err != nil { + return nil, err + } + + tp := toolPrefix(templateToProcess) + + name, arguments, err := extractToolArgs(tt) + if err != nil { + return nil, err + } + + return &Parser{ + tmpl: *tt, + sb: strings.Builder{}, + prefix: tp, + parseLeadingJSON: true, + name: name, + arguments: arguments, + }, nil +} diff --git a/tools/tools_test.go b/tools/tools_test.go new file mode 100644 index 00000000..1ae3bff8 --- /dev/null +++ b/tools/tools_test.go @@ -0,0 +1,644 @@ +package tools + +import ( + "bytes" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/template" +) + +func readFile(t *testing.T, base, name string) *bytes.Buffer { + t.Helper() + + bts, err := os.ReadFile(filepath.Join(base, name)) + if err != nil { + t.Fatal(err) + } + + return bytes.NewBuffer(bts) +} + +func TestParseJSONToolCalls(t *testing.T) { + tests := []struct { + name string + input string + nameField string + argsField string + wantToolCalls []api.ToolCall + wantErr error + prefix string + }{ + { + name: "valid single tool call", + input: `{"name": "test_tool", "arguments": {"arg1": "value1"}}`, + nameField: "name", + argsField: "arguments", + wantToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "test_tool", + Arguments: map[string]any{ + "arg1": "value1", + }, + }, + }, + }, + wantErr: nil, + prefix: "", + }, + { + name: "incomplete JSON", + input: `{"name": "test_tool", "arguments": {"arg1": `, + nameField: "name", + argsField: "arguments", + wantToolCalls: nil, + wantErr: errAccumulateMore, + prefix: "", + }, + { + name: "invalid JSON", + input: `not json at all`, + nameField: "name", + argsField: "arguments", + wantToolCalls: nil, + wantErr: errInvalidToolCall, + prefix: "", + }, + { + name: "missing required fields", + input: `{"other": "field"}`, + nameField: "name", + argsField: "arguments", + wantToolCalls: nil, + wantErr: errInvalidToolCall, + prefix: "", + }, + { + name: "multiple tool calls in array", + input: `[ + {"name": "tool1", "arguments": {"arg1": 1}}, + {"name": "tool2", "arguments": {"arg2": "value"}} + ]`, + nameField: "name", + argsField: "arguments", + wantToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "tool1", + Arguments: map[string]any{ + "arg1": float64(1), + }, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "tool2", + Arguments: map[string]any{ + "arg2": "value", + }, + }, + }, + }, + wantErr: nil, + prefix: "", + }, + { + name: "multiple tool calls without array", + input: ` + {"name": "tool1", "arguments": {"arg1": 1}}, + {"name": "tool2", "arguments": {"arg2": "value"}} + `, + nameField: "name", + argsField: "arguments", + wantToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "tool1", + Arguments: map[string]any{ + "arg1": float64(1), + }, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "tool2", + Arguments: map[string]any{ + "arg2": "value", + }, + }, + }, + }, + wantErr: nil, + prefix: "", + }, + { + name: "multiple tool calls with text after", + input: ` + {"name": "tool1", "arguments": {"arg1": 1}} text + {"name": "tool2", "arguments": {"arg2": "value"}} text + `, + nameField: "name", + argsField: "arguments", + wantToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "tool1", + Arguments: map[string]any{ + "arg1": float64(1), + }, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "tool2", + Arguments: map[string]any{ + "arg2": "value", + }, + }, + }, + }, + wantErr: nil, + prefix: "", + }, + { + name: "second tool call in array", + input: ` + , {"name": "tool2", "arguments": {"arg2": "value"}} + `, + nameField: "name", + argsField: "arguments", + wantToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "tool2", + Arguments: map[string]any{ + "arg2": "value", + }, + }, + }, + }, + wantErr: nil, + prefix: "", + }, + // a bad JSON would not return any tool calls or content as it would always accumulate more + { + name: "unbalanced square brackets", + input: `[{"name": "tool1", "arguments": {"arg1": [1, 2}]`, + nameField: "name", + argsField: "arguments", + wantToolCalls: nil, + wantErr: errAccumulateMore, + prefix: "", + }, + { + name: "incomplete square brackets", + input: `[{"name": "tool1", "arguments": {"arg1": [1, 2, 3`, + nameField: "name", + argsField: "arguments", + wantToolCalls: nil, + wantErr: errAccumulateMore, + prefix: "", + }, + { + name: "nested arrays in arguments", + input: `{"name": "tool1", "arguments": {"arg1": [1, 2, ["nested", "array"]]}}`, + nameField: "name", + argsField: "arguments", + wantToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "tool1", + Arguments: map[string]any{ + "arg1": []any{float64(1), float64(2), []any{"nested", "array"}}, + }, + }, + }, + }, + wantErr: nil, + prefix: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotCalls, err := parseJSONToolCalls(tt.input, tt.nameField, tt.argsField, tt.prefix) + + if err != tt.wantErr { + t.Errorf("parseJSONToolCalls() error = %v, want %v", err, tt.wantErr) + } + + if len(gotCalls) != 0 && tt.wantErr != nil { + t.Errorf("parseJSONToolCalls() valid = %v, want %v", len(gotCalls) == 0, tt.wantErr == nil) + } + + if diff := cmp.Diff(gotCalls, tt.wantToolCalls); diff != "" { + t.Errorf("parseJSONToolCalls() tool calls mismatch (-got +want):\n%s", diff) + } + }) + } +} + +func TestParseToolCalls(t *testing.T) { + p := filepath.Join("testdata") + t1 := api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: api.ToolCallFunctionArguments{ + "format": "fahrenheit", + "location": "San Francisco, CA", + }, + }, + } + t2 := api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: api.ToolCallFunctionArguments{ + "format": "celsius", + "location": "Toronto, Canada", + }, + }, + } + + cases := []struct { + name string + model string + output string + expectedToolCall []api.ToolCall + expectedTokens string + }{ + { + name: "mistral malformed json with tool calls prefix", + model: "mistral", + output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: "", + }, + { + name: "mistral multiple tool calls without prefix", + model: "mistral", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}} ]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "mistral tool calls with text between no prefix", + model: "mistral", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] + model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: `model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + }, + { + name: "mistral valid json with tool calls prefix", + model: "mistral", + output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "mistral multiple tool calls with text between and prefix", + model: "mistral", + output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] + model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2, t1, t2}, + expectedTokens: "", + }, + { + name: "mistral incomplete json with tool calls prefix", + model: "mistral", + output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `, + expectedToolCall: []api.ToolCall{}, + expectedTokens: "", + }, + { + name: "mistral invalid tool call with explanatory text no prefix", + model: "mistral", + output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: + + [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + }, + { + name: "mistral tool calls without prefix", + model: "mistral", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "command r plus tool calls with json block format", + model: "command-r-plus", + output: "Action: ```json" + ` + [ + { + "tool_name": "get_current_weather", + "parameters": { + "format": "fahrenheit", + "location": "San Francisco, CA" + } + }, + { + "tool_name": "get_current_weather", + "parameters": { + "format": "celsius", + "location": "Toronto, Canada" + } + } + ] + ` + "```", + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "firefunction tool calls with functools prefix", + model: "firefunction", + output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "llama3 groq single tool call with xml tags", + model: "llama3-groq-tool-use", + output: ` + {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} + `, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: "", + }, + { + name: "xlam tool calls with wrapper object", + model: "xlam", + output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "qwen2.5 single tool call with prefix", + model: "qwen2.5", + output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}`, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: "", + }, + { + name: "qwen2.5 multiple tool calls with and without prefix", + model: "qwen2.5", + output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}`, + expectedToolCall: []api.ToolCall{t1, t1, t2}, + expectedTokens: "", + }, + { + name: "qwen2.5 plain text response no tool calls", + model: "qwen2.5", + output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", + expectedToolCall: []api.ToolCall{}, + expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", + }, + { + name: "qwen2.5 tool calls with trailing text", + model: "qwen2.5", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "some tokens after call", + }, + { + name: "qwen2.5 tool calls with initial text", + model: "qwen2.5", + output: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + }, + { + name: "qwen2.5 tool calls with prefix and trailing text", + model: "qwen2.5", + output: ` [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "qwen2.5 tool calls with prefix and initial text", + model: "qwen2.5", + output: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] `, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "some tokens before call", + }, + { + name: "qwen2.5 tool calls without and with prefix", + model: "qwen2.5", + output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "qwen2.5 tool calls without and with prefix and text between", + model: "qwen2.5", + output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} some tokens between {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}} some tokens after call`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "some tokens between", + }, + { + name: "qwen2.5 tool calls without prefix and invalid tool call with other tokens", + model: "qwen2.5", + output: `hi [{"options": "foo"}]`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `hi [{"options": "foo"}]`, + }, + { + name: "qwen2.5 tool calls with prefix and invalid tool call", + model: "qwen2.5", + output: ` [{"options": "foo"}] `, + expectedToolCall: []api.ToolCall{}, + expectedTokens: ``, + }, + { + name: "qwen3 tool call with think prefix and tool prefix (sent as a single token)", + model: "qwen3", + output: `Okay, let me think what tool we should use...{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}`, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: "Okay, let me think what tool we should use...", + }, + { + name: "qwen3 tool call with think prefix, tool prefix, and whitespace (sent as separate tokens)", + model: "qwen3", + output: `Okay, let me think what tool we should use... { "name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: "Okay, let me think what tool we should use...", + }, + { + name: "qwen3 empty think prefix without tool prefix and invalid tool call", + model: "qwen3", + output: ` {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + expectedToolCall: []api.ToolCall{}, + expectedTokens: ` {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + }, + { + name: "qwen3 empty think prefix with tool prefix and valid tool call", + model: "qwen3", + output: `{ "name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: ``, + }, + { + name: "qwen3 invalid tool call with fake tool prefix (single rune suffix match)", + model: "qwen3", + output: `< fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `< fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + }, + { + name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)", + model: "qwen3", + output: ``, + expectedToolCall: []api.ToolCall{}, + expectedTokens: ``, + }, + { + name: "qwen3 invalid tool call with malformed tool prefix", + model: "qwen3", + output: ``, + expectedToolCall: []api.ToolCall{}, + expectedTokens: ``, + }, + { + name: "model with prefix in template, no prefix in output", + model: "qwen2.5", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "model with prefix in template, prefix in output", + model: "qwen2.5", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "model without prefix in template, no prefix in output", + model: "llama3.2", + output: `[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "model without prefix in template, no prefix in output, single tool call", + model: "llama3.2", + output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: "", + }, + { + name: "model without prefix in template, prefix in output", + model: "llama3.2", + output: ` [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: ` [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, + }, + { + name: "model with prefix in template, no prefix in output, tokens before", + model: "qwen2.5", + output: `some tokens before [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `some tokens before [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + }, + { + name: "model with prefix in template, prefix in output, tokens after", + model: "qwen2.5", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "model without prefix in template, no prefix in output, tokens after", + model: "llama3.2", + output: `[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "model without prefix in template, no prefix in output, tokens before", + model: "llama3.2", + output: `some tokens before [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `some tokens before [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, + }, + { + name: "model without prefix in template, prefix in output, tokens after", + model: "llama3.2", + output: ` [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: ` [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, + }, + } + + var tools []api.Tool + if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil { + t.Fatal(err) + } + + var messages []api.Message + if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil { + t.Fatal(err) + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String()) + if err != nil { + t.Fatal(err) + } + + t.Run("template", func(t *testing.T) { + actual := &bytes.Buffer{} // Create new buffer for each test + if err := tmpl.Execute(actual, template.Values{Tools: tools, Messages: messages}); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + + t.Run("parse", func(t *testing.T) { + tp, err := NewParser(tmpl.Template) + if err != nil { + t.Fatal(err) + } + got := []api.ToolCall{} + var gotTokens strings.Builder + + tokens := strings.Fields(tt.output) + for _, tok := range tokens { + s := " " + tok + + toolCalls, content := tp.Add(s) + if len(content) > 0 { + gotTokens.WriteString(content) + } else if len(toolCalls) > 0 { + got = append(got, toolCalls...) + } + } + + // Compare tool calls if we expect any + if diff := cmp.Diff(got, tt.expectedToolCall); diff != "" { + t.Errorf("tool calls mismatch (-got +want):\n%s", diff) + } + + // Compare tokens if we expect any + stripped := strings.TrimSpace(gotTokens.String()) + if diff := cmp.Diff(stripped, tt.expectedTokens); diff != "" { + t.Log("actualTokens", stripped, "expectedTokens", tt.expectedTokens) + t.Errorf("tokens mismatch (-got +want):\n%s", diff) + } + }) + }) + } +} diff --git a/tools/tools_utils.go b/tools/tools_utils.go new file mode 100644 index 00000000..48531b78 --- /dev/null +++ b/tools/tools_utils.go @@ -0,0 +1,227 @@ +package tools + +import ( + "bytes" + "encoding/json" + "errors" + "log/slog" + "slices" + "strings" + gotmpl "text/template" + "text/template/parse" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/template" +) + +// extractToolCallsFormat traverses a template AST to find text that follows a ".ToolCalls" condition. +// It walks the template nodes looking for if-statements containing ".ToolCalls" and extracts any +// immediate text nodes that follow. This is used to identify tool call prefixes and formatting. +// +// Returns: +// - string: The extracted text following the first ".ToolCalls" condition found +// - bool: Whether a ".ToolCalls" condition was found in the template +func extractToolCallsFormat(tmpl *gotmpl.Template) (string, bool) { + if tmpl == nil || tmpl.Tree == nil { + slog.Debug("template or tree is nil") + return "", false + } + + var result string + var found bool + + var walk func(nodes []parse.Node) + walk = func(nodes []parse.Node) { + for _, node := range nodes { + if found { + return + } + + switch n := node.(type) { + case *parse.IfNode: + if isToolCallsNode(n) { + // Collect immediate TextNode(s) at start of IfNode's list + var sb strings.Builder + for _, innerNode := range n.List.Nodes { + if tn, ok := innerNode.(*parse.TextNode); ok { + sb.Write(tn.Text) + } else { + // Stop at first non-text node + break + } + } + result = sb.String() + found = true + return + } + // Recurse into child nodes + walk(n.List.Nodes) + if n.ElseList != nil { + walk(n.ElseList.Nodes) + } + case *parse.ListNode: + walk(n.Nodes) + case *parse.RangeNode: + walk(n.List.Nodes) + if n.ElseList != nil { + walk(n.ElseList.Nodes) + } + case *parse.WithNode: + walk(n.List.Nodes) + if n.ElseList != nil { + walk(n.ElseList.Nodes) + } + default: + // Continue to next node + continue + } + } + } + + walk(tmpl.Tree.Root.Nodes) + return result, found +} + +// isToolCallsNode detects if a node's condition includes ".ToolCalls" +func isToolCallsNode(n *parse.IfNode) bool { + for _, cmd := range n.Pipe.Cmds { + for _, arg := range cmd.Args { + if field, ok := arg.(*parse.FieldNode); ok { + if slices.Contains(field.Ident, "ToolCalls") { + return true + } + } + } + } + return false +} + +func toolPrefix(tmpl *gotmpl.Template) string { + tokenText, ok := extractToolCallsFormat(tmpl) + if !ok { + return "" + } + tokenText = strings.TrimSpace(tokenText) + tokenText = strings.ReplaceAll(tokenText, "\r", "") + tokenText = strings.ReplaceAll(tokenText, "\n", " ") + + return tokenText +} + +// toolTemplate creates a subtree from the node that ranges over .ToolCalls +// +// Returns: +// - *gotmpl.Template: The subtree containing the .ToolCalls range +// - error: Error if parsing failed +func toolTemplate(t *template.Template) (*gotmpl.Template, error) { + tmpl := t.Subtree(func(n parse.Node) bool { + if t, ok := n.(*parse.RangeNode); ok { + return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") + } + + return false + }) + + if tmpl == nil { + return nil, errors.New("failed to find tool template") + } + + return tmpl, nil +} + +// suffixOverlap returns the index in s where the longest suffix overlap with prefix begins +// +// Returns: +// - int: The starting index in s where the suffix overlap begins +func suffixOverlap(s, prefix string) int { + max := min(len(prefix), len(s)) + for i := max; i > 0; i-- { + if strings.HasSuffix(s, prefix[:i]) { + return len(s) - i + } + } + return -1 +} + +// extractToolArgs executes a template with a known tool call format to extract the name and arguments +// +// Returns: +// - string: The name of the tool call +// - string: The arguments of the tool call +// - error: Error if parsing failed +func extractToolArgs(tmpl *gotmpl.Template) (name, arguments string, err error) { + var b bytes.Buffer + if err := tmpl.Execute(&b, map[string][]api.ToolCall{ + "ToolCalls": { + { + Function: api.ToolCallFunction{ + Name: "@@name@@", + Arguments: api.ToolCallFunctionArguments{ + "@@argument@@": 1, + }, + }, + }, + }, + }); err != nil { + return "", "", err + } + + var obj any + err = json.Unmarshal(b.Bytes(), &obj) + if err != nil { + return "", "", err + } + + var objs []map[string]any + switch v := obj.(type) { + case map[string]any: + objs = []map[string]any{v} + case []map[string]any: + objs = v + case []any: + objs = collect(v) + } + if len(objs) == 0 { + return "", "", errors.New("no template objects found") + } + + // find the keys that correspond to the name and arguments fields + for k, v := range objs[0] { + switch v.(type) { + case string: + name = k + case map[string]any: + arguments = k + } + } + + if name == "" || arguments == "" { + slog.Debug("missing required fields in tool call template", "name", name, "arguments", arguments) + return "", "", errors.New("missing required fields in tool call template") + } + + return name, arguments, nil +} + +// collect recursively traverses an object to collect all nested maps +// +// Returns: +// - []map[string]any: A slice of all nested maps found in the object +func collect(obj any) []map[string]any { + var all []map[string]any + switch o := obj.(type) { + case map[string]any: + all = append(all, o) + for _, v := range o { + all = append(all, collect(v)...) + } + case []any: + for _, v := range o { + all = append(all, collect(v)...) + } + default: + return nil + } + + return all +} diff --git a/tools/tools_utils_test.go b/tools/tools_utils_test.go new file mode 100644 index 00000000..769183b7 --- /dev/null +++ b/tools/tools_utils_test.go @@ -0,0 +1,464 @@ +package tools + +import ( + "testing" + gotmpl "text/template" + + "github.com/ollama/ollama/template" +) + +func TestExtractToolCallsFormat(t *testing.T) { + cases := []struct { + name string + template string + want string + found bool + }{ + { + name: "nil template", + template: "", + want: "", + found: false, + }, + { + name: "basic tool call with text", + template: "{{if .ToolCalls}}Hello world{{end}}", + want: "Hello world", + found: true, + }, + { + name: "tool call with json format", + template: "{{if .ToolCalls}}```json\n{{end}}", + want: "```json\n", + found: true, + }, + { + name: "tool call in range", + template: "{{range .ToolCalls}}tool: {{.}}{{end}}", + want: "", + found: false, + }, + { + name: "tool call with multiple text nodes", + template: "{{if .ToolCalls}}First text{{if .Something}}inner{{end}}Second text{{end}}", + want: "First text", + found: true, + }, + { + name: "nested if without tool calls", + template: "{{if .Something}}{{if .OtherThing}}text{{end}}{{end}}", + want: "", + found: false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + tmpl, err := gotmpl.New("test").Parse(tc.template) + if err != nil && tc.template != "" { + t.Fatalf("failed to parse template: %v", err) + } + + got, found := extractToolCallsFormat(tmpl) + if got != tc.want { + t.Errorf("got text %q, want %q", got, tc.want) + } + if found != tc.found { + t.Errorf("got found %v, want %v", found, tc.found) + } + }) + } +} + +func TestToolPrefix(t *testing.T) { + cases := []struct { + name string + template string + want string + }{ + { + name: "basic tool call with action prefix", + template: "{{if .ToolCalls}}Action: ```json{{end}}", + want: "Action: ```json", + }, + { + name: "incomplete functools bracket", + template: "{{if .ToolCalls}}functools[{{end}}", + want: "functools[", + }, + { + name: "tool call with angle brackets", + template: "{{if .ToolCalls}}Hello, world! {{end}}", + want: "Hello, world! ", + }, + { + name: "multiple tool call formats", + template: "{{if .ToolCalls}}[tool_call] {{end}}", + want: "[tool_call] ", + }, + { + name: "single angle bracket tool call", + template: "{{if .ToolCalls}}{{end}}", + want: "", + }, + { + name: "incomplete angle bracket after tool call", + template: "{{if .ToolCalls}}[tool_call] <{{end}}", + want: "[tool_call] <", + }, + { + name: "angle bracket prefix with tool call", + template: "{{if .ToolCalls}}> {{end}}", + want: "> ", + }, + { + name: "uppercase tool call with incomplete bracket", + template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}", + want: "[TOOL_CALL] [", + }, + { + name: "uppercase tool call with adjacent bracket", + template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}", + want: "[TOOL_CALL][", + }, + { + name: "tool call with pipe delimiters", + template: "{{if .ToolCalls}}<|tool_call|>{{end}}", + want: "<|tool_call|>", + }, + { + name: "tool with no prefix", + template: "{{if .ToolCalls}}{{end}}", + want: "", + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + tmpl, err := gotmpl.New("test").Parse(tt.template) + if err != nil { + t.Fatalf("failed to parse template: %v", err) + } + got := toolPrefix(tmpl) + if got != tt.want { + t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want) + } + }) + } +} + +func TestToolTemplate(t *testing.T) { + cases := []struct { + name string + template string + want bool + }{ + { + name: "basic tool call range", + template: "{{range .ToolCalls}}test{{end}}", + want: true, + }, + { + name: "no tool calls", + template: "{{range .Other}}test{{end}}", + want: false, + }, + { + name: "nested tool calls", + template: "{{range .Outer}}{{range .ToolCalls}}test{{end}}{{end}}", + want: true, + }, + { + name: "empty template", + template: "", + want: false, + }, + { + name: "tool calls in if statement", + template: "{{if .ToolCalls}}test{{end}}", + want: false, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + tmpl, err := gotmpl.New("test").Parse(tt.template) + if err != nil { + t.Fatalf("failed to parse template: %v", err) + } + + parsed, err := template.Parse(tmpl.Root.String()) + if err != nil { + t.Fatalf("failed to parse template: %v", err) + } + + _, err = toolTemplate(parsed) + if err != nil && tt.want { + t.Errorf("toolTemplate() = %v; want %v", err, tt.want) + } + }) + } +} + +func TestSuffixOverlap(t *testing.T) { + cases := []struct { + name string + s string + d string + want int + }{ + { + name: "no overlap", + s: "hello world", + d: "", + want: -1, + }, + { + name: "full overlap", + s: "", + d: "", + want: 0, + }, + { + name: "partial overlap", + s: "text ", + d: "", + want: 5, + }, + { + name: "delimiter longer than string", + s: "", + d: "", + want: -1, + }, + { + name: "empty string", + s: "", + d: "", + want: -1, + }, + { + name: "empty delimiter", + s: "", + d: "", + want: -1, + }, + { + name: "single char overlap", + s: "test<", + d: "", + want: 4, + }, + { + name: "partial tool call", + s: "hello ", + want: 6, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got := suffixOverlap(tt.s, tt.d) + if got != tt.want { + t.Errorf("suffixOverlap(%q, %q) = %d; want %d", tt.s, tt.d, got, tt.want) + } + }) + } +} + +func TestExtractToolArgs(t *testing.T) { + cases := []struct { + name string + template string + want string + ok bool + }{ + { + name: "basic tool call with text after", + template: `{{if .ToolCalls}}tool response{{end}}`, + want: "tool response", + ok: true, + }, + { + name: "tool call with mixed content after", + template: `{{if .ToolCalls}}{{.Something}}{{end}}`, + want: "", + ok: true, + }, + { + name: "tool call with no text after", + template: `{{if .ToolCalls}}{{.Something}}{{end}}`, + want: "", + ok: true, + }, + { + name: "nested tool call", + template: `{{if .Something}}{{if .ToolCalls}}[TOOL_CALL]{{end}}{{end}}`, + want: "[TOOL_CALL]", + ok: true, + }, + { + name: "no tool calls", + template: `{{if .Something}}no tools here{{end}}`, + want: "", + ok: false, + }, + { + name: "empty template", + template: ``, + want: "", + ok: false, + }, + { + name: "multiple tool calls sections", + template: `{{if .ToolCalls}}first{{end}}{{if .ToolCalls}}second{{end}}`, + want: "first", + ok: true, + }, + { + name: "range over tool calls", + template: `{{if .ToolCalls}}{{range .ToolCalls}}tool{{end}}{{end}}`, + want: "", + ok: true, + }, + { + name: "tool calls with pipe delimiters", + template: `{{if .ToolCalls}}<|tool|>{{end}}`, + want: "<|tool|>", + ok: true, + }, + { + name: "tool calls with nested template", + template: `{{if .ToolCalls}}{{template "tool" .}}{{end}}`, + want: "", + ok: true, + }, + { + name: "tool calls with whitespace variations", + template: `{{if .ToolCalls}} tool {{end}}`, + want: " tool ", + ok: true, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + tmpl, err := gotmpl.New("test").Parse(tt.template) + if err != nil { + t.Fatalf("failed to parse template: %v", err) + } + + got, ok := extractToolCallsFormat(tmpl) + if got != tt.want { + t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want) + } + if ok != tt.ok { + t.Errorf("TextAfterToolCalls() ok = %v, want %v", ok, tt.ok) + } + }) + } +} + +func TestCollect(t *testing.T) { + cases := []struct { + name string + obj any + want []map[string]any + }{ + { + name: "simple map", + obj: map[string]any{ + "key": "value", + }, + want: []map[string]any{ + {"key": "value"}, + }, + }, + { + name: "nested map", + obj: map[string]any{ + "outer": map[string]any{ + "inner": "value", + }, + }, + want: []map[string]any{ + {"outer": map[string]any{"inner": "value"}}, + {"inner": "value"}, + }, + }, + { + name: "array of maps", + obj: []any{ + map[string]any{"key1": "val1"}, + map[string]any{"key2": "val2"}, + }, + want: []map[string]any{ + {"key1": "val1"}, + {"key2": "val2"}, + }, + }, + { + name: "deeply nested", + obj: map[string]any{ + "l1": map[string]any{ + "l2": map[string]any{ + "l3": "value", + }, + }, + }, + want: []map[string]any{ + {"l1": map[string]any{"l2": map[string]any{"l3": "value"}}}, + {"l2": map[string]any{"l3": "value"}}, + {"l3": "value"}, + }, + }, + { + name: "non-map value", + obj: "string", + want: nil, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got := collect(tt.obj) + if len(got) != len(tt.want) { + t.Errorf("collect() got %d maps, want %d", len(got), len(tt.want)) + return + } + + // Compare each map in the result + for i := range tt.want { + if !mapsEqual(got[i], tt.want[i]) { + t.Errorf("collect() map[%d] = %v, want %v", i, got[i], tt.want[i]) + } + } + }) + } +} + +// mapsEqual compares two maps for deep equality +func mapsEqual(m1, m2 map[string]any) bool { + if len(m1) != len(m2) { + return false + } + for k, v1 := range m1 { + v2, ok := m2[k] + if !ok { + return false + } + switch val1 := v1.(type) { + case map[string]any: + val2, ok := v2.(map[string]any) + if !ok || !mapsEqual(val1, val2) { + return false + } + default: + if v1 != v2 { + return false + } + } + } + return true +}