diff --git a/CMakeLists.txt b/CMakeLists.txt index 16526153..e170e66f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,6 +51,8 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx) +add_compile_definitions(NDEBUG) + set(GGML_CPU ON) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src) set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE) diff --git a/README.md b/README.md index 0a51988e..4e8b7c57 100644 --- a/README.md +++ b/README.md @@ -427,6 +427,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama) - [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable) - [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers) +- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI) ### Cloud diff --git a/cmd/cmd.go b/cmd/cmd.go index ad4be7f9..b9047529 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1236,11 +1236,11 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error { return err } if err := client.Heartbeat(cmd.Context()); err != nil { - if !strings.Contains(err.Error(), " refused") { + if !(strings.Contains(err.Error(), " refused") || strings.Contains(err.Error(), "could not connect")) { return err } if err := startApp(cmd.Context(), client); err != nil { - return errors.New("could not connect to ollama app, is it running?") + return fmt.Errorf("ollama server not responding - %w", err) } } return nil diff --git a/cmd/start_windows.go b/cmd/start_windows.go index 5bca2433..bcc51057 100644 --- a/cmd/start_windows.go +++ b/cmd/start_windows.go @@ -4,17 +4,27 @@ import ( "context" "errors" "fmt" + "log/slog" "os" "os/exec" + "path" "path/filepath" "strings" "syscall" + "unsafe" "github.com/ollama/ollama/api" + "golang.org/x/sys/windows" +) + +const ( + Installer = "OllamaSetup.exe" ) func startApp(ctx context.Context, client *api.Client) error { - // log.Printf("XXX Attempting to find and start ollama app") + if len(isProcRunning(Installer)) > 0 { + return fmt.Errorf("upgrade in progress...") + } AppName := "ollama app.exe" exe, err := os.Executable() if err != nil { @@ -56,3 +66,41 @@ func startApp(ctx context.Context, client *api.Client) error { } return waitForServer(ctx, client) } + +func isProcRunning(procName string) []uint32 { + pids := make([]uint32, 2048) + var ret uint32 + if err := windows.EnumProcesses(pids, &ret); err != nil || ret == 0 { + slog.Debug("failed to check for running installers", "error", err) + return nil + } + pids = pids[:ret] + var matches []uint32 + for _, pid := range pids { + if pid == 0 { + continue + } + hProcess, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION|windows.PROCESS_VM_READ, false, pid) + if err != nil { + continue + } + defer windows.CloseHandle(hProcess) + var module windows.Handle + var cbNeeded uint32 + cb := (uint32)(unsafe.Sizeof(module)) + if err := windows.EnumProcessModules(hProcess, &module, cb, &cbNeeded); err != nil { + continue + } + var sz uint32 = 1024 * 8 + moduleName := make([]uint16, sz) + cb = uint32(len(moduleName)) * (uint32)(unsafe.Sizeof(uint16(0))) + if err 
:= windows.GetModuleBaseName(hProcess, module, &moduleName[0], cb); err != nil && err != syscall.ERROR_INSUFFICIENT_BUFFER { + continue + } + exeFile := path.Base(strings.ToLower(syscall.UTF16ToString(moduleName))) + if strings.EqualFold(exeFile, procName) { + matches = append(matches, pid) + } + } + return matches +} diff --git a/convert/convert.go b/convert/convert.go index 309b0ce1..4a6df66c 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -53,8 +53,11 @@ func (ModelParameters) KV(t *Tokenizer) ggml.KV { } for _, sv := range t.SpecialVocabulary { - kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID) kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken + kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID) + if len(sv.IDs) > 0 { + kv[fmt.Sprintf("tokenizer.ggml.%s_token_ids", sv.Key())] = sv.IDs + } } return kv diff --git a/convert/convert_mllama.go b/convert/convert_mllama.go index 12478be7..69d7f588 100644 --- a/convert/convert_mllama.go +++ b/convert/convert_mllama.go @@ -94,7 +94,9 @@ func (m *mllamaModel) Tensors(ts []Tensor) []*ggml.Tensor { var out []*ggml.Tensor var text []Tensor for _, t := range ts { - if t.Name() == "v.position_embd.gate" { + if !strings.HasPrefix(t.Name(), "v.") && !strings.HasPrefix(t.Name(), "mm.") { + text = append(text, t) + } else if t.Name() == "v.position_embd.gate" { for _, name := range []string{"v.position_embd.gate", "v.tile_position_embd.gate"} { tt := t.Clone() tt.SetRepacker(m.repack(name)) @@ -105,23 +107,21 @@ func (m *mllamaModel) Tensors(ts []Tensor) []*ggml.Tensor { WriterTo: tt, }) } - } else if t.Name() == "v.pre_tile_position_embd.gate" || t.Name() == "v.post_tile_position_embd.gate" { - t.SetRepacker(m.repack(t.Name())) - out = append(out, &ggml.Tensor{ - Name: t.Name(), - Kind: t.Kind(), - Shape: t.Shape(), - WriterTo: t, - }) - } else if strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm.") { - out = append(out, &ggml.Tensor{ - Name: t.Name(), - Kind: t.Kind(), - Shape: t.Shape(), - WriterTo: t, - }) } else { - text = append(text, t) + if t.Name() == "v.pre_tile_position_embd.gate" || t.Name() == "v.post_tile_position_embd.gate" { + t.SetRepacker(m.repack(t.Name())) + } else if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") { + t.SetRepacker(m.repack(t.Name())) + } else if strings.HasSuffix(t.Name(), "attn_gate") || strings.HasSuffix(t.Name(), "ffn_gate") { + t.SetRepacker(m.repack(t.Name())) + } + + out = append(out, &ggml.Tensor{ + Name: t.Name(), + Kind: t.Kind(), + Shape: t.Shape(), + WriterTo: t, + }) } } @@ -137,16 +137,35 @@ func (m *mllamaModel) repack(name string) Repacker { var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data)) - t, err = tensor.Tanh(t) - if err != nil { - return nil, err - } + if strings.HasSuffix(name, "attn_q.weight") || strings.HasSuffix(name, "attn_k.weight") { + heads := m.VisionModel.AttentionHeads + if err := t.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil { + return nil, err + } - if name == "v.position_embd.gate" { - t, err = tensor.Sub(float32(1), t) + if err := t.T(0, 2, 1, 3); err != nil { + return nil, err + } + + if err := t.Reshape(dims...); err != nil { + return nil, err + } + + if err := t.Transpose(); err != nil { + return nil, err + } + } else { + t, err = tensor.Tanh(t) if err != nil { return nil, err } + + if name == "v.position_embd.gate" { + t, err = tensor.Sub(float32(1), t) + 
if err != nil { + return nil, err + } + } } t = tensor.Materialize(t) diff --git a/convert/convert_test.go b/convert/convert_test.go index b9db6fa1..105fbb3d 100644 --- a/convert/convert_test.go +++ b/convert/convert_test.go @@ -47,7 +47,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) { } t.Cleanup(func() { r.Close() }) - m, _, err := ggml.Decode(r, -1) + m, err := ggml.Decode(r, -1) if err != nil { t.Fatal(err) } @@ -332,7 +332,7 @@ func TestConvertAdapter(t *testing.T) { } defer r.Close() - m, _, err := ggml.Decode(r, -1) + m, err := ggml.Decode(r, -1) if err != nil { t.Fatal(err) } diff --git a/convert/tokenizer.go b/convert/tokenizer.go index 74e2efed..bedcd4f8 100644 --- a/convert/tokenizer.go +++ b/convert/tokenizer.go @@ -110,6 +110,7 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error) } if f, err := fsys.Open("tokenizer_config.json"); errors.Is(err, os.ErrNotExist) { + // noop } else if err != nil { return nil, err } else { @@ -171,6 +172,34 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error) } } + if f, err := fsys.Open("generation_config.json"); errors.Is(err, os.ErrNotExist) { + } else if err != nil { + return nil, err + } else { + defer f.Close() + + var p map[string]json.RawMessage + if err := json.NewDecoder(f).Decode(&p); err != nil { + return nil, err + } + + for _, st := range specialTokenTypes { + if bts, ok := p[fmt.Sprintf("%s_token_id", st)]; ok { + var ids []int32 + if err := json.Unmarshal(bts, &ids); err != nil { + // value is not a list so the existing ID is used + continue + } + + if i := slices.IndexFunc(t.SpecialVocabulary, func(sv *SpecialVocabulary) bool { + return sv.Type == st + }); i >= 0 { + t.SpecialVocabulary[i].IDs = ids + } + } + } + } + return t, nil } @@ -280,6 +309,9 @@ type SpecialVocabulary struct { ID int Content string AddToken bool + + // IDs is populated by generation_config.json + IDs []int32 } func (sv SpecialVocabulary) Key() string { diff --git a/convert/tokenizer_test.go b/convert/tokenizer_test.go index c6ef9732..813096fd 100644 --- a/convert/tokenizer_test.go +++ b/convert/tokenizer_test.go @@ -247,6 +247,67 @@ func TestParseTokenizer(t *testing.T) { Pre: "default", }, }, + { + name: "generation config eos token ids", + fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{ + "tokenizer.json": strings.NewReader(`{ + "added_tokens": [ + { + "id": 0, + "content": "", + "special": true + }, + { + "id": 1, + "content": "", + "special": true + }, + { + "id": 2, + "content": "", + "special": true + }, + { + "id": 3, + "content": "", + "special": true + } + ], + "model": { + "vocab": { + "": 0, + "": 1, + "": 2, + "": 3 + } + } + }`), + "tokenizer_config.json": strings.NewReader(`{ + "add_bos_token": true, + "add_eos_token": false, + "bos_token": "", + "eos_token": "" + }`), + "generation_config.json": strings.NewReader(`{ + "bos_token_id": 0, + "eos_token_id": [1, 2, 3] + }`), + }), + specialTokenTypes: []string{"pad", "eos", "bos", "unk"}, + want: &Tokenizer{ + Vocabulary: &Vocabulary{ + Model: "gpt2", + Tokens: []string{"", "", "", ""}, + Scores: []float32{0, 1, 2, 3}, + Types: []int32{3, 3, 3, 3}, + }, + SpecialVocabulary: []*SpecialVocabulary{ + {Type: "eos", Content: "", ID: 1, IDs: []int32{1, 2, 3}, AddToken: false}, + {Type: "bos", Content: "", ID: 0, AddToken: true}, + }, + Pre: "default", + }, + }, } for _, tt := range cases { diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 8c0a2ae5..aa85aec2 100644 --- a/fs/ggml/ggml.go +++ 
b/fs/ggml/ggml.go @@ -15,6 +15,7 @@ import ( type GGML struct { container model + Length int64 } type model interface { @@ -386,12 +387,12 @@ func DetectContentType(b []byte) string { // // It collects array values for arrays with a size less than or equal to // maxArraySize. If the maxArraySize is negative, all arrays are collected. -func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) { +func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) { rs = bufioutil.NewBufferedSeeker(rs, 32<<10) var magic uint32 if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil { - return nil, 0, err + return nil, err } var c container @@ -401,24 +402,25 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) { case FILE_MAGIC_GGUF_BE: c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize} default: - return nil, 0, errors.New("invalid file magic") + return nil, errors.New("invalid file magic") } model, err := c.Decode(rs) if err != nil { - return nil, 0, err + return nil, err } offset, err := rs.Seek(0, io.SeekCurrent) if err != nil { - return nil, 0, err + return nil, err } // final model type return &GGML{ container: c, model: model, - }, offset, nil + Length: offset, + }, nil } func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) { diff --git a/fs/ggml/gguf_test.go b/fs/ggml/gguf_test.go index 10d3b684..0e071800 100644 --- a/fs/ggml/gguf_test.go +++ b/fs/ggml/gguf_test.go @@ -35,7 +35,7 @@ func TestWriteGGUF(t *testing.T) { } defer r.Close() - ff, _, err := Decode(r, 0) + ff, err := Decode(r, 0) if err != nil { t.Fatal(err) } diff --git a/integration/llm_image_test.go b/integration/llm_image_test.go index b9726c8f..bbd031a9 100644 --- a/integration/llm_image_test.go +++ b/integration/llm_image_test.go @@ -19,7 +19,7 @@ func TestVisionModels(t *testing.T) { } testCases := []testCase{ { - model: "llava:7b", + model: "qwen2.5vl", }, { model: "llama3.2-vision", @@ -60,6 +60,7 @@ func TestVisionModels(t *testing.T) { } func TestIntegrationSplitBatch(t *testing.T) { + skipUnderMinVRAM(t, 6) image, err := base64.StdEncoding.DecodeString(imageEncoding) require.NoError(t, err) req := api.GenerateRequest{ diff --git a/kvcache/causal.go b/kvcache/causal.go index 9bc1d5da..f6bacaaf 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -211,10 +211,9 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e c.curCellRange.max = len(c.cells) - 1 } - var err error - c.curMask, err = c.buildMask(ctx) + c.curMask = c.buildMask(ctx) - return err + return nil } func newRange() cellRange { @@ -297,7 +296,7 @@ func roundUp(length, pad int) int { // Builds a mask of history x batch indicating whether for each token in the batch the // token in the history should apply. This is based on both the sequence and causality (the // position of the history is not ahead of the token in the batch). 
-func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) { +func (c *Causal) buildMask(ctx ml.Context) ml.Tensor { // Align and pad the two dimensions as required by the backend batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding) @@ -325,10 +324,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) { mask[i] = float32(math.Inf(-1)) } - maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize) - if err != nil { - return nil, err - } + maskTensor := ctx.Input().FromFloatSlice(mask, length, batchSize) if c.config.MaskDType != ml.DTypeF32 { out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...) @@ -336,7 +332,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) { maskTensor = out } - return maskTensor, nil + return maskTensor } func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) { @@ -491,12 +487,7 @@ func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) { if !slices.Equal(c.opts.Except, opts.Except) { c.opts = opts if ctx != nil { - var err error - c.curMask, err = c.buildMask(ctx) - if err != nil { - // This error should never occur because we have previously built a mask with the same shape - panic(fmt.Errorf("SetCausal: %w", err)) - } + c.curMask = c.buildMask(ctx) } } } @@ -652,10 +643,7 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error { } } - kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets)) - if err != nil { - return err - } + kShift := ctx.Input().FromIntSlice(offsets, len(offsets)) for i, key := range c.keys { if key == nil { diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 79698708..5b1dbe86 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -344,7 +344,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) } cache.SetLayer(0) - tensor, _ := context.FromFloatSlice(test.in, test.inShape...) + tensor := context.FromFloatSlice(test.in, test.inShape...) cache.Put(context, tensor, tensor) out, _, mask := cache.Get(context) @@ -386,7 +386,7 @@ func TestCanResume(t *testing.T) { } cache.SetLayer(0) - tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4) + tensor := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4) cache.Put(context, tensor, tensor) // with window size 4, nothing has slid out of the window yet @@ -413,7 +413,7 @@ func TestCanResume(t *testing.T) { } cache.SetLayer(0) - tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2) + tensor = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2) cache.Put(context, tensor, tensor) // only the latest position has overlapping windows @@ -470,24 +470,24 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor { return c.Empty(dtype, shape...) } -func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { +func (c *testContext) FromFloatSlice(s []float32, shape ...int) ml.Tensor { t := c.Empty(ml.DTypeF32, shape...).(*testTensor) copy(t.data, s) - return t, nil + return t } -func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { +func (c *testContext) FromIntSlice(s []int32, shape ...int) ml.Tensor { f := make([]float32, len(s)) for i := range f { f[i] = float32(s[i]) } - out, _ := c.FromFloatSlice(f, shape...) + out := c.FromFloatSlice(f, shape...) 
out.(*testTensor).dtype = ml.DTypeI32 - return out, nil + return out } func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor { @@ -496,7 +496,7 @@ func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tenso s = append(s, i) } - out, _ := c.FromFloatSlice(s, len(s)) + out := c.FromFloatSlice(s, len(s)) out.(*testTensor).dtype = dtype return out } @@ -508,7 +508,7 @@ func (c *testContext) Forward(...ml.Tensor) ml.Context { return c } func (c *testContext) Compute(...ml.Tensor) {} -func (c *testContext) Reserve() error { return nil } +func (c *testContext) Reserve() {} func (c *testContext) MaxGraphNodes() int { return 10 diff --git a/llama/llama.go b/llama/llama.go index 1251be3a..0dc64e57 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -544,7 +544,7 @@ func NewSamplingContext(model *Model, params SamplingParams) (*SamplingContext, cparams.penalty_last_n = C.int32_t(params.RepeatLastN) cparams.penalty_repeat = C.float(params.PenaltyRepeat) cparams.penalty_freq = C.float(params.PenaltyFreq) - cparams.penalty_present = C.float(params.PenaltyFreq) + cparams.penalty_present = C.float(params.PenaltyPresent) cparams.seed = C.uint32_t(params.Seed) grammar := C.CString(params.Grammar) @@ -580,7 +580,7 @@ func SchemaToGrammar(schema []byte) []byte { defer C.free(unsafe.Pointer(cStr)) // Allocate buffer for grammar based on schema length but with upper bound - maxLen := min(1024*1024, len(schema)*4) + maxLen := max(32768, min(1024*1024, len(schema)*4)) buf := make([]byte, maxLen) // Call C function to convert schema to grammar @@ -602,7 +602,7 @@ type Grammar struct { mu sync.Mutex } -func NewGrammar(grammar string, vocabIds []uint32, vocabValues []string, eogTokens []uint32) *Grammar { +func NewGrammar(grammar string, vocabIds []uint32, vocabValues []string, eogTokens []int32) *Grammar { cGrammar := C.CString(grammar) defer C.free(unsafe.Pointer(cGrammar)) @@ -622,7 +622,7 @@ func NewGrammar(grammar string, vocabIds []uint32, vocabValues []string, eogToke cEogTokens[i] = C.uint32_t(token) } - g := C.grammar_init(cGrammar, (*C.uint32_t)(unsafe.Pointer(&cTokens[0])), C.size_t(len(cTokens)), (**C.char)(unsafe.Pointer(&cPieces[0])), (*C.uint32_t)(unsafe.Pointer(&cEogTokens[0])), C.size_t(len(cEogTokens))) + g := C.grammar_init(cGrammar, unsafe.SliceData(cTokens), C.size_t(len(cTokens)), unsafe.SliceData(cPieces), unsafe.SliceData(cEogTokens), C.size_t(len(cEogTokens))) if g == nil { return nil } diff --git a/llama/patches/0016-graph-memory-reporting-on-failure.patch b/llama/patches/0016-graph-memory-reporting-on-failure.patch new file mode 100644 index 00000000..92188224 --- /dev/null +++ b/llama/patches/0016-graph-memory-reporting-on-failure.patch @@ -0,0 +1,156 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Jesse Gross +Date: Fri, 18 Apr 2025 15:58:19 -0700 +Subject: [PATCH] graph memory reporting on failure + +--- + ggml/include/ggml-alloc.h | 6 ++++++ + ggml/include/ggml-backend.h | 6 ++++++ + ggml/src/ggml-alloc.c | 38 +++++++++++++++++++++++++++++++++---- + ggml/src/ggml-backend.cpp | 10 ++++++++++ + 4 files changed, 56 insertions(+), 4 deletions(-) + +diff --git a/ggml/include/ggml-alloc.h b/ggml/include/ggml-alloc.h +index 2cb150fd..781b1e10 100644 +--- a/ggml/include/ggml-alloc.h ++++ b/ggml/include/ggml-alloc.h +@@ -66,6 +66,12 @@ GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph + + GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id); + ++struct 
ggml_allocr_buffer_status { ++ size_t size; ++ bool allocated; ++}; ++GGML_API struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id); ++ + // Utils + // Create a buffer and allocate all the tensors in a ggml_context + GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); +diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h +index 778927f6..74e46716 100644 +--- a/ggml/include/ggml-backend.h ++++ b/ggml/include/ggml-backend.h +@@ -304,6 +304,12 @@ extern "C" { + + GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); + ++ struct ggml_backend_buffer_status { ++ size_t size; ++ bool allocated; ++ }; ++ GGML_API struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); ++ + GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); + GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node); + +diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c +index 5fd379f6..04812990 100644 +--- a/ggml/src/ggml-alloc.c ++++ b/ggml/src/ggml-alloc.c +@@ -364,6 +364,7 @@ struct node_alloc { + struct ggml_gallocr { + ggml_backend_buffer_type_t * bufts; // [n_buffers] + ggml_backend_buffer_t * buffers; // [n_buffers] ++ size_t *buffer_sizes; // [n_buffers] + struct ggml_dyn_tallocr ** buf_tallocs; // [n_buffers] + int n_buffers; + +@@ -387,6 +388,9 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs + galloc->buffers = calloc(n_bufs, sizeof(ggml_backend_buffer_t)); + GGML_ASSERT(galloc->buffers != NULL); + ++ galloc->buffer_sizes = calloc(n_bufs, sizeof(size_t)); ++ GGML_ASSERT(galloc->buffer_sizes != NULL); ++ + galloc->buf_tallocs = calloc(n_bufs, sizeof(struct ggml_dyn_tallocr *)); + GGML_ASSERT(galloc->buf_tallocs != NULL); + +@@ -453,6 +457,7 @@ void ggml_gallocr_free(ggml_gallocr_t galloc) { + ggml_hash_set_free(&galloc->hash_set); + free(galloc->hash_values); + free(galloc->bufts); ++ free(galloc->buffer_sizes); + free(galloc->buffers); + free(galloc->buf_tallocs); + free(galloc->node_allocs); +@@ -748,6 +753,8 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c + } + } + ++ bool success = true; ++ + // reallocate buffers if needed + for (int i = 0; i < galloc->n_buffers; i++) { + // if the buffer type is used multiple times, we reuse the same buffer +@@ -769,15 +776,20 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c + + ggml_backend_buffer_free(galloc->buffers[i]); + galloc->buffers[i] = ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size); +- if (galloc->buffers[i] == NULL) { ++ if (galloc->buffers[i]) { ++ galloc->buffer_sizes[i] = ggml_backend_buffer_get_size(galloc->buffers[i]); ++ ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); ++ } else { + GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size); +- return false; ++ galloc->buffer_sizes[i] = new_size; ++ success = false; + } +- ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); ++ } else { ++ galloc->buffer_sizes[i] = ggml_backend_buffer_get_size(galloc->buffers[i]); + } + } + +- return 
true; ++ return success; + } + + bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) { +@@ -934,6 +946,24 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) { + return ggml_backend_buffer_get_size(galloc->buffers[buffer_id]); + } + ++struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id) { ++ GGML_ASSERT(buffer_id >= 0 && buffer_id < galloc->n_buffers); ++ ++ for (int i = 0; i < buffer_id; i++) { ++ if (galloc->buf_tallocs[i] == galloc->buf_tallocs[buffer_id]) { ++ // This buffer is the same as a previous one due to the same buffer type being used multiple times ++ // (See above.) However, we need a different check because multiple buffers might be NULL in our ++ // case and we still want to know the attempted size. ++ ++ struct ggml_allocr_buffer_status status = {0, true}; ++ return status; ++ } ++ } ++ ++ struct ggml_allocr_buffer_status status = {galloc->buffer_sizes[buffer_id], galloc->buffers[buffer_id] != NULL}; ++ return status; ++} ++ + // utils + + static void free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) { +diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp +index 0ce73a99..be335e8c 100644 +--- a/ggml/src/ggml-backend.cpp ++++ b/ggml/src/ggml-backend.cpp +@@ -1629,6 +1629,16 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe + return ggml_gallocr_get_buffer_size(sched->galloc, backend_index); + } + ++struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) { ++ int backend_index = ggml_backend_sched_backend_id(sched, backend); ++ GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); ++ ++ struct ggml_allocr_buffer_status allocr_status = ggml_gallocr_get_attempted_buffer_size(sched->galloc, backend_index); ++ struct ggml_backend_buffer_status status = {allocr_status.size, allocr_status.allocated}; ++ ++ return status; ++} ++ + void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { + int backend_index = ggml_backend_sched_backend_id(sched, backend); + GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); diff --git a/llm/memory.go b/llm/memory.go index b5a8dd5c..05b3b2fd 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -1,12 +1,9 @@ package llm import ( - "cmp" "fmt" "log/slog" - "maps" "os" - "slices" "strconv" "strings" @@ -85,8 +82,11 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin var graphOffload uint64 // Projectors loaded into GPU0 only - var projectorWeights uint64 - var projectorGraph uint64 + var llamaEngineProjectorWeights uint64 + + // Projectors loaded with output layer + var ollamaEngineProjectorWeights uint64 + var ollamaEngineProjectorGraph uint64 // Conditional output size on GPU 0 var memoryLayerOutput uint64 @@ -111,21 +111,23 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", availableList) for _, projector := range projectors { - weight := projectorMemoryRequirements(projector) - projectorWeights += weight + llamaEngineProjectorWeights += projectorMemoryRequirements(projector) // multimodal models require at least 2048 context opts.NumCtx = max(opts.NumCtx, 2048) } - if projectorWeights == 0 && projectorGraph == 0 { - projectorWeights, projectorGraph = 
f.VisionGraphSize() + if llamaEngineProjectorWeights == 0 { + ollamaEngineProjectorWeights, ollamaEngineProjectorGraph = f.VisionGraphSize() + opts.NumCtx = max(opts.NumCtx, 2048) } layers := f.Tensors().GroupLayers() - // add one layer (chosing the max layer) worth of memory as a buffer - layerSize = slices.MaxFunc(slices.Collect(maps.Values(layers)), func(a, b ggml.Layer) int { - return cmp.Compare(a.Size(), b.Size()) - }).Size() + // add one layer worth of memory as a buffer + if blk0, ok := layers["blk.0"]; ok { + layerSize = blk0.Size() + } else { + slog.Warn("model missing blk.0 layer size") + } var kvct string if envconfig.FlashAttention() && @@ -163,6 +165,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin graphFullOffload = graphPartialOffload } + // Output layer handled at the end if we have space if layer, ok := layers["output_norm"]; ok { memoryLayerOutput += layer.Size() } @@ -172,8 +175,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin memoryLayerOutput += layer.Size() } - // Output layer handled at the end if we have space - gpuZeroOverhead := projectorWeights + projectorGraph + gpuZeroOverhead := llamaEngineProjectorWeights // Reduce set of GPUs to only those that have sufficient space to fit overhead and at least one layer var layerCount int @@ -216,6 +218,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin if len(gpusWithSpace) > 0 { gpuZeroID = gpusWithSpace[0].i gpuAllocations[gpuZeroID] += gpuZeroOverhead + } else { + overflow += gpuZeroOverhead } // For all the layers, find where they can fit on the GPU(s) @@ -256,21 +260,24 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin } // Determine if we need to consider output then find where it fits - if memoryLayerOutput > 0 && (opts.NumGPU < 0 || layerCount < opts.NumGPU) { - for j := len(gpusWithSpace); j > 0; j-- { - g := gpusWithSpace[layerCount%j] - used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload) - if g.g.FreeMemory > overhead+used+memoryLayerOutput { - gpuAllocations[g.i] += memoryLayerOutput - layerCounts[g.i]++ - layerCount++ - break + memoryLastLayer := memoryLayerOutput + ollamaEngineProjectorWeights + ollamaEngineProjectorGraph + if memoryLastLayer > 0 { + if opts.NumGPU < 0 || layerCount < opts.NumGPU { + for j := len(gpusWithSpace); j > 0; j-- { + g := gpusWithSpace[layerCount%j] + used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload) + if g.g.FreeMemory > overhead+used+memoryLastLayer { + gpuAllocations[g.i] += memoryLastLayer + layerCounts[g.i]++ + layerCount++ + break + } } } if layerCount < int(f.KV().BlockCount())+1 { fullyLoaded = false - overflow += memoryLayerOutput + overflow += memoryLastLayer } } @@ -328,8 +335,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin memoryLayerOutput: memoryLayerOutput, graphFullOffload: graphFullOffload, graphPartialOffload: graphPartialOffload, - projectorWeights: projectorWeights, - projectorGraph: projectorGraph, + projectorWeights: llamaEngineProjectorWeights + ollamaEngineProjectorWeights, + projectorGraph: ollamaEngineProjectorGraph, } if gpus[0].Library == "cpu" { @@ -415,7 +422,7 @@ func projectorMemoryRequirements(filename string) (weights uint64) { } defer file.Close() - ggml, _, err := ggml.Decode(file, 1024) + ggml, err := ggml.Decode(file, 1024) if err != nil { return 0 } diff --git a/llm/server.go b/llm/server.go index c07315fa..4abb569f 100644 
--- a/llm/server.go +++ b/llm/server.go @@ -121,7 +121,7 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) { } defer f.Close() - ggml, _, err := ggml.Decode(f, maxArraySize) + ggml, err := ggml.Decode(f, maxArraySize) return ggml, err } diff --git a/ml/backend.go b/ml/backend.go index cb32d818..6beb7d2b 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -6,7 +6,6 @@ import ( "encoding/binary" "fmt" "math" - "os" "slices" "strconv" "strings" @@ -15,6 +14,11 @@ import ( ) type Backend interface { + Load(ctx context.Context, progress func(float32)) error + + // BackendMemory returns the memory allocations that were made for this model + BackendMemory() BackendMemory + Config() fs.Config Get(name string) Tensor NewContext() Context @@ -52,10 +56,6 @@ type CacheConfig struct { // BackendParams controls how the backend loads and executes models type BackendParams struct { - // Progress is a callback function that allows reporting percentage completion - // of model loading - Progress func(float32) - // NumThreads sets the number of threads to use if running on the CPU NumThreads int @@ -72,9 +72,87 @@ type BackendParams struct { FlashAttention bool } -var backends = make(map[string]func(context.Context, *os.File, BackendParams) (Backend, error)) +// ErrNoMem is returned when panicing due to insufficient memory. It includes +// the attempted memory allocation. +type ErrNoMem struct { + BackendMemory +} -func RegisterBackend(name string, f func(context.Context, *os.File, BackendParams) (Backend, error)) { +func (e ErrNoMem) Error() string { + return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory) +} + +type AllocationStatus int + +const ( + // Unallocated memory - have not yet attempted to allocate + Unallocated AllocationStatus = iota + + // Failed memory - tried to allocate the memory and did not succeed + Failed + + // Allocated memory = tried and succeeded to allocate memory + Allocated +) + +// Memory is the size of an allocation and whether it was successful. +type Memory struct { + Size uint64 + Status AllocationStatus +} + +func (m Memory) String() string { + s := fmt.Sprint(m.Size) + + switch m.Status { + case Unallocated: + s += "U" + case Failed: + s += "F" + case Allocated: + s += "A" + } + + return s +} + +// DeviceMemory provides a breakdown of the memory needed +// per device, such as a CPU or GPU. +type DeviceMemory struct { + // Name is the name of the device as labeled by the backend. It + // may not be persistent across instances of the runner. + Name string + + // Weights is the per-layer memory needed for the model weights. + Weights []Memory + + // Cache is the per-layer memory needed for the KV cache. + Cache []Memory + + // Graph is the size of the compute graph. It is not per-layer. + Graph Memory +} + +// BackendMemory provides the amount of memory required to load the model +// per device based on the BackendParams. In some cases, not all required +// allocations will be known at this point. However, the size of the most recent +// allocation is guaranteed to be provided so that if it failed, the caller can +// accommodate that to make forward progress. +type BackendMemory struct { + // InputsWeights are always located on the CPU and cannot be moved + InputWeights Memory + + // CPU model components are located in system memory. This does not + // include unified memory allocated through the GPU. + CPU DeviceMemory + + // GPU model components are located on one or more GPUs. 
+ GPUs []DeviceMemory +} + +var backends = make(map[string]func(string, BackendParams) (Backend, error)) + +func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) { if _, ok := backends[name]; ok { panic("backend: backend already registered") } @@ -82,9 +160,9 @@ func RegisterBackend(name string, f func(context.Context, *os.File, BackendParam backends[name] = f } -func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend, error) { +func NewBackend(modelPath string, params BackendParams) (Backend, error) { if backend, ok := backends["ggml"]; ok { - return backend(ctx, f, params) + return backend(modelPath, params) } return nil, fmt.Errorf("unsupported backend") @@ -93,8 +171,8 @@ func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend, type Context interface { Empty(dtype DType, shape ...int) Tensor Zeros(dtype DType, shape ...int) Tensor - FromFloatSlice(s []float32, shape ...int) (Tensor, error) - FromIntSlice(s []int32, shape ...int) (Tensor, error) + FromFloatSlice(s []float32, shape ...int) Tensor + FromIntSlice(s []int32, shape ...int) Tensor // Arange creates a 1D tensor with values within an interval (start, stop] increased by step. Arange(start, stop, step float32, dtype DType) Tensor @@ -106,7 +184,7 @@ type Context interface { // graph, simply preallocates memory. Typically called with a // worst case graph to ensure all resources are available for // for future inference. - Reserve() error + Reserve() MaxGraphNodes() int Close() @@ -119,21 +197,6 @@ type Context interface { Layer(int) Context } -// RopeOptions contains optional parameters for RoPE function -type RopeOptions struct { - OriginalContextLen uint32 -} - -// RopeOption defines a function that modifies RopeOpts -type RopeOption func(*RopeOptions) - -// WithContextLen sets a custom context length -func WithContextLen(len uint32) RopeOption { - return func(opts *RopeOptions) { - opts.OriginalContextLen = len - } -} - type Tensor interface { Dim(n int) int Stride(n int) int @@ -147,6 +210,8 @@ type Tensor interface { Neg(ctx Context) Tensor Add(ctx Context, t2 Tensor) Tensor Mul(ctx Context, t2 Tensor) Tensor + Div(ctx Context, t2 Tensor) Tensor + Mulmat(ctx Context, t2 Tensor) Tensor MulmatFullPrec(ctx Context, t2 Tensor) Tensor MulmatID(ctx Context, t2, ids Tensor) Tensor @@ -155,11 +220,11 @@ type Tensor interface { LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor RMSNorm(ctx Context, weight Tensor, eps float32) Tensor Scale(ctx Context, s float64) Tensor + SumRows(ctx Context) Tensor AvgPool2D(ctx Context, k, s int, p float32) Tensor Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor - RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32, options ...RopeOption) Tensor IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor Sin(ctx Context) Tensor diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 2821ad11..76172ae1 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -10,7 +10,6 @@ import "C" import ( "context" - "errors" "fmt" "io" "log/slog" @@ -30,6 +29,7 @@ import ( "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src" + "github.com/ollama/ollama/ml/nn/rope" "golang.org/x/sync/errgroup" ) @@ -44,8 +44,15 @@ func devices() []*C.struct_ggml_backend_device { } type Backend struct { + // modelPath is the location of the model data + modelPath string + meta 
*fsggml.GGML + // tensorLoadTargets maps from the name of the tensor in the file + // to the name that is used by the model definition + tensorLoadTargets map[string][]string + sched *C.struct_ggml_backend_sched schedBackends []*C.struct_ggml_backend schedBufts []*C.struct_ggml_backend_buffer_type @@ -58,14 +65,26 @@ type Backend struct { // layers is the backend used for repeating layers layers map[int]*C.struct_ggml_backend_buffer_type + // requiredMemory is the cumulative memory allocations needed by the backend + requiredMemory *ml.BackendMemory + + // btDeviceMemory maps from a buffer type to the memory allocations associated with that device + btDeviceMemory map[*C.struct_ggml_backend_buffer_type]*ml.DeviceMemory + flashAttention bool // maxGraphNodes is the maximum allowed number of graph nodes in this scheduler maxGraphNodes int } -func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) { - meta, n, err := fsggml.Decode(r, -1) +func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { + r, err := os.Open(modelPath) + if err != nil { + return nil, err + } + defer r.Close() + + meta, err := fsggml.Decode(r, -1) if err != nil { return nil, err } @@ -80,6 +99,9 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, "num_key_values", len(meta.KV()), ) + var requiredMemory ml.BackendMemory + btDeviceMemory := make(map[*C.struct_ggml_backend_buffer_type]*ml.DeviceMemory) + type deviceBufferType struct { d *C.struct_ggml_backend_device bts []*C.struct_ggml_backend_buffer_type @@ -100,6 +122,8 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, } } + blocks := int(meta.KV().BlockCount()) + // create list of buffer types for the cpu cpuDeviceBufferType := deviceBufferType{d: C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU)} for _, d := range append(accels, append(gpus, cpus...)...) { @@ -107,17 +131,27 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, case C.GGML_BACKEND_DEVICE_TYPE_CPU, C.GGML_BACKEND_DEVICE_TYPE_ACCEL: cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, C.ggml_backend_dev_buffer_type(d)) + btDeviceMemory[C.ggml_backend_dev_buffer_type(d)] = &requiredMemory.CPU } } + requiredMemory.CPU.Name = C.GoString(C.ggml_backend_dev_name(cpuDeviceBufferType.d)) + requiredMemory.CPU.Weights = make([]ml.Memory, blocks+1) + requiredMemory.CPU.Cache = make([]ml.Memory, blocks+1) + // create list of buffer types for each gpu var gpuDeviceBufferTypes []deviceBufferType - for _, d := range gpus { + requiredMemory.GPUs = make([]ml.DeviceMemory, len(gpus)) + for i, d := range gpus { bt := C.ggml_backend_dev_buffer_type(d) gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{ d: d, bts: append([]*C.struct_ggml_backend_buffer_type{bt}, cpuDeviceBufferType.bts...), }) + btDeviceMemory[bt] = &requiredMemory.GPUs[i] + requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d)) + requiredMemory.GPUs[i].Weights = make([]ml.Memory, blocks+1) + requiredMemory.GPUs[i].Cache = make([]ml.Memory, blocks+1) } useDefaultSplit := true @@ -156,8 +190,6 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, // inputs always use cpu input := cpuDeviceBufferType - blocks := int(meta.KV().BlockCount()) - // define a range of gpu layers. 
anything outside of this range is assigned to the cpu gpuRangeStart := max(0, blocks-params.NumGPULayers) gpuRangeStop := min(gpuRangeStart+params.NumGPULayers, blocks+1) @@ -198,7 +230,7 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, // contexts are shared by tensors of the same buffer type ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context) - createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type) *C.struct_ggml_tensor { + createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type, layer int) *C.struct_ggml_tensor { for _, bt := range bts { if _, ok := ctxs[bt]; !ok { ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{ @@ -224,6 +256,16 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, C.ggml_set_name(tt, cname) slog.Log(context.TODO(), logutil.LevelTrace, "created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt))) + + size := pad(C.ggml_backend_buft_get_alloc_size(bt, tt), C.ggml_backend_buft_get_alignment(bt)) + if layer == -1 { + // Assume that InputWeights can be allocated - they're always in system memory and can't be moved in any case + requiredMemory.InputWeights.Status = ml.Allocated + requiredMemory.InputWeights.Size += uint64(size) + } else { + btDeviceMemory[bt].Weights[layer].Size += uint64(size) + } + //nolint:staticcheck // TODO: check if buffer type supports this tensor return tt } @@ -245,22 +287,22 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, for _, t := range meta.Tensors().Items() { switch { case contains(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"): - createTensor(tensor{source: t}, input.bts) + createTensor(tensor{source: t}, input.bts, -1) if _, ok := meta.Tensors().GroupLayers()["output"]; !ok && t.Name == "token_embd.weight" { - createTensor(tensor{source: t, target: "output.weight"}, output.bts) + createTensor(tensor{source: t, target: "output.weight"}, output.bts, blocks) } case contains(t.Name, "cls", "output", "output_norm"): - createTensor(tensor{source: t}, output.bts) + createTensor(tensor{source: t}, output.bts, blocks) case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."): // TODO: assign vision tensors to the gpu if possible - createTensor(tensor{source: t}, output.bts) + createTensor(tensor{source: t}, output.bts, blocks) case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"): // these tensors should be repeated per layer for i, layer := range layers { createTensor(tensor{ source: t, target: "blk." + strconv.Itoa(i) + "." 
+ t.Name, - }, layer.bts) + }, layer.bts, i) } default: layerIndex := -1 @@ -271,10 +313,10 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, } if layerIndex >= 0 { - createTensor(tensor{source: t}, layers[layerIndex].bts) + createTensor(tensor{source: t}, layers[layerIndex].bts, layerIndex) } else { // load all other tensors on the cpu - createTensor(tensor{source: t}, input.bts) + createTensor(tensor{source: t}, input.bts, -1) } } } @@ -287,8 +329,18 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, } b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt) + for i := range btDeviceMemory[bt].Weights { + if btDeviceMemory[bt].Weights[i].Size != 0 { + if b != nil { + btDeviceMemory[bt].Weights[i].Status = ml.Allocated + } else { + btDeviceMemory[bt].Weights[i].Status = ml.Failed + } + } + } + if b == nil { - return nil, fmt.Errorf("unable to allocate memory from device %v for model weights", C.GoString(C.ggml_backend_buft_name(bt))) + panic(ml.ErrNoMem{BackendMemory: requiredMemory}) } C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS) @@ -307,73 +359,6 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, } } - var doneBytes atomic.Uint64 - totalBytes := uint64(n) - meta.Tensors().Offset - - g, ctx := errgroup.WithContext(ctx) - g.SetLimit(runtime.GOMAXPROCS(0)) - for _, t := range meta.Tensors().Items() { - t := t - g.Go(func() error { - tts := make([]*C.struct_ggml_tensor, max(1, len(targets[t.Name]))) - for i := range tts { - target := targets[t.Name][i] - if target == "" { - target = t.Name - } - - tt, ok := tensors[target] - if !ok { - return fmt.Errorf("unassigned tensor: %s", t.Name) - } - - tts[i] = tt - } - - // Create a new FD for each goroutine so that each FD is read sequentially, rather than - // seeking around within an FD shared between all goroutines. 
- file, err := os.Open(r.Name()) - if err != nil { - slog.Warn("file open error", "file", r.Name(), "error", err) - return err - } - defer file.Close() - sr := io.NewSectionReader(file, int64(meta.Tensors().Offset+t.Offset), int64(t.Size())) - bts := make([]byte, 128*format.KibiByte) - - var s uint64 - for s < t.Size() { - // Stop if either the parent context has been canceled or if any of the other tensors returned an error - if err := ctx.Err(); err != nil { - return err - } - - n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))]) - if err != nil { - slog.Warn("file read error", "file", r.Name(), "error", err) - return err - } - - for _, tt := range tts { - C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n)) - } - - s += uint64(n) - - if params.Progress != nil { - done := doneBytes.Add(uint64(n)) - params.Progress(float32(done) / float32(totalBytes)) - } - } - - return nil - }) - } - - if err := g.Wait(); err != nil { - return nil, err - } - // map devices to backend buffer types so new tensors can be assigned to the correct device deviceBufferTypes := make(map[*C.struct_ggml_backend_device]*C.struct_ggml_backend_buffer_type) @@ -397,9 +382,11 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, maxGraphNodes := max(8192, len(meta.Tensors().Items())*5) return &Backend{ - flashAttention: params.FlashAttention, - meta: meta, - tensors: tensors, + modelPath: modelPath, + flashAttention: params.FlashAttention, + meta: meta, + tensorLoadTargets: targets, + tensors: tensors, sched: C.ggml_backend_sched_new( (*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])), (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])), @@ -418,7 +405,9 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, } return m }(), - maxGraphNodes: maxGraphNodes, + requiredMemory: &requiredMemory, + btDeviceMemory: btDeviceMemory, + maxGraphNodes: maxGraphNodes, }, nil } @@ -426,6 +415,81 @@ func init() { ml.RegisterBackend("ggml", New) } +func (b *Backend) Load(ctx context.Context, progress func(float32)) error { + var doneBytes atomic.Uint64 + totalBytes := uint64(b.meta.Length) - b.meta.Tensors().Offset + + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(runtime.GOMAXPROCS(0)) + for _, t := range b.meta.Tensors().Items() { + t := t + g.Go(func() error { + tts := make([]*C.struct_ggml_tensor, max(1, len(b.tensorLoadTargets[t.Name]))) + for i := range tts { + target := b.tensorLoadTargets[t.Name][i] + if target == "" { + target = t.Name + } + + tt, ok := b.tensors[target] + if !ok { + return fmt.Errorf("unassigned tensor: %s", t.Name) + } + + tts[i] = tt + } + + // Create a new FD for each goroutine so that each FD is read sequentially, rather than + // seeking around within an FD shared between all goroutines. 
+ file, err := os.Open(b.modelPath) + if err != nil { + slog.Warn("file open error", "file", b.modelPath, "error", err) + return err + } + defer file.Close() + sr := io.NewSectionReader(file, int64(b.meta.Tensors().Offset+t.Offset), int64(t.Size())) + bts := make([]byte, 128*format.KibiByte) + + var s uint64 + for s < t.Size() { + // Stop if either the parent context has been canceled or if any of the other tensors returned an error + if err := ctx.Err(); err != nil { + return err + } + + n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))]) + if err != nil { + slog.Warn("file read error", "file", b.modelPath, "error", err) + return err + } + + for _, tt := range tts { + C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n)) + } + + s += uint64(n) + + if progress != nil { + done := doneBytes.Add(uint64(n)) + progress(float32(done) / float32(totalBytes)) + } + } + + return nil + }) + } + + if err := g.Wait(); err != nil { + return err + } + + return nil +} + +func (b *Backend) BackendMemory() ml.BackendMemory { + return *b.requiredMemory +} + func (b *Backend) Config() fs.Config { return b.meta.KV() } @@ -457,6 +521,7 @@ func (b *Backend) NewContextSize(n int) ml.Context { no_alloc: true, }), allocatedBuffers: &allocatedBuffers, + layer: -1, } } @@ -483,6 +548,9 @@ type Context struct { // maxGraphNodes is the maximum allowed number of graph nodes in this context maxGraphNodes int + + // layer is the graph layer that this context is allocating for - assumed to be cache + layer int } func (c *Context) Input() ml.Context { @@ -493,6 +561,7 @@ func (c *Context) Input() ml.Context { buft: c.b.input, allocatedBuffers: c.allocatedBuffers, maxGraphNodes: c.maxGraphNodes, + layer: -1, } } @@ -507,6 +576,7 @@ func (c *Context) Layer(i int) ml.Context { buft: buft, allocatedBuffers: c.allocatedBuffers, maxGraphNodes: c.maxGraphNodes, + layer: i, } } @@ -544,22 +614,34 @@ func (c *Context) Compute(tensors ...ml.Tensor) { } } -func (c *Context) Reserve() error { - if !C.ggml_backend_sched_reserve(c.b.sched, c.graph) { - C.ggml_backend_sched_reset(c.b.sched) - return errors.New("failed to reserve graph") - } +func (c *Context) Reserve() { + reserved := C.ggml_backend_sched_reserve(c.b.sched, c.graph) slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(c.b.sched)) - for i := range c.b.schedBackends { - size := C.ggml_backend_sched_get_buffer_size(c.b.sched, c.b.schedBackends[i]) - slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])), - "size", format.HumanBytes2(uint64(size))) + + // Reserve may get called multiple times for different graphs - we just want the last run, which will contain the max allocations + for _, bt := range c.b.schedBufts { + c.b.btDeviceMemory[bt].Graph = ml.Memory{} } - C.ggml_backend_sched_reset(c.b.sched) + for i := range c.b.schedBackends { + bufferStatus := C.ggml_backend_sched_get_attempted_buffer_size(c.b.sched, c.b.schedBackends[i]) - return nil + graph := &c.b.btDeviceMemory[c.b.schedBufts[i]].Graph + graph.Size += uint64(bufferStatus.size) + if bufferStatus.allocated && graph.Status != ml.Failed { + graph.Status = ml.Allocated + } else { + graph.Status = ml.Failed + } + + slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])), + "size", 
format.HumanBytes2(uint64(bufferStatus.size))) + } + + if !reserved { + panic(ml.ErrNoMem{BackendMemory: *c.b.requiredMemory}) + } } func (c *Context) MaxGraphNodes() int { @@ -579,7 +661,7 @@ func pad(length, pad C.size_t) C.size_t { return ((length + pad - 1) / pad) * pad } -func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) { +func (c *Context) newTensor(dtype ml.DType, shape []int) ml.Tensor { if c.buft == nil { panic("set Input or Layer before creating tensors") } @@ -602,7 +684,7 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) { if len(shape) < 1 || shape[0] == 0 { var shape C.int64_t = 0 - return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}, nil + return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)} } else if len(shape) > 4 { panic("unsupported number of dimensions") } @@ -615,40 +697,43 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) { t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape)) size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft)) - b := C.ggml_backend_buft_alloc_buffer(c.buft, size) - if b == nil { - return nil, fmt.Errorf("unable to allocate %v from device %v for new tensor", format.HumanBytes2(uint64(size)), C.GoString(C.ggml_backend_buft_name(c.buft))) - } - *c.allocatedBuffers = append(*c.allocatedBuffers, b) + b := C.ggml_backend_buft_alloc_buffer(c.buft, size) + if c.layer >= 0 { + cache := &c.b.btDeviceMemory[c.buft].Cache[c.layer] + + cache.Size += uint64(size) + if b != nil { + cache.Status = ml.Allocated + } else { + cache.Status = ml.Failed + } + } + + if b == nil { + panic(ml.ErrNoMem{BackendMemory: *c.b.requiredMemory}) + } + + *c.allocatedBuffers = append(*c.allocatedBuffers, b) C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b)) - return &Tensor{b: c.b, t: t}, nil + return &Tensor{b: c.b, t: t} } func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor { - t, err := c.newTensor(dtype, shape) - if err != nil { - panic(err) - } - - return t + return c.newTensor(dtype, shape) } func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { - t, err := c.newTensor(dtype, shape) - if err != nil { - panic(err) - } - + t := c.newTensor(dtype, shape) C.ggml_set_zero(t.(*Tensor).t) return t } -func checkShape[S ~[]E, E any](s S, shape ...int) error { +func checkShape[S ~[]E, E any](s S, shape ...int) { n := len(s) if n == 0 { - return nil + return } for _, v := range shape { @@ -656,44 +741,32 @@ func checkShape[S ~[]E, E any](s S, shape ...int) error { } if n != 1 { - return fmt.Errorf("invalid shape: %v", shape) + panic(fmt.Errorf("invalid shape: %v", shape)) } - - return nil } -func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { - if err := checkShape(s, shape...); err != nil { - return nil, err - } +func (c *Context) FromFloatSlice(s []float32, shape ...int) ml.Tensor { + checkShape(s, shape...) - t, err := c.newTensor(ml.DTypeF32, shape) - if err != nil { - return nil, err - } + t := c.newTensor(ml.DTypeF32, shape) if len(s) > 0 { C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) } - return t, nil + return t } -func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { - if err := checkShape(s, shape...); err != nil { - return nil, err - } +func (c *Context) FromIntSlice(s []int32, shape ...int) ml.Tensor { + checkShape(s, shape...) 
- t, err := c.newTensor(ml.DTypeI32, shape) - if err != nil { - return nil, err - } + t := c.newTensor(ml.DTypeI32, shape) if len(s) > 0 { C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) } - return t, nil + return t } func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor { @@ -711,12 +784,7 @@ func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor { arange = append(arange, int32(i)) } - t, err := c.Input().FromIntSlice(arange, len(arange)) - if err != nil { - panic(err) - } - - return t + return c.Input().FromIntSlice(arange, len(arange)) default: panic("unsupported dtype for arange") } @@ -867,6 +935,13 @@ func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor { } } +func (t *Tensor) Div(ctx ml.Context, t2 ml.Tensor) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_div(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), + } +} + func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor { return &Tensor{ b: t.b, @@ -984,6 +1059,13 @@ func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor { } } +func (t *Tensor) SumRows(ctx ml.Context) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_sum_rows(ctx.(*Context).ctx, t.t), + } +} + func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor { return &Tensor{ b: t.b, @@ -1055,28 +1137,15 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { } } -const ( - ropeTypeNorm C.int = 0 - ropeTypeNeox C.int = 2 - ropeTypeMrope C.int = 8 - ropeTypeVision C.int = 24 -) - -func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32, options ...ml.RopeOption) ml.Tensor { +func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase, ropeScale float32, options ...func(*rope.Options)) ml.Tensor { // Default options - opts := &ml.RopeOptions{ - OriginalContextLen: 131072, - } + opts := &rope.Options{OriginalContextLength: 131072, Factors: &Tensor{}} // Apply any provided options for _, option := range options { option(opts) } - if ropeFactors == nil { - ropeFactors = &Tensor{b: t.b} - } - dequant := t.t if C.ggml_is_quantized(t.t._type) { dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32) @@ -1087,11 +1156,11 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi t: C.ggml_rope_ext( ctx.(*Context).ctx, dequant, - positionIDs.(*Tensor).t, - ropeFactors.(*Tensor).t, + positions.(*Tensor).t, + opts.Factors.(*Tensor).t, C.int(ropeDim), - C.int(ropeType), - C.int(opts.OriginalContextLen), + C.int(opts.Type), + C.int(opts.OriginalContextLength), C.float(ropeBase), C.float(ropeScale), C.float(0.0), diff --git a/ml/backend/ggml/ggml/include/ggml-alloc.h b/ml/backend/ggml/ggml/include/ggml-alloc.h index 2cb150fd..781b1e10 100644 --- a/ml/backend/ggml/ggml/include/ggml-alloc.h +++ b/ml/backend/ggml/ggml/include/ggml-alloc.h @@ -66,6 +66,12 @@ GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id); +struct ggml_allocr_buffer_status { + size_t size; + bool allocated; +}; +GGML_API struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id); + // Utils // Create a buffer and allocate all the tensors in a ggml_context GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); diff --git 
a/ml/backend/ggml/ggml/include/ggml-backend.h b/ml/backend/ggml/ggml/include/ggml-backend.h index 778927f6..74e46716 100644 --- a/ml/backend/ggml/ggml/include/ggml-backend.h +++ b/ml/backend/ggml/ggml/include/ggml-backend.h @@ -304,6 +304,12 @@ extern "C" { GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); + struct ggml_backend_buffer_status { + size_t size; + bool allocated; + }; + GGML_API struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); + GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node); diff --git a/ml/backend/ggml/ggml/src/ggml-alloc.c b/ml/backend/ggml/ggml/src/ggml-alloc.c index 5fd379f6..04812990 100644 --- a/ml/backend/ggml/ggml/src/ggml-alloc.c +++ b/ml/backend/ggml/ggml/src/ggml-alloc.c @@ -364,6 +364,7 @@ struct node_alloc { struct ggml_gallocr { ggml_backend_buffer_type_t * bufts; // [n_buffers] ggml_backend_buffer_t * buffers; // [n_buffers] + size_t *buffer_sizes; // [n_buffers] struct ggml_dyn_tallocr ** buf_tallocs; // [n_buffers] int n_buffers; @@ -387,6 +388,9 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs galloc->buffers = calloc(n_bufs, sizeof(ggml_backend_buffer_t)); GGML_ASSERT(galloc->buffers != NULL); + galloc->buffer_sizes = calloc(n_bufs, sizeof(size_t)); + GGML_ASSERT(galloc->buffer_sizes != NULL); + galloc->buf_tallocs = calloc(n_bufs, sizeof(struct ggml_dyn_tallocr *)); GGML_ASSERT(galloc->buf_tallocs != NULL); @@ -453,6 +457,7 @@ void ggml_gallocr_free(ggml_gallocr_t galloc) { ggml_hash_set_free(&galloc->hash_set); free(galloc->hash_values); free(galloc->bufts); + free(galloc->buffer_sizes); free(galloc->buffers); free(galloc->buf_tallocs); free(galloc->node_allocs); @@ -748,6 +753,8 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c } } + bool success = true; + // reallocate buffers if needed for (int i = 0; i < galloc->n_buffers; i++) { // if the buffer type is used multiple times, we reuse the same buffer @@ -769,15 +776,20 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c ggml_backend_buffer_free(galloc->buffers[i]); galloc->buffers[i] = ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size); - if (galloc->buffers[i] == NULL) { + if (galloc->buffers[i]) { + galloc->buffer_sizes[i] = ggml_backend_buffer_get_size(galloc->buffers[i]); + ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); + } else { GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size); - return false; + galloc->buffer_sizes[i] = new_size; + success = false; } - ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); + } else { + galloc->buffer_sizes[i] = ggml_backend_buffer_get_size(galloc->buffers[i]); } } - return true; + return success; } bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) { @@ -934,6 +946,24 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) { return ggml_backend_buffer_get_size(galloc->buffers[buffer_id]); } +struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id) { + GGML_ASSERT(buffer_id >= 0 && 
buffer_id < galloc->n_buffers); + + for (int i = 0; i < buffer_id; i++) { + if (galloc->buf_tallocs[i] == galloc->buf_tallocs[buffer_id]) { + // This buffer is the same as a previous one due to the same buffer type being used multiple times + // (See above.) However, we need a different check because multiple buffers might be NULL in our + // case and we still want to know the attempted size. + + struct ggml_allocr_buffer_status status = {0, true}; + return status; + } + } + + struct ggml_allocr_buffer_status status = {galloc->buffer_sizes[buffer_id], galloc->buffers[buffer_id] != NULL}; + return status; +} + // utils static void free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) { diff --git a/ml/backend/ggml/ggml/src/ggml-backend.cpp b/ml/backend/ggml/ggml/src/ggml-backend.cpp index 0ce73a99..be335e8c 100644 --- a/ml/backend/ggml/ggml/src/ggml-backend.cpp +++ b/ml/backend/ggml/ggml/src/ggml-backend.cpp @@ -1629,6 +1629,16 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe return ggml_gallocr_get_buffer_size(sched->galloc, backend_index); } +struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) { + int backend_index = ggml_backend_sched_backend_id(sched, backend); + GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); + + struct ggml_allocr_buffer_status allocr_status = ggml_gallocr_get_attempted_buffer_size(sched->galloc, backend_index); + struct ggml_backend_buffer_status status = {allocr_status.size, allocr_status.allocated}; + + return status; +} + void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { int backend_index = ggml_backend_sched_backend_id(sched, backend); GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/cpu.go b/ml/backend/ggml/ggml/src/ggml-cpu/cpu.go index 895d093c..895b7f6e 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/cpu.go +++ b/ml/backend/ggml/ggml/src/ggml-cpu/cpu.go @@ -3,7 +3,7 @@ package cpu // #cgo CFLAGS: -O3 -Wno-implicit-function-declaration // #cgo CXXFLAGS: -std=c++17 // #cgo CPPFLAGS: -I${SRCDIR}/amx -I${SRCDIR}/llamafile -I${SRCDIR}/.. -I${SRCDIR}/../../include -// #cgo CPPFLAGS: -DGGML_USE_LLAMAFILE +// #cgo CPPFLAGS: -DNDEBUG -DGGML_USE_LLAMAFILE // #cgo linux CPPFLAGS: -D_GNU_SOURCE // #cgo darwin,arm64 CPPFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64 // #cgo darwin,arm64 LDFLAGS: -framework Accelerate diff --git a/ml/backend/ggml/ggml/src/ggml-metal/metal.go b/ml/backend/ggml/ggml/src/ggml-metal/metal.go index eb65dfde..0ee017dd 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/metal.go +++ b/ml/backend/ggml/ggml/src/ggml-metal/metal.go @@ -4,6 +4,6 @@ package metal //go:generate sh -c "{ echo // Code generated by 'go generate'. DO NOT EDIT.; sed -e '/__embed_ggml-common.h__/r ../ggml-common.h' -e '/__embed_ggml-common.h__/d' -e '/#include \"ggml-metal-impl.h\"/r ggml-metal-impl.h' -e '/#include \"ggml-metal-impl.h\"/d' ggml-metal.metal; } >ggml-metal-embed.metal" -// #cgo CPPFLAGS: -DGGML_METAL_EMBED_LIBRARY -I.. -I../../include +// #cgo CPPFLAGS: -DGGML_METAL_NDEBUG -DGGML_METAL_EMBED_LIBRARY -I.. 
-I../../include // #cgo LDFLAGS: -framework Metal -framework MetalKit import "C" diff --git a/ml/nn/fast/rope.go b/ml/nn/fast/rope.go new file mode 100644 index 00000000..b45938eb --- /dev/null +++ b/ml/nn/fast/rope.go @@ -0,0 +1,21 @@ +// fast provides implementations of fast (fused) operations for increased performance. +package fast + +import ( + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn/rope" +) + +// fastRoPE is an interface for tensors that support fast rotary positional embedding. +type fastRoPE interface { + RoPE(ctx ml.Context, positionIDs ml.Tensor, dim int, base, scale float32, options ...func(*rope.Options)) ml.Tensor +} + +// RoPE applies rotary positional embedding to tensor `t`. +func RoPE(ctx ml.Context, t, positions ml.Tensor, dim int, base, scale float32, options ...func(*rope.Options)) ml.Tensor { + if t, ok := t.(fastRoPE); ok { + return t.RoPE(ctx, positions, dim, base, scale, options...) + } + + panic("RoPE not implemented for this tensor type") +} diff --git a/ml/nn/rope/rope.go b/ml/nn/rope/rope.go new file mode 100644 index 00000000..b0c00a5b --- /dev/null +++ b/ml/nn/rope/rope.go @@ -0,0 +1,33 @@ +package rope + +import "github.com/ollama/ollama/ml" + +// Options contains optional parameters for RoPE function +type Options struct { + OriginalContextLength int + Type int + Factors ml.Tensor +} + +// WithOriginalContextLength sets a custom context length +func WithOriginalContextLength(n int) func(*Options) { + return func(opts *Options) { + opts.OriginalContextLength = n + } +} + +// WithType sets RoPE type to NeoX +func WithTypeNeoX() func(*Options) { + return func(opts *Options) { + opts.Type = 2 + } +} + +// WithFactors sets custom rope factors +func WithFactors(factors ml.Tensor) func(*Options) { + return func(opts *Options) { + if factors != nil { + opts.Factors = factors + } + } +} diff --git a/model/process_text.go b/model/bytepairencoding.go similarity index 66% rename from model/process_text.go rename to model/bytepairencoding.go index 7b87ccc3..6bb9a003 100644 --- a/model/process_text.go +++ b/model/bytepairencoding.go @@ -5,116 +5,13 @@ import ( "context" "iter" "log/slog" - "slices" "strings" - "sync" "github.com/dlclark/regexp2" heap "github.com/emirpasic/gods/v2/trees/binaryheap" "github.com/ollama/ollama/logutil" ) -type Special int32 - -const ( - SpecialBOS Special = iota - SpecialEOS -) - -const ( - TOKEN_TYPE_NORMAL = iota + 1 - TOKEN_TYPE_UNKNOWN - TOKEN_TYPE_CONTROL - TOKEN_TYPE_USER_DEFINED - TOKEN_TYPE_UNUSED - TOKEN_TYPE_BYTE -) - -type TextProcessor interface { - Encode(s string, addSpecial bool) ([]int32, error) - Decode([]int32) (string, error) - Is(int32, Special) bool - Vocabulary() *Vocabulary -} - -type Vocabulary struct { - Values []string - Types []int32 - Scores []float32 - Merges []string - - BOS, EOS, EOT int32 - AddBOS, AddEOS, AddEOT bool - - specialOnce sync.Once - special []string - - valuesOnce sync.Once - values map[string]int32 - - mergeOnce sync.Once - merge map[string]int32 -} - -func (v *Vocabulary) Is(id int32, special Special) bool { - switch special { - case SpecialBOS: - return id == v.BOS - case SpecialEOS: - return id == v.EOS || id == v.EOT - default: - return false - } -} - -func (v *Vocabulary) Encode(s string) int32 { - v.valuesOnce.Do(func() { - v.values = make(map[string]int32, len(v.Values)) - for i, value := range v.Values { - v.values[value] = int32(i) - } - }) - - if id, ok := v.values[s]; ok { - return id - } - - return -1 -} - -func (v *Vocabulary) Decode(id int32) string { - 
return v.Values[id] -} - -func (v *Vocabulary) SpecialVocabulary() []string { - v.specialOnce.Do(func() { - for i := range v.Values { - if slices.Contains([]int{105, 106}, i) { - v.special = append(v.special, v.Values[i]) - } else if v.Types[i] == TOKEN_TYPE_CONTROL { - v.special = append(v.special, v.Values[i]) - } - } - }) - - return v.special -} - -func (v *Vocabulary) Merge(left, right string) int { - v.mergeOnce.Do(func() { - v.merge = make(map[string]int32, len(v.Merges)) - for i, merge := range v.Merges { - v.merge[merge] = int32(i) - } - }) - - if id, ok := v.merge[left+" "+right]; ok { - return int(id) - } - - return -1 -} - type BytePairEncoding struct { pre *regexp2.Regexp vocab *Vocabulary @@ -304,27 +201,12 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) { } } + slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids) + if addSpecial && len(ids) > 0 { - if bpe.vocab.AddBOS { - if ids[0] == bpe.vocab.BOS { - slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS) - } - - slog.Debug("adding bos token to prompt", "id", bpe.vocab.BOS) - ids = append([]int32{bpe.vocab.BOS}, ids...) - } - - if bpe.vocab.AddEOS { - if ids[len(ids)-1] == bpe.vocab.EOS { - slog.Warn("adding eos token to prompt which already has it", "id", bpe.vocab.EOS) - } - - slog.Debug("adding eos token to prompt", "id", bpe.vocab.EOS) - ids = append(ids, bpe.vocab.EOS) - } + ids = bpe.vocab.addSpecials(ids) } - slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "ids", ids) return ids, nil } @@ -352,6 +234,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) { } } - slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "string", sb.String()) + slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "ids", ids, "string", sb.String()) return sb.String(), nil } diff --git a/model/process_text_test.go b/model/bytepairencoding_test.go similarity index 100% rename from model/process_text_test.go rename to model/bytepairencoding_test.go diff --git a/model/model.go b/model/model.go index 98381c90..25097e01 100644 --- a/model/model.go +++ b/model/model.go @@ -98,14 +98,8 @@ func Register(name string, f func(fs.Config) (Model, error)) { } // New initializes a new model instance with the provided configuration based on the metadata in the model file -func New(ctx context.Context, modelPath string, params ml.BackendParams) (Model, error) { - r, err := os.Open(modelPath) - if err != nil { - return nil, err - } - defer r.Close() - - b, err := ml.NewBackend(ctx, r, params) +func New(modelPath string, params ml.BackendParams) (Model, error) { + b, err := ml.NewBackend(modelPath, params) if err != nil { return nil, err } @@ -134,7 +128,7 @@ func NewTextProcessor(s string) (TextProcessor, error) { return nil, err } defer r.Close() - meta, _, err := fsggml.Decode(r, -1) + meta, err := fsggml.Decode(r, -1) if err != nil { return nil, err } @@ -293,11 +287,7 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten return nil, errors.New("batch size cannot be less than 1") } - var err error - batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs)) - if err != nil { - return nil, err - } + batch.Inputs = ctx.Input().FromIntSlice(inputs, len(inputs)) cache := m.Config().Cache if cache != nil { diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 3156b006..e621d03a 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -7,6 +7,8 @@ import ( 
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" + "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" ) @@ -43,10 +45,13 @@ func New(c fs.Config) (model.Model, error) { Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), Types: c.Ints("tokenizer.ggml.token_type"), - BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), - EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), - // TODO: set EOT to EOS otherwise 0 will stop generation - EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), }, ), Layers: make([]Layer, c.Uint("block_count")), @@ -80,11 +85,10 @@ type SelfAttention struct { func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { batchSize := hiddenState.Dim(1) - ropeType := uint32(2) q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) - q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale) + q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) if opts.largeModelScaling { q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) @@ -94,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) - k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale) + k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) @@ -124,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten } func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil + return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, m.Options.ropeScale, rope.WithTypeNeoX()), nil } type MLP struct { @@ -171,15 +175,8 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } - - outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize))) diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index d53eb6cc..53bf8275 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -60,12 +60,16 @@ func 
New(c fs.Config) (model.Model, error) { Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), Types: c.Ints("tokenizer.ggml.token_type"), - BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), - EOS: int32(1), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - EOT: int32(106), - AddEOT: c.Bool("tokenizer.ggml.add_eot_token", false), + EOS: append( + []int32{ + int32(c.Uint("tokenizer.ggml.eos_token_id")), + int32(c.Uint("tokenizer.ggml.eot_token_id", 106)), + }, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), }, ), ImageProcessor: newImageProcessor(c), @@ -97,14 +101,11 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input return nil, err } - pixelValues, err := ctx.Input().FromFloatSlice(f32s, + pixelValues := ctx.Input().FromFloatSlice(f32s, m.ImageProcessor.imageSize, m.ImageProcessor.imageSize, m.ImageProcessor.numChannels, ) - if err != nil { - return nil, err - } visionOutputs := m.VisionModel.Forward(ctx, pixelValues) visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps) @@ -140,15 +141,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } - - outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil } diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index a40614af..70d7797e 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -7,6 +7,8 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" + "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model/input" ) @@ -73,7 +75,6 @@ type TextSelfAttention struct { func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor { batchSize := hiddenState.Dim(1) - ropeType := uint32(2) ropeBase := opts.ropeLocalBase if (layer+1)%gemmaGlobalCacheCount == 0 { @@ -83,7 +84,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) q = sa.QueryNorm.Forward(ctx, q, opts.eps) - q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale) + q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX()) if opts.largeModelScaling { q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) @@ -94,7 +95,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) k = sa.KeyNorm.Forward(ctx, k, opts.eps) - k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale) + k = 
fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX()) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) @@ -112,7 +113,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T ropeBase = m.TextConfig.ropeGlobalBase } - return key.RoPE(ctx, shift, nil, uint32(m.TextConfig.attnKeyLen), uint32(2), ropeBase, m.TextConfig.ropeScale), nil + return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil } type TextMLP struct { diff --git a/model/models/llama/model.go b/model/models/llama/model.go index c75d7eb2..3cf782d0 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -1,22 +1,23 @@ package llama import ( - "fmt" + "cmp" "math" - "strings" "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" + "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" ) type Options struct { hiddenSize, numHeads, numKVHeads int + headDim, ropeDim int eps, ropeBase, ropeScale float32 - ropeDim uint32 } type Model struct { @@ -32,10 +33,6 @@ type Model struct { } func New(c fs.Config) (model.Model, error) { - if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") { - return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model")) - } - m := Model{ BytePairEncoding: model.NewBytePairEncoding( c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), @@ -43,13 +40,13 @@ func New(c fs.Config) (model.Model, error) { Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), Merges: c.Strings("tokenizer.ggml.merges"), - BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), - EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - // TODO: set EOT to EOS otherwise 0 will stop generation - EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")), - AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), }, ), Layers: make([]Layer, c.Uint("block_count")), @@ -57,10 +54,11 @@ func New(c fs.Config) (model.Model, error) { hiddenSize: int(c.Uint("embedding_length")), numHeads: int(c.Uint("attention.head_count")), numKVHeads: int(c.Uint("attention.head_count_kv")), + headDim: int(c.Uint("attention.key_length")), + ropeDim: int(c.Uint("rope.dimension_count")), eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base"), ropeScale: c.Float("rope.freq_scale", 1), - ropeDim: c.Uint("rope.dimension_count"), }, } @@ -77,31 +75,31 @@ type SelfAttention struct { RopeFactors ml.Tensor `gguf:"rope_freqs.weight"` } -func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { +func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { batchSize := hiddenState.Dim(1) - headDim := opts.hiddenSize / opts.numHeads - ropeType := uint32(0) + headDim := 
cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads) + ropeDim := cmp.Or(opts.ropeDim, headDim) - q := sa.Query.Forward(ctx, hiddenState) - q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + query := sa.Query.Forward(ctx, hiddenState) + query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) - k := sa.Key.Forward(ctx, hiddenState) - k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + key := sa.Key.Forward(ctx, hiddenState) + key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - v := sa.Value.Forward(ctx, hiddenState) - v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + value := sa.Value.Forward(ctx, hiddenState) + value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - scaleFactor := 1.0 / math.Sqrt(float64(headDim)) - kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache) - kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize) + query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) - return sa.Output.Forward(ctx, kqv) + attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) + attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) + return sa.Output.Forward(ctx, attention) } func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil + ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads) + return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil } type MLP struct { @@ -122,11 +120,11 @@ type Layer struct { MLP *MLP } -func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { +func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { residual := hiddenState hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps) - hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts) + hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, opts) // In the final layer (outputs != nil), optimize by pruning to just the token positions // we need logits for. 
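For orientation, a minimal usage sketch (not part of the patch) of the new fast.RoPE helper and the rope option functions added above; the helper name applyRoPE, the package name example, and the tensor/parameter names are assumptions for illustration only:

package example

import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn/fast"
	"github.com/ollama/ollama/ml/nn/rope"
)

// applyRoPE rotates query and key with the option-based API. Passing no options
// keeps the defaults (rope type 0, original context length 131072, empty factors).
// Models in this patch add rope.WithTypeNeoX() for NeoX-style rotation,
// rope.WithFactors(f) for per-layer frequency factors (a nil f is ignored), or
// rope.WithOriginalContextLength(n) when the training context differs from the default.
func applyRoPE(ctx ml.Context, query, key, positions, factors ml.Tensor, ropeDim int, base, scale float32) (ml.Tensor, ml.Tensor) {
	query = fast.RoPE(ctx, query, positions, ropeDim, base, scale, rope.WithFactors(factors))
	key = fast.RoPE(ctx, key, positions, ropeDim, base, scale, rope.WithFactors(factors))
	return query, key
}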
@@ -144,27 +142,19 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } - - outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) for i, layer := range m.Layers { m.Cache.SetLayer(i) - var lastLayerOutputs ml.Tensor + var outputs ml.Tensor if i == len(m.Layers)-1 { - lastLayerOutputs = outputs + outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) } - hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options) + hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options) } hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go index c94aa72f..8084760b 100644 --- a/model/models/llama4/model.go +++ b/model/models/llama4/model.go @@ -40,13 +40,13 @@ func New(c fs.Config) (model.Model, error) { Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), Merges: c.Strings("tokenizer.ggml.merges"), - BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), - EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - // TODO: set EOT to EOS otherwise 0 will stop generation - EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")), - AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), }, ), ImageProcessor: newImageProcessor(c), @@ -77,10 +77,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input return nil, err } - tilesLocal, err := ctx.Input().FromFloatSlice(pixelsLocal, size.X, size.Y, m.numChannels) - if err != nil { - return nil, err - } + tilesLocal := ctx.Input().FromFloatSlice(pixelsLocal, size.X, size.Y, m.numChannels) ratioW, ratioH := size.X/m.imageSize, size.Y/m.imageSize @@ -91,11 +88,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input pixelValues := tilesLocal if len(pixelsGlobal) > 0 { - tilesGlobal, err := ctx.Input().FromFloatSlice(pixelsGlobal, m.imageSize, m.imageSize, m.numChannels) - if err != nil { - return nil, err - } - + tilesGlobal := ctx.Input().FromFloatSlice(pixelsGlobal, m.imageSize, m.imageSize, m.numChannels) pixelValues = pixelValues.Concat(ctx, tilesGlobal, 3) } @@ -182,15 +175,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } - - outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil } diff 
--git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go index d98587bd..27935f40 100644 --- a/model/models/llama4/model_text.go +++ b/model/models/llama4/model_text.go @@ -8,6 +8,8 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" + "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model/input" ) @@ -31,8 +33,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) if useRope { - query = query.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale) - key = key.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale) + query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) } if opts.useQKNorm { @@ -80,7 +82,7 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)) for i := 1; i < opts.numExpertsUsed; i++ { - nextStates.Add(ctx, downStates.View(ctx, i*downStates.Stride(1), hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))) + nextStates = nextStates.Add(ctx, downStates.View(ctx, i*downStates.Stride(1), hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))) } return nextStates @@ -221,11 +223,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor scales[i] = float32(math.Log(math.Floor(((float64(p)+1.0)/float64(m.attentionFloorScale))+1.0))*m.attentionScale + 1.0) } - var err error - attentionScales, err = ctx.Input().FromFloatSlice(scales, 1, 1, len(scales)) - if err != nil { - panic(err) - } + attentionScales = ctx.Input().FromFloatSlice(scales, 1, 1, len(scales)) } for i, layer := range m.Layers { @@ -250,5 +248,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return key.RoPE(ctx, shift, m.Layers[layer].Attention.RopeFactors, uint32(0), uint32(m.ropeDim), m.ropeBase, m.ropeScale), nil + return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil } diff --git a/model/models/llama4/model_vision.go b/model/models/llama4/model_vision.go index e6b1afef..dc6f82b8 100644 --- a/model/models/llama4/model_vision.go +++ b/model/models/llama4/model_vision.go @@ -245,10 +245,7 @@ func (m *VisionModel) rotaryEmbedding(ctx ml.Context) (ml.Tensor, ml.Tensor) { } } - ropeFreqs, err := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2) - if err != nil { - panic(err) - } + ropeFreqs := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2) ropeFreqs = ropeFreqs.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) ropeFreqs = ropeFreqs.Reshape(ctx, freqDim, 1, numPatches) diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go index b93882a9..9d662fc1 100644 --- a/model/models/mistral3/model.go +++ b/model/models/mistral3/model.go @@ -31,31 +31,26 @@ var _ model.MultimodalProcessor = (*Model)(nil) var _ model.TextProcessor = (*Model)(nil) func New(c fs.Config) (model.Model, error) { - textModel, err := 
NewTextModel(c) - if err != nil { - return nil, err - } - m := &Model{ - TextModel: textModel, - VisionModel: newVisionModel(c), - ImageProcessor: newImageProcessor(c), - MultiModalProjector: newMultiModalProjector(c), BytePairEncoding: model.NewBytePairEncoding( c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), Merges: c.Strings("tokenizer.ggml.merges"), - BOS: int32(c.Uint("tokenizer.ggml.bos_token_id", 1)), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), - EOS: int32(c.Uint("tokenizer.ggml.eos_token_id", 2)), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - // TODO: set EOT to EOS otherwise 0 will stop generation - EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")), - AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), }, ), + TextModel: newTextModel(c), + VisionModel: newVisionModel(c), + ImageProcessor: newImageProcessor(c), + MultiModalProjector: newMultiModalProjector(c), } m.Cache = kvcache.NewCausalCache(m.TextModel.Shift) @@ -119,10 +114,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input return nil, err } - pixelValues, err := ctx.Input().FromFloatSlice(f32s, size.X, size.Y, m.ImageProcessor.numChannels) - if err != nil { - return nil, err - } + pixelValues := ctx.Input().FromFloatSlice(f32s, size.X, size.Y, m.ImageProcessor.numChannels) visionOutputs := m.VisionModel.Forward(ctx, pixelValues) features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size) @@ -166,15 +158,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } - - outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil } diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go index 17939800..19c36f9f 100644 --- a/model/models/mistral3/model_text.go +++ b/model/models/mistral3/model_text.go @@ -1,21 +1,21 @@ package mistral3 import ( - "fmt" + "cmp" "math" - "strings" "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/model/input" ) type TextOptions struct { - hiddenSize, numHeads, numKVHeads, headDim int - eps, ropeBase, ropeScale float32 - ropeDim uint32 + hiddenSize, numHeads, numKVHeads int + headDim, ropeDim int + eps, ropeBase, ropeScale float32 } type TextModel struct { @@ -36,19 +36,15 @@ type SelfAttention struct { func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { batchSize := 
hiddenState.Dim(1) - ropeType := uint32(0) - headDim := opts.headDim - if headDim == 0 { - headDim = opts.hiddenSize / opts.numHeads - } + headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads) q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale) k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -59,7 +55,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return key.RoPE(ctx, shift, nil, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil + return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale), nil } type MLP struct { @@ -125,24 +121,18 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor return m.Output.Forward(ctx, hiddenState) } -func NewTextModel(c fs.Config) (*TextModel, error) { - if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") { - return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model")) - } - - textModel := &TextModel{ +func newTextModel(c fs.Config) *TextModel { + return &TextModel{ Layers: make([]Layer, c.Uint("block_count")), TextOptions: &TextOptions{ hiddenSize: int(c.Uint("embedding_length")), numHeads: int(c.Uint("attention.head_count")), numKVHeads: int(c.Uint("attention.head_count_kv")), headDim: int(c.Uint("attention.key_length")), + ropeDim: int(c.Uint("rope.dimension_count")), eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base"), ropeScale: c.Float("rope.freq_scale", 1), - ropeDim: c.Uint("rope.dimension_count"), }, } - - return textModel, nil } diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go index 469dc40c..65bdcff2 100644 --- a/model/models/mistral3/model_vision.go +++ b/model/models/mistral3/model_vision.go @@ -110,15 +110,8 @@ func (m *VisionModel) positionalEmbedding(ctx ml.Context, positionIDs ml.Tensor) } } - h, err := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2) - if err != nil { - panic(err) - } - - w, err := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2) - if err != nil { - panic(err) - } + h := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2) + w := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2) h = h.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) w = w.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) @@ -151,10 +144,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor { } } - positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions)) - if err != nil { - panic(err) - } + positionIDs := ctx.Input().FromIntSlice(positions, len(positions)) positionEmbedding := m.positionalEmbedding(ctx, positionIDs) cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx) @@ -170,7 +160,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor { func 
newVisionModel(c fs.Config) *VisionModel { return &VisionModel{ - Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 24)), + Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")), VisionModelOptions: &VisionModelOptions{ hiddenSize: int(c.Uint("vision.embedding_length", 1024)), numHeads: int(c.Uint("vision.attention.head_count", 16)), diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index 15571d9c..45cb3e02 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -38,13 +38,13 @@ func New(c fs.Config) (model.Model, error) { Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), Merges: c.Strings("tokenizer.ggml.merges"), - BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), - EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - // TODO: set EOT to EOS otherwise 0 will stop generation - EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")), - AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), }, ), ImageProcessor: newImageProcessor(c), @@ -80,15 +80,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input f32s = f32s[:m.imageSize*m.imageSize*m.numChannels*m.maxNumTiles] } - pixelValues, err := ctx.Input().FromFloatSlice(f32s, m.imageSize, m.imageSize, m.numChannels, m.maxNumTiles) - if err != nil { - return nil, err - } - - aspectRatio, err := ctx.Input().FromIntSlice([]int32{int32(ratio.rank)}, 1) - if err != nil { - return nil, err - } + pixelValues := ctx.Input().FromFloatSlice(f32s, m.imageSize, m.imageSize, m.numChannels, m.maxNumTiles) + aspectRatio := ctx.Input().FromIntSlice([]int32{int32(ratio.rank)}, 1) positionIDs := ctx.Arange(0, 1601, 1, ml.DTypeI32) crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio) @@ -113,15 +106,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { crossAttentionStates = batch.Multimodal[len(batch.Multimodal)-1].Multimodal[0].Tensor } - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } - - outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) // TODO: attention mask, cross attention mask return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index 9bd414af..47a518ce 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -8,6 +8,8 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" + "github.com/ollama/ollama/ml/nn/rope" ) type TextSelfAttention struct { @@ -21,15 +23,14 @@ type TextSelfAttention struct { func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { batchSize := hiddenState.Dim(1) headDim := 
opts.hiddenSize / opts.numHeads - ropeType := uint32(0) query := sa.Query.Forward(ctx, hiddenState) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) - query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) key := sa.Key.Forward(ctx, hiddenState) key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -44,7 +45,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { // This will only get called for layers in the cache, which are just the self attention layers if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok { - return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil + return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil } return key, nil @@ -199,8 +200,8 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, type TextModelOptions struct { hiddenSize, numHeads, numKVHeads int + ropeDim int eps, ropeBase, ropeScale float32 - ropeDim uint32 crossAttentionLayers []int32 } @@ -240,10 +241,10 @@ func newTextModel(c fs.Config) *TextModel { hiddenSize: int(c.Uint("embedding_length")), numHeads: int(c.Uint("attention.head_count")), numKVHeads: int(c.Uint("attention.head_count_kv")), + ropeDim: int(c.Uint("rope.dimension_count")), eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base"), ropeScale: c.Float("rope.freq_scale", 1), - ropeDim: c.Uint("rope.dimension_count"), crossAttentionLayers: c.Ints("attention.cross_attention_layers"), }, } diff --git a/model/models/mllama/model_vision.go b/model/models/mllama/model_vision.go index 77ea5373..2d424947 100644 --- a/model/models/mllama/model_vision.go +++ b/model/models/mllama/model_vision.go @@ -16,8 +16,6 @@ type VisionSelfAttention struct { Key *nn.Linear `gguf:"attn_k"` Value *nn.Linear `gguf:"attn_v"` Output *nn.Linear `gguf:"attn_output"` - - Gate ml.Tensor `gguf:"attn_gate"` } func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor { @@ -25,27 +23,16 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op query := sa.Query.Forward(ctx, hiddenState) query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize) - query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) key := sa.Key.Forward(ctx, hiddenState) key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize) - key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize) - value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - scores := key.Mulmat(ctx, query) - scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim))) - scores = scores.Softmax(ctx) - - attention := value.Mulmat(ctx, scores) - attention = attention.Reshape(ctx, 
headDim, attention.Dim(1), opts.numHeads, batchSize) - attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil) attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize) - - hiddenState = sa.Output.Forward(ctx, attention) - return hiddenState + return sa.Output.Forward(ctx, attention) } type VisionMLP struct { @@ -76,21 +63,18 @@ func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts // self attention hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps) hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts) - if e.AttentionGate != nil { hiddenState = hiddenState.Mul(ctx, e.AttentionGate) } hiddenState = hiddenState.Add(ctx, residual) residual = hiddenState - // feed forward hiddenState = e.MLPNorm.Forward(ctx, hiddenState, opts.eps) hiddenState = e.MLP.Forward(ctx, hiddenState, opts) - hiddenState = hiddenState.Add(ctx, residual) if e.MLPGate != nil { hiddenState = hiddenState.Mul(ctx, e.MLPGate) } - + hiddenState = hiddenState.Add(ctx, residual) return hiddenState } diff --git a/model/models/models.go b/model/models/models.go index 133e5176..5471ce89 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -7,5 +7,7 @@ import ( _ "github.com/ollama/ollama/model/models/llama4" _ "github.com/ollama/ollama/model/models/mistral3" _ "github.com/ollama/ollama/model/models/mllama" + _ "github.com/ollama/ollama/model/models/qwen2" _ "github.com/ollama/ollama/model/models/qwen25vl" + _ "github.com/ollama/ollama/model/models/qwen3" ) diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go new file mode 100644 index 00000000..42338d0d --- /dev/null +++ b/model/models/qwen2/model.go @@ -0,0 +1,164 @@ +package qwen2 + +import ( + "cmp" + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" + "github.com/ollama/ollama/ml/nn/rope" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type Options struct { + hiddenSize, numHeads, numKVHeads int + headDim, ropeDim int + eps, ropeBase, ropeScale float32 +} + +type Attention struct { + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` +} + +func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + batchSize := hiddenStates.Dim(1) + headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads) + ropeDim := cmp.Or(opts.ropeDim, headDim) + + query := attn.Query.Forward(ctx, hiddenStates) + query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) + + key := attn.Key.Forward(ctx, hiddenStates) + key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + + value := attn.Value.Forward(ctx, hiddenStates) + value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + + query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + + attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) + attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) + + return attn.Output.Forward(ctx, attention) +} + +type MLP struct { + Gate *nn.Linear `gguf:"ffn_gate"` + Up 
*nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (mlp MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor { + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + return mlp.Down.Forward(ctx, hiddenStates) +} + +type DecoderLayer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + Attention *Attention + MLPNorm *nn.RMSNorm `gguf:"ffn_norm"` + MLP *MLP +} + +func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + residual := hiddenStates + + hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, cache, opts) + if outputs != nil { + hiddenStates = hiddenStates.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + + hiddenStates = hiddenStates.Add(ctx, residual) + residual = hiddenStates + + hiddenStates = d.MLPNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = d.MLP.Forward(ctx, hiddenStates) + return hiddenStates.Add(ctx, residual) +} + +type Model struct { + model.Base + model.BytePairEncoding + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + Layers []DecoderLayer `gguf:"blk"` + OutputNorm *nn.RMSNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + Options +} + +// Forward implements model.Model. +func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + + hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) + + for i, layer := range m.Layers { + m.Cache.SetLayer(i) + + var outputs ml.Tensor + if i == len(m.Layers)-1 { + outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + } + + hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options) + } + + hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) + hiddenStates = m.Output.Forward(ctx, hiddenStates) + return hiddenStates, nil +} + +func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads) + return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil +} + +func New(c fs.Config) (model.Model, error) { + m := Model{ + Layers: make([]DecoderLayer, c.Uint("block_count")), + BytePairEncoding: model.NewBytePairEncoding( + c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), + }, + ), + Options: Options{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + headDim: int(c.Uint("attention.key_length")), + ropeDim: int(c.Uint("rope.dimension_count")), + ropeBase: c.Float("rope.freq_base"), + ropeScale: c.Float("rope.freq_scale", 1), + eps: c.Float("attention.layer_norm_rms_epsilon"), + }, + } + + m.Cache = 
kvcache.NewCausalCache(m.Shift) + return &m, nil +} + +func init() { + model.Register("qwen2", New) +} diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go index 48655450..ee38cad9 100644 --- a/model/models/qwen25vl/model.go +++ b/model/models/qwen25vl/model.go @@ -34,12 +34,13 @@ func New(c fs.Config) (model.Model, error) { Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), Merges: c.Strings("tokenizer.ggml.merges"), - BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), - AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false), - EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")), - AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), }, ), TextModel: NewTextModel(c), @@ -68,10 +69,7 @@ func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, * m.ImageProcessor.patchSize * m.ImageProcessor.patchSize numPatches := grid.Temporal * grid.Height * grid.Width - pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches) - if err != nil { - return nil, nil, fmt.Errorf("failed to create tensor from image: %w", err) - } + pixelValues := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches) return pixelValues, grid, nil } @@ -120,13 +118,14 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1) // First add the vision start token - result = append(result, input.Input{Token: visionStartToken, SameBatch: patchesPerChunk + 1}) + result = append(result, input.Input{Token: visionStartToken}) // Add the image token with the multimodal tensor data at the first position result = append(result, input.Input{ Token: imageToken, Multimodal: inp.Multimodal, MultimodalHash: inp.MultimodalHash, + SameBatch: patchesPerChunk, }) // Add the placeholder tokens for the remaining positions (tokensPerGrid-1) @@ -140,15 +139,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } - - outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache) } diff --git a/model/models/qwen25vl/model_text.go b/model/models/qwen25vl/model_text.go index 800fd961..4b6bc166 100644 --- a/model/models/qwen25vl/model_text.go +++ b/model/models/qwen25vl/model_text.go @@ -7,13 +7,15 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" + "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model/input" ) type TextOptions struct { - ctxLen, hiddenSize, numHeads, numKVHeads int - eps, ropeBase, ropeScale float32 - ropeDim, defaultContextLen uint32 + hiddenSize, numHeads, numKVHeads int + ropeDim, originalContextLength int 
+ eps, ropeBase, ropeScale float32 } type TextModel struct { @@ -29,15 +31,14 @@ func NewTextModel(c fs.Config) *TextModel { m := TextModel{ Layers: make([]Layer, c.Uint("block_count")), TextOptions: &TextOptions{ - ctxLen: int(c.Uint("context_length")), - hiddenSize: int(c.Uint("embedding_length")), - numHeads: int(c.Uint("attention.head_count")), - numKVHeads: int(c.Uint("attention.head_count_kv")), - eps: c.Float("attention.layer_norm_rms_epsilon"), - ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), - ropeDim: c.Uint("rope.dimension_count", 128), - defaultContextLen: c.Uint("context_length", 128000), + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + ropeDim: int(c.Uint("rope.dimension_count", 128)), + originalContextLength: int(c.Uint("context_length", 128000)), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeBase: c.Float("rope.freq_base"), + ropeScale: c.Float("rope.freq_scale", 1), }, } @@ -59,11 +60,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen)) + q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen)) + k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -77,7 +78,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten // Shift applies rotary position embeddings to the key tensor for causal attention caching func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return key.RoPE(ctx, shift, nil, m.ropeDim, 2, m.ropeBase, m.ropeScale, ml.WithContextLen(m.defaultContextLen)), nil + return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil } // MLP implements the feed-forward network component with SwiGLU activation diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go index 01eef392..4d7afaa1 100644 --- a/model/models/qwen25vl/model_vision.go +++ b/model/models/qwen25vl/model_vision.go @@ -1,7 +1,6 @@ package qwen25vl import ( - "fmt" "math" "slices" @@ -44,10 +43,8 @@ func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int } } - mask, err := ctx.Input().FromFloatSlice(flat, seqLength, seqLength) - if err != nil { - panic(err) - } + mask := ctx.Input().FromFloatSlice(flat, seqLength, seqLength) + // Reshape to match [seqLength, seqLength, 1] for broadcasting mask = mask.Reshape(ctx, seqLength, seqLength, 1) @@ -303,10 +300,7 @@ func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) } } - t, err := ctx.Input().FromIntSlice(index, len(index)) - if err != nil { - panic(err) - } + t := ctx.Input().FromIntSlice(index, 
len(index)) return t, bounds } @@ -326,10 +320,7 @@ func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor freqVals[i*freq+j] = float32(i) / float32(math.Pow(theta, float64(j*2)/float64(dim))) } } - freqs, err := ctx.Input().FromFloatSlice(freqVals, freq, maxGridSize) - if err != nil { - panic(fmt.Errorf("failed to create tensor from frequencies: %w", err)) - } + freqs := ctx.Input().FromFloatSlice(freqVals, freq, maxGridSize) // Create position coordinates (y,x pairs) for the grid // In PyTorch: Equivalent to generating position ids with torch.arange() @@ -339,10 +330,7 @@ func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor coords = append(coords, int32(y), int32(x)) } } - pos, err := ctx.Input().FromIntSlice(coords, 2, grid.Width, grid.Height) - if err != nil { - panic(fmt.Errorf("failed to create tensor from positions: %w", err)) - } + pos := ctx.Input().FromIntSlice(coords, 2, grid.Width, grid.Height) // Reshape and permute positions to match spatial merging pattern pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge) diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go new file mode 100644 index 00000000..1930da7e --- /dev/null +++ b/model/models/qwen3/model.go @@ -0,0 +1,233 @@ +package qwen3 + +import ( + "cmp" + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" + "github.com/ollama/ollama/ml/nn/rope" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type Options struct { + hiddenSize, numHeads, numKVHeads int + eps float32 + ropeBase, ropeScale float32 + + keyLength, valueLength int + + numExperts, numExpertsUsed int + normTopKProb bool +} + +func (o Options) headDim() int { + return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads) +} + +type Attention struct { + QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"` + Query *nn.Linear `gguf:"attn_q"` + KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` +} + +func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + batchSize := hiddenStates.Dim(1) + + query := sa.Query.Forward(ctx, hiddenStates) + key := sa.Key.Forward(ctx, hiddenStates) + value := sa.Value.Forward(ctx, hiddenStates) + + query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize) + key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize) + value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize) + + query = sa.QueryNorm.Forward(ctx, query, opts.eps) + key = sa.KeyNorm.Forward(ctx, key, opts.eps) + + query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + + attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache) + attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) + return sa.Output.Forward(ctx, attention) +} + +type MLP interface { + Forward(ml.Context, ml.Tensor, *Options) ml.Tensor +} + +type sparse struct { + Router *nn.Linear `gguf:"ffn_gate_inp"` + Gate ml.Tensor `gguf:"ffn_gate_exps.weight"` + Up ml.Tensor `gguf:"ffn_up_exps.weight"` + Down ml.Tensor `gguf:"ffn_down_exps.weight"` +} + 
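The `sparse` struct above holds the mixture-of-experts weights for qwen3moe, and the `Forward` method that follows routes each token through only `numExpertsUsed` of the `numExperts` experts: it softmaxes the router logits, keeps the top-k experts, and, when `normTopKProb` is set, renormalizes the kept weights to sum to one before mixing the expert outputs. A minimal, self-contained sketch of just that routing step on plain float slices rather than the `ml.Tensor` API (all names here are illustrative, not part of the codebase):

package main

import (
	"fmt"
	"math"
	"sort"
)

// routeTopK mirrors the routing performed in sparse.Forward: softmax the router
// logits, keep the k largest probabilities, and optionally renormalize them so
// the kept weights sum to 1 (the normTopKProb option).
func routeTopK(routerLogits []float64, k int, normTopKProb bool) ([]int, []float64) {
	// Softmax over all experts.
	maxLogit := math.Inf(-1)
	for _, l := range routerLogits {
		maxLogit = math.Max(maxLogit, l)
	}
	probs := make([]float64, len(routerLogits))
	var sum float64
	for i, l := range routerLogits {
		probs[i] = math.Exp(l - maxLogit)
		sum += probs[i]
	}
	for i := range probs {
		probs[i] /= sum
	}

	// Pick the k most probable experts.
	idx := make([]int, len(probs))
	for i := range idx {
		idx[i] = i
	}
	sort.Slice(idx, func(a, b int) bool { return probs[idx[a]] > probs[idx[b]] })
	experts := idx[:k]

	// Optionally renormalize the kept weights so they sum to one.
	weights := make([]float64, k)
	var kept float64
	for i, e := range experts {
		weights[i] = probs[e]
		kept += probs[e]
	}
	if normTopKProb {
		for i := range weights {
			weights[i] /= kept
		}
	}
	return experts, weights
}

func main() {
	experts, weights := routeTopK([]float64{0.1, 2.3, -1.0, 1.7}, 2, true)
	fmt.Println(experts, weights) // e.g. [1 3] [0.645... 0.354...]
}

The real implementation below performs the same selection with MulmatID over the stacked expert tensors, so all selected experts are evaluated as batched matrix multiplies rather than per-expert loops.
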
+func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2) + hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize) + routerLogits := mlp.Router.Forward(ctx, hiddenStates) + + routingWeights := routerLogits.Softmax(ctx) + selectedExperts := routingWeights.TopK(ctx, opts.numExpertsUsed) + routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, selectedExperts) + if opts.normTopKProb { + routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1)) + routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx)) + routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1)) + } + + hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1)) + + upStates := mlp.Up.MulmatID(ctx, hiddenStates, selectedExperts) + + hiddenStates = mlp.Gate.MulmatID(ctx, hiddenStates, selectedExperts) + hiddenStates = hiddenStates.SILU(ctx) + hiddenStates = hiddenStates.Mul(ctx, upStates) + + experts := mlp.Down.MulmatID(ctx, hiddenStates, selectedExperts) + experts = experts.Mul(ctx, routingWeights) + + nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2)) + for i := 1; i < opts.numExpertsUsed; i++ { + nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2))) + } + + return nextStates +} + +type dense struct { + Gate *nn.Linear `gguf:"ffn_gate"` + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor { + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + return mlp.Down.Forward(ctx, hiddenStates) +} + +type Layer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + *Attention + + MLPNorm *nn.RMSNorm `gguf:"ffn_norm"` + MLP +} + +func (d *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + residual := hiddenStates + hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, cache, opts) + + if outputs != nil { + hiddenStates = hiddenStates.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + + hiddenStates = hiddenStates.Add(ctx, residual) + + residual = hiddenStates + hiddenStates = d.MLPNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = d.MLP.Forward(ctx, hiddenStates, opts) + return hiddenStates.Add(ctx, residual) +} + +type Model struct { + model.Base + model.BytePairEncoding + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + OutputNorm *nn.RMSNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + Layers []Layer `gguf:"blk"` + + *Options +} + +// Forward implements model.Model. 
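+// It embeds the batch tokens, runs every transformer layer with its KV cache slot
+// selected, gathers only the requested output positions on the final layer, and
+// projects the normalized hidden states to vocabulary logits.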
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + + hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) + + for i, layer := range m.Layers { + m.Cache.SetLayer(i) + + var outputs ml.Tensor + if i == len(m.Layers)-1 { + outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + } + + hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options) + } + + hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) + return m.Output.Forward(ctx, hiddenStates), nil +} + +func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil +} + +var _ model.Model = (*Model)(nil) + +func New(c fs.Config) (model.Model, error) { + layers := make([]Layer, c.Uint("block_count")) + for i := range layers { + if c.String("general.architecture") == "qwen3moe" { + layers[i].MLP = &sparse{} + } else { + layers[i].MLP = &dense{} + } + } + + m := Model{ + BytePairEncoding: model.NewBytePairEncoding( + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), + }, + ), + Layers: layers, + Options: &Options{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + keyLength: int(c.Uint("attention.key_length")), + valueLength: int(c.Uint("attention.value_length")), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeBase: c.Float("rope.freq_base"), + ropeScale: c.Float("rope.freq_scale", 1), + numExperts: int(c.Uint("expert_count")), + numExpertsUsed: int(c.Uint("expert_used_count")), + normTopKProb: c.Bool("norm_top_k_prob", true), + }, + } + + m.Cache = kvcache.NewCausalCache(m.Shift) + return &m, nil +} + +func init() { + model.Register("qwen3", New) + model.Register("qwen3moe", New) +} diff --git a/model/process_text_spm.go b/model/sentencepiece.go similarity index 89% rename from model/process_text_spm.go rename to model/sentencepiece.go index b1cff7d2..7d725f04 100644 --- a/model/process_text_spm.go +++ b/model/sentencepiece.go @@ -182,27 +182,12 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) } } + slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids) + if addSpecial && len(ids) > 0 { - if spm.vocab.AddBOS { - if ids[0] == spm.vocab.BOS { - slog.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS) - } - - slog.Debug("adding bos token to prompt", "id", spm.vocab.BOS) - ids = append([]int32{spm.vocab.BOS}, ids...) 
- } - - if spm.vocab.AddEOS { - if ids[len(ids)-1] == spm.vocab.EOS { - slog.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS) - } - - slog.Debug("adding eos token to prompt", "id", spm.vocab.EOS) - ids = append(ids, spm.vocab.EOS) - } + ids = spm.vocab.addSpecials(ids) } - slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "ids", ids) return ids, nil } @@ -261,6 +246,6 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) { } } - slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "string", sb.String()) + slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "ids", ids, "string", sb.String()) return sb.String(), nil } diff --git a/model/process_text_spm_test.go b/model/sentencepiece_test.go similarity index 100% rename from model/process_text_spm_test.go rename to model/sentencepiece_test.go diff --git a/model/textprocessor.go b/model/textprocessor.go new file mode 100644 index 00000000..4a36f235 --- /dev/null +++ b/model/textprocessor.go @@ -0,0 +1,17 @@ +package model + +const ( + TOKEN_TYPE_NORMAL = iota + 1 + TOKEN_TYPE_UNKNOWN + TOKEN_TYPE_CONTROL + TOKEN_TYPE_USER_DEFINED + TOKEN_TYPE_UNUSED + TOKEN_TYPE_BYTE +) + +type TextProcessor interface { + Encode(s string, addSpecial bool) ([]int32, error) + Decode([]int32) (string, error) + Is(int32, Special) bool + Vocabulary() *Vocabulary +} diff --git a/model/vocabulary.go b/model/vocabulary.go new file mode 100644 index 00000000..24adbaca --- /dev/null +++ b/model/vocabulary.go @@ -0,0 +1,112 @@ +package model + +import ( + "log/slog" + "slices" + "sync" +) + +type Special int32 + +const ( + SpecialBOS Special = iota + SpecialEOS +) + +type Vocabulary struct { + Values []string + Types []int32 + Scores []float32 + Merges []string + + BOS, EOS []int32 + AddBOS, AddEOS bool + + specialOnce sync.Once + special []string + + valuesOnce sync.Once + values map[string]int32 + + mergeOnce sync.Once + merge map[string]int32 +} + +func (v *Vocabulary) Is(id int32, special Special) bool { + switch special { + case SpecialBOS: + return slices.Contains(v.BOS, id) + case SpecialEOS: + return slices.Contains(v.EOS, id) + default: + return false + } +} + +func (v *Vocabulary) addSpecials(ids []int32) []int32 { + if v.AddBOS && len(v.BOS) > 0 { + if slices.Contains(v.BOS, ids[0]) { + slog.Warn("adding bos token to prompt which already has it", "id", v.BOS) + } + + slog.Debug("adding bos token to prompt", "id", v.BOS) + ids = append([]int32{v.BOS[0]}, ids...) 
+ } + + if v.AddEOS && len(v.EOS) > 0 { + if slices.Contains(v.BOS, ids[len(ids)-1]) { + slog.Warn("adding eos token to prompt which already has it", "id", v.EOS) + } + + slog.Debug("adding eos token to prompt", "id", v.EOS) + ids = append(ids, v.EOS[0]) + } + + return ids +} + +func (v *Vocabulary) Encode(s string) int32 { + v.valuesOnce.Do(func() { + v.values = make(map[string]int32, len(v.Values)) + for i, value := range v.Values { + v.values[value] = int32(i) + } + }) + + if id, ok := v.values[s]; ok { + return id + } + + return -1 +} + +func (v *Vocabulary) Decode(id int32) string { + return v.Values[id] +} + +func (v *Vocabulary) SpecialVocabulary() []string { + v.specialOnce.Do(func() { + for i := range v.Values { + if v.Types[i] == TOKEN_TYPE_CONTROL { + v.special = append(v.special, v.Values[i]) + } + } + }) + + return v.special +} + +func (v *Vocabulary) Merge(left, right string) int { + v.mergeOnce.Do(func() { + v.merge = make(map[string]int32, len(v.Merges)) + for i, merge := range v.Merges { + v.merge[merge] = int32(i) + } + }) + + if id, ok := v.merge[left+" "+right]; ok { + return int(id) + } + + return -1 +} diff --git a/runner/ollamarunner/multimodal.go b/runner/ollamarunner/multimodal.go index d78612fe..fbdc7d72 100644 --- a/runner/ollamarunner/multimodal.go +++ b/runner/ollamarunner/multimodal.go @@ -95,17 +95,14 @@ func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Ten } } } else { - err := computeCtx.Reserve() - if err != nil { - return nil, err - } + computeCtx.Reserve() } } for i, t := range entry.mm { if in == t.Tensor { if !reserve { - return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...) + return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...), nil } else { return ctx.Input().Empty(t.Tensor.DType(), t.Tensor.Shape()...), nil } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index cd42d434..a7a889f1 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -808,10 +808,7 @@ func (s *Server) reserveWorstCaseGraph() error { batch.Outputs[i] = int32(i) } - batch.Inputs, err = ctx.Input().FromIntSlice(batchInputs, len(batchInputs)) - if err != nil { - return err - } + batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs)) cache := s.model.Config().Cache if cache != nil { @@ -826,16 +823,12 @@ func (s *Server) reserveWorstCaseGraph() error { return err } - err = ctx.Forward(t).Reserve() - if err != nil { - return err - } + ctx.Forward(t).Reserve() return nil } -func (s *Server) loadModel( - ctx context.Context, +func (s *Server) initModel( mpath string, params ml.BackendParams, lpath multiLPath, @@ -843,21 +836,21 @@ func (s *Server) loadModel( kvCacheType string, kvSize int, multiUserCache bool, -) { +) error { var err error - s.model, err = model.New(ctx, mpath, params) + s.model, err = model.New(mpath, params) if err != nil { - panic(err) + return err } // TODO(jessegross): LoRA loading if lpath.String() != "" { - panic("loras are not yet implemented") + return errors.New("loras are not yet implemented") } s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache) if err != nil { - panic(err) + return err } if !s.cache.enabled && parallel > 1 { @@ -869,7 +862,30 @@ func (s *Server) loadModel( s.seqs = make([]*Sequence, s.parallel) s.seqsSem = semaphore.NewWeighted(int64(s.parallel)) - err = s.reserveWorstCaseGraph() + return s.reserveWorstCaseGraph() +} + +func (s *Server) load( + ctx 
context.Context, + mpath string, + params ml.BackendParams, + lpath multiLPath, + parallel int, + kvCacheType string, + kvSize int, + multiUserCache bool, +) { + err := s.initModel(mpath, params, lpath, parallel, kvCacheType, kvSize, multiUserCache) + if err != nil { + panic(err) + } + + slog.Debug("memory", "allocated", s.model.Backend().BackendMemory()) + + err = s.model.Backend().Load(ctx, + func(progress float32) { + s.progress = progress + }) if err != nil { panic(err) } @@ -913,9 +929,14 @@ func Execute(args []string) error { status: llm.ServerStatusLoadingModel, } + server.cond = sync.NewCond(&server.mu) + server.ready.Add(1) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // TODO(jessegross): Parameters that need to be implemented: // no-mmap - // mlock var tensorSplitFloats []float32 if *tensorSplit != "" { @@ -928,9 +949,6 @@ func Execute(args []string) error { } params := ml.BackendParams{ - Progress: func(progress float32) { - server.progress = progress - }, NumThreads: *threads, NumGPULayers: *numGPULayers, MainGPU: *mainGPU, @@ -938,14 +956,7 @@ func Execute(args []string) error { FlashAttention: *flashAttention, } - server.ready.Add(1) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache) - - server.cond = sync.NewCond(&server.mu) - + go server.load(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache) go server.run(ctx) addr := "127.0.0.1:" + strconv.Itoa(*port) diff --git a/sample/samplers.go b/sample/samplers.go index f0846c8d..d395650d 100644 --- a/sample/samplers.go +++ b/sample/samplers.go @@ -176,7 +176,7 @@ func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSa vocabIds[i] = uint32(i) } - grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, []uint32{uint32(model.Vocabulary().EOS), uint32(model.Vocabulary().EOT)}) + grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, model.Vocabulary().EOS) if grammar == nil { return nil, errors.New("sample: failed to initialize grammar") } diff --git a/server/create.go b/server/create.go index 68e003df..bd970876 100644 --- a/server/create.go +++ b/server/create.go @@ -295,7 +295,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is } defer bin.Close() - f, _, err := ggml.Decode(bin, -1) + f, err := ggml.Decode(bin, -1) if err != nil { return nil, err } @@ -467,7 +467,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr return nil, err } - f, _, err := ggml.Decode(temp, 1024) + f, err := ggml.Decode(temp, 1024) if err != nil { slog.Error(fmt.Sprintf("error decoding ggml: %s\n", err)) return nil, err @@ -501,48 +501,27 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML return nil, errOnlyGGUFSupported } - stat, err := blob.Stat() + f, err := ggml.Decode(blob, -1) if err != nil { return nil, err } - var offset int64 - for offset < stat.Size() { - f, n, err := ggml.Decode(blob, 1024) - if errors.Is(err, io.EOF) { - break - } else if err != nil { - return nil, err - } - - mediatype := "application/vnd.ollama.image.model" - if f.KV().Kind() == "adapter" { - mediatype = "application/vnd.ollama.image.adapter" - } else if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok || f.KV().Kind() == "projector" { - mediatype = "application/vnd.ollama.image.projector" - } - - var layer Layer - if digest != "" 
&& n == stat.Size() && offset == 0 { - layer, err = NewLayerFromLayer(digest, mediatype, blob.Name()) - if err != nil { - slog.Debug("could not create new layer from layer", "error", err) - return nil, err - } - } - - // Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size()) - if layer.Digest == "" { - layer, err = NewLayer(io.NewSectionReader(blob, offset, n), mediatype) - if err != nil { - return nil, err - } - } - - layers = append(layers, &layerGGML{layer, f}) - offset = n + mediatype := "application/vnd.ollama.image.model" + if f.KV().Kind() == "adapter" { + mediatype = "application/vnd.ollama.image.adapter" + } else if (f.KV().Uint("block_count") == 0 && f.KV().Uint("vision.block_count") > 0) || f.KV().Kind() == "projector" { + // if a model has vision.block_count but not block_count, it is a standalone vision model + mediatype = "application/vnd.ollama.image.projector" } + layer, err := NewLayerFromLayer(digest, mediatype, blob.Name()) + if err != nil { + slog.Debug("could not create new layer from layer", "error", err) + return nil, err + } + + layers = append(layers, &layerGGML{layer, f}) + return detectChatTemplate(layers) } diff --git a/server/images.go b/server/images.go index 352f10f2..a69e2a9f 100644 --- a/server/images.go +++ b/server/images.go @@ -75,7 +75,7 @@ func (m *Model) Capabilities() []model.Capability { if err == nil { defer r.Close() - f, _, err := ggml.Decode(r, 1024) + f, err := ggml.Decode(r, 1024) if err == nil { if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok { capabilities = append(capabilities, model.CapabilityEmbedding) diff --git a/server/model.go b/server/model.go index 2149ff85..401547e4 100644 --- a/server/model.go +++ b/server/model.go @@ -10,9 +10,6 @@ import ( "log/slog" "net/http" "os" - "slices" - "strings" - "text/template/parse" "github.com/ollama/ollama/api" "github.com/ollama/ollama/fs/ggml" @@ -64,7 +61,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe } defer blob.Close() - f, _, err := ggml.Decode(blob, -1) + f, err := ggml.Decode(blob, -1) if err != nil { return nil, err } @@ -128,124 +125,3 @@ func detectContentType(r io.Reader) (string, error) { return "unknown", nil } - -func parseObjects(s string) []map[string]any { - var objs []map[string]any - for offset := 0; offset < len(s); { - var obj map[string]any - decoder := json.NewDecoder(strings.NewReader(s[offset:])) - if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { - break - } else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) { - // skip over any syntax errors - offset += int(syntax.Offset) - } else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) { - // skip over any unmarshalable types - offset += int(unmarshalType.Offset) - } else if err != nil { - return nil - } else { - offset += int(decoder.InputOffset()) - objs = append(objs, obj) - } - } - - return objs -} - -// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls. 
-// mxyng: this only really works if the input contains tool calls in some JSON format -func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { - // create a subtree from the node that ranges over .ToolCalls - tmpl := m.Template.Subtree(func(n parse.Node) bool { - if t, ok := n.(*parse.RangeNode); ok { - return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") - } - - return false - }) - - if tmpl == nil { - return nil, false - } - - var b bytes.Buffer - if err := tmpl.Execute(&b, map[string][]api.ToolCall{ - "ToolCalls": { - { - Function: api.ToolCallFunction{ - Name: "@@name@@", - Arguments: api.ToolCallFunctionArguments{ - "@@argument@@": 1, - }, - }, - }, - }, - }); err != nil { - return nil, false - } - - templateObjects := parseObjects(b.String()) - if len(templateObjects) == 0 { - return nil, false - } - - // find the keys that correspond to the name and arguments fields - var name, arguments string - for k, v := range templateObjects[0] { - switch v.(type) { - case string: - name = k - case map[string]any: - arguments = k - } - } - - if name == "" || arguments == "" { - return nil, false - } - - responseObjects := parseObjects(s) - if len(responseObjects) == 0 { - return nil, false - } - - // collect all nested objects - var collect func(any) []map[string]any - collect = func(obj any) (all []map[string]any) { - switch o := obj.(type) { - case map[string]any: - all = append(all, o) - for _, v := range o { - all = append(all, collect(v)...) - } - case []any: - for _, v := range o { - all = append(all, collect(v)...) - } - } - - return all - } - - var objs []map[string]any - for _, p := range responseObjects { - objs = append(objs, collect(p)...) - } - - var toolCalls []api.ToolCall - for _, kv := range objs { - n, nok := kv[name].(string) - a, aok := kv[arguments].(map[string]any) - if nok && aok { - toolCalls = append(toolCalls, api.ToolCall{ - Function: api.ToolCallFunction{ - Name: n, - Arguments: a, - }, - }) - } - } - - return toolCalls, len(toolCalls) > 0 -} diff --git a/server/model_test.go b/server/model_test.go deleted file mode 100644 index e5c2f2bb..00000000 --- a/server/model_test.go +++ /dev/null @@ -1,179 +0,0 @@ -package server - -import ( - "bytes" - "encoding/json" - "fmt" - "os" - "path/filepath" - "testing" - - "github.com/google/go-cmp/cmp" - - "github.com/ollama/ollama/api" - "github.com/ollama/ollama/template" -) - -func readFile(t *testing.T, base, name string) *bytes.Buffer { - t.Helper() - - bts, err := os.ReadFile(filepath.Join(base, name)) - if err != nil { - t.Fatal(err) - } - - return bytes.NewBuffer(bts) -} - -func TestExecuteWithTools(t *testing.T) { - p := filepath.Join("testdata", "tools") - cases := []struct { - model string - output string - ok bool - }{ - {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, - {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] - -The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true}, - {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": 
"get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false}, - {"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: - - [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, - {"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, - {"command-r-plus", "Action: ```json" + ` -[ - { - "tool_name": "get_current_weather", - "parameters": { - "format": "fahrenheit", - "location": "San Francisco, CA" - } - }, - { - "tool_name": "get_current_weather", - "parameters": { - "format": "celsius", - "location": "Toronto, Canada" - } - } -] -` + "```", true}, - {"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, - {"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, - {"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, - {"llama3-groq-tool-use", ` -{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} -{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}} -`, true}, - {"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true}, - {"nemotron", `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} `, true}, - } - - var tools []api.Tool - if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil { - t.Fatal(err) - } - - var messages []api.Message - if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil { - t.Fatal(err) - } - - calls := []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: api.ToolCallFunctionArguments{ - "format": "fahrenheit", - "location": "San Francisco, CA", - }, - }, - }, - { - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: api.ToolCallFunctionArguments{ - "format": "celsius", - "location": "Toronto, Canada", - }, - }, - }, - } - - for _, tt := range cases { - t.Run(tt.model, func(t *testing.T) { - tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String()) - if err != nil { - t.Fatal(err) - } - - t.Run("template", func(t *testing.T) { - var actual bytes.Buffer - if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil { - t.Fatal(err) - } - - if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) - } - }) - - t.Run("parse", func(t *testing.T) { - m := &Model{Template: tmpl} - actual, ok := m.parseToolCalls(tt.output) - if ok != tt.ok { - t.Fatalf("expected %t, got %t", tt.ok, ok) - } - - if tt.ok { - if diff := cmp.Diff(actual, calls); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) - } - } - }) - }) 
- } -} - -func TestParseObjects(t *testing.T) { - tests := []struct { - input string - want []map[string]any - }{ - { - input: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - want: []map[string]any{ - {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, - {"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, Canada"}}, - }, - }, - { - input: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - want: []map[string]any{ - {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, - }, - }, - { - input: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, ON"}} `, - want: []map[string]any{ - {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, - {"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, ON"}}, - }, - }, - { - input: `{"name": "get_current_weather", "arguments": `, - want: nil, - }, - } - - for _, tc := range tests { - t.Run(tc.input, func(t *testing.T) { - got := parseObjects(tc.input) - - if diff := cmp.Diff(got, tc.want); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) - } - }) - } -} diff --git a/server/quantization.go b/server/quantization.go index adfc948e..e57e8a4d 100644 --- a/server/quantization.go +++ b/server/quantization.go @@ -120,14 +120,30 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType if newType.IsQuantized() { nx := shape[0] - ny := uint64(1) - if len(shape) > 1 { - ny = shape[1] - } qk_k := newType.BlockSize() + + // Check if first dimension is divisible by block size if nx%qk_k != 0 { - slog.Warn(fmt.Sprintf("tensor cols %d x %d are not divisible by %d, required for %s. 
Falling back to quantization %s", nx, ny, qk_k, newType.String(), fsggml.TensorTypeF16.String())) - newType = fsggml.TensorTypeF16 + // Store the original type for logging + originalType := newType + + // Select appropriate fallback based on original type + switch newType { + case fsggml.TensorTypeQ4_K: + newType = fsggml.TensorTypeQ5_0 + case fsggml.TensorTypeQ5_K: + newType = fsggml.TensorTypeQ5_1 + case fsggml.TensorTypeQ6_K: + newType = fsggml.TensorTypeQ8_0 + } + + // Final check - if still incompatible, fall back to F16 + if nx%newType.BlockSize() != 0 { + newType = fsggml.TensorTypeF16 + } + + slog.Warn(fmt.Sprintf("tensor cols %d are not divisible by %d, required for %s - using fallback quantization %s", + nx, qk_k, originalType.String(), newType.String())) } } return newType diff --git a/server/quantization_test.go b/server/quantization_test.go index 495297df..4f717c2c 100644 --- a/server/quantization_test.go +++ b/server/quantization_test.go @@ -271,7 +271,7 @@ func TestQuantizeModel(t *testing.T) { t.Fatal(err.Error()) } defer fp.Close() - meta, _, err := fsggml.Decode(fp, -1) + meta, err := fsggml.Decode(fp, -1) if err != nil { t.Fatal(err.Error()) } @@ -303,7 +303,7 @@ func TestQuantizeModel(t *testing.T) { t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err) } defer fpNew.Close() - newMeta, _, err := fsggml.Decode(fpNew, -1) + newMeta, err := fsggml.Decode(fpNew, -1) if err != nil { t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err) } diff --git a/server/routes.go b/server/routes.go index d0b8f487..42e8cdd1 100644 --- a/server/routes.go +++ b/server/routes.go @@ -38,6 +38,7 @@ import ( "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/registry" "github.com/ollama/ollama/template" + "github.com/ollama/ollama/tools" "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" @@ -1482,11 +1483,20 @@ func (s *Server) ChatHandler(c *gin.Context) { return } + var toolParser *tools.Parser + if len(req.Tools) > 0 { + toolParser, err = tools.NewParser(m.Template.Template) + if err != nil { + slog.Error("failed to create tool parser", "error", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + ch := make(chan any) go func() { defer close(ch) - var sb strings.Builder - var toolCallIndex int = 0 + if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: prompt, Images: images, @@ -1512,37 +1522,21 @@ func (s *Server) ChatHandler(c *gin.Context) { res.LoadDuration = checkpointLoaded.Sub(checkpointStart) } - // TODO: tool call checking and filtering should be moved outside of this callback once streaming - // however this was a simple change for now without reworking streaming logic of this (and other) - // handlers - if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 { - ch <- res - return - } - - // Streaming tool calls: - // If tools are recognized, use a flag to track the sending of a tool downstream - // This ensures that content is cleared from the message on the last chunk sent - sb.WriteString(r.Content) - if toolCalls, ok := m.parseToolCalls(sb.String()); ok { - res.Message.ToolCalls = toolCalls - for i := range toolCalls { - toolCalls[i].Function.Index = toolCallIndex - toolCallIndex++ + if len(req.Tools) > 0 { + toolCalls, content := toolParser.Add(r.Content) + if len(content) > 0 { + res.Message.Content = content + } else if len(toolCalls) > 0 { + res.Message.ToolCalls 
= toolCalls + res.Message.Content = "" + } else { + if r.Done { + ch <- res + } + return } - res.Message.Content = "" - sb.Reset() - ch <- res - return - } - - if r.Done { - // Send any remaining content if no tool calls were detected - if toolCallIndex == 0 { - res.Message.Content = sb.String() - } - ch <- res } + ch <- res }); err != nil { ch <- gin.H{"error": err.Error()} } @@ -1551,11 +1545,15 @@ func (s *Server) ChatHandler(c *gin.Context) { if req.Stream != nil && !*req.Stream { var resp api.ChatResponse var sb strings.Builder + var toolCalls []api.ToolCall for rr := range ch { switch t := rr.(type) { case api.ChatResponse: sb.WriteString(t.Message.Content) resp = t + if len(req.Tools) > 0 { + toolCalls = append(toolCalls, t.Message.ToolCalls...) + } case gin.H: msg, ok := t["error"].(string) if !ok { @@ -1571,12 +1569,8 @@ func (s *Server) ChatHandler(c *gin.Context) { } resp.Message.Content = sb.String() - - if len(req.Tools) > 0 { - if toolCalls, ok := m.parseToolCalls(sb.String()); ok { - resp.Message.ToolCalls = toolCalls - resp.Message.Content = "" - } + if len(toolCalls) > 0 { + resp.Message.ToolCalls = toolCalls } c.JSON(http.StatusOK, resp) diff --git a/server/sched.go b/server/sched.go index 3fc54e55..612e4702 100644 --- a/server/sched.go +++ b/server/sched.go @@ -387,6 +387,17 @@ func (s *Scheduler) processCompleted(ctx context.Context) { s.loadedMu.Unlock() runner.refMu.Unlock() slog.Debug("duplicate expired event, ignoring", "runner", runner) + } else if runner.pid != runnerToUnload.pid { + // If the pids do not match, we likely had multiple load + // failures for the same model in quick succession due to + // request context canceled and are draining the queue of + // events. Ensure the orphaned runner is properly shut down, but + // do not delete the mismatched loaded runner, or wait for VRAM + // convergence. 
+ slog.Debug("orphaned runner shutting down", "orphan", runner, "loaded", runnerToUnload) + runner.unload() + s.loadedMu.Unlock() + runner.refMu.Unlock() } else { slog.Debug("starting background wait for VRAM recovery", "runner", runner) finished := runner.waitForVRAMRecovery() diff --git a/server/testdata/tools/command-r-plus.gotmpl b/tools/testdata/command-r-plus.gotmpl similarity index 100% rename from server/testdata/tools/command-r-plus.gotmpl rename to tools/testdata/command-r-plus.gotmpl diff --git a/server/testdata/tools/command-r-plus.out b/tools/testdata/command-r-plus.out similarity index 100% rename from server/testdata/tools/command-r-plus.out rename to tools/testdata/command-r-plus.out diff --git a/server/testdata/tools/firefunction.gotmpl b/tools/testdata/firefunction.gotmpl similarity index 100% rename from server/testdata/tools/firefunction.gotmpl rename to tools/testdata/firefunction.gotmpl diff --git a/server/testdata/tools/firefunction.out b/tools/testdata/firefunction.out similarity index 100% rename from server/testdata/tools/firefunction.out rename to tools/testdata/firefunction.out diff --git a/server/testdata/tools/llama3-groq-tool-use.gotmpl b/tools/testdata/llama3-groq-tool-use.gotmpl similarity index 100% rename from server/testdata/tools/llama3-groq-tool-use.gotmpl rename to tools/testdata/llama3-groq-tool-use.gotmpl diff --git a/server/testdata/tools/llama3-groq-tool-use.out b/tools/testdata/llama3-groq-tool-use.out similarity index 100% rename from server/testdata/tools/llama3-groq-tool-use.out rename to tools/testdata/llama3-groq-tool-use.out diff --git a/tools/testdata/llama3.2.gotmpl b/tools/testdata/llama3.2.gotmpl new file mode 100644 index 00000000..b132423e --- /dev/null +++ b/tools/testdata/llama3.2.gotmpl @@ -0,0 +1,44 @@ +<|start_header_id|>system<|end_header_id|> + +Cutting Knowledge Date: December 2023 + +{{ if .System }}{{ .System }} +{{- end }} +{{- if .Tools }}When you receive a tool call response, use the output to format an answer to the orginal user question. + +You are a helpful assistant with tool calling capabilities. +{{- end }}<|eot_id|> +{{- range $i, $_ := .Messages }} +{{- $last := eq (len (slice $.Messages $i)) 1 }} +{{- if eq .Role "user" }}<|start_header_id|>user<|end_header_id|> +{{- if and $.Tools $last }} + +Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. + +Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables. + +{{ range $.Tools }} +{{- . 
}} +{{ end }} +{{ .Content }}<|eot_id|> +{{- else }} + +{{ .Content }}<|eot_id|> +{{- end }}{{ if $last }}<|start_header_id|>assistant<|end_header_id|> + +{{ end }} +{{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|> +{{- if .ToolCalls }} +{{ range .ToolCalls }} +{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }} +{{- else }} + +{{ .Content }} +{{- end }}{{ if not $last }}<|eot_id|>{{ end }} +{{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|> + +{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|> + +{{ end }} +{{- end }} +{{- end }} \ No newline at end of file diff --git a/tools/testdata/llama3.2.out b/tools/testdata/llama3.2.out new file mode 100644 index 00000000..a27c6eaf --- /dev/null +++ b/tools/testdata/llama3.2.out @@ -0,0 +1,24 @@ +<|start_header_id|>system<|end_header_id|> + +Cutting Knowledge Date: December 2023 + +You are a knowledgeable assistant. You can answer questions and perform tasks.When you receive a tool call response, use the output to format an answer to the orginal user question. + +You are a helpful assistant with tool calling capabilities.<|eot_id|><|start_header_id|>user<|end_header_id|> + +What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +{"name": "get_current_weather", "parameters": {"format":"celsius","location":"Paris, France"}}<|eot_id|><|start_header_id|>ipython<|end_header_id|> + +22<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|> + +Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. + +Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables. + +{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. 
San Francisco, CA"}}}}} + +What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|> + diff --git a/server/testdata/tools/messages.json b/tools/testdata/messages.json similarity index 100% rename from server/testdata/tools/messages.json rename to tools/testdata/messages.json diff --git a/server/testdata/tools/mistral.gotmpl b/tools/testdata/mistral.gotmpl similarity index 100% rename from server/testdata/tools/mistral.gotmpl rename to tools/testdata/mistral.gotmpl diff --git a/server/testdata/tools/mistral.out b/tools/testdata/mistral.out similarity index 100% rename from server/testdata/tools/mistral.out rename to tools/testdata/mistral.out diff --git a/server/testdata/tools/nemotron.gotmpl b/tools/testdata/nemotron.gotmpl similarity index 100% rename from server/testdata/tools/nemotron.gotmpl rename to tools/testdata/nemotron.gotmpl diff --git a/server/testdata/tools/nemotron.out b/tools/testdata/nemotron.out similarity index 100% rename from server/testdata/tools/nemotron.out rename to tools/testdata/nemotron.out diff --git a/tools/testdata/qwen2.5.gotmpl b/tools/testdata/qwen2.5.gotmpl new file mode 100644 index 00000000..cbd7302c --- /dev/null +++ b/tools/testdata/qwen2.5.gotmpl @@ -0,0 +1,51 @@ +{{- if .Suffix }}<|fim_prefix|>{{ .Prompt }}<|fim_suffix|>{{ .Suffix }}<|fim_middle|> +{{- else if .Messages }} +{{- if or .System .Tools }}<|im_start|>system +{{- if .System }} +{{ .System }} +{{- end }} +{{- if .Tools }} + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{{- range .Tools }} +{"type": "function", "function": {{ .Function }}} +{{- end }} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } + +{{- end }}<|im_end|> +{{ end }} +{{- range $i, $_ := .Messages }} +{{- $last := eq (len (slice $.Messages $i)) 1 -}} +{{- if eq .Role "user" }}<|im_start|>user +{{ .Content }}<|im_end|> +{{ else if eq .Role "assistant" }}<|im_start|>assistant +{{ if .Content }}{{ .Content }} +{{- else if .ToolCalls }} +{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} +{{ end }} +{{- end }}{{ if not $last }}<|im_end|> +{{ end }} +{{- else if eq .Role "tool" }}<|im_start|>user + +{{ .Content }} +<|im_end|> +{{ end }} +{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant +{{ end }} +{{- end }} +{{- else }} +{{- if .System }}<|im_start|>system +{{ .System }}<|im_end|> +{{ end }}{{ if .Prompt }}<|im_start|>user +{{ .Prompt }}<|im_end|> +{{ end }}<|im_start|>assistant +{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }} \ No newline at end of file diff --git a/tools/testdata/qwen2.5.out b/tools/testdata/qwen2.5.out new file mode 100644 index 00000000..76bfbfa9 --- /dev/null +++ b/tools/testdata/qwen2.5.out @@ -0,0 +1,31 @@ +<|im_start|>system +You are a knowledgeable assistant. You can answer questions and perform tasks. + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. 
Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } +<|im_end|> +<|im_start|>user +What's the weather like today in Paris?<|im_end|> +<|im_start|>assistant + +{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}} +<|im_end|> +<|im_start|>user + +22 +<|im_end|> +<|im_start|>assistant +The current temperature in Paris, France is 22 degrees Celsius.<|im_end|> +<|im_start|>user +What's the weather like today in San Francisco and Toronto?<|im_end|> +<|im_start|>assistant diff --git a/tools/testdata/qwen3.gotmpl b/tools/testdata/qwen3.gotmpl new file mode 100644 index 00000000..26f6656f --- /dev/null +++ b/tools/testdata/qwen3.gotmpl @@ -0,0 +1,50 @@ +{{- if .Messages }} +{{- if or .System .Tools }}<|im_start|>system +{{- if .System }} +{{ .System }} +{{- end }} +{{- if .Tools }} + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{{- range .Tools }} +{"type": "function", "function": {{ .Function }}} +{{- end }} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } + +{{- end }}<|im_end|> +{{ end }} +{{- range $i, $_ := .Messages }} +{{- $last := eq (len (slice $.Messages $i)) 1 -}} +{{- if eq .Role "user" }}<|im_start|>user +{{ .Content }}<|im_end|> +{{ else if eq .Role "assistant" }}<|im_start|>assistant +{{ if .Content }}{{ .Content }} +{{- else if .ToolCalls }} +{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} +{{ end }} +{{- end }}{{ if not $last }}<|im_end|> +{{ end }} +{{- else if eq .Role "tool" }}<|im_start|>user + +{{ .Content }} +<|im_end|> +{{ end }} +{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant +{{ end }} +{{- end }} +{{- else }} +{{- if .System }}<|im_start|>system +{{ .System }}<|im_end|> +{{ end }}{{ if .Prompt }}<|im_start|>user +{{ .Prompt }}<|im_end|> +{{ end }}<|im_start|>assistant +{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }} \ No newline at end of file diff --git a/tools/testdata/qwen3.out b/tools/testdata/qwen3.out new file mode 100644 index 00000000..76bfbfa9 --- /dev/null +++ b/tools/testdata/qwen3.out @@ -0,0 +1,31 @@ +<|im_start|>system +You are a knowledgeable assistant. You can answer questions and perform tasks. + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. 
San Francisco, CA"}}}}} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } +<|im_end|> +<|im_start|>user +What's the weather like today in Paris?<|im_end|> +<|im_start|>assistant + +{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}} +<|im_end|> +<|im_start|>user + +22 +<|im_end|> +<|im_start|>assistant +The current temperature in Paris, France is 22 degrees Celsius.<|im_end|> +<|im_start|>user +What's the weather like today in San Francisco and Toronto?<|im_end|> +<|im_start|>assistant diff --git a/server/testdata/tools/tools.json b/tools/testdata/tools.json similarity index 100% rename from server/testdata/tools/tools.json rename to tools/testdata/tools.json diff --git a/server/testdata/tools/xlam.gotmpl b/tools/testdata/xlam.gotmpl similarity index 100% rename from server/testdata/tools/xlam.gotmpl rename to tools/testdata/xlam.gotmpl diff --git a/server/testdata/tools/xlam.out b/tools/testdata/xlam.out similarity index 100% rename from server/testdata/tools/xlam.out rename to tools/testdata/xlam.out diff --git a/tools/tools.go b/tools/tools.go new file mode 100644 index 00000000..509ca90a --- /dev/null +++ b/tools/tools.go @@ -0,0 +1,271 @@ +package tools + +import ( + "encoding/json" + "errors" + "log/slog" + "strings" + gotmpl "text/template" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/template" +) + +var ( + errInvalidToolCall = errors.New("invalid tool call format") + errAccumulateMore = errors.New("need to accumulate more content") +) + +type Parser struct { + parseLeadingJSON bool + prefix string + prefixFound bool + tmpl gotmpl.Template + sb strings.Builder + index int + name string + arguments string + done bool +} + +// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls. +// +// Parameters: +// - s: The string to parse +// - name: The field name from template that identifies the tool call name +// - arguments: The field name from template that identifies the tool call arguments +// +// Returns: +// - []api.ToolCall: The parsed tool calls if successful +// - error: ErrAccumulateMore if braces unbalanced, ErrInvalidToolCall if invalid, or nil if successful +func parseJSONToolCalls(s string, name, arguments string, prefix string) ([]api.ToolCall, error) { + // Check for balanced braces before attempting to parse + braceCount := 0 + squareCount := 0 + startIndex := -1 + var rawToolCalls []string + s = strings.TrimSpace(s) + + // Only track these if we don't have a prefix as it will be cut off from the prefix. Also track in the parseLeadingJSON case. 
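+	// For example, when the template prefix itself ends in "[", the opening bracket never
+	// reaches this function, so counting a stray "]" would drive squareCount negative and
+	// wrongly reject an otherwise valid tool call.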
+ trackSquareBrackets := prefix == "" || !strings.HasSuffix(prefix, "[") || strings.HasPrefix(s, "[") + for i, c := range s { + switch c { + case '{': + braceCount++ + if startIndex == -1 { + startIndex = i + } + case '}': + braceCount-- + if braceCount == 0 { + rawToolCalls = append(rawToolCalls, s[startIndex:i+1]) + startIndex = -1 + } + case '[': + if trackSquareBrackets { + squareCount++ + } + case ']': + if trackSquareBrackets { + squareCount-- + } + } + + // Negative means we have an extra closing brace/bracket + if braceCount < 0 || squareCount < 0 { + return nil, errInvalidToolCall + } + } + + // If braces/brackets aren't balanced, need more input + if braceCount > 0 || squareCount > 0 { + return nil, errAccumulateMore + } + + t := strings.TrimSpace(s) + if len(t) == 0 { + return nil, errAccumulateMore + } + // If the input is a single square bracket, it's not a valid tool call + if t[0] == '[' && len(t) == 1 { + return nil, errAccumulateMore + } + + // Attempt full unmarshal of the JSON + var toolCalls []api.ToolCall + for _, rawToolCall := range rawToolCalls { + var resp map[string]any + if err := json.Unmarshal([]byte(rawToolCall), &resp); err != nil { + continue + } + + // Collect nested objects that could contain tool calls + objs := collect(resp) + if len(objs) == 0 { + continue + } + + // Extract tool calls from objects + for _, kv := range objs { + n, nok := kv[name].(string) + a, aok := kv[arguments].(map[string]any) + if nok && aok { + toolCalls = append(toolCalls, api.ToolCall{ + Function: api.ToolCallFunction{ + Name: n, + Arguments: a, + }, + }) + } else { + slog.Debug("No valid tool call found in object.", "object", kv) + } + } + } + + // Valid JSON, no tool calls found + if len(toolCalls) == 0 { + slog.Debug("No valid tool calls found in any raw tool calls.", "rawToolCalls", rawToolCalls) + return nil, errInvalidToolCall + } + + return toolCalls, nil +} + +// checkPrefix processes a string to find and handle a prefix pattern. +// +// Returns: +// - The processed string with prefix removed if found +// - error: ErrAccumulateMore if prefix is incomplete, or nil if successful +func (p *Parser) checkPrefix(s string) (string, error) { + original := s + if strings.ContainsRune(s, '\n') { + s = strings.ReplaceAll(s, "\n", " ") + } + + if s == "" || p.prefix == "" { + return s, nil + } + + // Check for prefix at start of string + if cut, hasPrefix := strings.CutPrefix(s, p.prefix); hasPrefix { + // Found prefix at start - accumulate for potential tool + p.prefixFound = true + return cut, nil + } + + // Check if prefix overlaps end of string + if idx := suffixOverlap(s, p.prefix); idx != -1 { + // Return everything except overlapping portion + p.sb.Reset() + p.sb.WriteString(s[idx:]) + return original[:idx], errAccumulateMore + } + + // Check if prefix appears in middle of string + if idx := strings.Index(s, p.prefix); idx != -1 { + // Save remainder starting at prefix for next pass + p.sb.Reset() + p.sb.WriteString(strings.TrimSpace(s[idx:])) + // Return everything before prefix + return original[:idx], errAccumulateMore + } + + // No partial prefix found + return s, nil +} + +// Add processes a string input to parse tool calls and content. +// It handles prefix detection and JSON parsing to extract tool calls. 
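+// Content that is not recognized as part of a tool call is handed back to the caller so
+// it can be streamed to the client unchanged, while partial tool calls are buffered
+// across calls until enough input has accumulated to parse them.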
+// +// Returns: +// - tools: Any parsed tool calls +// - content: Non-tool call content +func (p *Parser) Add(s string) (tools []api.ToolCall, content string) { + if strings.TrimSpace(s) == "" { + return nil, s + } + if p.done { + if p.index == 0 { + // Return original string if no tool calls found at start + return nil, s + } + // Return empty if no tool calls found after start + return nil, "" + } + p.sb.WriteString(s) + s = p.sb.String() + + // Check for prefix pattern in input + s, err := p.checkPrefix(s) + if err != nil { + // Need more input to complete prefix + return nil, s + } + + // Exit if prefix exists in template, greedy parsing is off, and prefix not found + if !p.parseLeadingJSON && !p.prefixFound { + p.sb.Reset() + return nil, s + } + + toolCalls, err := parseJSONToolCalls(s, p.name, p.arguments, p.prefix) + if err != nil { + if errors.Is(err, errAccumulateMore) { + return nil, "" + } + p.sb.Reset() + // Do not try parsing leading JSON if JSON not found + p.parseLeadingJSON = false + if p.prefix == "" { + p.done = true + } + if p.index != 0 && p.prefix == "" { + return nil, "" + } + if p.prefixFound { + // Drop tokens since prefix was found + return nil, "" + } + return nil, s + } + + for _, tc := range toolCalls { + tc.Function.Index = p.index + p.index++ + } + + p.sb.Reset() + return toolCalls, "" +} + +// NewParser creates a new tool call parser from a template. It extracts the tool call format, +// prefix, and field names from the template to use for parsing tool calls from model output. +// +// Returns an error if the template does not contain valid tool call formatting. +func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) { + parsed, err := template.Parse(templateToProcess.Root.String()) + if err != nil { + return nil, err + } + + tt, err := toolTemplate(parsed) + if err != nil { + return nil, err + } + + tp := toolPrefix(templateToProcess) + + name, arguments, err := extractToolArgs(tt) + if err != nil { + return nil, err + } + + return &Parser{ + tmpl: *tt, + sb: strings.Builder{}, + prefix: tp, + parseLeadingJSON: true, + name: name, + arguments: arguments, + }, nil +} diff --git a/tools/tools_test.go b/tools/tools_test.go new file mode 100644 index 00000000..1ae3bff8 --- /dev/null +++ b/tools/tools_test.go @@ -0,0 +1,644 @@ +package tools + +import ( + "bytes" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/template" +) + +func readFile(t *testing.T, base, name string) *bytes.Buffer { + t.Helper() + + bts, err := os.ReadFile(filepath.Join(base, name)) + if err != nil { + t.Fatal(err) + } + + return bytes.NewBuffer(bts) +} + +func TestParseJSONToolCalls(t *testing.T) { + tests := []struct { + name string + input string + nameField string + argsField string + wantToolCalls []api.ToolCall + wantErr error + prefix string + }{ + { + name: "valid single tool call", + input: `{"name": "test_tool", "arguments": {"arg1": "value1"}}`, + nameField: "name", + argsField: "arguments", + wantToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "test_tool", + Arguments: map[string]any{ + "arg1": "value1", + }, + }, + }, + }, + wantErr: nil, + prefix: "", + }, + { + name: "incomplete JSON", + input: `{"name": "test_tool", "arguments": {"arg1": `, + nameField: "name", + argsField: "arguments", + wantToolCalls: nil, + wantErr: errAccumulateMore, + prefix: "", + }, + { + name: "invalid JSON", + input: `not json at all`, + 
nameField: "name", + argsField: "arguments", + wantToolCalls: nil, + wantErr: errInvalidToolCall, + prefix: "", + }, + { + name: "missing required fields", + input: `{"other": "field"}`, + nameField: "name", + argsField: "arguments", + wantToolCalls: nil, + wantErr: errInvalidToolCall, + prefix: "", + }, + { + name: "multiple tool calls in array", + input: `[ + {"name": "tool1", "arguments": {"arg1": 1}}, + {"name": "tool2", "arguments": {"arg2": "value"}} + ]`, + nameField: "name", + argsField: "arguments", + wantToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "tool1", + Arguments: map[string]any{ + "arg1": float64(1), + }, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "tool2", + Arguments: map[string]any{ + "arg2": "value", + }, + }, + }, + }, + wantErr: nil, + prefix: "", + }, + { + name: "multiple tool calls without array", + input: ` + {"name": "tool1", "arguments": {"arg1": 1}}, + {"name": "tool2", "arguments": {"arg2": "value"}} + `, + nameField: "name", + argsField: "arguments", + wantToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "tool1", + Arguments: map[string]any{ + "arg1": float64(1), + }, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "tool2", + Arguments: map[string]any{ + "arg2": "value", + }, + }, + }, + }, + wantErr: nil, + prefix: "", + }, + { + name: "multiple tool calls with text after", + input: ` + {"name": "tool1", "arguments": {"arg1": 1}} text + {"name": "tool2", "arguments": {"arg2": "value"}} text + `, + nameField: "name", + argsField: "arguments", + wantToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "tool1", + Arguments: map[string]any{ + "arg1": float64(1), + }, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "tool2", + Arguments: map[string]any{ + "arg2": "value", + }, + }, + }, + }, + wantErr: nil, + prefix: "", + }, + { + name: "second tool call in array", + input: ` + , {"name": "tool2", "arguments": {"arg2": "value"}} + `, + nameField: "name", + argsField: "arguments", + wantToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "tool2", + Arguments: map[string]any{ + "arg2": "value", + }, + }, + }, + }, + wantErr: nil, + prefix: "", + }, + // a bad JSON would not return any tool calls or content as it would always accumulate more + { + name: "unbalanced square brackets", + input: `[{"name": "tool1", "arguments": {"arg1": [1, 2}]`, + nameField: "name", + argsField: "arguments", + wantToolCalls: nil, + wantErr: errAccumulateMore, + prefix: "", + }, + { + name: "incomplete square brackets", + input: `[{"name": "tool1", "arguments": {"arg1": [1, 2, 3`, + nameField: "name", + argsField: "arguments", + wantToolCalls: nil, + wantErr: errAccumulateMore, + prefix: "", + }, + { + name: "nested arrays in arguments", + input: `{"name": "tool1", "arguments": {"arg1": [1, 2, ["nested", "array"]]}}`, + nameField: "name", + argsField: "arguments", + wantToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "tool1", + Arguments: map[string]any{ + "arg1": []any{float64(1), float64(2), []any{"nested", "array"}}, + }, + }, + }, + }, + wantErr: nil, + prefix: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotCalls, err := parseJSONToolCalls(tt.input, tt.nameField, tt.argsField, tt.prefix) + + if err != tt.wantErr { + t.Errorf("parseJSONToolCalls() error = %v, want %v", err, tt.wantErr) + } + + if len(gotCalls) != 0 && tt.wantErr != nil { + t.Errorf("parseJSONToolCalls() valid = %v, want %v", 
len(gotCalls) == 0, tt.wantErr == nil) + } + + if diff := cmp.Diff(gotCalls, tt.wantToolCalls); diff != "" { + t.Errorf("parseJSONToolCalls() tool calls mismatch (-got +want):\n%s", diff) + } + }) + } +} + +func TestParseToolCalls(t *testing.T) { + p := filepath.Join("testdata") + t1 := api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: api.ToolCallFunctionArguments{ + "format": "fahrenheit", + "location": "San Francisco, CA", + }, + }, + } + t2 := api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: api.ToolCallFunctionArguments{ + "format": "celsius", + "location": "Toronto, Canada", + }, + }, + } + + cases := []struct { + name string + model string + output string + expectedToolCall []api.ToolCall + expectedTokens string + }{ + { + name: "mistral malformed json with tool calls prefix", + model: "mistral", + output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: "", + }, + { + name: "mistral multiple tool calls without prefix", + model: "mistral", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}} ]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "mistral tool calls with text between no prefix", + model: "mistral", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] + model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: `model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + }, + { + name: "mistral valid json with tool calls prefix", + model: "mistral", + output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "mistral multiple tool calls with text between and prefix", + model: "mistral", + output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] + model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2, t1, t2}, + expectedTokens: "", + }, + { + name: "mistral incomplete json with tool calls prefix", + model: "mistral", + output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `, + 
expectedToolCall: []api.ToolCall{}, + expectedTokens: "", + }, + { + name: "mistral invalid tool call with explanatory text no prefix", + model: "mistral", + output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: + + [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + }, + { + name: "mistral tool calls without prefix", + model: "mistral", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "command r plus tool calls with json block format", + model: "command-r-plus", + output: "Action: ```json" + ` + [ + { + "tool_name": "get_current_weather", + "parameters": { + "format": "fahrenheit", + "location": "San Francisco, CA" + } + }, + { + "tool_name": "get_current_weather", + "parameters": { + "format": "celsius", + "location": "Toronto, Canada" + } + } + ] + ` + "```", + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "firefunction tool calls with functools prefix", + model: "firefunction", + output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "llama3 groq single tool call with xml tags", + model: "llama3-groq-tool-use", + output: ` + {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} + `, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: "", + }, + { + name: "xlam tool calls with wrapper object", + model: "xlam", + output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "qwen2.5 single tool call with prefix", + model: "qwen2.5", + output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}`, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: "", + }, + { + name: "qwen2.5 multiple tool calls with and without prefix", + model: "qwen2.5", + output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}`, + expectedToolCall: []api.ToolCall{t1, t1, t2}, + expectedTokens: "", + }, + { + name: "qwen2.5 plain text response no tool calls", + model: "qwen2.5", + output: "The 
weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", + expectedToolCall: []api.ToolCall{}, + expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", + }, + { + name: "qwen2.5 tool calls with trailing text", + model: "qwen2.5", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "some tokens after call", + }, + { + name: "qwen2.5 tool calls with initial text", + model: "qwen2.5", + output: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + }, + { + name: "qwen2.5 tool calls with prefix and trailing text", + model: "qwen2.5", + output: ` [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "qwen2.5 tool calls with prefix and initial text", + model: "qwen2.5", + output: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] `, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "some tokens before call", + }, + { + name: "qwen2.5 tool calls without and with prefix", + model: "qwen2.5", + output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "qwen2.5 tool calls without and with prefix and text between", + model: "qwen2.5", + output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} some tokens between {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}} some tokens after call`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "some tokens between", + }, + { + name: "qwen2.5 tool calls without prefix and invalid tool call with other tokens", + model: "qwen2.5", + output: `hi [{"options": "foo"}]`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `hi [{"options": "foo"}]`, + }, + { + name: "qwen2.5 tool calls with prefix and invalid tool call", + model: "qwen2.5", + output: ` [{"options": "foo"}] `, + expectedToolCall: []api.ToolCall{}, + expectedTokens: ``, + }, + { + name: "qwen3 tool call with think prefix and tool prefix (sent as a single token)", + model: "qwen3", + output: `Okay, let me think what tool we should use...{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}`, + expectedToolCall: []api.ToolCall{t1}, + 
expectedTokens: "Okay, let me think what tool we should use...", + }, + { + name: "qwen3 tool call with think prefix, tool prefix, and whitespace (sent as separate tokens)", + model: "qwen3", + output: `Okay, let me think what tool we should use... { "name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: "Okay, let me think what tool we should use...", + }, + { + name: "qwen3 empty think prefix without tool prefix and invalid tool call", + model: "qwen3", + output: ` {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + expectedToolCall: []api.ToolCall{}, + expectedTokens: ` {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + }, + { + name: "qwen3 empty think prefix with tool prefix and valid tool call", + model: "qwen3", + output: `{ "name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: ``, + }, + { + name: "qwen3 invalid tool call with fake tool prefix (single rune suffix match)", + model: "qwen3", + output: `< fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `< fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + }, + { + name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)", + model: "qwen3", + output: ``, + expectedToolCall: []api.ToolCall{}, + expectedTokens: ``, + }, + { + name: "qwen3 invalid tool call with malformed tool prefix", + model: "qwen3", + output: ``, + expectedToolCall: []api.ToolCall{}, + expectedTokens: ``, + }, + { + name: "model with prefix in template, no prefix in output", + model: "qwen2.5", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "model with prefix in template, prefix in output", + model: "qwen2.5", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "model without prefix in template, no prefix in output", + model: "llama3.2", + output: `[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "model without prefix in template, no prefix in output, single tool call", + model: "llama3.2", + output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: "", + }, + { + name: "model without prefix in template, prefix in output", + model: "llama3.2", + output: ` [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": 
{"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: ` [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, + }, + { + name: "model with prefix in template, no prefix in output, tokens before", + model: "qwen2.5", + output: `some tokens before [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `some tokens before [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + }, + { + name: "model with prefix in template, prefix in output, tokens after", + model: "qwen2.5", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "model without prefix in template, no prefix in output, tokens after", + model: "llama3.2", + output: `[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "model without prefix in template, no prefix in output, tokens before", + model: "llama3.2", + output: `some tokens before [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `some tokens before [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, + }, + { + name: "model without prefix in template, prefix in output, tokens after", + model: "llama3.2", + output: ` [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: ` [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, + }, + } + + var tools []api.Tool + if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil { + t.Fatal(err) + } + + var messages []api.Message + if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil { + t.Fatal(err) + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String()) + if err != nil { + t.Fatal(err) + } + + t.Run("template", func(t *testing.T) { + actual := &bytes.Buffer{} // Create new buffer for each 
test + if err := tmpl.Execute(actual, template.Values{Tools: tools, Messages: messages}); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + + t.Run("parse", func(t *testing.T) { + tp, err := NewParser(tmpl.Template) + if err != nil { + t.Fatal(err) + } + got := []api.ToolCall{} + var gotTokens strings.Builder + + tokens := strings.Fields(tt.output) + for _, tok := range tokens { + s := " " + tok + + toolCalls, content := tp.Add(s) + if len(content) > 0 { + gotTokens.WriteString(content) + } else if len(toolCalls) > 0 { + got = append(got, toolCalls...) + } + } + + // Compare tool calls if we expect any + if diff := cmp.Diff(got, tt.expectedToolCall); diff != "" { + t.Errorf("tool calls mismatch (-got +want):\n%s", diff) + } + + // Compare tokens if we expect any + stripped := strings.TrimSpace(gotTokens.String()) + if diff := cmp.Diff(stripped, tt.expectedTokens); diff != "" { + t.Log("actualTokens", stripped, "expectedTokens", tt.expectedTokens) + t.Errorf("tokens mismatch (-got +want):\n%s", diff) + } + }) + }) + } +} diff --git a/tools/tools_utils.go b/tools/tools_utils.go new file mode 100644 index 00000000..48531b78 --- /dev/null +++ b/tools/tools_utils.go @@ -0,0 +1,227 @@ +package tools + +import ( + "bytes" + "encoding/json" + "errors" + "log/slog" + "slices" + "strings" + gotmpl "text/template" + "text/template/parse" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/template" +) + +// extractToolCallsFormat traverses a template AST to find text that follows a ".ToolCalls" condition. +// It walks the template nodes looking for if-statements containing ".ToolCalls" and extracts any +// immediate text nodes that follow. This is used to identify tool call prefixes and formatting. 
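+// For example, given {{if .ToolCalls}}[TOOL_CALLS] {{end}}, the extracted text is
+// "[TOOL_CALLS] ", which toolPrefix later trims and uses as the tool call prefix.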
+// +// Returns: +// - string: The extracted text following the first ".ToolCalls" condition found +// - bool: Whether a ".ToolCalls" condition was found in the template +func extractToolCallsFormat(tmpl *gotmpl.Template) (string, bool) { + if tmpl == nil || tmpl.Tree == nil { + slog.Debug("template or tree is nil") + return "", false + } + + var result string + var found bool + + var walk func(nodes []parse.Node) + walk = func(nodes []parse.Node) { + for _, node := range nodes { + if found { + return + } + + switch n := node.(type) { + case *parse.IfNode: + if isToolCallsNode(n) { + // Collect immediate TextNode(s) at start of IfNode's list + var sb strings.Builder + for _, innerNode := range n.List.Nodes { + if tn, ok := innerNode.(*parse.TextNode); ok { + sb.Write(tn.Text) + } else { + // Stop at first non-text node + break + } + } + result = sb.String() + found = true + return + } + // Recurse into child nodes + walk(n.List.Nodes) + if n.ElseList != nil { + walk(n.ElseList.Nodes) + } + case *parse.ListNode: + walk(n.Nodes) + case *parse.RangeNode: + walk(n.List.Nodes) + if n.ElseList != nil { + walk(n.ElseList.Nodes) + } + case *parse.WithNode: + walk(n.List.Nodes) + if n.ElseList != nil { + walk(n.ElseList.Nodes) + } + default: + // Continue to next node + continue + } + } + } + + walk(tmpl.Tree.Root.Nodes) + return result, found +} + +// isToolCallsNode detects if a node's condition includes ".ToolCalls" +func isToolCallsNode(n *parse.IfNode) bool { + for _, cmd := range n.Pipe.Cmds { + for _, arg := range cmd.Args { + if field, ok := arg.(*parse.FieldNode); ok { + if slices.Contains(field.Ident, "ToolCalls") { + return true + } + } + } + } + return false +} + +func toolPrefix(tmpl *gotmpl.Template) string { + tokenText, ok := extractToolCallsFormat(tmpl) + if !ok { + return "" + } + tokenText = strings.TrimSpace(tokenText) + tokenText = strings.ReplaceAll(tokenText, "\r", "") + tokenText = strings.ReplaceAll(tokenText, "\n", " ") + + return tokenText +} + +// toolTemplate creates a subtree from the node that ranges over .ToolCalls +// +// Returns: +// - *gotmpl.Template: The subtree containing the .ToolCalls range +// - error: Error if parsing failed +func toolTemplate(t *template.Template) (*gotmpl.Template, error) { + tmpl := t.Subtree(func(n parse.Node) bool { + if t, ok := n.(*parse.RangeNode); ok { + return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") + } + + return false + }) + + if tmpl == nil { + return nil, errors.New("failed to find tool template") + } + + return tmpl, nil +} + +// suffixOverlap returns the index in s where the longest suffix overlap with prefix begins +// +// Returns: +// - int: The starting index in s where the suffix overlap begins +func suffixOverlap(s, prefix string) int { + max := min(len(prefix), len(s)) + for i := max; i > 0; i-- { + if strings.HasSuffix(s, prefix[:i]) { + return len(s) - i + } + } + return -1 +} + +// extractToolArgs executes a template with a known tool call format to extract the name and arguments +// +// Returns: +// - string: The name of the tool call +// - string: The arguments of the tool call +// - error: Error if parsing failed +func extractToolArgs(tmpl *gotmpl.Template) (name, arguments string, err error) { + var b bytes.Buffer + if err := tmpl.Execute(&b, map[string][]api.ToolCall{ + "ToolCalls": { + { + Function: api.ToolCallFunction{ + Name: "@@name@@", + Arguments: api.ToolCallFunctionArguments{ + "@@argument@@": 1, + }, + }, + }, + }, + }); err != nil { + return "", "", err + } + + var obj any + err = 
json.Unmarshal(b.Bytes(), &obj) + if err != nil { + return "", "", err + } + + var objs []map[string]any + switch v := obj.(type) { + case map[string]any: + objs = []map[string]any{v} + case []map[string]any: + objs = v + case []any: + objs = collect(v) + } + if len(objs) == 0 { + return "", "", errors.New("no template objects found") + } + + // find the keys that correspond to the name and arguments fields + for k, v := range objs[0] { + switch v.(type) { + case string: + name = k + case map[string]any: + arguments = k + } + } + + if name == "" || arguments == "" { + slog.Debug("missing required fields in tool call template", "name", name, "arguments", arguments) + return "", "", errors.New("missing required fields in tool call template") + } + + return name, arguments, nil +} + +// collect recursively traverses an object to collect all nested maps +// +// Returns: +// - []map[string]any: A slice of all nested maps found in the object +func collect(obj any) []map[string]any { + var all []map[string]any + switch o := obj.(type) { + case map[string]any: + all = append(all, o) + for _, v := range o { + all = append(all, collect(v)...) + } + case []any: + for _, v := range o { + all = append(all, collect(v)...) + } + default: + return nil + } + + return all +} diff --git a/tools/tools_utils_test.go b/tools/tools_utils_test.go new file mode 100644 index 00000000..769183b7 --- /dev/null +++ b/tools/tools_utils_test.go @@ -0,0 +1,464 @@ +package tools + +import ( + "testing" + gotmpl "text/template" + + "github.com/ollama/ollama/template" +) + +func TestExtractToolCallsFormat(t *testing.T) { + cases := []struct { + name string + template string + want string + found bool + }{ + { + name: "nil template", + template: "", + want: "", + found: false, + }, + { + name: "basic tool call with text", + template: "{{if .ToolCalls}}Hello world{{end}}", + want: "Hello world", + found: true, + }, + { + name: "tool call with json format", + template: "{{if .ToolCalls}}```json\n{{end}}", + want: "```json\n", + found: true, + }, + { + name: "tool call in range", + template: "{{range .ToolCalls}}tool: {{.}}{{end}}", + want: "", + found: false, + }, + { + name: "tool call with multiple text nodes", + template: "{{if .ToolCalls}}First text{{if .Something}}inner{{end}}Second text{{end}}", + want: "First text", + found: true, + }, + { + name: "nested if without tool calls", + template: "{{if .Something}}{{if .OtherThing}}text{{end}}{{end}}", + want: "", + found: false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + tmpl, err := gotmpl.New("test").Parse(tc.template) + if err != nil && tc.template != "" { + t.Fatalf("failed to parse template: %v", err) + } + + got, found := extractToolCallsFormat(tmpl) + if got != tc.want { + t.Errorf("got text %q, want %q", got, tc.want) + } + if found != tc.found { + t.Errorf("got found %v, want %v", found, tc.found) + } + }) + } +} + +func TestToolPrefix(t *testing.T) { + cases := []struct { + name string + template string + want string + }{ + { + name: "basic tool call with action prefix", + template: "{{if .ToolCalls}}Action: ```json{{end}}", + want: "Action: ```json", + }, + { + name: "incomplete functools bracket", + template: "{{if .ToolCalls}}functools[{{end}}", + want: "functools[", + }, + { + name: "tool call with angle brackets", + template: "{{if .ToolCalls}}Hello, world! {{end}}", + want: "Hello, world! 
", + }, + { + name: "multiple tool call formats", + template: "{{if .ToolCalls}}[tool_call] {{end}}", + want: "[tool_call] ", + }, + { + name: "single angle bracket tool call", + template: "{{if .ToolCalls}}{{end}}", + want: "", + }, + { + name: "incomplete angle bracket after tool call", + template: "{{if .ToolCalls}}[tool_call] <{{end}}", + want: "[tool_call] <", + }, + { + name: "angle bracket prefix with tool call", + template: "{{if .ToolCalls}}> {{end}}", + want: "> ", + }, + { + name: "uppercase tool call with incomplete bracket", + template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}", + want: "[TOOL_CALL] [", + }, + { + name: "uppercase tool call with adjacent bracket", + template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}", + want: "[TOOL_CALL][", + }, + { + name: "tool call with pipe delimiters", + template: "{{if .ToolCalls}}<|tool_call|>{{end}}", + want: "<|tool_call|>", + }, + { + name: "tool with no prefix", + template: "{{if .ToolCalls}}{{end}}", + want: "", + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + tmpl, err := gotmpl.New("test").Parse(tt.template) + if err != nil { + t.Fatalf("failed to parse template: %v", err) + } + got := toolPrefix(tmpl) + if got != tt.want { + t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want) + } + }) + } +} + +func TestToolTemplate(t *testing.T) { + cases := []struct { + name string + template string + want bool + }{ + { + name: "basic tool call range", + template: "{{range .ToolCalls}}test{{end}}", + want: true, + }, + { + name: "no tool calls", + template: "{{range .Other}}test{{end}}", + want: false, + }, + { + name: "nested tool calls", + template: "{{range .Outer}}{{range .ToolCalls}}test{{end}}{{end}}", + want: true, + }, + { + name: "empty template", + template: "", + want: false, + }, + { + name: "tool calls in if statement", + template: "{{if .ToolCalls}}test{{end}}", + want: false, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + tmpl, err := gotmpl.New("test").Parse(tt.template) + if err != nil { + t.Fatalf("failed to parse template: %v", err) + } + + parsed, err := template.Parse(tmpl.Root.String()) + if err != nil { + t.Fatalf("failed to parse template: %v", err) + } + + _, err = toolTemplate(parsed) + if err != nil && tt.want { + t.Errorf("toolTemplate() = %v; want %v", err, tt.want) + } + }) + } +} + +func TestSuffixOverlap(t *testing.T) { + cases := []struct { + name string + s string + d string + want int + }{ + { + name: "no overlap", + s: "hello world", + d: "", + want: -1, + }, + { + name: "full overlap", + s: "", + d: "", + want: 0, + }, + { + name: "partial overlap", + s: "text ", + d: "", + want: 5, + }, + { + name: "delimiter longer than string", + s: "", + d: "", + want: -1, + }, + { + name: "empty string", + s: "", + d: "", + want: -1, + }, + { + name: "empty delimiter", + s: "", + d: "", + want: -1, + }, + { + name: "single char overlap", + s: "test<", + d: "", + want: 4, + }, + { + name: "partial tool call", + s: "hello ", + want: 6, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got := suffixOverlap(tt.s, tt.d) + if got != tt.want { + t.Errorf("suffixOverlap(%q, %q) = %d; want %d", tt.s, tt.d, got, tt.want) + } + }) + } +} + +func TestExtractToolArgs(t *testing.T) { + cases := []struct { + name string + template string + want string + ok bool + }{ + { + name: "basic tool call with text after", + template: `{{if .ToolCalls}}tool response{{end}}`, + want: "tool response", + ok: true, + }, + { + name: "tool call 
with mixed content after", + template: `{{if .ToolCalls}}{{.Something}}{{end}}`, + want: "", + ok: true, + }, + { + name: "tool call with no text after", + template: `{{if .ToolCalls}}{{.Something}}{{end}}`, + want: "", + ok: true, + }, + { + name: "nested tool call", + template: `{{if .Something}}{{if .ToolCalls}}[TOOL_CALL]{{end}}{{end}}`, + want: "[TOOL_CALL]", + ok: true, + }, + { + name: "no tool calls", + template: `{{if .Something}}no tools here{{end}}`, + want: "", + ok: false, + }, + { + name: "empty template", + template: ``, + want: "", + ok: false, + }, + { + name: "multiple tool calls sections", + template: `{{if .ToolCalls}}first{{end}}{{if .ToolCalls}}second{{end}}`, + want: "first", + ok: true, + }, + { + name: "range over tool calls", + template: `{{if .ToolCalls}}{{range .ToolCalls}}tool{{end}}{{end}}`, + want: "", + ok: true, + }, + { + name: "tool calls with pipe delimiters", + template: `{{if .ToolCalls}}<|tool|>{{end}}`, + want: "<|tool|>", + ok: true, + }, + { + name: "tool calls with nested template", + template: `{{if .ToolCalls}}{{template "tool" .}}{{end}}`, + want: "", + ok: true, + }, + { + name: "tool calls with whitespace variations", + template: `{{if .ToolCalls}} tool {{end}}`, + want: " tool ", + ok: true, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + tmpl, err := gotmpl.New("test").Parse(tt.template) + if err != nil { + t.Fatalf("failed to parse template: %v", err) + } + + got, ok := extractToolCallsFormat(tmpl) + if got != tt.want { + t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want) + } + if ok != tt.ok { + t.Errorf("TextAfterToolCalls() ok = %v, want %v", ok, tt.ok) + } + }) + } +} + +func TestCollect(t *testing.T) { + cases := []struct { + name string + obj any + want []map[string]any + }{ + { + name: "simple map", + obj: map[string]any{ + "key": "value", + }, + want: []map[string]any{ + {"key": "value"}, + }, + }, + { + name: "nested map", + obj: map[string]any{ + "outer": map[string]any{ + "inner": "value", + }, + }, + want: []map[string]any{ + {"outer": map[string]any{"inner": "value"}}, + {"inner": "value"}, + }, + }, + { + name: "array of maps", + obj: []any{ + map[string]any{"key1": "val1"}, + map[string]any{"key2": "val2"}, + }, + want: []map[string]any{ + {"key1": "val1"}, + {"key2": "val2"}, + }, + }, + { + name: "deeply nested", + obj: map[string]any{ + "l1": map[string]any{ + "l2": map[string]any{ + "l3": "value", + }, + }, + }, + want: []map[string]any{ + {"l1": map[string]any{"l2": map[string]any{"l3": "value"}}}, + {"l2": map[string]any{"l3": "value"}}, + {"l3": "value"}, + }, + }, + { + name: "non-map value", + obj: "string", + want: nil, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got := collect(tt.obj) + if len(got) != len(tt.want) { + t.Errorf("collect() got %d maps, want %d", len(got), len(tt.want)) + return + } + + // Compare each map in the result + for i := range tt.want { + if !mapsEqual(got[i], tt.want[i]) { + t.Errorf("collect() map[%d] = %v, want %v", i, got[i], tt.want[i]) + } + } + }) + } +} + +// mapsEqual compares two maps for deep equality +func mapsEqual(m1, m2 map[string]any) bool { + if len(m1) != len(m2) { + return false + } + for k, v1 := range m1 { + v2, ok := m2[k] + if !ok { + return false + } + switch val1 := v1.(type) { + case map[string]any: + val2, ok := v2.(map[string]any) + if !ok || !mapsEqual(val1, val2) { + return false + } + default: + if v1 != v2 { + return false + } + } + } + return true +}
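
Usage note for reviewers: a minimal sketch of how the new parser could be driven outside the tests. The inline template, the token chunks, and the main package are illustrative assumptions, not part of this change; the real integration wires the parser into the server's chat handling.

package main

import (
	"fmt"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/template"
	"github.com/ollama/ollama/tools"
)

func main() {
	// Illustrative chat template: a "[TOOL_CALLS] [" prefix followed by a range over
	// .ToolCalls, similar in shape to the mistral fixture used in the tests.
	raw := `{{if .ToolCalls}}[TOOL_CALLS] [{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}]{{end}}`

	tmpl, err := template.Parse(raw)
	if err != nil {
		panic(err)
	}

	parser, err := tools.NewParser(tmpl.Template)
	if err != nil {
		panic(err)
	}

	// Feed streamed chunks; Add buffers until a complete tool call or plain content is available.
	chunks := []string{
		`[TOOL_CALLS] `,
		`[{"name": "get_current_weather", `,
		`"arguments": {"location": "Paris, France"}}]`,
	}

	var calls []api.ToolCall
	var content string
	for _, chunk := range chunks {
		tc, text := parser.Add(chunk)
		calls = append(calls, tc...)
		content += text
	}

	// Expect one get_current_weather call and no leftover content.
	fmt.Printf("calls=%v content=%q\n", calls, content)
}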