From 071a9872cb76f07d09dc8a3c65046d35d921f4e6 Mon Sep 17 00:00:00 2001 From: Leandro Borges Ferreira Date: Mon, 31 Mar 2025 02:28:06 +0200 Subject: [PATCH 01/17] readme: add Writeopia to community integrations (#10042) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 3afd83c3..c3d49105 100644 --- a/README.md +++ b/README.md @@ -395,6 +395,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance) - [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history - [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).) +- [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama) ### Cloud From 5d097277ef8b08c86f354b54596976869998257d Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Thu, 27 Mar 2025 14:00:05 -0700 Subject: [PATCH 02/17] ollamarunner: Ensure batch size limits are not exceeded With the llama runner, we can generate up to NUM_PARALLEL batches at once, which will then get broken up to into individual batches to get executed by llama.cpp (i.e. we add up to 2048 tokens and this gets split into 4 batches of 512 tokens at default settings). This splitting can improve parallelism on multi-GPU systems because the individual batches can move though the pipeline without blocking on the first one to fully complete. However, we don't yet support this in the Ollama runner, partially because it makes it hard to enforce model-specified batch constraints, which didn't exist previously. The result is that we will try to execute the full, unsplit batch. This could result in out of memory or insufficient KV cache space errors. This triggers batch breaking when the total inputs from all sequences exceeds the batch size, rather than per-sequence. In order to ensure fairness, it also reintroduces round-robinning around sequences so that we don't let one busy sequence starve the others. --- runner/ollamarunner/runner.go | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 31d20db8..6d20fa85 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -267,6 +267,9 @@ type Server struct { // KV cache cache *InputCache + // next sequence for prompt processing to avoid starvation + nextSeq int + // multimodalHash generates hashes for comparing equality // of non-text data multimodalHash maphash.Hash @@ -351,14 +354,19 @@ func (s *Server) processBatch() error { var batchInputs []int32 var batch input.Batch - for i, seq := range s.seqs { + resumeSeq := -1 + seqIdx := s.nextSeq - 1 + for range s.seqs { + seqIdx = (seqIdx + 1) % len(s.seqs) + seq := s.seqs[seqIdx] + if seq == nil { continue } // if past the num predict limit if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict { - s.removeSequence(i, "limit") + s.removeSequence(seqIdx, "limit") continue } @@ -369,16 +377,23 @@ func (s *Server) processBatch() error { batchSize := s.batchSize - for j, inp := range seq.inputs { + for i, inp := range seq.inputs { // If we are required to put following inputs into a single batch then extend the // batch size. Since we are only extending the size the minimum amount possible, this - // will cause a break if we have pending inputs. + // will cause a break if we have existing inputs. minBatch := 1 + inp.SameBatch if minBatch > batchSize { batchSize = minBatch } - if len(seq.pendingInputs)+minBatch > batchSize { + // Stop if the required batch would put us over the total batch size (including tokens + // added by other sequences). If we haven't been able to add anything yet then pick up + // here again for the next batch to avoid starvation, though we can opportunistically + // check if other sequences can still squeeze something in. + if len(batchInputs)+minBatch > batchSize { + if len(seq.pendingInputs) == 0 && resumeSeq == -1 { + resumeSeq = seqIdx + } break } @@ -405,7 +420,7 @@ func (s *Server) processBatch() error { batch.Sequences = append(batch.Sequences, seq.cache.Id) seq.iBatch = len(batch.Outputs) - if j+1 == len(seq.inputs) { + if i+1 == len(seq.inputs) { batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1)) } seq.pendingInputs = append(seq.pendingInputs, inp) @@ -414,6 +429,12 @@ func (s *Server) processBatch() error { seq.inputs = seq.inputs[len(seq.pendingInputs):] } + if resumeSeq != -1 { + s.nextSeq = resumeSeq + } else { + s.nextSeq = seqIdx + 1 + } + if len(batchInputs) == 0 { return nil } From b2a465296d7131ca440fd81c1bee888f4103a585 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Fri, 14 Mar 2025 17:24:46 -0700 Subject: [PATCH 03/17] runner: Release semaphore and improve error messages on failures If we have an error after creating a new sequence but before finding a slot for it, we return without releasing the semaphore. This reduces our parallel sequences and eventually leads to deadlock. In practice this should never happen because once we have acquired the semaphore, we should always be able to find a slot. However, the code is clearly not correct. --- runner/llamarunner/runner.go | 8 ++++++-- runner/ollamarunner/runner.go | 4 +++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index 83802d60..ee5d47f6 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -599,7 +599,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { if errors.Is(err, context.Canceled) { slog.Info("aborting completion request due to client closing the connection") } else { - slog.Error("Failed to acquire semaphore", "error", err) + http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError) } return } @@ -611,6 +611,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true) if err != nil { s.mu.Unlock() + s.seqsSem.Release(1) http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) return } @@ -626,6 +627,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { s.mu.Unlock() if !found { + s.seqsSem.Release(1) http.Error(w, "could not find an available sequence", http.StatusInternalServerError) return } @@ -691,7 +693,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { if errors.Is(err, context.Canceled) { slog.Info("aborting embeddings request due to client closing the connection") } else { - slog.Error("Failed to acquire semaphore", "error", err) + http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError) } return } @@ -703,6 +705,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false) if err != nil { s.mu.Unlock() + s.seqsSem.Release(1) http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) return } @@ -715,6 +718,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { s.mu.Unlock() if !found { + s.seqsSem.Release(1) http.Error(w, "could not find an available sequence", http.StatusInternalServerError) return } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 6d20fa85..bc7a07ed 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -609,7 +609,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { if errors.Is(err, context.Canceled) { slog.Info("aborting completion request due to client closing the connection") } else { - slog.Error("Failed to acquire semaphore", "error", err) + http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError) } return } @@ -621,6 +621,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs) if err != nil { s.mu.Unlock() + s.seqsSem.Release(1) http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) return } @@ -634,6 +635,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { s.mu.Unlock() if !found { + s.seqsSem.Release(1) http.Error(w, "could not find an available sequence", http.StatusInternalServerError) return } From ef27d52e7957e00fe664e7dc73cff2714f85468f Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Sun, 30 Mar 2025 23:54:54 -0700 Subject: [PATCH 04/17] server/internal/client/ollama: cache completed chunks (#9933) This change adds tracking of download chunks during the pull process so that subsequent pulls can skip downloading already completed chunks. This works across restarts of ollama. Currently, download state will be lost if a prune is triggered during a pull (e.g. restart or remove). This issue should be addressed in a follow-up PR. --- server/internal/client/ollama/registry.go | 95 ++- .../internal/client/ollama/registry_test.go | 651 ++++++++++-------- 2 files changed, 443 insertions(+), 303 deletions(-) diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go index 665defd5..409932bf 100644 --- a/server/internal/client/ollama/registry.go +++ b/server/internal/client/ollama/registry.go @@ -421,14 +421,6 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error { return err } -func canRetry(err error) bool { - var re *Error - if !errors.As(err, &re) { - return false - } - return re.Status >= 500 -} - // trackingReader is an io.Reader that tracks the number of bytes read and // calls the update function with the layer, the number of bytes read. // @@ -514,13 +506,40 @@ func (r *Registry) Pull(ctx context.Context, name string) error { break } + cacheKey := fmt.Sprintf( + "v1 pull chunksum %s %s %d-%d", + l.Digest, + cs.Digest, + cs.Chunk.Start, + cs.Chunk.End, + ) + cacheKeyDigest := blob.DigestFromBytes(cacheKey) + _, err := c.Get(cacheKeyDigest) + if err == nil { + received.Add(cs.Chunk.Size()) + t.update(l, cs.Chunk.Size(), ErrCached) + continue + } + wg.Add(1) g.Go(func() (err error) { defer func() { if err == nil { + // Ignore cache key write errors for now. We've already + // reported to trace that the chunk is complete. + // + // Ideally, we should only report completion to trace + // after successful cache commit. This current approach + // works but could trigger unnecessary redownloads if + // the checkpoint key is missing on next pull. + // + // Not incorrect, just suboptimal - fix this in a + // future update. + _ = blob.PutBytes(c, cacheKeyDigest, cacheKey) + received.Add(cs.Chunk.Size()) } else { - err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err) + t.update(l, 0, err) } wg.Done() }() @@ -563,7 +582,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error { return err } if received.Load() != expected { - return fmt.Errorf("%w: received %d/%d", ErrIncomplete, received.Load(), expected) + return fmt.Errorf("%w: received %d/%d bytes", ErrIncomplete, received.Load(), expected) } md := blob.DigestFromBytes(m.Data) @@ -608,6 +627,30 @@ func (m *Manifest) Layer(d blob.Digest) *Layer { return nil } +func (m *Manifest) All() iter.Seq[*Layer] { + return func(yield func(*Layer) bool) { + if !yield(m.Config) { + return + } + for _, l := range m.Layers { + if !yield(l) { + return + } + } + } +} + +func (m *Manifest) Size() int64 { + var size int64 + if m.Config != nil { + size += m.Config.Size + } + for _, l := range m.Layers { + size += l.Size + } + return size +} + // MarshalJSON implements json.Marshaler. // // NOTE: It adds an empty config object to the manifest, which is required by @@ -750,20 +793,32 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se return } - // A chunksums response is a sequence of chunksums in a - // simple, easy to parse line-oriented format. + // The response is a sequence of chunksums. // - // Example: + // Chunksums are chunks of a larger blob that can be + // downloaded and verified independently. // - // >> GET /v2///chunksums/ + // The chunksums endpoint is a GET request that returns a + // sequence of chunksums in the following format: // - // << HTTP/1.1 200 OK - // << Content-Location: - // << - // << - - // << ... + // > GET /v2///chunksums/ // - // The blobURL is the URL to download the chunks from. + // < HTTP/1.1 200 OK + // < Content-Location: + // < + // < - + // < ... + // + // The is the URL to download the chunks from and + // each is the digest of the chunk, and - + // is the range the chunk in the blob. + // + // Ranges may be used directly in Range headers like + // "bytes=-". + // + // The chunksums returned are guaranteed to be contiguous and + // include all bytes of the layer. If the stream is cut short, + // clients should retry. chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s", scheme, diff --git a/server/internal/client/ollama/registry_test.go b/server/internal/client/ollama/registry_test.go index f8136c06..80d39b76 100644 --- a/server/internal/client/ollama/registry_test.go +++ b/server/internal/client/ollama/registry_test.go @@ -9,17 +9,14 @@ import ( "fmt" "io" "io/fs" - "math/rand/v2" + "net" "net/http" "net/http/httptest" "os" - "path" "reflect" - "slices" "strings" - "sync" + "sync/atomic" "testing" - "time" "github.com/ollama/ollama/server/internal/cache/blob" "github.com/ollama/ollama/server/internal/testutil" @@ -338,15 +335,8 @@ func TestPushCommitRoundtripError(t *testing.T) { } } -func checkNotExist(t *testing.T, err error) { - t.Helper() - if !errors.Is(err, fs.ErrNotExist) { - t.Fatalf("err = %v; want fs.ErrNotExist", err) - } -} - func TestRegistryPullInvalidName(t *testing.T) { - rc, _ := newClient(t, nil) + rc, _ := newRegistryClient(t, nil) err := rc.Pull(t.Context(), "://") if !errors.Is(err, ErrNameInvalid) { t.Errorf("err = %v; want %v", err, ErrNameInvalid) @@ -362,197 +352,16 @@ func TestRegistryPullInvalidManifest(t *testing.T) { } for _, resp := range cases { - rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, resp) }) - err := rc.Pull(t.Context(), "x") + err := rc.Pull(t.Context(), "http://example.com/a/b") if !errors.Is(err, ErrManifestInvalid) { t.Errorf("err = %v; want invalid manifest", err) } } } -func TestRegistryPullNotCached(t *testing.T) { - check := testutil.Checker(t) - - var c *blob.DiskCache - var rc *Registry - - d := blob.DigestFromBytes("some data") - rc, c = newClient(t, func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, "/blobs/") { - io.WriteString(w, "some data") - return - } - fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":9}]}`, d) - }) - - // Confirm that the layer does not exist locally - _, err := rc.ResolveLocal("model") - checkNotExist(t, err) - - _, err = c.Get(d) - checkNotExist(t, err) - - err = rc.Pull(t.Context(), "model") - check(err) - - mw, err := rc.Resolve(t.Context(), "model") - check(err) - mg, err := rc.ResolveLocal("model") - check(err) - if !reflect.DeepEqual(mw, mg) { - t.Errorf("mw = %v; mg = %v", mw, mg) - } - - // Confirm successful download - info, err := c.Get(d) - check(err) - if info.Digest != d { - t.Errorf("info.Digest = %v; want %v", info.Digest, d) - } - if info.Size != 9 { - t.Errorf("info.Size = %v; want %v", info.Size, 9) - } - - data, err := os.ReadFile(c.GetFile(d)) - check(err) - if string(data) != "some data" { - t.Errorf("data = %q; want %q", data, "exists") - } -} - -func TestRegistryPullCached(t *testing.T) { - cached := blob.DigestFromBytes("exists") - rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, "/blobs/") { - w.WriteHeader(499) // should not be called - return - } - if strings.Contains(r.URL.Path, "/manifests/") { - fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, cached) - } - }) - - var errs []error - var reads []int64 - ctx := WithTrace(t.Context(), &Trace{ - Update: func(d *Layer, n int64, err error) { - t.Logf("update %v %d %v", d, n, err) - reads = append(reads, n) - errs = append(errs, err) - }, - }) - - ctx, cancel := context.WithTimeout(ctx, 3*time.Second) - defer cancel() - - err := rc.Pull(ctx, "single") - testutil.Check(t, err) - - want := []int64{0, 6} - if !errors.Is(errors.Join(errs...), ErrCached) { - t.Errorf("errs = %v; want %v", errs, ErrCached) - } - if !slices.Equal(reads, want) { - t.Errorf("pairs = %v; want %v", reads, want) - } -} - -func TestRegistryPullManifestNotFound(t *testing.T) { - rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNotFound) - }) - err := rc.Pull(t.Context(), "notfound") - checkErrCode(t, err, 404, "") -} - -func TestRegistryPullResolveRemoteError(t *testing.T) { - rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - io.WriteString(w, `{"errors":[{"code":"an_error"}]}`) - }) - err := rc.Pull(t.Context(), "single") - checkErrCode(t, err, 500, "an_error") -} - -func TestRegistryPullResolveRoundtripError(t *testing.T) { - rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, "/manifests/") { - w.WriteHeader(499) // force RoundTrip error - return - } - }) - err := rc.Pull(t.Context(), "single") - if !errors.Is(err, errRoundTrip) { - t.Errorf("err = %v; want %v", err, errRoundTrip) - } -} - -// TestRegistryPullMixedCachedNotCached tests that cached layers do not -// interfere with pulling layers that are not cached -func TestRegistryPullMixedCachedNotCached(t *testing.T) { - x := blob.DigestFromBytes("xxxxxx") - e := blob.DigestFromBytes("exists") - y := blob.DigestFromBytes("yyyyyy") - - for i := range 10 { - t.Logf("iteration %d", i) - - digests := []blob.Digest{x, e, y} - - rand.Shuffle(len(digests), func(i, j int) { - digests[i], digests[j] = digests[j], digests[i] - }) - - manifest := fmt.Sprintf(`{ - "layers": [ - {"digest":"%s","size":6}, - {"digest":"%s","size":6}, - {"digest":"%s","size":6} - ] - }`, digests[0], digests[1], digests[2]) - - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { - switch path.Base(r.URL.Path) { - case "latest": - io.WriteString(w, manifest) - case x.String(): - io.WriteString(w, "xxxxxx") - case e.String(): - io.WriteString(w, "exists") - case y.String(): - io.WriteString(w, "yyyyyy") - default: - panic(fmt.Sprintf("unexpected request: %v", r)) - } - }) - - ctx := WithTrace(t.Context(), &Trace{ - Update: func(l *Layer, n int64, err error) { - t.Logf("update %v %d %v", l, n, err) - }, - }) - - // Check that we pull all layers that we can. - - err := rc.Pull(ctx, "mixed") - if err != nil { - t.Fatal(err) - } - - for _, d := range digests { - info, err := c.Get(d) - if err != nil { - t.Fatalf("Get(%v): %v", d, err) - } - if info.Size != 6 { - t.Errorf("info.Size = %v; want %v", info.Size, 6) - } - } - } -} - func TestRegistryResolveByDigest(t *testing.T) { check := testutil.Checker(t) @@ -590,26 +399,6 @@ func TestInsecureSkipVerify(t *testing.T) { testutil.Check(t, err) } -func TestCanRetry(t *testing.T) { - cases := []struct { - err error - want bool - }{ - {nil, false}, - {errors.New("x"), false}, - {ErrCached, false}, - {ErrManifestInvalid, false}, - {ErrNameInvalid, false}, - {&Error{Status: 100}, false}, - {&Error{Status: 500}, true}, - } - for _, tt := range cases { - if got := canRetry(tt.err); got != tt.want { - t.Errorf("CanRetry(%v) = %v; want %v", tt.err, got, tt.want) - } - } -} - func TestErrorUnmarshal(t *testing.T) { cases := []struct { name string @@ -761,17 +550,23 @@ func TestParseNameExtended(t *testing.T) { func TestUnlink(t *testing.T) { t.Run("found by name", func(t *testing.T) { - rc, _ := newClient(t, nil) + check := testutil.Checker(t) + + rc, _ := newRegistryClient(t, nil) + // make a blob and link it + d := blob.DigestFromBytes("{}") + err := blob.PutBytes(rc.Cache, d, "{}") + check(err) + err = rc.Cache.Link("registry.ollama.ai/library/single:latest", d) + check(err) // confirm linked - _, err := rc.ResolveLocal("single") - if err != nil { - t.Errorf("unexpected error: %v", err) - } + _, err = rc.ResolveLocal("single") + check(err) // unlink _, err = rc.Unlink("single") - testutil.Check(t, err) + check(err) // confirm unlinked _, err = rc.ResolveLocal("single") @@ -780,7 +575,7 @@ func TestUnlink(t *testing.T) { } }) t.Run("not found by name", func(t *testing.T) { - rc, _ := newClient(t, nil) + rc, _ := newRegistryClient(t, nil) ok, err := rc.Unlink("manifestNotFound") if err != nil { t.Fatal(err) @@ -791,78 +586,368 @@ func TestUnlink(t *testing.T) { }) } -func TestPullChunksums(t *testing.T) { - check := testutil.Checker(t) +// Many tests from here out, in this file are based on a single blob, "abc", +// with the checksum of its sha256 hash. The checksum is: +// +// "abc" -> sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad +// +// Using the literal value instead of a constant with fmt.Xprintf calls proved +// to be the most readable and maintainable approach. The sum is consistently +// used in the tests and unique so searches do not yield false positives. - content := "hello" - var chunksums string - contentDigest := func() blob.Digest { - return blob.DigestFromBytes(content) +func checkRequest(t *testing.T, req *http.Request, method, path string) { + t.Helper() + if got := req.URL.Path; got != path { + t.Errorf("URL = %q, want %q", got, path) } - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { - switch { - case strings.Contains(r.URL.Path, "/manifests/latest"): - fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content)) - case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()): - loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest()) - w.Header().Set("Content-Location", loc) - io.WriteString(w, chunksums) - case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()): - http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content)) - default: - t.Errorf("unexpected request: %v", r) - http.NotFound(w, r) - } - }) + if req.Method != method { + t.Errorf("Method = %q, want %q", req.Method, method) + } +} - rc.MaxStreams = 1 // prevent concurrent chunk downloads - rc.ChunkingThreshold = 1 // for all blobs to be chunked +func newRegistryClient(t *testing.T, h http.HandlerFunc) (*Registry, context.Context) { + s := httptest.NewServer(h) + t.Cleanup(s.Close) + cache, err := blob.Open(t.TempDir()) + if err != nil { + t.Fatal(err) + } - var mu sync.Mutex - var reads []int64 ctx := WithTrace(t.Context(), &Trace{ Update: func(l *Layer, n int64, err error) { - t.Logf("Update: %v %d %v", l, n, err) - mu.Lock() - reads = append(reads, n) - mu.Unlock() + t.Log("trace:", l.Digest.Short(), n, err) }, }) - chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n", - blob.DigestFromBytes("hel"), - blob.DigestFromBytes("lo"), - ) - err := rc.Pull(ctx, "test") - check(err) - wantReads := []int64{ - 0, // initial signaling of layer pull starting - 3, // first chunk read - 2, // second chunk read - } - if !slices.Equal(reads, wantReads) { - t.Errorf("reads = %v; want %v", reads, wantReads) + rc := &Registry{ + Cache: cache, + HTTPClient: &http.Client{Transport: &http.Transport{ + Dial: func(network, addr string) (net.Conn, error) { + return net.Dial(network, s.Listener.Addr().String()) + }, + }}, } + return rc, ctx +} - mw, err := rc.Resolve(t.Context(), "test") - check(err) - mg, err := rc.ResolveLocal("test") - check(err) - if !reflect.DeepEqual(mw, mg) { - t.Errorf("mw = %v; mg = %v", mw, mg) - } - for i := range mg.Layers { - _, err = c.Get(mg.Layers[i].Digest) - if err != nil { - t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err) +func TestPullChunked(t *testing.T) { + var steps atomic.Int64 + c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) { + switch steps.Add(1) { + case 1: + checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest") + io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`) + case 2: + checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") + w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") + fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab")) + fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c")) + case 3, 4: + checkRequest(t, r, "GET", "/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") + switch rng := r.Header.Get("Range"); rng { + case "bytes=0-1": + io.WriteString(w, "ab") + case "bytes=2-2": + t.Logf("writing c") + io.WriteString(w, "c") + default: + t.Errorf("unexpected range %q", rng) + } + default: + t.Errorf("unexpected steps %d: %v", steps.Load(), r) + http.Error(w, "unexpected steps", http.StatusInternalServerError) } - } + }) - // missing chunks - content = "llama" - chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll")) - err = rc.Pull(ctx, "missingchunks") - if err == nil { - t.Error("expected error because of missing chunks") + c.ChunkingThreshold = 1 // force chunking + + err := c.Pull(ctx, "http://o.com/library/abc") + testutil.Check(t, err) + + _, err = c.Cache.Resolve("o.com/library/abc:latest") + testutil.Check(t, err) + + if g := steps.Load(); g != 4 { + t.Fatalf("got %d steps, want 4", g) + } +} + +func TestPullCached(t *testing.T) { + c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) { + checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest") + io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`) + }) + + check := testutil.Checker(t) + + // Premeptively cache the blob + d, err := blob.ParseDigest("sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") + check(err) + err = blob.PutBytes(c.Cache, d, []byte("abc")) + check(err) + + // Pull only the manifest, which should be enough to resolve the cached blob + err = c.Pull(ctx, "http://o.com/library/abc") + check(err) +} + +func TestPullManifestError(t *testing.T) { + c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) { + checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest") + w.WriteHeader(http.StatusNotFound) + io.WriteString(w, `{"errors":[{"code":"MANIFEST_UNKNOWN"}]}`) + }) + + err := c.Pull(ctx, "http://o.com/library/abc") + if err == nil { + t.Fatalf("expected error") + } + var got *Error + if !errors.Is(err, ErrModelNotFound) { + t.Fatalf("err = %v, want %v", got, ErrModelNotFound) + } +} + +func TestPullLayerError(t *testing.T) { + c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) { + checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest") + io.WriteString(w, `!`) + }) + + err := c.Pull(ctx, "http://o.com/library/abc") + if err == nil { + t.Fatalf("expected error") + } + var want *json.SyntaxError + if !errors.As(err, &want) { + t.Fatalf("err = %T, want %T", err, want) + } +} + +func TestPullLayerChecksumError(t *testing.T) { + var step atomic.Int64 + c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) { + switch step.Add(1) { + case 1: + checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest") + io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`) + case 2: + checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") + w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") + fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab")) + fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c")) + case 3: + w.WriteHeader(http.StatusNotFound) + io.WriteString(w, `{"errors":[{"code":"BLOB_UNKNOWN"}]}`) + case 4: + io.WriteString(w, "c") + default: + t.Errorf("unexpected steps %d: %v", step.Load(), r) + http.Error(w, "unexpected steps", http.StatusInternalServerError) + } + }) + + c.MaxStreams = 1 + c.ChunkingThreshold = 1 // force chunking + + var written atomic.Int64 + ctx := WithTrace(t.Context(), &Trace{ + Update: func(l *Layer, n int64, err error) { + t.Log("trace:", l.Digest.Short(), n, err) + written.Add(n) + }, + }) + + err := c.Pull(ctx, "http://o.com/library/abc") + var got *Error + if !errors.As(err, &got) || got.Code != "BLOB_UNKNOWN" { + t.Fatalf("err = %v, want %v", err, got) + } + + if g := written.Load(); g != 1 { + t.Fatalf("wrote %d bytes, want 1", g) + } +} + +func TestPullChunksumStreamError(t *testing.T) { + var step atomic.Int64 + c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) { + switch step.Add(1) { + case 1: + checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest") + io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`) + case 2: + w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") + + // Write one valid chunksum and one invalid chunksum + fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab")) // valid + fmt.Fprint(w, "sha256:!") // invalid + case 3: + io.WriteString(w, "ab") + default: + t.Errorf("unexpected steps %d: %v", step.Load(), r) + http.Error(w, "unexpected steps", http.StatusInternalServerError) + } + }) + + c.ChunkingThreshold = 1 // force chunking + + got := c.Pull(ctx, "http://o.com/library/abc") + if !errors.Is(got, ErrIncomplete) { + t.Fatalf("err = %v, want %v", got, ErrIncomplete) + } +} + +type flushAfterWriter struct { + w io.Writer +} + +func (f *flushAfterWriter) Write(p []byte) (n int, err error) { + n, err = f.w.Write(p) + f.w.(http.Flusher).Flush() // panic if not a flusher + return +} + +func TestPullChunksumStreaming(t *testing.T) { + csr, csw := io.Pipe() + defer csw.Close() + + var step atomic.Int64 + c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) { + switch step.Add(1) { + case 1: + checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest") + io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`) + case 2: + w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") + fw := &flushAfterWriter{w} // ensure client gets data as it arrives by aggressively flushing + _, err := io.Copy(fw, csr) + if err != nil { + t.Errorf("copy: %v", err) + } + case 3: + io.WriteString(w, "ab") + case 4: + io.WriteString(w, "c") + default: + t.Errorf("unexpected steps %d: %v", step.Load(), r) + http.Error(w, "unexpected steps", http.StatusInternalServerError) + } + }) + + c.ChunkingThreshold = 1 // force chunking + + update := make(chan int64, 1) + ctx := WithTrace(t.Context(), &Trace{ + Update: func(l *Layer, n int64, err error) { + t.Log("trace:", l.Digest.Short(), n, err) + if n > 0 { + update <- n + } + }, + }) + + errc := make(chan error, 1) + go func() { + errc <- c.Pull(ctx, "http://o.com/library/abc") + }() + + // Send first chunksum and ensure it kicks off work immediately + fmt.Fprintf(csw, "%s 0-1\n", blob.DigestFromBytes("ab")) + if g := <-update; g != 2 { + t.Fatalf("got %d, want 2", g) + } + + // now send the second chunksum and ensure it kicks off work immediately + fmt.Fprintf(csw, "%s 2-2\n", blob.DigestFromBytes("c")) + if g := <-update; g != 1 { + t.Fatalf("got %d, want 1", g) + } + csw.Close() + testutil.Check(t, <-errc) +} + +func TestPullChunksumsCached(t *testing.T) { + var step atomic.Int64 + c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) { + switch step.Add(1) { + case 1: + checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest") + io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`) + case 2: + w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") + fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab")) + fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c")) + case 3, 4: + switch rng := r.Header.Get("Range"); rng { + case "bytes=0-1": + io.WriteString(w, "ab") + case "bytes=2-2": + io.WriteString(w, "c") + default: + t.Errorf("unexpected range %q", rng) + } + default: + t.Errorf("unexpected steps %d: %v", step.Load(), r) + http.Error(w, "unexpected steps", http.StatusInternalServerError) + } + }) + + c.MaxStreams = 1 // force serial processing of chunksums + c.ChunkingThreshold = 1 // force chunking + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + // Cancel the pull after the first chunksum is processed, but before + // the second chunksum is processed (which is waiting because + // MaxStreams=1). This should cause the second chunksum to error out + // leaving the blob incomplete. + ctx = WithTrace(ctx, &Trace{ + Update: func(l *Layer, n int64, err error) { + if n > 0 { + cancel() + } + }, + }) + err := c.Pull(ctx, "http://o.com/library/abc") + if !errors.Is(err, context.Canceled) { + t.Fatalf("err = %v, want %v", err, context.Canceled) + } + + _, err = c.Cache.Resolve("o.com/library/abc:latest") + if !errors.Is(err, fs.ErrNotExist) { + t.Fatalf("err = %v, want nil", err) + } + + // Reset state and pull again to ensure the blob chunks that should + // have been cached are, and the remaining chunk was downloaded, making + // the blob complete. + step.Store(0) + var written atomic.Int64 + var cached atomic.Int64 + ctx = WithTrace(t.Context(), &Trace{ + Update: func(l *Layer, n int64, err error) { + t.Log("trace:", l.Digest.Short(), n, err) + if errors.Is(err, ErrCached) { + cached.Add(n) + } + written.Add(n) + }, + }) + + check := testutil.Checker(t) + + err = c.Pull(ctx, "http://o.com/library/abc") + check(err) + + _, err = c.Cache.Resolve("o.com/library/abc:latest") + check(err) + + if g := written.Load(); g != 3 { + t.Fatalf("wrote %d bytes, want 3", g) + } + if g := cached.Load(); g != 2 { // "ab" should have been cached + t.Fatalf("cached %d bytes, want 3", g) } } From 66b253923891d41a31d28531e9db5efccf53e1d0 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Mon, 31 Mar 2025 12:54:45 -0700 Subject: [PATCH 05/17] runner: clear cache when shift is not possible (#9433) Clear KV cache when shift operation is not supported by model. Added KvCacheCanShift() check to handle models that can't perform cache shifts, falling back to full cache clear while preserving logical token history to maintain expected behavior when context window fills up. --- llama/llama.go | 4 ++ runner/llamarunner/cache.go | 55 +++++++++++++++---- runner/llamarunner/runner.go | 10 +++- runner/ollamarunner/cache.go | 24 +++++++- runner/ollamarunner/cache_test.go | 91 +++++++++++++++++++++++++++++++ runner/ollamarunner/runner.go | 10 +++- 6 files changed, 180 insertions(+), 14 deletions(-) diff --git a/llama/llama.go b/llama/llama.go index a026bee2..e8cdafe7 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -166,6 +166,10 @@ func (c *Context) KvCacheDefrag() { C.llama_kv_cache_defrag(c.c) } +func (c *Context) KvCacheCanShift() bool { + return bool(C.llama_kv_cache_can_shift(c.c)) +} + // Get the embeddings for a sequence id func (c *Context) GetEmbeddingsSeq(seqId int) []float32 { e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId))) diff --git a/runner/llamarunner/cache.go b/runner/llamarunner/cache.go index d29e94b6..2e55b09d 100644 --- a/runner/llamarunner/cache.go +++ b/runner/llamarunner/cache.go @@ -213,8 +213,16 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int { return discard } -// Frees up space in the KV cache by deleting the oldest half of history and shifting -// the newest half into that space (saving numKeep inputs at the beginning). +type ErrReprocessInputs struct { + Inputs []input +} + +func (e *ErrReprocessInputs) Error() string { + return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs)) +} + +// ShiftCacheSlot frees up space in the KV cache by deleting the oldest half of history +// and shifting the newest half into that space (saving numKeep inputs at the beginning). // // Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx) func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error { @@ -222,7 +230,8 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error { return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx) } - discard := c.ShiftDiscard(len(slot.Inputs), numKeep) + inputLen := len(slot.Inputs) + discard := c.ShiftDiscard(inputLen, numKeep) if discard <= 0 { return nil @@ -231,16 +240,42 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error { slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs), "keep", numKeep, "discard", discard) - // TODO (jessegross): KV cache removal can fail for certain types of models - if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) { - return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard) - } - c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard) + var shiftFailed bool - for i := numKeep + discard; i < len(slot.Inputs); i++ { + if c.lc.KvCacheCanShift() { + // For models that support shifting, attempt to shift the KV cache + if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) { + shiftFailed = true + slog.Debug("kv cache removal not supported, clearing cache and returning inputs for reprocessing", "id", slot.Id) + } else { + c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, inputLen, -discard) + } + } else { + // For models that don't support shifting + shiftFailed = true + slog.Debug("kv cache cannot shift, clearing cache and returning inputs for reprocessing", "id", slot.Id) + } + + if shiftFailed { + // Create new input slice with preserved tokens (numKeep + remaining tokens after discard) + newInputs := make([]input, numKeep+inputLen-(numKeep+discard)) + copy(newInputs[:numKeep], slot.Inputs[:numKeep]) + copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:]) + + // Clear the entire KV cache + _ = c.lc.KvCacheSeqRm(slot.Id, 0, -1) + // Reset the slot inputs since we've cleared the cache + slot.Inputs = []input{} + + // Return error with inputs that need to be reprocessed + return &ErrReprocessInputs{Inputs: newInputs} + } + + // Standard shift succeeded - update input array + for i := numKeep + discard; i < inputLen; i++ { slot.Inputs[i-discard] = slot.Inputs[i] } - slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard] + slot.Inputs = slot.Inputs[:inputLen-discard] return nil } diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index ee5d47f6..a4264f5f 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -389,7 +389,15 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) if len(seq.pendingInputs) == 0 { err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) if err != nil { - return err + var reprocess *ErrReprocessInputs + if errors.As(err, &reprocess) { + // Prepend these inputs to the sequence's inputs queue for reprocessing + seq.inputs = append(reprocess.Inputs, seq.inputs...) + // Continue processing as normal + continue + } else { + return err + } } } else { break diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index aa56c982..af48ff22 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -239,6 +239,14 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 { return discard } +type ErrReprocessInputs struct { + Inputs []input.Input +} + +func (e *ErrReprocessInputs) Error() string { + return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs)) +} + // Frees up space in the KV cache by deleting the oldest half of history and shifting // the newest half into that space (saving numKeep inputs at the beginning). // @@ -258,11 +266,23 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error { slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs), "keep", numKeep, "discard", discard) - // TODO (jessegross): KV cache removal can fail for certain types of models if c.cache != nil { err := c.cache.Remove(slot.Id, numKeep, numKeep+discard) if err != nil { - return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err) + slog.Debug("kv cache removal unsupported, clearing cache and returning inputs for reprocessing", + "id", slot.Id, "error", err) + + // Create new input slice with preserved tokens (numKeep + remaining tokens after discard) + newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard)) + copy(newInputs[:numKeep], slot.Inputs[:numKeep]) + copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:]) + + // Reset the cache + _ = c.cache.Remove(slot.Id, 0, -1) + slot.Inputs = []input.Input{} + + // Return error with inputs that need to be reprocessed + return &ErrReprocessInputs{Inputs: newInputs} } } diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go index f8925d11..6a8d8a6a 100644 --- a/runner/ollamarunner/cache_test.go +++ b/runner/ollamarunner/cache_test.go @@ -1,10 +1,13 @@ package ollamarunner import ( + "errors" + "fmt" "image" "testing" "time" + "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model/input" ) @@ -425,3 +428,91 @@ func TestLoadCacheSlot(t *testing.T) { }) } } + +// Mock implementation of the Cache interface +type mockCache struct { + shouldFail bool +} + +// Implement only the methods needed for the test +func (m *mockCache) Remove(seq int, beginIndex, endIndex int32) error { + if m.shouldFail { + return fmt.Errorf("mock cache removal error") + } + return nil +} + +// Stub implementations for other interface methods +func (m *mockCache) SetLayer(layer int) {} +func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { return nil, nil, nil } +func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor) {} +func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {} +func (m *mockCache) Close() {} +func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch) error { return nil } +func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {} +func (m *mockCache) SetConfig(ml.CacheConfig) {} + +func TestShiftCacheSlot(t *testing.T) { + tests := []struct { + name string + numCtx int32 + inputs []input.Input + numKeep int32 + cacheErr bool + wantErr any + wantInputsLen int + }{ + { + name: "Normal shift", + numCtx: 10, + inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}}, + numKeep: 2, + cacheErr: false, // No error + wantErr: nil, + wantInputsLen: 6, // After discarding 4 tokens + }, + { + name: "Cache removal fails", + numCtx: 10, + inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}}, + numKeep: 2, + cacheErr: true, + wantErr: &ErrReprocessInputs{}, + wantInputsLen: 0, // Original inputs should be cleared + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mock := &mockCache{shouldFail: tt.cacheErr} + c := InputCache{ + numCtx: tt.numCtx, + cache: mock, + } + slot := &InputCacheSlot{ + Id: 123, + Inputs: make([]input.Input, len(tt.inputs)), + } + copy(slot.Inputs, tt.inputs) + + err := c.ShiftCacheSlot(slot, tt.numKeep) + + if tt.wantErr != nil { + if err == nil { + t.Errorf("Expected error but got nil") + return + } + + if !errors.As(err, &tt.wantErr) { + t.Errorf("Expected error of type %T but got %T: %v", tt.wantErr, err, err) + } + } else if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if len(slot.Inputs) != tt.wantInputsLen { + t.Errorf("Slot inputs length after operation: got %v, want %v", len(slot.Inputs), tt.wantInputsLen) + } + }) + } +} diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index bc7a07ed..45838718 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -407,7 +407,15 @@ func (s *Server) processBatch() error { err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) if err != nil { - return err + var reprocess *ErrReprocessInputs + if errors.As(err, &reprocess) { + // Prepend these inputs to the sequence's inputs queue for reprocessing + seq.inputs = append(reprocess.Inputs, seq.inputs...) + // Skip this sequence but continue processing the rest + continue + } else { + return err + } } } From 4059a297a6d95ce94f3619eac0536fda666d58f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B9=9B=E9=9C=B2=E5=85=88=E7=94=9F?= Date: Tue, 1 Apr 2025 08:07:42 +0800 Subject: [PATCH 06/17] discover: /proc/cpuinfo file open and close. (#9950) Signed-off-by: zhanluxianshen --- discover/gpu_linux.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/discover/gpu_linux.go b/discover/gpu_linux.go index d636a54e..44c53b44 100644 --- a/discover/gpu_linux.go +++ b/discover/gpu_linux.go @@ -111,6 +111,7 @@ func GetCPUDetails() ([]CPU, error) { if err != nil { return nil, err } + defer file.Close() return linuxCPUDetails(file) } @@ -168,13 +169,11 @@ func linuxCPUDetails(file io.Reader) ([]CPU, error) { for id, s := range socketByID { s.CoreCount = len(coreBySocket[id]) s.ThreadCount = 0 - for _, tc := range threadsByCoreBySocket[id] { - s.ThreadCount += tc - } // This only works if HT is enabled, consider a more reliable model, maybe cache size comparisons? efficiencyCoreCount := 0 for _, threads := range threadsByCoreBySocket[id] { + s.ThreadCount += threads if threads == 1 { efficiencyCoreCount++ } From 23fc8e92eb01ddd1cf06b34ff270926ec7edd4b8 Mon Sep 17 00:00:00 2001 From: Abyss-c0re Date: Tue, 1 Apr 2025 03:23:04 +0300 Subject: [PATCH 07/17] docs: add DeepShell to community projects (#9955) Co-authored-by: Bruce MacDonald --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index c3d49105..24391caf 100644 --- a/README.md +++ b/README.md @@ -435,6 +435,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [SwollamaCLI](https://github.com/marcusziade/Swollama) bundled with the Swollama Swift package. [Demo](https://github.com/marcusziade/Swollama?tab=readme-ov-file#cli-usage) - [aichat](https://github.com/sigoden/aichat) All-in-one LLM CLI tool featuring Shell Assistant, Chat-REPL, RAG, AI tools & agents, with access to OpenAI, Claude, Gemini, Ollama, Groq, and more. - [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama +- [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis. - [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama. - [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull and download models from Ollama Registry in your terminal. From c001b98087e45b7b60509127d4d2e9d9ba809444 Mon Sep 17 00:00:00 2001 From: Ilian Date: Tue, 1 Apr 2025 02:28:59 +0200 Subject: [PATCH 08/17] docs: add TagSpaces to community integrations (#9983) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 24391caf..aca00a87 100644 --- a/README.md +++ b/README.md @@ -285,6 +285,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt) - [HTML UI](https://github.com/rtcfirefly/ollama-ui) - [Saddle](https://github.com/jikkuatwork/saddle) +- [TagSpaces](https://www.tagspaces.org) (A platform for file based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions) - [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama) - [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui) - [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file) From e172f095ba4af2c98d7744ce4ffcf4cd3a8e123c Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Tue, 1 Apr 2025 15:21:46 -0700 Subject: [PATCH 09/17] api: return model capabilities from the show endpoint (#10066) With support for multimodal models becoming more varied and common it is important for clients to be able to easily see what capabilities a model has. Retuning these from the show endpoint will allow clients to easily see what a model can do. --- api/types.go | 24 +-- cmd/cmd.go | 15 ++ cmd/cmd_test.go | 29 +++ docs/api.md | 8 +- server/images.go | 101 +++++++---- server/images_test.go | 360 ++++++++++++++++++++++++++++++++++++++ server/routes.go | 35 ++-- server/sched.go | 3 +- types/model/capability.go | 15 ++ 9 files changed, 521 insertions(+), 69 deletions(-) create mode 100644 server/images_test.go create mode 100644 types/model/capability.go diff --git a/api/types.go b/api/types.go index a38b335b..b4a65fe5 100644 --- a/api/types.go +++ b/api/types.go @@ -12,6 +12,7 @@ import ( "time" "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/types/model" ) // StatusError is an error with an HTTP status code and message. @@ -340,17 +341,18 @@ type ShowRequest struct { // ShowResponse is the response returned from [Client.Show]. type ShowResponse struct { - License string `json:"license,omitempty"` - Modelfile string `json:"modelfile,omitempty"` - Parameters string `json:"parameters,omitempty"` - Template string `json:"template,omitempty"` - System string `json:"system,omitempty"` - Details ModelDetails `json:"details,omitempty"` - Messages []Message `json:"messages,omitempty"` - ModelInfo map[string]any `json:"model_info,omitempty"` - ProjectorInfo map[string]any `json:"projector_info,omitempty"` - Tensors []Tensor `json:"tensors,omitempty"` - ModifiedAt time.Time `json:"modified_at,omitempty"` + License string `json:"license,omitempty"` + Modelfile string `json:"modelfile,omitempty"` + Parameters string `json:"parameters,omitempty"` + Template string `json:"template,omitempty"` + System string `json:"system,omitempty"` + Details ModelDetails `json:"details,omitempty"` + Messages []Message `json:"messages,omitempty"` + ModelInfo map[string]any `json:"model_info,omitempty"` + ProjectorInfo map[string]any `json:"projector_info,omitempty"` + Tensors []Tensor `json:"tensors,omitempty"` + Capabilities []model.Capability `json:"capabilities,omitempty"` + ModifiedAt time.Time `json:"modified_at,omitempty"` } // CopyRequest is the request passed to [Client.Copy]. diff --git a/cmd/cmd.go b/cmd/cmd.go index abb4806b..36d7e6cf 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -18,6 +18,7 @@ import ( "os/signal" "path/filepath" "runtime" + "slices" "sort" "strconv" "strings" @@ -339,6 +340,11 @@ func RunHandler(cmd *cobra.Command, args []string) error { return err } + opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision) + + // TODO: remove the projector info and vision info checks below, + // these are left in for backwards compatibility with older servers + // that don't have the capabilities field in the model info if len(info.ProjectorInfo) != 0 { opts.MultiModal = true } @@ -669,6 +675,15 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error { return }) + if len(resp.Capabilities) > 0 { + tableRender("Capabilities", func() (rows [][]string) { + for _, capability := range resp.Capabilities { + rows = append(rows, []string{"", capability.String()}) + } + return + }) + } + if resp.ProjectorInfo != nil { tableRender("Projector", func() (rows [][]string) { arch := resp.ProjectorInfo["general.architecture"].(string) diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index ea3bdffe..e6a542d0 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -16,6 +16,7 @@ import ( "github.com/spf13/cobra" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/types/model" ) func TestShowInfo(t *testing.T) { @@ -260,6 +261,34 @@ Weigh anchor! t.Errorf("unexpected output (-want +got):\n%s", diff) } }) + + t.Run("capabilities", func(t *testing.T) { + var b bytes.Buffer + if err := showInfo(&api.ShowResponse{ + Details: api.ModelDetails{ + Family: "test", + ParameterSize: "7B", + QuantizationLevel: "FP16", + }, + Capabilities: []model.Capability{model.CapabilityVision, model.CapabilityTools}, + }, false, &b); err != nil { + t.Fatal(err) + } + + expect := " Model\n" + + " architecture test \n" + + " parameters 7B \n" + + " quantization FP16 \n" + + "\n" + + " Capabilities\n" + + " vision \n" + + " tools \n" + + "\n" + + if diff := cmp.Diff(expect, b.String()); diff != "" { + t.Errorf("unexpected output (-want +got):\n%s", diff) + } + }) } func TestDeleteHandler(t *testing.T) { diff --git a/docs/api.md b/docs/api.md index fe044d79..04ee299d 100644 --- a/docs/api.md +++ b/docs/api.md @@ -1217,7 +1217,7 @@ Show information about a model including details, modelfile, template, parameter ```shell curl http://localhost:11434/api/show -d '{ - "model": "llama3.2" + "model": "llava" }' ``` @@ -1260,7 +1260,11 @@ curl http://localhost:11434/api/show -d '{ "tokenizer.ggml.pre": "llama-bpe", "tokenizer.ggml.token_type": [], // populates if `verbose=true` "tokenizer.ggml.tokens": [] // populates if `verbose=true` - } + }, + "capabilities": [ + "completion", + "vision" + ], } ``` diff --git a/server/images.go b/server/images.go index 290e68ba..2ef9e5d0 100644 --- a/server/images.go +++ b/server/images.go @@ -35,17 +35,11 @@ var ( errCapabilityCompletion = errors.New("completion") errCapabilityTools = errors.New("tools") errCapabilityInsert = errors.New("insert") + errCapabilityVision = errors.New("vision") + errCapabilityEmbedding = errors.New("embedding") errInsecureProtocol = errors.New("insecure protocol http") ) -type Capability string - -const ( - CapabilityCompletion = Capability("completion") - CapabilityTools = Capability("tools") - CapabilityInsert = Capability("insert") -) - type registryOptions struct { Insecure bool Username string @@ -72,46 +66,77 @@ type Model struct { Template *template.Template } +// Capabilities returns the capabilities that the model supports +func (m *Model) Capabilities() []model.Capability { + capabilities := []model.Capability{} + + // Check for completion capability + r, err := os.Open(m.ModelPath) + if err == nil { + defer r.Close() + + f, _, err := ggml.Decode(r, 0) + if err == nil { + if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok { + capabilities = append(capabilities, model.CapabilityEmbedding) + } else { + capabilities = append(capabilities, model.CapabilityCompletion) + } + if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok { + capabilities = append(capabilities, model.CapabilityVision) + } + } else { + slog.Error("couldn't decode ggml", "error", err) + } + } else { + slog.Error("couldn't open model file", "error", err) + } + + if m.Template == nil { + return capabilities + } + + // Check for tools capability + if slices.Contains(m.Template.Vars(), "tools") { + capabilities = append(capabilities, model.CapabilityTools) + } + + // Check for insert capability + if slices.Contains(m.Template.Vars(), "suffix") { + capabilities = append(capabilities, model.CapabilityInsert) + } + + return capabilities +} + // CheckCapabilities checks if the model has the specified capabilities returning an error describing // any missing or unknown capabilities -func (m *Model) CheckCapabilities(caps ...Capability) error { +func (m *Model) CheckCapabilities(want ...model.Capability) error { + available := m.Capabilities() var errs []error - for _, cap := range caps { - switch cap { - case CapabilityCompletion: - r, err := os.Open(m.ModelPath) - if err != nil { - slog.Error("couldn't open model file", "error", err) - continue - } - defer r.Close() - // TODO(mxyng): decode the GGML into model to avoid doing this multiple times - f, _, err := ggml.Decode(r, 0) - if err != nil { - slog.Error("couldn't decode ggml", "error", err) - continue - } + // Map capabilities to their corresponding error + capToErr := map[model.Capability]error{ + model.CapabilityCompletion: errCapabilityCompletion, + model.CapabilityTools: errCapabilityTools, + model.CapabilityInsert: errCapabilityInsert, + model.CapabilityVision: errCapabilityVision, + model.CapabilityEmbedding: errCapabilityEmbedding, + } - if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok { - errs = append(errs, errCapabilityCompletion) - } - case CapabilityTools: - if !slices.Contains(m.Template.Vars(), "tools") { - errs = append(errs, errCapabilityTools) - } - case CapabilityInsert: - vars := m.Template.Vars() - if !slices.Contains(vars, "suffix") { - errs = append(errs, errCapabilityInsert) - } - default: + for _, cap := range want { + err, ok := capToErr[cap] + if !ok { slog.Error("unknown capability", "capability", cap) return fmt.Errorf("unknown capability: %s", cap) } + + if !slices.Contains(available, cap) { + errs = append(errs, err) + } } - if err := errors.Join(errs...); err != nil { + if len(errs) > 0 { return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...)) } diff --git a/server/images_test.go b/server/images_test.go new file mode 100644 index 00000000..22e5b7e6 --- /dev/null +++ b/server/images_test.go @@ -0,0 +1,360 @@ +package server + +import ( + "bytes" + "encoding/binary" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/ollama/ollama/template" + "github.com/ollama/ollama/types/model" +) + +// Constants for GGUF magic bytes and version +var ( + ggufMagic = []byte{0x47, 0x47, 0x55, 0x46} // "GGUF" + ggufVer = uint32(3) // Version 3 +) + +// Helper function to create mock GGUF data +func createMockGGUFData(architecture string, vision bool) []byte { + var buf bytes.Buffer + + // Write GGUF header + buf.Write(ggufMagic) + binary.Write(&buf, binary.LittleEndian, ggufVer) + + // Write tensor count (0 for our test) + var numTensors uint64 = 0 + binary.Write(&buf, binary.LittleEndian, numTensors) + + // Calculate number of metadata entries + numMetaEntries := uint64(1) // architecture entry + if vision { + numMetaEntries++ + } + // Add embedding entry if architecture is "bert" + if architecture == "bert" { + numMetaEntries++ + } + binary.Write(&buf, binary.LittleEndian, numMetaEntries) + + // Write architecture metadata + archKey := "general.architecture" + keyLen := uint64(len(archKey)) + binary.Write(&buf, binary.LittleEndian, keyLen) + buf.WriteString(archKey) + + // String type (8) + var strType uint32 = 8 + binary.Write(&buf, binary.LittleEndian, strType) + + // String length + strLen := uint64(len(architecture)) + binary.Write(&buf, binary.LittleEndian, strLen) + buf.WriteString(architecture) + + if vision { + visionKey := architecture + ".vision.block_count" + keyLen = uint64(len(visionKey)) + binary.Write(&buf, binary.LittleEndian, keyLen) + buf.WriteString(visionKey) + + // uint32 type (4) + var uint32Type uint32 = 4 + binary.Write(&buf, binary.LittleEndian, uint32Type) + + // uint32 value (1) + var countVal uint32 = 1 + binary.Write(&buf, binary.LittleEndian, countVal) + } + // Write embedding metadata if architecture is "bert" + if architecture == "bert" { + poolKey := architecture + ".pooling_type" + keyLen = uint64(len(poolKey)) + binary.Write(&buf, binary.LittleEndian, keyLen) + buf.WriteString(poolKey) + + // uint32 type (4) + var uint32Type uint32 = 4 + binary.Write(&buf, binary.LittleEndian, uint32Type) + + // uint32 value (1) + var poolingVal uint32 = 1 + binary.Write(&buf, binary.LittleEndian, poolingVal) + } + + return buf.Bytes() +} + +func TestModelCapabilities(t *testing.T) { + // Create a temporary directory for test files + tempDir, err := os.MkdirTemp("", "model_capabilities_test") + if err != nil { + t.Fatalf("Failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + // Create different types of mock model files + completionModelPath := filepath.Join(tempDir, "model.bin") + visionModelPath := filepath.Join(tempDir, "vision_model.bin") + embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin") + // Create a simple model file for tests that don't depend on GGUF content + simpleModelPath := filepath.Join(tempDir, "simple_model.bin") + + err = os.WriteFile(completionModelPath, createMockGGUFData("llama", false), 0o644) + if err != nil { + t.Fatalf("Failed to create completion model file: %v", err) + } + err = os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644) + if err != nil { + t.Fatalf("Failed to create completion model file: %v", err) + } + err = os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644) + if err != nil { + t.Fatalf("Failed to create embedding model file: %v", err) + } + err = os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644) + if err != nil { + t.Fatalf("Failed to create simple model file: %v", err) + } + + toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}") + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + chatTemplate, err := template.Parse("{{ .prompt }}") + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}") + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + + testModels := []struct { + name string + model Model + expectedCaps []model.Capability + }{ + { + name: "model with completion capability", + model: Model{ + ModelPath: completionModelPath, + Template: chatTemplate, + }, + expectedCaps: []model.Capability{model.CapabilityCompletion}, + }, + + { + name: "model with completion, tools, and insert capability", + model: Model{ + ModelPath: completionModelPath, + Template: toolsInsertTemplate, + }, + expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert}, + }, + { + name: "model with tools and insert capability", + model: Model{ + ModelPath: simpleModelPath, + Template: toolsInsertTemplate, + }, + expectedCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert}, + }, + { + name: "model with tools capability", + model: Model{ + ModelPath: simpleModelPath, + Template: toolsTemplate, + }, + expectedCaps: []model.Capability{model.CapabilityTools}, + }, + { + name: "model with vision capability", + model: Model{ + ModelPath: visionModelPath, + Template: chatTemplate, + }, + expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision}, + }, + { + name: "model with vision, tools, and insert capability", + model: Model{ + ModelPath: visionModelPath, + Template: toolsInsertTemplate, + }, + expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert}, + }, + { + name: "model with embedding capability", + model: Model{ + ModelPath: embeddingModelPath, + Template: chatTemplate, + }, + expectedCaps: []model.Capability{model.CapabilityEmbedding}, + }, + } + + // compare two slices of model.Capability regardless of order + compareCapabilities := func(a, b []model.Capability) bool { + if len(a) != len(b) { + return false + } + + aCount := make(map[model.Capability]int) + for _, cap := range a { + aCount[cap]++ + } + + bCount := make(map[model.Capability]int) + for _, cap := range b { + bCount[cap]++ + } + + for cap, count := range aCount { + if bCount[cap] != count { + return false + } + } + + return true + } + + for _, tt := range testModels { + t.Run(tt.name, func(t *testing.T) { + // Test Capabilities method + caps := tt.model.Capabilities() + if !compareCapabilities(caps, tt.expectedCaps) { + t.Errorf("Expected capabilities %v, got %v", tt.expectedCaps, caps) + } + }) + } +} + +func TestModelCheckCapabilities(t *testing.T) { + // Create a temporary directory for test files + tempDir, err := os.MkdirTemp("", "model_check_capabilities_test") + if err != nil { + t.Fatalf("Failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + visionModelPath := filepath.Join(tempDir, "vision_model.bin") + simpleModelPath := filepath.Join(tempDir, "model.bin") + embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin") + + err = os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644) + if err != nil { + t.Fatalf("Failed to create simple model file: %v", err) + } + err = os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644) + if err != nil { + t.Fatalf("Failed to create vision model file: %v", err) + } + err = os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644) + if err != nil { + t.Fatalf("Failed to create embedding model file: %v", err) + } + + toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}") + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + chatTemplate, err := template.Parse("{{ .prompt }}") + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}") + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + + tests := []struct { + name string + model Model + checkCaps []model.Capability + expectedErrMsg string + }{ + { + name: "completion model without tools capability", + model: Model{ + ModelPath: simpleModelPath, + Template: chatTemplate, + }, + checkCaps: []model.Capability{model.CapabilityTools}, + expectedErrMsg: "does not support tools", + }, + { + name: "model with all needed capabilities", + model: Model{ + ModelPath: simpleModelPath, + Template: toolsInsertTemplate, + }, + checkCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert}, + }, + { + name: "model missing insert capability", + model: Model{ + ModelPath: simpleModelPath, + Template: toolsTemplate, + }, + checkCaps: []model.Capability{model.CapabilityInsert}, + expectedErrMsg: "does not support insert", + }, + { + name: "model missing vision capability", + model: Model{ + ModelPath: simpleModelPath, + Template: toolsTemplate, + }, + checkCaps: []model.Capability{model.CapabilityVision}, + expectedErrMsg: "does not support vision", + }, + { + name: "model with vision capability", + model: Model{ + ModelPath: visionModelPath, + Template: chatTemplate, + }, + checkCaps: []model.Capability{model.CapabilityVision}, + }, + { + name: "model with embedding capability", + model: Model{ + ModelPath: embeddingModelPath, + Template: chatTemplate, + }, + checkCaps: []model.Capability{model.CapabilityEmbedding}, + }, + { + name: "unknown capability", + model: Model{ + ModelPath: simpleModelPath, + Template: chatTemplate, + }, + checkCaps: []model.Capability{"unknown"}, + expectedErrMsg: "unknown capability", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test CheckCapabilities method + err := tt.model.CheckCapabilities(tt.checkCaps...) + if tt.expectedErrMsg == "" { + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } + } else { + if err == nil { + t.Errorf("Expected error containing %q, got nil", tt.expectedErrMsg) + } else if !strings.Contains(err.Error(), tt.expectedErrMsg) { + t.Errorf("Expected error containing %q, got: %v", tt.expectedErrMsg, err) + } + } + }) + } +} diff --git a/server/routes.go b/server/routes.go index 92336af0..95e49820 100644 --- a/server/routes.go +++ b/server/routes.go @@ -87,7 +87,7 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options // scheduleRunner schedules a runner after validating inputs such as capabilities and model options. // It returns the allocated runner, model instance, and consolidated options if successful and error otherwise. -func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) { +func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) { if name == "" { return nil, nil, nil, fmt.Errorf("model %w", errRequired) } @@ -144,7 +144,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - model, err := GetModel(name.String()) + m, err := GetModel(name.String()) if err != nil { switch { case errors.Is(err, fs.ErrNotExist): @@ -159,7 +159,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { // expire the runner if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 { - s.sched.expireRunner(model) + s.sched.expireRunner(m) c.JSON(http.StatusOK, api.GenerateResponse{ Model: req.Model, @@ -176,9 +176,9 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - caps := []Capability{CapabilityCompletion} + caps := []model.Capability{model.CapabilityCompletion} if req.Suffix != "" { - caps = append(caps, CapabilityInsert) + caps = append(caps, model.CapabilityInsert) } r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive) @@ -203,7 +203,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - isMllama := checkMllamaModelFamily(model) + isMllama := checkMllamaModelFamily(m) if isMllama && len(req.Images) > 1 { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image: more than one image sent"}) return @@ -211,7 +211,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { images := make([]llm.ImageData, len(req.Images)) for i := range req.Images { - if isMllama && len(model.ProjectorPaths) > 0 { + if isMllama && len(m.ProjectorPaths) > 0 { data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i])) if err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"}) @@ -422,7 +422,7 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } - r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive) + r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive) if err != nil { handleScheduleError(c, req.Model, err) return @@ -530,7 +530,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { return } - r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive) + r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive) if err != nil { handleScheduleError(c, req.Model, err) return @@ -813,12 +813,13 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { } resp := &api.ShowResponse{ - License: strings.Join(m.License, "\n"), - System: m.System, - Template: m.Template.String(), - Details: modelDetails, - Messages: msgs, - ModifiedAt: manifest.fi.ModTime(), + License: strings.Join(m.License, "\n"), + System: m.System, + Template: m.Template.String(), + Details: modelDetails, + Messages: msgs, + Capabilities: m.Capabilities(), + ModifiedAt: manifest.fi.ModTime(), } var params []string @@ -1468,9 +1469,9 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - caps := []Capability{CapabilityCompletion} + caps := []model.Capability{model.CapabilityCompletion} if len(req.Tools) > 0 { - caps = append(caps, CapabilityTools) + caps = append(caps, model.CapabilityTools) } name := model.ParseName(req.Model) diff --git a/server/sched.go b/server/sched.go index 9126c296..e6cefa5a 100644 --- a/server/sched.go +++ b/server/sched.go @@ -20,6 +20,7 @@ import ( "github.com/ollama/ollama/format" "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/llm" + "github.com/ollama/ollama/types/model" ) type LlmRequest struct { @@ -195,7 +196,7 @@ func (s *Scheduler) processPending(ctx context.Context) { } // Embedding models should always be loaded with parallel=1 - if pending.model.CheckCapabilities(CapabilityCompletion) != nil { + if pending.model.CheckCapabilities(model.CapabilityCompletion) != nil { numParallel = 1 } diff --git a/types/model/capability.go b/types/model/capability.go new file mode 100644 index 00000000..fb868940 --- /dev/null +++ b/types/model/capability.go @@ -0,0 +1,15 @@ +package model + +type Capability string + +const ( + CapabilityCompletion = Capability("completion") + CapabilityTools = Capability("tools") + CapabilityInsert = Capability("insert") + CapabilityVision = Capability("vision") + CapabilityEmbedding = Capability("embedding") +) + +func (c Capability) String() string { + return string(c) +} From 4e415029b30b2dc8a666491fdbe6254536e5d810 Mon Sep 17 00:00:00 2001 From: IsAurora6 <85173010+IsAurora6@users.noreply.github.com> Date: Wed, 2 Apr 2025 16:27:16 +0800 Subject: [PATCH 10/17] readme: add Casibase to community integrations (#10057) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index aca00a87..e472fbc5 100644 --- a/README.md +++ b/README.md @@ -325,6 +325,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) (RWKV offline LLM deployment tool, also usable as a client for ChatGPT and Ollama) - [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) (app to evaluate and compare models) - [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama) +- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support and multiple large language models.) - [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS) - [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama) - [Shinkai Desktop](https://github.com/dcSpark/shinkai-apps) (Two click install Local AI using Ollama + Files + RAG) From 9876c9faa41c7dd7143fa47727520d353559f81b Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Wed, 2 Apr 2025 09:44:27 -0700 Subject: [PATCH 11/17] chore(all): replace instances of interface with any (#10067) Both interface{} and any (which is just an alias for interface{} introduced in Go 1.18) represent the empty interface that all types satisfy. --- api/types.go | 22 +++++++++---------- api/types_test.go | 2 +- benchmark/server_benchmark_test.go | 4 ++-- cmd/cmd.go | 4 ++-- .../sentencepiece/sentencepiece_model.pb.go | 14 ++++++------ discover/cpu_common.go | 2 +- format/time_test.go | 2 +- integration/basic_test.go | 8 +++---- integration/concurrency_test.go | 4 ++-- integration/context_test.go | 4 ++-- integration/llm_image_test.go | 6 ++--- integration/llm_test.go | 4 ++-- integration/max_queue_test.go | 2 +- integration/utils_test.go | 10 ++++----- openai/openai.go | 10 ++++----- openai/openai_test.go | 2 +- server/images.go | 2 +- server/routes.go | 6 ++--- server/sched.go | 8 +++---- 19 files changed, 58 insertions(+), 58 deletions(-) diff --git a/api/types.go b/api/types.go index b4a65fe5..a70fb120 100644 --- a/api/types.go +++ b/api/types.go @@ -82,7 +82,7 @@ type GenerateRequest struct { // Options lists model-specific options. For example, temperature can be // set through this field, if the model supports it. - Options map[string]interface{} `json:"options"` + Options map[string]any `json:"options"` } // ChatRequest describes a request sent by [Client.Chat]. @@ -107,7 +107,7 @@ type ChatRequest struct { Tools `json:"tools,omitempty"` // Options lists model-specific options. - Options map[string]interface{} `json:"options"` + Options map[string]any `json:"options"` } type Tools []Tool @@ -261,7 +261,7 @@ type EmbedRequest struct { Truncate *bool `json:"truncate,omitempty"` // Options lists model-specific options. - Options map[string]interface{} `json:"options"` + Options map[string]any `json:"options"` } // EmbedResponse is the response from [Client.Embed]. @@ -287,7 +287,7 @@ type EmbeddingRequest struct { KeepAlive *Duration `json:"keep_alive,omitempty"` // Options lists model-specific options. - Options map[string]interface{} `json:"options"` + Options map[string]any `json:"options"` } // EmbeddingResponse is the response from [Client.Embeddings]. @@ -333,7 +333,7 @@ type ShowRequest struct { Template string `json:"template"` Verbose bool `json:"verbose"` - Options map[string]interface{} `json:"options"` + Options map[string]any `json:"options"` // Deprecated: set the model name with Model instead Name string `json:"name"` @@ -505,7 +505,7 @@ func (m *Metrics) Summary() { } } -func (opts *Options) FromMap(m map[string]interface{}) error { +func (opts *Options) FromMap(m map[string]any) error { valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct @@ -562,12 +562,12 @@ func (opts *Options) FromMap(m map[string]interface{}) error { } field.SetString(val) case reflect.Slice: - // JSON unmarshals to []interface{}, not []string - val, ok := val.([]interface{}) + // JSON unmarshals to []any, not []string + val, ok := val.([]any) if !ok { return fmt.Errorf("option %q must be of type array", key) } - // convert []interface{} to []string + // convert []any to []string slice := make([]string, len(val)) for i, item := range val { str, ok := item.(string) @@ -674,7 +674,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) { } // FormatParams converts specified parameter options to their correct types -func FormatParams(params map[string][]string) (map[string]interface{}, error) { +func FormatParams(params map[string][]string) (map[string]any, error) { opts := Options{} valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct @@ -688,7 +688,7 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) { } } - out := make(map[string]interface{}) + out := make(map[string]any) // iterate params and set values based on json struct tags for key, vals := range params { if opt, ok := jsonOpts[key]; !ok { diff --git a/api/types_test.go b/api/types_test.go index a9de5a9a..b28d4249 100644 --- a/api/types_test.go +++ b/api/types_test.go @@ -134,7 +134,7 @@ func TestUseMmapParsingFromJSON(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - var oMap map[string]interface{} + var oMap map[string]any err := json.Unmarshal([]byte(test.req), &oMap) require.NoError(t, err) opts := DefaultOptions() diff --git a/benchmark/server_benchmark_test.go b/benchmark/server_benchmark_test.go index b27aa630..672b8b17 100644 --- a/benchmark/server_benchmark_test.go +++ b/benchmark/server_benchmark_test.go @@ -92,7 +92,7 @@ func BenchmarkColdStart(b *testing.B) { req := &api.GenerateRequest{ Model: m, Prompt: tt.prompt, - Options: map[string]interface{}{"num_predict": tt.maxTokens, "temperature": 0.1}, + Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1}, } runGenerateBenchmark(b, ctx, client, req) @@ -155,7 +155,7 @@ func warmup(client *api.Client, model string, prompt string, b *testing.B) { &api.GenerateRequest{ Model: model, Prompt: prompt, - Options: map[string]interface{}{"num_predict": 50, "temperature": 0.1}, + Options: map[string]any{"num_predict": 50, "temperature": 0.1}, }, func(api.GenerateResponse) error { return nil }, ) diff --git a/cmd/cmd.go b/cmd/cmd.go index 36d7e6cf..84727862 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -268,7 +268,7 @@ func RunHandler(cmd *cobra.Command, args []string) error { opts := runOptions{ Model: args[0], WordWrap: os.Getenv("TERM") == "xterm-256color", - Options: map[string]interface{}{}, + Options: map[string]any{}, } format, err := cmd.Flags().GetString("format") @@ -852,7 +852,7 @@ type runOptions struct { Format string System string Images []api.ImageData - Options map[string]interface{} + Options map[string]any MultiModal bool KeepAlive *api.Duration } diff --git a/convert/sentencepiece/sentencepiece_model.pb.go b/convert/sentencepiece/sentencepiece_model.pb.go index 6bf66891..76d136e8 100644 --- a/convert/sentencepiece/sentencepiece_model.pb.go +++ b/convert/sentencepiece/sentencepiece_model.pb.go @@ -1360,7 +1360,7 @@ func file_sentencepiece_model_proto_rawDescGZIP() []byte { var file_sentencepiece_model_proto_enumTypes = make([]protoimpl.EnumInfo, 2) var file_sentencepiece_model_proto_msgTypes = make([]protoimpl.MessageInfo, 6) -var file_sentencepiece_model_proto_goTypes = []interface{}{ +var file_sentencepiece_model_proto_goTypes = []any{ (TrainerSpec_ModelType)(0), // 0: sentencepiece.TrainerSpec.ModelType (ModelProto_SentencePiece_Type)(0), // 1: sentencepiece.ModelProto.SentencePiece.Type (*TrainerSpec)(nil), // 2: sentencepiece.TrainerSpec @@ -1392,7 +1392,7 @@ func file_sentencepiece_model_proto_init() { return } if !protoimpl.UnsafeEnabled { - file_sentencepiece_model_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + file_sentencepiece_model_proto_msgTypes[0].Exporter = func(v any, i int) any { switch v := v.(*TrainerSpec); i { case 0: return &v.state @@ -1406,7 +1406,7 @@ func file_sentencepiece_model_proto_init() { return nil } } - file_sentencepiece_model_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + file_sentencepiece_model_proto_msgTypes[1].Exporter = func(v any, i int) any { switch v := v.(*NormalizerSpec); i { case 0: return &v.state @@ -1420,7 +1420,7 @@ func file_sentencepiece_model_proto_init() { return nil } } - file_sentencepiece_model_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + file_sentencepiece_model_proto_msgTypes[2].Exporter = func(v any, i int) any { switch v := v.(*SelfTestData); i { case 0: return &v.state @@ -1434,7 +1434,7 @@ func file_sentencepiece_model_proto_init() { return nil } } - file_sentencepiece_model_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + file_sentencepiece_model_proto_msgTypes[3].Exporter = func(v any, i int) any { switch v := v.(*ModelProto); i { case 0: return &v.state @@ -1448,7 +1448,7 @@ func file_sentencepiece_model_proto_init() { return nil } } - file_sentencepiece_model_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + file_sentencepiece_model_proto_msgTypes[4].Exporter = func(v any, i int) any { switch v := v.(*SelfTestData_Sample); i { case 0: return &v.state @@ -1460,7 +1460,7 @@ func file_sentencepiece_model_proto_init() { return nil } } - file_sentencepiece_model_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + file_sentencepiece_model_proto_msgTypes[5].Exporter = func(v any, i int) any { switch v := v.(*ModelProto_SentencePiece); i { case 0: return &v.state diff --git a/discover/cpu_common.go b/discover/cpu_common.go index 242e4879..2b9f7292 100644 --- a/discover/cpu_common.go +++ b/discover/cpu_common.go @@ -12,7 +12,7 @@ func IsNUMA() bool { // numa support in llama.cpp is linux only return false } - ids := map[string]interface{}{} + ids := map[string]any{} packageIds, _ := filepath.Glob("/sys/devices/system/cpu/cpu*/topology/physical_package_id") for _, packageId := range packageIds { id, err := os.ReadFile(packageId) diff --git a/format/time_test.go b/format/time_test.go index bd0ba9a8..d0f8934d 100644 --- a/format/time_test.go +++ b/format/time_test.go @@ -5,7 +5,7 @@ import ( "time" ) -func assertEqual(t *testing.T, a interface{}, b interface{}) { +func assertEqual(t *testing.T, a any, b any) { if a != b { t.Errorf("Assert failed, expected %v, got %v", b, a) } diff --git a/integration/basic_test.go b/integration/basic_test.go index 88d3530e..09e62259 100644 --- a/integration/basic_test.go +++ b/integration/basic_test.go @@ -22,7 +22,7 @@ func TestOrcaMiniBlueSky(t *testing.T) { Model: "orca-mini", Prompt: "why is the sky blue?", Stream: &stream, - Options: map[string]interface{}{ + Options: map[string]any{ "temperature": 0, "seed": 123, }, @@ -39,7 +39,7 @@ func TestUnicode(t *testing.T) { Model: "deepseek-coder-v2:16b-lite-instruct-q2_K", Prompt: "天空为什么是蓝色的?", Stream: &stream, - Options: map[string]interface{}{ + Options: map[string]any{ "temperature": 0, "seed": 123, // Workaround deepseek context shifting bug @@ -61,7 +61,7 @@ func TestExtendedUnicodeOutput(t *testing.T) { Model: "gemma2:2b", Prompt: "Output some smily face emoji", Stream: &stream, - Options: map[string]interface{}{ + Options: map[string]any{ "temperature": 0, "seed": 123, }, @@ -96,7 +96,7 @@ func TestUnicodeModelDir(t *testing.T) { Model: "orca-mini", Prompt: "why is the sky blue?", Stream: &stream, - Options: map[string]interface{}{ + Options: map[string]any{ "temperature": 0, "seed": 123, }, diff --git a/integration/concurrency_test.go b/integration/concurrency_test.go index 78e3b5ab..5f7f289e 100644 --- a/integration/concurrency_test.go +++ b/integration/concurrency_test.go @@ -25,7 +25,7 @@ func TestMultiModelConcurrency(t *testing.T) { Prompt: "why is the ocean blue?", Stream: &stream, KeepAlive: &api.Duration{Duration: 10 * time.Second}, - Options: map[string]interface{}{ + Options: map[string]any{ "seed": 42, "temperature": 0.0, }, @@ -34,7 +34,7 @@ func TestMultiModelConcurrency(t *testing.T) { Prompt: "what is the origin of the us thanksgiving holiday?", Stream: &stream, KeepAlive: &api.Duration{Duration: 10 * time.Second}, - Options: map[string]interface{}{ + Options: map[string]any{ "seed": 42, "temperature": 0.0, }, diff --git a/integration/context_test.go b/integration/context_test.go index add41a76..409d913a 100644 --- a/integration/context_test.go +++ b/integration/context_test.go @@ -23,7 +23,7 @@ func TestLongInputContext(t *testing.T) { Model: "llama2", Prompt: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?", Stream: &stream, - Options: map[string]interface{}{ + Options: map[string]any{ "temperature": 0, "seed": 123, "num_ctx": 128, @@ -50,7 +50,7 @@ func TestContextExhaustion(t *testing.T) { Model: "llama2", Prompt: "Write me a story with a ton of emojis?", Stream: &stream, - Options: map[string]interface{}{ + Options: map[string]any{ "temperature": 0, "seed": 123, "num_ctx": 128, diff --git a/integration/llm_image_test.go b/integration/llm_image_test.go index fbbd9d5c..51a16fc7 100644 --- a/integration/llm_image_test.go +++ b/integration/llm_image_test.go @@ -19,7 +19,7 @@ func TestIntegrationLlava(t *testing.T) { Model: "llava:7b", Prompt: "what does the text in this image say?", Stream: &stream, - Options: map[string]interface{}{ + Options: map[string]any{ "seed": 42, "temperature": 0.0, }, @@ -47,7 +47,7 @@ func TestIntegrationMllama(t *testing.T) { Model: "x/llama3.2-vision", Prompt: "what does the text in this image say?", Stream: &stream, - Options: map[string]interface{}{ + Options: map[string]any{ "seed": 42, "temperature": 0.0, }, @@ -75,7 +75,7 @@ func TestIntegrationSplitBatch(t *testing.T) { System: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed aliquet, justo in malesuada lobortis, odio ligula volutpat quam, quis faucibus ipsum magna quis sapien. Aliquam in venenatis diam, eu viverra magna. Phasellus imperdiet hendrerit volutpat. Vivamus sem ex, facilisis placerat felis non, dictum elementum est. Phasellus aliquam imperdiet lacus, eget placerat ligula sodales vel. Pellentesque nec auctor mi. Curabitur arcu nisi, faucibus eget nunc id, viverra interdum mi. Curabitur ornare ipsum ex, ac euismod ex aliquam in. Vestibulum id magna at purus accumsan fermentum. Proin scelerisque posuere nunc quis interdum. Maecenas sed mollis nisl. Etiam vitae ipsum interdum, placerat est quis, tincidunt velit. Nullam tempor nibh non lorem volutpat efficitur. Cras laoreet diam imperdiet ipsum auctor bibendum. Suspendisse ultrices urna sed metus sagittis suscipit. Quisque ullamcorper aliquam nibh ut mollis. Aenean dapibus mauris pharetra, venenatis elit ac, hendrerit odio. Cras vestibulum erat tempor, lobortis justo eu, lobortis ipsum. Nam laoreet dapibus sem. Proin vel diam ultrices, elementum ante et, ornare lectus. Proin eu accumsan nisl. Praesent ac ex vitae ipsum vulputate tristique facilisis sit amet lacus. Nullam faucibus magna a pellentesque pretium. Nunc lacinia ullamcorper sollicitudin. Donec vitae accumsan turpis, sed porttitor est. Donec porttitor mi vitae augue faucibus, vel mollis diam tincidunt.", Prompt: "what does the text in this image say?", Stream: &stream, - Options: map[string]interface{}{ + Options: map[string]any{ "seed": 42, "temperature": 0.0, }, diff --git a/integration/llm_test.go b/integration/llm_test.go index 398e0a03..f897fdd9 100644 --- a/integration/llm_test.go +++ b/integration/llm_test.go @@ -20,7 +20,7 @@ var ( Model: "orca-mini", Prompt: "why is the ocean blue?", Stream: &stream, - Options: map[string]interface{}{ + Options: map[string]any{ "seed": 42, "temperature": 0.0, }, @@ -28,7 +28,7 @@ var ( Model: "orca-mini", Prompt: "what is the origin of the us thanksgiving holiday?", Stream: &stream, - Options: map[string]interface{}{ + Options: map[string]any{ "seed": 42, "temperature": 0.0, }, diff --git a/integration/max_queue_test.go b/integration/max_queue_test.go index 1878d0da..c316aa62 100644 --- a/integration/max_queue_test.go +++ b/integration/max_queue_test.go @@ -32,7 +32,7 @@ func TestMaxQueue(t *testing.T) { req := api.GenerateRequest{ Model: "orca-mini", Prompt: "write a long historical fiction story about christopher columbus. use at least 10 facts from his actual journey", - Options: map[string]interface{}{ + Options: map[string]any{ "seed": 42, "temperature": 0.0, }, diff --git a/integration/utils_test.go b/integration/utils_test.go index e76b63f2..71304cd0 100644 --- a/integration/utils_test.go +++ b/integration/utils_test.go @@ -291,7 +291,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) { Prompt: "why is the ocean blue?", Stream: &stream, KeepAlive: &api.Duration{Duration: 10 * time.Second}, - Options: map[string]interface{}{ + Options: map[string]any{ "seed": 42, "temperature": 0.0, }, @@ -300,7 +300,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) { Prompt: "why is the color of dirt brown?", Stream: &stream, KeepAlive: &api.Duration{Duration: 10 * time.Second}, - Options: map[string]interface{}{ + Options: map[string]any{ "seed": 42, "temperature": 0.0, }, @@ -309,7 +309,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) { Prompt: "what is the origin of the us thanksgiving holiday?", Stream: &stream, KeepAlive: &api.Duration{Duration: 10 * time.Second}, - Options: map[string]interface{}{ + Options: map[string]any{ "seed": 42, "temperature": 0.0, }, @@ -318,7 +318,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) { Prompt: "what is the origin of independence day?", Stream: &stream, KeepAlive: &api.Duration{Duration: 10 * time.Second}, - Options: map[string]interface{}{ + Options: map[string]any{ "seed": 42, "temperature": 0.0, }, @@ -327,7 +327,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) { Prompt: "what is the composition of air?", Stream: &stream, KeepAlive: &api.Duration{Duration: 10 * time.Second}, - Options: map[string]interface{}{ + Options: map[string]any{ "seed": 42, "temperature": 0.0, }, diff --git a/openai/openai.go b/openai/openai.go index 214801fa..012189d2 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -23,10 +23,10 @@ import ( var finishReasonToolCalls = "tool_calls" type Error struct { - Message string `json:"message"` - Type string `json:"type"` - Param interface{} `json:"param"` - Code *string `json:"code"` + Message string `json:"message"` + Type string `json:"type"` + Param any `json:"param"` + Code *string `json:"code"` } type ErrorResponse struct { @@ -465,7 +465,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { } } - options := make(map[string]interface{}) + options := make(map[string]any) switch stop := r.Stop.(type) { case string: diff --git a/openai/openai_test.go b/openai/openai_test.go index d8c821d3..a6acfcac 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -219,7 +219,7 @@ func TestChatMiddleware(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_current_weather", - Arguments: map[string]interface{}{ + Arguments: map[string]any{ "location": "Paris, France", "format": "celsius", }, diff --git a/server/images.go b/server/images.go index 2ef9e5d0..bd6d92a6 100644 --- a/server/images.go +++ b/server/images.go @@ -60,7 +60,7 @@ type Model struct { System string License []string Digest string - Options map[string]interface{} + Options map[string]any Messages []api.Message Template *template.Template diff --git a/server/routes.go b/server/routes.go index 95e49820..eee34033 100644 --- a/server/routes.go +++ b/server/routes.go @@ -72,7 +72,7 @@ var ( errBadTemplate = errors.New("template error") ) -func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) { +func modelOptions(model *Model, requestOpts map[string]any) (api.Options, error) { opts := api.DefaultOptions() if err := opts.FromMap(model.Options); err != nil { return api.Options{}, err @@ -826,7 +826,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { cs := 30 for k, v := range m.Options { switch val := v.(type) { - case []interface{}: + case []any: for _, nv := range val { params = append(params, fmt.Sprintf("%-*s %#v", cs, k, nv)) } @@ -1336,7 +1336,7 @@ func Serve(ln net.Listener) error { return nil } -func waitForStream(c *gin.Context, ch chan interface{}) { +func waitForStream(c *gin.Context, ch chan any) { c.Header("Content-Type", "application/json") for resp := range ch { switch r := resp.(type) { diff --git a/server/sched.go b/server/sched.go index e6cefa5a..8082680b 100644 --- a/server/sched.go +++ b/server/sched.go @@ -38,7 +38,7 @@ type Scheduler struct { pendingReqCh chan *LlmRequest finishedReqCh chan *LlmRequest expiredCh chan *runnerRef - unloadedCh chan interface{} + unloadedCh chan any loaded map[string]*runnerRef loadedMu sync.Mutex @@ -68,7 +68,7 @@ func InitScheduler(ctx context.Context) *Scheduler { pendingReqCh: make(chan *LlmRequest, maxQueue), finishedReqCh: make(chan *LlmRequest, maxQueue), expiredCh: make(chan *runnerRef, maxQueue), - unloadedCh: make(chan interface{}, maxQueue), + unloadedCh: make(chan any, maxQueue), loaded: make(map[string]*runnerRef), newServerFn: llm.NewLlamaServer, getGpuFn: discover.GetGPUInfo, @@ -618,8 +618,8 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool // a before and after GPU memory allocation. The returned channel // will be notified when we're done waiting, or have timed out and should // proceed anyway -func (runner *runnerRef) waitForVRAMRecovery() chan interface{} { - finished := make(chan interface{}, 1) +func (runner *runnerRef) waitForVRAMRecovery() chan any { + finished := make(chan any, 1) // CPU or Metal don't need checking, so no waiting required // windows can page VRAM, only cuda currently can report accurate used vram usage From 493385eb3e811ebbb49c6a23d6db7c39885bbb89 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Tue, 1 Apr 2025 15:01:23 -0700 Subject: [PATCH 12/17] ollamarunner: Don't truncate a SameBatch When truncating inputs to the the context window at the beginning of a sequence, we remove the minimum amount possible. However, this may cause us to truncate to the middle of a set of inputs that the model specified should not be split up. To avoid this, we need to remove the rest of the partial batch. --- runner/ollamarunner/cache.go | 2 ++ runner/ollamarunner/runner.go | 33 +++++++++++++++++++++++++++++---- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index af48ff22..30292f64 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -225,6 +225,8 @@ func countCommonPrefix(a []input.Input, b []input.Input) int32 { return count } +// TODO(jessegross): If we need to reprocess the inputs we should ensure that +// we don't split up a SameBatch func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 { targetFree := (c.numCtx - numKeep) / 2 targetFree = max(targetFree, 1) diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 45838718..f3286aba 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -115,16 +115,41 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe params.numKeep = int32(len(inputs)) } - // TODO(jessegross): We should ensure that we always leave minBatch of context space to shift, - // otherwise we might truncate or split the batch against the model's wishes - // Ensure that at least 1 input can be discarded during shift params.numKeep = min(params.numKeep, s.cache.numCtx-1) if int32(len(inputs)) > s.cache.numCtx { discard := int32(len(inputs)) - s.cache.numCtx + promptStart := params.numKeep + discard + + // If we need to truncate in the middle of a unbreakable batch, remove the entire batch + sameBatch := 0 + for i, inp := range inputs { + if sameBatch > 0 { + sameBatch-- + + if promptStart == int32(i) { + promptStart++ + } + } else if promptStart == int32(i) { + break + } + + if inp.SameBatch != 0 { + if int32(i) < params.numKeep { + return nil, fmt.Errorf("SameBatch may not be specified within numKeep (index: %v numKeep: %v SameBatch: %v)", i, params.numKeep, inp.SameBatch) + } + + sameBatch = inp.SameBatch + } + } + + if promptStart >= int32(len(inputs)) { + return nil, errors.New("entire prompt removed by truncation") + } + newInputs := inputs[:params.numKeep] - newInputs = append(newInputs, inputs[params.numKeep+discard:]...) + newInputs = append(newInputs, inputs[promptStart:]...) slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs)) inputs = newInputs From b42970063d8f05c47dd6d9a6b71f1e14cc4805c9 Mon Sep 17 00:00:00 2001 From: jmorganca Date: Sun, 30 Mar 2025 16:05:40 -0700 Subject: [PATCH 13/17] kvcache: Add check for values that fall out of sliding window cache The sliding window cache trims entries that are outside the window for the latest token. This works when we are extending the cache, such as when the conversation continues. However, if we have a partial overlap in conversation (including the BOS tokens), then we resume from a past point in the conversation and the needed tokens are no longer stored in memory. This verifies that the new window overlaps with the old one before reusing the cache. Co-authored-by: Jesse Gross --- kvcache/cache.go | 5 +++ kvcache/causal.go | 38 ++++++++++++++++- kvcache/causal_test.go | 71 +++++++++++++++++++++++++++++++ kvcache/encoder.go | 4 ++ kvcache/wrapper.go | 10 +++++ runner/ollamarunner/cache.go | 4 ++ runner/ollamarunner/cache_test.go | 1 + 7 files changed, 131 insertions(+), 2 deletions(-) diff --git a/kvcache/cache.go b/kvcache/cache.go index 18aec800..07015b9e 100644 --- a/kvcache/cache.go +++ b/kvcache/cache.go @@ -62,6 +62,11 @@ type Cache interface { // CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq CopyPrefix(srcSeq, dstSeq int, len int32) + // CanResume returns true if the cache can continue with the next token at + // the given position and sequence. Assumes that the caller has already + // verified the contents of the cache. + CanResume(seq int, pos int32) bool + // Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set // endIndex to math.MaxInt32 to remove everything starting at beginIndex. // diff --git a/kvcache/causal.go b/kvcache/causal.go index fb4f0f74..4fc18d88 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -581,6 +581,35 @@ func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) { c.cellRanges[dstSeq] = seqRange } +func (c *Causal) CanResume(seq int, pos int32) bool { + if c.windowSize == math.MaxInt32 { + return true + } + + seqRange, ok := c.cellRanges[seq] + if !ok { + return false + } + + // for sliding window, check that the window of the new sequence is contained in + // the window of what we are storing + var last int32 = -1 + for i := seqRange.min; i <= seqRange.max; i++ { + if slices.Contains(c.cells[i].sequences, seq) { + last = max(last, c.cells[i].pos) + } + } + + if last == -1 { + return false + } + + lastWindowStart := max(0, last-c.windowSize) + posWindowStart := max(0, pos-c.windowSize) + + return posWindowStart >= lastWindowStart +} + func (c *Causal) shift(seq int, beginIndex, offset int32) error { if c.shiftFn == nil { return ErrNotSupported @@ -635,6 +664,12 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error { } func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error { + // TODO(jessegross): We should check to see if removing the middle of the sequence will + // cause the sliding window to encompass tokens that we no longer have. If so, then we + // should return an error, which will trigger the runner to evaluate the full history and + // rebuild the window. However, if we have multimodal inputs in our history, this reuse + // results in use after free, so we don't do it for now. + var offset int32 if endIndex != math.MaxInt32 { offset = beginIndex - endIndex @@ -649,8 +684,7 @@ func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error { } else { if c.cells[i].pos >= endIndex { if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) { - // TODO(jessegross): Need to be careful about data shared between sequences - return errors.New("shifting on cells shared by multiple sequences not yet implemented") + return errors.New("shifting cells shared by multiple sequences not supported") } c.cells[i].pos += offset diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index b1dc7d77..bf98abef 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -300,6 +300,77 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) } } +func TestCanResume(t *testing.T) { + backend := &testBackend{} + windowSize := int32(4) + cache := NewSWACache(windowSize, nil) + defer cache.Close() + + cache.Init(backend, ml.DTypeF16, 1, 16, 16) + + context := backend.NewContext() + defer context.Close() + + err := cache.StartForward(context, input.Batch{ + Positions: []int32{0, 1, 2, 3}, + Sequences: []int{0, 0, 0, 0}, + }) + if err != nil { + t.Fatalf("StartForward failed: %v", err) + } + + cache.SetLayer(0) + tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4) + cache.Put(context, tensor, tensor) + + // with window size 4, nothing has slid out of the window yet + if !cache.CanResume(0, 0) { + t.Errorf("CanResume(0, 0) = false, want true (within window)") + } + if !cache.CanResume(0, 1) { + t.Errorf("CanResume(0, 1) = false, want true (within window)") + } + if !cache.CanResume(0, 2) { + t.Errorf("CanResume(0, 2) = false, want true (within window)") + } + if !cache.CanResume(0, 3) { + t.Errorf("CanResume(0, 3) = false, want true (latest position)") + } + + // shift window by adding position 4 + err = cache.StartForward(context, input.Batch{ + Positions: []int32{4, 5}, + Sequences: []int{0, 0}, + }) + if err != nil { + t.Fatalf("StartForward failed: %v", err) + } + + cache.SetLayer(0) + tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2) + cache.Put(context, tensor, tensor) + + // only the latest position has overlapping windows + if cache.CanResume(0, 0) { + t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)") + } + if cache.CanResume(0, 1) { + t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)") + } + if cache.CanResume(0, 2) { + t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)") + } + if cache.CanResume(0, 3) { + t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)") + } + if cache.CanResume(0, 4) { + t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)") + } + if !cache.CanResume(0, 5) { + t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)") + } +} + type testBackend struct{} func (b *testBackend) Config() ml.Config { diff --git a/kvcache/encoder.go b/kvcache/encoder.go index 07ff4291..03d650a3 100644 --- a/kvcache/encoder.go +++ b/kvcache/encoder.go @@ -134,6 +134,10 @@ func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) { panic("encoder cache does not support multiple sequences") } +func (c *EncoderCache) CanResume(seq int, pos int32) bool { + return true +} + func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error { if c.encoderPos >= beginIndex && c.encoderPos < endIndex { c.encoderCached = false diff --git a/kvcache/wrapper.go b/kvcache/wrapper.go index 0e8ff1f3..926bc2d4 100644 --- a/kvcache/wrapper.go +++ b/kvcache/wrapper.go @@ -87,6 +87,16 @@ func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) { } } +func (c *WrapperCache) CanResume(seq int, pos int32) bool { + for _, cache := range c.caches { + if !cache.CanResume(seq, pos) { + return false + } + } + + return true +} + func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error { // If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail for _, cache := range c.caches { diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index 30292f64..01f435e4 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -118,6 +118,10 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp } if c.cache != nil { + if numPast > 0 && !c.cache.CanResume(slot.Id, numPast) { + numPast = 0 + } + err = c.cache.Remove(slot.Id, numPast, math.MaxInt32) if err != nil { // Some models don't support partial erasure diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go index 6a8d8a6a..543b4b2f 100644 --- a/runner/ollamarunner/cache_test.go +++ b/runner/ollamarunner/cache_test.go @@ -451,6 +451,7 @@ func (m *mockCache) Close() func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch) error { return nil } func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {} func (m *mockCache) SetConfig(ml.CacheConfig) {} +func (m *mockCache) CanResume(seq int, pos int32) bool { return true } func TestShiftCacheSlot(t *testing.T) { tests := []struct { From b51e0f397ced70bbfa7f22e9b3c94953967cb8e5 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Wed, 2 Apr 2025 13:22:56 -0700 Subject: [PATCH 14/17] model: fix issues with spm tokenizer for Gemma 3 (#10081) --- model/models/gemma2/model.go | 1 - model/models/gemma3/model.go | 1 - model/models/gemma3/model_text.go | 1 - model/process_text_spm.go | 225 ++++++++++++++++-------------- model/process_text_spm_test.go | 60 +++++++- 5 files changed, 175 insertions(+), 113 deletions(-) diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 67c69ee8..b8f5f066 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -38,7 +38,6 @@ const ( func New(c ml.Config) (model.Model, error) { m := Model{ SentencePieceModel: model.NewSentencePieceModel( - c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index 567ad1a4..f9c53343 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -55,7 +55,6 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i func New(c ml.Config) (model.Model, error) { m := Model{ SentencePieceModel: model.NewSentencePieceModel( - c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index 7d8b6577..7b2b83c0 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -45,7 +45,6 @@ func newTextModel(c ml.Config) *TextModel { m := TextModel{ SentencePieceModel: model.NewSentencePieceModel( - c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), diff --git a/model/process_text_spm.go b/model/process_text_spm.go index 68e3ed01..c6e08dbd 100644 --- a/model/process_text_spm.go +++ b/model/process_text_spm.go @@ -1,29 +1,23 @@ package model import ( - "iter" + "container/heap" + "fmt" "log/slog" + "strconv" "strings" - - "github.com/dlclark/regexp2" - queue "github.com/emirpasic/gods/v2/queues/priorityqueue" ) const spmWhitespaceSep = "▁" -func replaceWhitespaceBySeperator(s string) string { - return strings.ReplaceAll(s, " ", spmWhitespaceSep) -} - type SentencePieceModel struct { maxTokenLen int - pre *regexp2.Regexp vocab *Vocabulary } var _ TextProcessor = (*SentencePieceModel)(nil) -func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel { +func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel { slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5]) counter := map[int]int{} @@ -44,7 +38,6 @@ func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel { return SentencePieceModel{ maxTokenLen: maxTokenLen, - pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2), vocab: vocab, } } @@ -53,20 +46,9 @@ func (spm SentencePieceModel) Is(id int32, special Special) bool { return spm.vocab.Is(id, special) } -func (spm *SentencePieceModel) split(s string) iter.Seq[string] { - return func(yield func(string) bool) { - for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) { - if !yield(m.String()) { - break - } - } - } -} - func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) { fragments := []fragment{{value: s}} for _, special := range spm.vocab.SpecialVocabulary() { - // TODO: process special tokens concurrently id := spm.vocab.Encode(special) for i := 0; i < len(fragments); i++ { frag := fragments[i] @@ -91,7 +73,6 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...) } } - slog.Debug("fragments", "frags", fragments) var ids []int32 for _, frag := range fragments { @@ -100,105 +81,96 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) continue } - for split := range spm.split(frag.value) { - split = replaceWhitespaceBySeperator(split) + text := strings.ReplaceAll(frag.value, " ", spmWhitespaceSep) - var sb strings.Builder - sb.Write([]byte(split)) - if id := spm.vocab.Encode(sb.String()); id >= 0 { - ids = append(ids, id) - continue + if id := spm.vocab.Encode(text); id >= 0 { + ids = append(ids, id) + continue + } + + q := &queue{} + heap.Init(q) + + runes := []rune(text) + merges := make([]merge, len(runes)) + for r := range runes { + merges[r] = merge{ + p: r - 1, + n: r + 1, + runes: []rune{runes[r]}, } + } - runes := []rune(sb.String()) - pq := queue.NewWith(func(a, b any) int { - priA := a.(*candidate) - priB := b.(*candidate) - if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) { - return -1 - } - return 1 - }) - - merges := make([]merge, len(runes)) - for r := range runes { - merges[r] = merge{ - p: r - 1, - n: r + 1, - runes: []rune{runes[r]}, - } - } - - slog.Debug("tokenizer", "merges", merges) - - pairwise := func(a, b int) *candidate { - if a < 0 || b >= len(runes) { - return nil - } - - left, right := string(merges[a].runes), string(merges[b].runes) - if id := spm.vocab.Encode(left + right); id >= 0 { - return &candidate{ - a: a, - b: b, - score: spm.vocab.Scores[id], - } - } + pairwise := func(a, b int) *candidate { + if a < 0 || b >= len(runes) { return nil } - for i := range len(runes) - 1 { - if pair := pairwise(i, i+1); pair != nil { - pq.Enqueue(pair) + left, right := string(merges[a].runes), string(merges[b].runes) + if id := spm.vocab.Encode(left + right); id >= 0 { + return &candidate{ + a: a, + b: b, + score: spm.vocab.Scores[id], + size: len(left) + len(right), } } - pqv := pq.Values() - for _, v := range pqv { - e := v.(*candidate) - slog.Debug("candidate", "candidate", e) + return nil + } + + for i := range len(runes) - 1 { + if pair := pairwise(i, i+1); pair != nil { + heap.Push(q, pair) + } + } + + for q.Len() > 0 { + pair := heap.Pop(q).(*candidate) + left, right := merges[pair.a], merges[pair.b] + + if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size { + continue } - for !pq.Empty() { - v, _ := pq.Dequeue() - pair := v.(*candidate) - left, right := merges[pair.a], merges[pair.b] + merges[pair.a].runes = append(left.runes, right.runes...) + merges[pair.b].runes = nil + merges[pair.a].n = right.n + if right.n < len(merges) { + merges[right.n].p = pair.a + } - slog.Debug("pair", "left", left, "right", right) - if len(left.runes) == 0 || len(right.runes) == 0 { + if pair := pairwise(merges[pair.a].p, pair.a); pair != nil { + heap.Push(q, pair) + } + + if pair := pairwise(pair.a, merges[pair.a].n); pair != nil { + heap.Push(q, pair) + } + } + + for _, merge := range merges { + if token := string(merge.runes); token != "" { + id := spm.vocab.Encode(token) + + if id >= 0 { + ids = append(ids, id) continue } - if id := spm.vocab.Encode(string(left.runes) + string(right.runes)); id < 0 { - continue - } - - merges[pair.a].runes = append(left.runes, right.runes...) - merges[pair.b].runes = nil - merges[pair.a].n = right.n - if right.n < len(merges) { - merges[right.n].p = pair.a - } - - if pair := pairwise(merges[pair.a].p, pair.a); pair != nil { - pq.Enqueue(pair) - } - - if pair := pairwise(pair.a, merges[pair.a].n); pair != nil { - pq.Enqueue(pair) - } - } - - slog.Debug("merges", "merges", merges) - - for _, merge := range merges { - if len(merge.runes) > 0 { - if id := spm.vocab.Encode(string(merge.runes)); id >= 0 { - ids = append(ids, id) + // Fallback to byte tokenization + var result []int32 + for _, b := range []byte(token) { + byteToken := fmt.Sprintf("<0x%02X>", b) + unknownID := spm.vocab.Encode(byteToken) + if unknownID >= 0 { + result = append(result, unknownID) } else { - slog.Debug("missing token", "token", string(merge.runes)) + slog.Debug("unknown byte token", "byte", b, "token", byteToken) } } + + ids = append(ids, result...) } } } @@ -229,6 +201,30 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) type candidate struct { a, b int score float32 + size int +} + +type queue []*candidate + +func (q queue) Len() int { return len(q) } + +func (q queue) Less(i, j int) bool { + return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a) +} + +func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] } + +func (q *queue) Push(x interface{}) { + item := x.(*candidate) + *q = append(*q, item) +} + +func (q *queue) Pop() interface{} { + old := *q + n := len(old) + item := old[n-1] + *q = old[0 : n-1] + return item } func (spm SentencePieceModel) Decode(ids []int32) (string, error) { @@ -236,11 +232,26 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) { for _, id := range ids { data := spm.vocab.Decode(id) data = strings.ReplaceAll(data, spmWhitespaceSep, " ") - if _, err := sb.WriteString(data); err != nil { - return "", err + + // For tokenizers that use byte tokens like "<0xEA>" + // convert them to the partial unicode character + // so they are buffered correctly by the runner instead + // of being sent back to the api as "<0xEA>" + if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") { + byteVal, err := strconv.ParseUint(data[1:5], 0, 8) + if err != nil { + return "", fmt.Errorf("failed to parse hex byte: %v", err) + } + + if err := sb.WriteByte(byte(byteVal)); err != nil { + return "", err + } + } else { + if _, err := sb.WriteString(data); err != nil { + return "", err + } } } - slog.Debug("decoded", "ids", ids, "text", sb.String()) return sb.String(), nil } diff --git a/model/process_text_spm_test.go b/model/process_text_spm_test.go index a43004db..4813333e 100644 --- a/model/process_text_spm_test.go +++ b/model/process_text_spm_test.go @@ -25,8 +25,6 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel { t.Fatal(err) } - preTokenizer := `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+` - var v Vocabulary for _, piece := range spm.GetPieces() { @@ -47,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel { } } - return NewSentencePieceModel(preTokenizer, &v) + return NewSentencePieceModel(&v) } func TestSentencePieceEncode(t *testing.T) { @@ -116,3 +114,59 @@ func TestSentencePieceEncode(t *testing.T) { } }) } + +func TestSentencePieceModelDecodeByteTokens(t *testing.T) { + vocab := &Vocabulary{ + Values: []string{ + "normal", + "<0xEA>", + "<0x41>", + "<0xC3>", + "<0xA3>", + }, + Types: []uint32{ + TOKEN_TYPE_NORMAL, + TOKEN_TYPE_BYTE, + TOKEN_TYPE_BYTE, + TOKEN_TYPE_BYTE, + TOKEN_TYPE_BYTE, + }, + Scores: []float32{0, 0, 0, 0, 0}, + } + + spm := NewSentencePieceModel(vocab) + + tests := []struct { + name string + ids []int32 + expected string + }{ + { + name: "single byte token", + ids: []int32{1}, + expected: "\xea", + }, + { + name: "ASCII byte token", + ids: []int32{2}, + expected: "A", + }, + { + name: "multiple byte tokens forming UTF-8 character", + ids: []int32{3, 4}, + expected: "ã", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := spm.Decode(tt.ids) + if err != nil { + t.Errorf("failed to decode token IDs %v: %v", tt.ids, err) + } + if result != tt.expected { + t.Errorf("got %q, want %q", result, tt.expected) + } + }) + } +} From e53b3cbd0c3f08eb692a318c8eaf687a01c2e8c0 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Thu, 3 Apr 2025 10:19:24 -0700 Subject: [PATCH 15/17] llm: set done reason at server level (#9830) No functional change. Many different done reasons can be set at the runner level, so rather than obsuring them we should return them to the server process and let it choose what to do with the done reason. This separates the API concerns from the runner. --- llm/server.go | 26 ++++++++++++++++++++++++-- runner/llamarunner/runner.go | 21 ++++++++------------- runner/ollamarunner/runner.go | 21 ++++++++------------- server/routes.go | 20 ++++++++++---------- server/routes_generate_test.go | 8 ++++---- 5 files changed, 54 insertions(+), 42 deletions(-) diff --git a/llm/server.go b/llm/server.go index e6046db6..a2bc1548 100644 --- a/llm/server.go +++ b/llm/server.go @@ -675,9 +675,32 @@ type CompletionRequest struct { Grammar string // set before sending the request to the subprocess } +// DoneReason represents the reason why a completion response is done +type DoneReason int + +const ( + // DoneReasonStop indicates the completion stopped naturally + DoneReasonStop DoneReason = iota + // DoneReasonLength indicates the completion stopped due to length limits + DoneReasonLength + // DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed + DoneReasonConnectionClosed +) + +func (d DoneReason) String() string { + switch d { + case DoneReasonLength: + return "length" + case DoneReasonStop: + return "stop" + default: + return "" // closed + } +} + type CompletionResponse struct { Content string `json:"content"` - DoneReason string `json:"done_reason"` + DoneReason DoneReason `json:"done_reason"` Done bool `json:"done"` PromptEvalCount int `json:"prompt_eval_count"` PromptEvalDuration time.Duration `json:"prompt_eval_duration"` @@ -786,7 +809,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu continue } - // slog.Debug("got line", "line", string(line)) evt, ok := bytes.CutPrefix(line, []byte("data: ")) if !ok { evt = line diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index a4264f5f..d8169be4 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -83,7 +83,7 @@ type Sequence struct { // true if an embedding are to be returned instead of text generation embeddingOnly bool - doneReason string + doneReason llm.DoneReason // Metrics startProcessingTime time.Time @@ -301,7 +301,7 @@ func flushPending(seq *Sequence) bool { } } -func (s *Server) removeSequence(seqIndex int, reason string) { +func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) { seq := s.seqs[seqIndex] flushPending(seq) @@ -380,7 +380,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) // if past the num predict limit if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict { - s.removeSequence(seqIdx, "limit") + s.removeSequence(seqIdx, llm.DoneReasonLength) continue } @@ -482,7 +482,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) } seq.embedding <- embed - s.removeSequence(i, "") + s.removeSequence(i, llm.DoneReasonStop) continue } @@ -499,7 +499,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) // as it's important for the /api/generate context // seq.responses <- piece - s.removeSequence(i, "stop") + s.removeSequence(i, llm.DoneReasonStop) continue } @@ -530,7 +530,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) } seq.cache.Inputs = seq.cache.Inputs[:tokenLen] - s.removeSequence(i, "stop") + s.removeSequence(i, llm.DoneReasonStop) continue } @@ -543,7 +543,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) } if !flushPending(seq) { - s.removeSequence(i, "connection") + s.removeSequence(i, llm.DoneReasonConnectionClosed) } } @@ -657,14 +657,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { flusher.Flush() } else { - // Send the final response - doneReason := "stop" - if seq.doneReason == "limit" { - doneReason = "length" - } if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ Done: true, - DoneReason: doneReason, + DoneReason: seq.doneReason, PromptEvalCount: seq.numPromptInputs, PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime), EvalCount: seq.numDecoded, diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index f3286aba..7b7e0940 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -82,7 +82,7 @@ type Sequence struct { // true if an embedding are to be returned instead of text generation embeddingOnly bool - doneReason string + doneReason llm.DoneReason // Metrics startProcessingTime time.Time @@ -341,7 +341,7 @@ func flushPending(seq *Sequence) bool { } } -func (s *Server) removeSequence(seqIndex int, reason string) { +func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) { seq := s.seqs[seqIndex] flushPending(seq) @@ -391,7 +391,7 @@ func (s *Server) processBatch() error { // if past the num predict limit if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict { - s.removeSequence(seqIdx, "limit") + s.removeSequence(seqIdx, llm.DoneReasonLength) continue } @@ -510,7 +510,7 @@ func (s *Server) processBatch() error { if seq.embeddingOnly { // TODO(jessegross): Embedding support slog.Warn("generation of embedding outputs not yet supported") - s.removeSequence(i, "") + s.removeSequence(i, llm.DoneReasonStop) continue } @@ -528,7 +528,7 @@ func (s *Server) processBatch() error { // as it's important for the /api/generate context // seq.responses <- piece - s.removeSequence(i, "stop") + s.removeSequence(i, llm.DoneReasonStop) continue } @@ -564,7 +564,7 @@ func (s *Server) processBatch() error { } seq.cache.Inputs = seq.cache.Inputs[:tokenLen] - s.removeSequence(i, "stop") + s.removeSequence(i, llm.DoneReasonStop) continue } @@ -577,7 +577,7 @@ func (s *Server) processBatch() error { } if !flushPending(seq) { - s.removeSequence(i, "connection") + s.removeSequence(i, llm.DoneReasonConnectionClosed) } } @@ -690,14 +690,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { flusher.Flush() } else { - // Send the final response - doneReason := "stop" - if seq.doneReason == "limit" { - doneReason = "length" - } if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ Done: true, - DoneReason: doneReason, + DoneReason: seq.doneReason, PromptEvalCount: seq.numPromptInputs, PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime), EvalCount: seq.numPredicted, diff --git a/server/routes.go b/server/routes.go index eee34033..906426b1 100644 --- a/server/routes.go +++ b/server/routes.go @@ -308,11 +308,10 @@ func (s *Server) GenerateHandler(c *gin.Context) { Options: opts, }, func(cr llm.CompletionResponse) { res := api.GenerateResponse{ - Model: req.Model, - CreatedAt: time.Now().UTC(), - Response: cr.Content, - Done: cr.Done, - DoneReason: cr.DoneReason, + Model: req.Model, + CreatedAt: time.Now().UTC(), + Response: cr.Content, + Done: cr.Done, Metrics: api.Metrics{ PromptEvalCount: cr.PromptEvalCount, PromptEvalDuration: cr.PromptEvalDuration, @@ -326,6 +325,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { } if cr.Done { + res.DoneReason = cr.DoneReason.String() res.TotalDuration = time.Since(checkpointStart) res.LoadDuration = checkpointLoaded.Sub(checkpointStart) @@ -1533,11 +1533,10 @@ func (s *Server) ChatHandler(c *gin.Context) { Options: opts, }, func(r llm.CompletionResponse) { res := api.ChatResponse{ - Model: req.Model, - CreatedAt: time.Now().UTC(), - Message: api.Message{Role: "assistant", Content: r.Content}, - Done: r.Done, - DoneReason: r.DoneReason, + Model: req.Model, + CreatedAt: time.Now().UTC(), + Message: api.Message{Role: "assistant", Content: r.Content}, + Done: r.Done, Metrics: api.Metrics{ PromptEvalCount: r.PromptEvalCount, PromptEvalDuration: r.PromptEvalDuration, @@ -1547,6 +1546,7 @@ func (s *Server) ChatHandler(c *gin.Context) { } if r.Done { + res.DoneReason = r.DoneReason.String() res.TotalDuration = time.Since(checkpointStart) res.LoadDuration = checkpointLoaded.Sub(checkpointStart) } diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index aa263bf9..f219387c 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -58,7 +58,7 @@ func TestGenerateChat(t *testing.T) { mock := mockRunner{ CompletionResponse: llm.CompletionResponse{ Done: true, - DoneReason: "stop", + DoneReason: llm.DoneReasonStop, PromptEvalCount: 1, PromptEvalDuration: 1, EvalCount: 1, @@ -401,7 +401,7 @@ func TestGenerateChat(t *testing.T) { mock.CompletionResponse = llm.CompletionResponse{ Content: `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`, Done: true, - DoneReason: "done", + DoneReason: llm.DoneReasonStop, PromptEvalCount: 1, PromptEvalDuration: 1, EvalCount: 1, @@ -519,7 +519,7 @@ func TestGenerateChat(t *testing.T) { { Content: `, WA","unit":"celsius"}}`, Done: true, - DoneReason: "tool_call", + DoneReason: llm.DoneReasonStop, PromptEvalCount: 3, PromptEvalDuration: 1, }, @@ -594,7 +594,7 @@ func TestGenerate(t *testing.T) { mock := mockRunner{ CompletionResponse: llm.CompletionResponse{ Done: true, - DoneReason: "stop", + DoneReason: llm.DoneReasonStop, PromptEvalCount: 1, PromptEvalDuration: 1, EvalCount: 1, From 3b96a93672377129f2a2aafc447e79ef1ca48c5f Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 18 Mar 2025 14:38:44 -0700 Subject: [PATCH 16/17] fs: move ml.Config to fs package --- fs/config.go | 13 +++++++++++++ kvcache/causal_test.go | 3 ++- ml/backend.go | 16 +++------------- ml/backend/ggml/ggml.go | 11 ++++++----- model/model.go | 11 ++++++----- model/model_test.go | 9 +++++---- model/models/gemma2/model.go | 3 ++- model/models/gemma3/model.go | 3 ++- model/models/gemma3/model_text.go | 3 ++- model/models/gemma3/model_vision.go | 3 ++- model/models/gemma3/process_image.go | 4 ++-- model/models/llama/model.go | 3 ++- model/models/mllama/model.go | 3 ++- model/models/mllama/model_text.go | 3 ++- model/models/mllama/model_vision.go | 3 ++- model/models/mllama/process_image.go | 4 ++-- 16 files changed, 55 insertions(+), 40 deletions(-) create mode 100644 fs/config.go diff --git a/fs/config.go b/fs/config.go new file mode 100644 index 00000000..bc5bfa55 --- /dev/null +++ b/fs/config.go @@ -0,0 +1,13 @@ +package fs + +type Config interface { + Architecture() string + String(string, ...string) string + Uint(string, ...uint32) uint32 + Float(string, ...float32) float32 + Bool(string, ...bool) bool + + Strings(string, ...[]string) []string + Uints(string, ...[]uint32) []uint32 + Floats(string, ...[]float32) []float32 +} diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index bf98abef..517e3726 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -5,6 +5,7 @@ import ( "slices" "testing" + "github.com/ollama/ollama/fs" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model/input" ) @@ -373,7 +374,7 @@ func TestCanResume(t *testing.T) { type testBackend struct{} -func (b *testBackend) Config() ml.Config { +func (b *testBackend) Config() fs.Config { panic("not implemented") } diff --git a/ml/backend.go b/ml/backend.go index cfb18d6a..b22ba795 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -9,22 +9,12 @@ import ( "slices" "strconv" "strings" + + "github.com/ollama/ollama/fs" ) -type Config interface { - Architecture() string - String(string, ...string) string - Uint(string, ...uint32) uint32 - Float(string, ...float32) float32 - Bool(string, ...bool) bool - - Strings(string, ...[]string) []string - Uints(string, ...[]uint32) []uint32 - Floats(string, ...[]float32) []float32 -} - type Backend interface { - Config() Config + Config() fs.Config Get(name string) Tensor NewContext() Context NewContextSize(size int) Context diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index b6f59ae0..17f06384 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -24,7 +24,8 @@ import ( "unsafe" "github.com/ollama/ollama/format" - fs "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/fs" + fsggml "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/ml" ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src" "golang.org/x/sync/errgroup" @@ -41,7 +42,7 @@ func devices() []*C.struct_ggml_backend_device { } type Backend struct { - meta *fs.GGML + meta *fsggml.GGML sched *C.struct_ggml_backend_sched tensors map[string]*C.struct_ggml_tensor @@ -58,7 +59,7 @@ type Backend struct { } func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) { - meta, n, err := fs.Decode(r, -1) + meta, n, err := fsggml.Decode(r, -1) if err != nil { return nil, err } @@ -182,7 +183,7 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, maxTensors += blocks * 2 type tensor struct { - source *fs.Tensor + source *fsggml.Tensor target string } @@ -413,7 +414,7 @@ func init() { ml.RegisterBackend("ggml", New) } -func (b *Backend) Config() ml.Config { +func (b *Backend) Config() fs.Config { return b.meta.KV() } diff --git a/model/model.go b/model/model.go index 8355a55a..bc8944d2 100644 --- a/model/model.go +++ b/model/model.go @@ -16,7 +16,8 @@ import ( _ "golang.org/x/image/tiff" _ "golang.org/x/image/webp" - fs "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/fs" + fsggml "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" _ "github.com/ollama/ollama/ml/backend" @@ -83,10 +84,10 @@ func (m *Base) Config() config { return m.config } -var models = make(map[string]func(ml.Config) (Model, error)) +var models = make(map[string]func(fs.Config) (Model, error)) // Register registers a model constructor for the given architecture -func Register(name string, f func(ml.Config) (Model, error)) { +func Register(name string, f func(fs.Config) (Model, error)) { if _, ok := models[name]; ok { panic("model: model already registered") } @@ -131,14 +132,14 @@ func NewTextProcessor(s string) (TextProcessor, error) { return nil, err } defer r.Close() - meta, _, err := fs.Decode(r, -1) + meta, _, err := fsggml.Decode(r, -1) if err != nil { return nil, err } return getTextProcessor(meta.KV()) } -func getTextProcessor(kv fs.KV) (TextProcessor, error) { +func getTextProcessor(kv fsggml.KV) (TextProcessor, error) { arch := kv.Architecture() f, ok := models[arch] if !ok { diff --git a/model/model_test.go b/model/model_test.go index 0b1ea08e..717c425e 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -7,7 +7,8 @@ import ( "testing" "github.com/google/go-cmp/cmp" - fs "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/fs" + fsggml "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/backend/ggml" "github.com/ollama/ollama/ml/nn" @@ -139,7 +140,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) { } func TestGetTextProcessor(t *testing.T) { - tp, err := getTextProcessor(fs.KV{}) + tp, err := getTextProcessor(fsggml.KV{}) if err == nil { t.Error("expected error") } else if !strings.Contains(err.Error(), "unsupported model architecture") { @@ -148,10 +149,10 @@ func TestGetTextProcessor(t *testing.T) { t.Error("expected nil tp") } - models["dummy"] = func(ml.Config) (Model, error) { + models["dummy"] = func(fs.Config) (Model, error) { return notTextProcessorModel{}, nil } - tp, err = getTextProcessor(fs.KV{"general.architecture": "dummy"}) + tp, err = getTextProcessor(fsggml.KV{"general.architecture": "dummy"}) if err == nil { t.Error("expected error") } else if !strings.Contains(err.Error(), "not a TextProcessor") { diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index b8f5f066..752cb5cc 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -3,6 +3,7 @@ package gemma2 import ( "math" + "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" @@ -35,7 +36,7 @@ const ( gemma27BLayerCount = 46 ) -func New(c ml.Config) (model.Model, error) { +func New(c fs.Config) (model.Model, error) { m := Model{ SentencePieceModel: model.NewSentencePieceModel( &model.Vocabulary{ diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index f9c53343..cef058e2 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -6,6 +6,7 @@ import ( "math" "slices" + "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" @@ -52,7 +53,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i return visionOutputs } -func New(c ml.Config) (model.Model, error) { +func New(c fs.Config) (model.Model, error) { m := Model{ SentencePieceModel: model.NewSentencePieceModel( &model.Vocabulary{ diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index 7b2b83c0..3b640a96 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -3,6 +3,7 @@ package gemma3 import ( "math" + "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" @@ -40,7 +41,7 @@ const ( cacheTypeCausal ) -func newTextModel(c ml.Config) *TextModel { +func newTextModel(c fs.Config) *TextModel { numBlocks := int(c.Uint("block_count")) m := TextModel{ diff --git a/model/models/gemma3/model_vision.go b/model/models/gemma3/model_vision.go index 94aa27bd..636a363d 100644 --- a/model/models/gemma3/model_vision.go +++ b/model/models/gemma3/model_vision.go @@ -3,6 +3,7 @@ package gemma3 import ( "math" + "github.com/ollama/ollama/fs" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" ) @@ -111,7 +112,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor { return hiddenState } -func newVisionModel(c ml.Config) *VisionModel { +func newVisionModel(c fs.Config) *VisionModel { return &VisionModel{ Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")), VisionModelOptions: &VisionModelOptions{ diff --git a/model/models/gemma3/process_image.go b/model/models/gemma3/process_image.go index fe8269a3..611a17bd 100644 --- a/model/models/gemma3/process_image.go +++ b/model/models/gemma3/process_image.go @@ -3,7 +3,7 @@ package gemma3 import ( "image" - "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/fs" "github.com/ollama/ollama/model/imageproc" ) @@ -11,7 +11,7 @@ type ImageProcessor struct { imageSize, patchSize, numChannels int } -func newImageProcessor(c ml.Config) ImageProcessor { +func newImageProcessor(c fs.Config) ImageProcessor { return ImageProcessor{ imageSize: int(c.Uint("vision.image_size")), patchSize: int(c.Uint("vision.patch_size")), diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 5c173997..68980dd7 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -5,6 +5,7 @@ import ( "math" "strings" + "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" @@ -30,7 +31,7 @@ type Model struct { *Options } -func New(c ml.Config) (model.Model, error) { +func New(c fs.Config) (model.Model, error) { if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") { return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model")) } diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index 988a189d..e53eb184 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -8,6 +8,7 @@ import ( "image" "slices" + "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" @@ -32,7 +33,7 @@ const ( selfAttentionLayer ) -func New(c ml.Config) (model.Model, error) { +func New(c fs.Config) (model.Model, error) { // Verify unified config if c.Uint("vision.block_count") == 0 { return nil, fmt.Errorf("non-unified vision model not supported") diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index 1cf30d89..261897c3 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -4,6 +4,7 @@ import ( "math" "slices" + "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" @@ -220,7 +221,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, mask return m.Output.Forward(ctx, hiddenState) } -func newTextModel(c ml.Config) *TextModel { +func newTextModel(c fs.Config) *TextModel { var decoderLayers []TextDecoderLayer for i := range c.Uint("block_count") { var textDecoderLayer TextDecoderLayer diff --git a/model/models/mllama/model_vision.go b/model/models/mllama/model_vision.go index ac777f05..2f7d26ca 100644 --- a/model/models/mllama/model_vision.go +++ b/model/models/mllama/model_vision.go @@ -4,6 +4,7 @@ import ( "math" "slices" + "github.com/ollama/ollama/fs" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" ) @@ -213,7 +214,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRa return hiddenState.Concat(ctx, hiddenStates, 0) } -func newVisionModel(c ml.Config) *VisionModel { +func newVisionModel(c fs.Config) *VisionModel { return &VisionModel{ Transformer: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count"))}, GlobalTransformer: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.global.block_count"))}, diff --git a/model/models/mllama/process_image.go b/model/models/mllama/process_image.go index c94d14a6..1b0506d3 100644 --- a/model/models/mllama/process_image.go +++ b/model/models/mllama/process_image.go @@ -8,14 +8,14 @@ import ( "golang.org/x/image/draw" - "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/fs" ) type ImageProcessor struct { imageSize, numChannels, maxNumTiles int } -func newImageProcessor(c ml.Config) ImageProcessor { +func newImageProcessor(c fs.Config) ImageProcessor { return ImageProcessor{ imageSize: int(c.Uint("vision.image_size")), numChannels: int(c.Uint("vision.num_channels")), From 6bd0a983cd2cf74f27df2e5a5c80f1794a2ed7ef Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Fri, 14 Mar 2025 16:56:32 -0700 Subject: [PATCH 17/17] model: support for mistral-small in the ollama runner Mistral is a popular research lab making open source models. This updates the forward pass of llama architecture models to support both llama models and mistral models by accounting for additional metadata present in mistral models, and finding the correct dimensions for the output projection. --- convert/convert.go | 4 +- convert/convert_mistral.go | 190 +++++++++++++++ convert/reader.go | 5 +- fs/ggml/ggml.go | 7 +- kvcache/causal_test.go | 26 ++- llama/llama.cpp/src/llama-arch.cpp | 17 ++ llama/llama.cpp/src/llama-arch.h | 1 + llama/llama.cpp/src/llama-model.cpp | 3 + llama/llama.cpp/src/llama-quant.cpp | 9 +- ...tch => 0021-add-model-quantizations.patch} | 102 ++++++-- llama/patches/0022-metal-add-op_neg.patch | 75 ++++++ ml/backend.go | 10 +- ml/backend/ggml/ggml.go | 56 +++++ .../src/ggml-metal/ggml-metal-embed.metal | 7 + .../ggml/ggml/src/ggml-metal/ggml-metal.m | 15 ++ .../ggml/ggml/src/ggml-metal/ggml-metal.metal | 7 + model/models/gemma3/model_text.go | 22 +- model/models/mistral3/imageproc.go | 56 +++++ model/models/mistral3/model.go | 189 +++++++++++++++ model/models/mistral3/model_text.go | 177 ++++++++++++++ model/models/mistral3/model_vision.go | 186 +++++++++++++++ model/models/mllama/model_vision.go | 2 +- model/models/models.go | 1 + model/models/pixtral/imageproc.go | 68 ------ model/models/pixtral/imageproc_test.go | 219 ------------------ model/process_text.go | 4 + parser/parser.go | 8 +- 27 files changed, 1116 insertions(+), 350 deletions(-) create mode 100644 convert/convert_mistral.go rename llama/patches/{0021-gemma3-quantization.patch => 0021-add-model-quantizations.patch} (52%) create mode 100644 llama/patches/0022-metal-add-op_neg.patch create mode 100644 model/models/mistral3/imageproc.go create mode 100644 model/models/mistral3/model.go create mode 100644 model/models/mistral3/model_text.go create mode 100644 model/models/mistral3/model_vision.go delete mode 100644 model/models/pixtral/imageproc.go delete mode 100644 model/models/pixtral/imageproc_test.go diff --git a/convert/convert.go b/convert/convert.go index a31b0d6c..26bc72cc 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -182,8 +182,10 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error { var conv ModelConverter switch p.Architectures[0] { - case "LlamaForCausalLM", "MistralForCausalLM": + case "LlamaForCausalLM": conv = &llamaModel{} + case "Mistral3ForConditionalGeneration": + conv = &mistral3Model{} case "MixtralForCausalLM": conv = &mixtralModel{} case "GemmaForCausalLM": diff --git a/convert/convert_mistral.go b/convert/convert_mistral.go new file mode 100644 index 00000000..6c224ae4 --- /dev/null +++ b/convert/convert_mistral.go @@ -0,0 +1,190 @@ +package convert + +import ( + "cmp" + "fmt" + "strings" + + "github.com/pdevine/tensor" + "github.com/pdevine/tensor/native" + + "github.com/ollama/ollama/fs/ggml" +) + +type mistral3Model struct { + ModelParameters + ImageTokenIndex uint32 `json:"image_token_index"` + SpatialMergeSize uint32 `json:"spatial_merge_size"` + VisionFeatureLayer int32 `json:"vision_feature_layer"` + TextModel struct { + NumHiddenLayers uint32 `json:"num_hidden_layers"` + MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` + HiddenSize uint32 `json:"hidden_size"` + IntermediateSize uint32 `json:"intermediate_size"` + NumAttentionHeads uint32 `json:"num_attention_heads"` + NumKeyValueHeads uint32 `json:"num_key_value_heads"` + RopeTheta float32 `json:"rope_theta"` + RMSNormEPS float32 `json:"rms_norm_eps"` + HeadDim uint32 `json:"head_dim"` + SlidingWindow *uint32 `json:"sliding_window"` + HiddenAct string `json:"hidden_act"` + VocabSize uint32 `json:"vocab_size"` + } `json:"text_config"` + VisionModel struct { + NumAttentionHeads uint32 `json:"num_attention_heads"` + NumHiddenLayers uint32 `json:"num_hidden_layers"` + HiddenSize uint32 `json:"hidden_size"` + IntermediateSize uint32 `json:"intermediate_size"` + ImageSize uint32 `json:"image_size"` + NumChannels uint32 `json:"num_channels"` + PatchSize uint32 `json:"patch_size"` + HeadDim uint32 `json:"head_dim"` + HiddenAct string `json:"hidden_act"` + RopeTheta float32 `json:"rope_theta"` + } `json:"vision_config"` + MultiModalProjectorBias bool `json:"multimodal_projector_bias"` + ProjectorHiddenAct string `json:"projector_hidden_act"` +} + +func (p *mistral3Model) KV(t *Tokenizer) ggml.KV { + kv := p.ModelParameters.KV(t) + kv["general.architecture"] = "mistral3" + kv["mistral3.vocab_size"] = p.TextModel.VocabSize + + // Text configuration + kv["mistral3.block_count"] = p.TextModel.NumHiddenLayers + kv["mistral3.context_length"] = p.TextModel.MaxPositionEmbeddings + kv["mistral3.embedding_length"] = p.TextModel.HiddenSize + kv["mistral3.feed_forward_length"] = p.TextModel.IntermediateSize + kv["mistral3.attention.head_count"] = p.TextModel.NumAttentionHeads + kv["mistral3.attention.head_count_kv"] = p.TextModel.NumKeyValueHeads + kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS + kv["mistral3.attention.key_length"] = p.TextModel.HeadDim + kv["mistral3.attention.value_length"] = p.TextModel.HeadDim + kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers + kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta + + // Vision configuration + kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers + kv["mistral3.vision.embedding_length"] = p.VisionModel.HiddenSize + kv["mistral3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize + kv["mistral3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads + kv["mistral3.vision.attention.key_length"] = p.VisionModel.HeadDim + kv["mistral3.vision.image_size"] = p.VisionModel.ImageSize + kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize + kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels + // kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value + kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta + + // Multimodal configuration + kv["mistral3.image_token_index"] = p.ImageTokenIndex + kv["mistral3.spatial_merge_size"] = p.SpatialMergeSize + + kv["mistral3.mm.projector_bias"] = p.MultiModalProjectorBias + + if p.ProjectorHiddenAct != "" { + kv["mistral3.mm.projector_hidden_act"] = p.ProjectorHiddenAct + } + + return kv +} + +func (p *mistral3Model) Tensors(ts []Tensor) []ggml.Tensor { + var out []ggml.Tensor + + for _, t := range ts { + if !strings.HasPrefix(t.Name(), "v.") { + if strings.HasSuffix(t.Name(), ".attn_q.weight") || + strings.HasSuffix(t.Name(), ".attn_k.weight") { + t.SetRepacker(p.repack) + } + } + + out = append(out, ggml.Tensor{ + Name: t.Name(), + Kind: t.Kind(), + Shape: t.Shape(), + WriterTo: t, + }) + } + + return out +} + +func (p *mistral3Model) Replacements() []string { + return []string{ + "language_model.model.norm", "output_norm", + "language_model.model.", "", + "language_model.", "", + "layers", "blk", + "transformer.layers", "blk", + "vision_tower", "v", + "ln_pre", "encoder_norm", + "input_layernorm", "attn_norm", + "post_attention_layernorm", "ffn_norm", + "embed_tokens", "token_embd", + "self_attn.q_proj", "attn_q", + "self_attn.k_proj", "attn_k", + "self_attn.v_proj", "attn_v", + "self_attn.o_proj", "attn_output", + "mlp.down_proj", "ffn_down", + "mlp.gate_proj", "ffn_gate", + "mlp.up_proj", "ffn_up", + "attention.q_proj", "attn_q", + "attention.k_proj", "attn_k", + "attention.v_proj", "attn_v", + "attention.o_proj", "attn_output", + "attention_norm", "attn_norm", + "feed_forward.gate_proj", "ffn_gate", + "feed_forward.down_proj", "ffn_down", + "feed_forward.up_proj", "ffn_up", + "multi_modal_projector", "mm", + "ffn_norm", "ffn_norm", + "lm_head", "output", + } +} + +func (p *mistral3Model) repack(name string, data []float32, shape []uint64) ([]float32, error) { + var dims []int + for _, dim := range shape { + dims = append(dims, int(dim)) + } + + var heads uint32 + if strings.HasSuffix(name, ".attn_q.weight") { + heads = p.TextModel.NumAttentionHeads + } else if strings.HasSuffix(name, ".attn_k.weight") { + heads = cmp.Or(p.TextModel.NumKeyValueHeads, p.TextModel.NumAttentionHeads) + } else { + return nil, fmt.Errorf("unknown tensor for repack: %s", name) + } + + n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data)) + if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil { + return nil, err + } + + if err := n.T(0, 2, 1, 3); err != nil { + return nil, err + } + + if err := n.Reshape(dims...); err != nil { + return nil, err + } + + if err := n.Transpose(); err != nil { + return nil, err + } + + ts, err := native.SelectF32(n, 1) + if err != nil { + return nil, err + } + + var f32s []float32 + for _, t := range ts { + f32s = append(f32s, t...) + } + + return f32s, nil +} diff --git a/convert/reader.go b/convert/reader.go index c1218e66..904b13a4 100644 --- a/convert/reader.go +++ b/convert/reader.go @@ -62,10 +62,7 @@ func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) { Pattern string Func func(fs.FS, *strings.Replacer, ...string) ([]Tensor, error) }{ - {"model-*-of-*.safetensors", parseSafetensors}, - {"model.safetensors", parseSafetensors}, - {"adapters.safetensors", parseSafetensors}, - {"adapter_model.safetensors", parseSafetensors}, + {"*.safetensors", parseSafetensors}, {"pytorch_model-*-of-*.bin", parseTorch}, {"pytorch_model.bin", parseTorch}, {"consolidated.*.pth", parseTorch}, diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index c88583fb..9431e9cc 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -134,7 +134,10 @@ func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 { } func (kv KV) OllamaEngineRequired() bool { - return kv.Architecture() == "gemma3" + return slices.Contains([]string{ + "gemma3", + "mistral3", + }, kv.Architecture()) } func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T { @@ -638,7 +641,7 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) { embeddingLength*numPatches*maxNumTiles + 9*embeddingLength*numPaddedPatches*maxNumTiles + numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount) - case "gemma3": + case "gemma3", "mistral3": graphSize = 4 * (imageSize*imageSize*numChannels + embeddingLength*patchSize + numPatches*numPatches*headCount) diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 517e3726..bd63214c 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -484,6 +484,14 @@ func (t *testTensor) Floats() []float32 { return out } +func (t *testTensor) Neg(ctx ml.Context) ml.Tensor { + out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor) + for i := range out.data { + out.data[i] = -t.data[i] + } + return out +} + func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor { out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor) @@ -538,17 +546,15 @@ func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, di panic("not implemented") } -func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor { +func (t *testTensor) IM2Col(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { panic("not implemented") } -func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { - panic("not implemented") -} - -func (t *testTensor) SILU(ctx ml.Context) ml.Tensor { - panic("not implemented") -} +func (t *testTensor) Cos(ctx ml.Context) ml.Tensor { panic("not implemented") } +func (t *testTensor) Sin(ctx ml.Context) ml.Tensor { panic("not implemented") } +func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor { panic("not implemented") } +func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { panic("not implemented") } +func (t *testTensor) SILU(ctx ml.Context) ml.Tensor { panic("not implemented") } func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor { panic("not implemented") @@ -600,6 +606,8 @@ func (t *testTensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor { panic("not implemented") } +func (t *testTensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor { panic("not implemented") } + func (t *testTensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor { panic("not implemented") } @@ -612,3 +620,5 @@ func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor { copy(t2.(*testTensor).data, t.data) return nil } + +func (t *testTensor) Duplicate(ctx ml.Context) ml.Tensor { panic("not implemented") } diff --git a/llama/llama.cpp/src/llama-arch.cpp b/llama/llama.cpp/src/llama-arch.cpp index b443fcd3..13a0a988 100644 --- a/llama/llama.cpp/src/llama-arch.cpp +++ b/llama/llama.cpp/src/llama-arch.cpp @@ -65,6 +65,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_CHAMELEON, "chameleon" }, { LLM_ARCH_SOLAR, "solar" }, { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, + { LLM_ARCH_MISTRAL3, "mistral3" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -1371,6 +1372,22 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" }, }, }, + { + LLM_ARCH_MISTRAL3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + } + }, { LLM_ARCH_UNKNOWN, { diff --git a/llama/llama.cpp/src/llama-arch.h b/llama/llama.cpp/src/llama-arch.h index aad92a5d..8476ae0a 100644 --- a/llama/llama.cpp/src/llama-arch.h +++ b/llama/llama.cpp/src/llama-arch.h @@ -69,6 +69,7 @@ enum llm_arch { LLM_ARCH_CHAMELEON, LLM_ARCH_SOLAR, LLM_ARCH_WAVTOKENIZER_DEC, + LLM_ARCH_MISTRAL3, LLM_ARCH_UNKNOWN, }; diff --git a/llama/llama.cpp/src/llama-model.cpp b/llama/llama.cpp/src/llama-model.cpp index 70183041..db4f2685 100644 --- a/llama/llama.cpp/src/llama-model.cpp +++ b/llama/llama.cpp/src/llama-model.cpp @@ -1277,6 +1277,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); } break; + case LLM_ARCH_MISTRAL3: break; default: throw std::runtime_error("unsupported model architecture"); } @@ -3537,6 +3538,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0); output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0); } break; + case LLM_ARCH_MISTRAL3: break; default: throw std::runtime_error("unknown architecture"); } @@ -4015,6 +4017,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_CHAMELEON: case LLM_ARCH_SOLAR: + case LLM_ARCH_MISTRAL3: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 diff --git a/llama/llama.cpp/src/llama-quant.cpp b/llama/llama.cpp/src/llama-quant.cpp index d2f3a510..ebcbafa1 100644 --- a/llama/llama.cpp/src/llama-quant.cpp +++ b/llama/llama.cpp/src/llama-quant.cpp @@ -738,13 +738,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? // don't quantize vision stuff - quantize &= name.find("v.blk.") == std::string::npos; - - quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos; - quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos; - quantize &= name.find("v.patch_embedding.weight") == std::string::npos; - quantize &= name.find("v.position_embedding.weight") == std::string::npos; - quantize &= name.find("v.post_layernorm.weight") == std::string::npos; + quantize &= name.find("v.") == std::string::npos; + quantize &= name.find("mm.") == std::string::npos; // quantize only 2D and 3D tensors (experts) quantize &= (ggml_n_dims(tensor) >= 2); diff --git a/llama/patches/0021-gemma3-quantization.patch b/llama/patches/0021-add-model-quantizations.patch similarity index 52% rename from llama/patches/0021-gemma3-quantization.patch rename to llama/patches/0021-add-model-quantizations.patch index 4f6dbc11..cdc35a41 100644 --- a/llama/patches/0021-gemma3-quantization.patch +++ b/llama/patches/0021-add-model-quantizations.patch @@ -1,17 +1,19 @@ From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Fri, 14 Mar 2025 16:33:23 -0700 -Subject: [PATCH] gemma3 quantization +Subject: [PATCH] add model quantizations +- gemma3 +- mistral3 --- - src/llama-arch.cpp | 19 +++++++++++++++++++ - src/llama-arch.h | 1 + - src/llama-model.cpp | 7 +++++++ - src/llama-quant.cpp | 9 +++++++++ - 4 files changed, 36 insertions(+) + src/llama-arch.cpp | 36 ++++++++++++++++++++++++++++++++++++ + src/llama-arch.h | 2 ++ + src/llama-model.cpp | 10 ++++++++++ + src/llama-quant.cpp | 4 ++++ + 4 files changed, 52 insertions(+) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp -index b6f20286..b443fcd3 100644 +index b6f20286..13a0a988 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -37,6 +37,7 @@ static const std::map LLM_ARCH_NAMES = { @@ -22,7 +24,15 @@ index b6f20286..b443fcd3 100644 { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_XVERSE, "xverse" }, -@@ -804,6 +805,24 @@ static const std::map> LLM_TENSOR_N +@@ -64,6 +65,7 @@ static const std::map LLM_ARCH_NAMES = { + { LLM_ARCH_CHAMELEON, "chameleon" }, + { LLM_ARCH_SOLAR, "solar" }, + { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, ++ { LLM_ARCH_MISTRAL3, "mistral3" }, + { LLM_ARCH_UNKNOWN, "(unknown)" }, + }; + +@@ -804,6 +806,24 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, }, }, @@ -47,8 +57,31 @@ index b6f20286..b443fcd3 100644 { LLM_ARCH_STARCODER2, { +@@ -1352,6 +1372,22 @@ static const std::map> LLM_TENSOR_N + { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" }, + }, + }, ++ { ++ LLM_ARCH_MISTRAL3, ++ { ++ { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, ++ { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, ++ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, ++ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, ++ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, ++ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, ++ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, ++ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, ++ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, ++ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, ++ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, ++ } ++ }, + { + LLM_ARCH_UNKNOWN, + { diff --git a/src/llama-arch.h b/src/llama-arch.h -index ec742224..aad92a5d 100644 +index ec742224..8476ae0a 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -41,6 +41,7 @@ enum llm_arch { @@ -59,8 +92,16 @@ index ec742224..aad92a5d 100644 LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, LLM_ARCH_XVERSE, +@@ -68,6 +69,7 @@ enum llm_arch { + LLM_ARCH_CHAMELEON, + LLM_ARCH_SOLAR, + LLM_ARCH_WAVTOKENIZER_DEC, ++ LLM_ARCH_MISTRAL3, + LLM_ARCH_UNKNOWN, + }; + diff --git a/src/llama-model.cpp b/src/llama-model.cpp -index ab1a07d1..70183041 100644 +index ab1a07d1..db4f2685 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -878,6 +878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { @@ -73,7 +114,15 @@ index ab1a07d1..70183041 100644 case LLM_ARCH_STARCODER2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); -@@ -2537,6 +2540,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { +@@ -1274,6 +1277,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + } break; ++ case LLM_ARCH_MISTRAL3: break; + default: throw std::runtime_error("unsupported model architecture"); + } + +@@ -2537,6 +2541,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); } } break; @@ -83,7 +132,23 @@ index ab1a07d1..70183041 100644 case LLM_ARCH_STARCODER2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); -@@ -4029,6 +4035,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { +@@ -3531,6 +3538,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0); + } break; ++ case LLM_ARCH_MISTRAL3: break; + default: + throw std::runtime_error("unknown architecture"); + } +@@ -4009,6 +4017,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { + case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_CHAMELEON: + case LLM_ARCH_SOLAR: ++ case LLM_ARCH_MISTRAL3: + return LLAMA_ROPE_TYPE_NORM; + + // the pairs of head values are offset by n_rot/2 +@@ -4029,6 +4038,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { case LLM_ARCH_PHIMOE: case LLM_ARCH_GEMMA: case LLM_ARCH_GEMMA2: @@ -92,21 +157,16 @@ index ab1a07d1..70183041 100644 case LLM_ARCH_OPENELM: case LLM_ARCH_GPTNEOX: diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp -index 6eb1da08..d2f3a510 100644 +index 6eb1da08..ebcbafa1 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp -@@ -737,6 +737,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: +@@ -737,6 +737,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // This used to be a regex, but has an extreme cost to compile times. bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? + // don't quantize vision stuff -+ quantize &= name.find("v.blk.") == std::string::npos; -+ -+ quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos; -+ quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos; -+ quantize &= name.find("v.patch_embedding.weight") == std::string::npos; -+ quantize &= name.find("v.position_embedding.weight") == std::string::npos; -+ quantize &= name.find("v.post_layernorm.weight") == std::string::npos; ++ quantize &= name.find("v.") == std::string::npos; ++ quantize &= name.find("mm.") == std::string::npos; + // quantize only 2D and 3D tensors (experts) quantize &= (ggml_n_dims(tensor) >= 2); diff --git a/llama/patches/0022-metal-add-op_neg.patch b/llama/patches/0022-metal-add-op_neg.patch new file mode 100644 index 00000000..a903535f --- /dev/null +++ b/llama/patches/0022-metal-add-op_neg.patch @@ -0,0 +1,75 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Michael Yang +Date: Wed, 2 Apr 2025 15:26:15 -0700 +Subject: [PATCH] metal: add op_neg + +--- + ggml/src/ggml-metal/ggml-metal.m | 15 +++++++++++++++ + ggml/src/ggml-metal/ggml-metal.metal | 7 +++++++ + 2 files changed, 22 insertions(+) + +diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m +index e4c093f9..d8422f1b 100644 +--- a/ggml/src/ggml-metal/ggml-metal.m ++++ b/ggml/src/ggml-metal/ggml-metal.m +@@ -423,6 +423,7 @@ enum ggml_metal_kernel_type { + GGML_METAL_KERNEL_TYPE_SQRT, + GGML_METAL_KERNEL_TYPE_SIN, + GGML_METAL_KERNEL_TYPE_COS, ++ GGML_METAL_KERNEL_TYPE_NEG, + GGML_METAL_KERNEL_TYPE_SUM_ROWS, + GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, + GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, +@@ -1039,6 +1040,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); ++ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true); +@@ -1202,6 +1204,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_ELU: ++ case GGML_UNARY_OP_NEG: + return ggml_is_contiguous(op->src[0]); + default: + return false; +@@ -1873,6 +1876,18 @@ static void ggml_metal_encode_node( + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; ++ case GGML_UNARY_OP_NEG: ++ { ++ id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline; ++ ++ [encoder setComputePipelineState:pipeline]; ++ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; ++ [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; ++ ++ const int64_t n = ggml_nelements(dst); ++ ++ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; ++ } break; + default: + { + GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); +diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal +index f38909d0..bb0ff668 100644 +--- a/ggml/src/ggml-metal/ggml-metal.metal ++++ b/ggml/src/ggml-metal/ggml-metal.metal +@@ -945,6 +945,13 @@ kernel void kernel_cos( + dst[tpig] = cos(src0[tpig]); + } + ++kernel void kernel_neg( ++ device const float * src0, ++ device float * dst, ++ uint tpig[[thread_position_in_grid]]) { ++ dst[tpig] = -src0[tpig]; ++} ++ + kernel void kernel_sum_rows( + device const float * src0, + device float * dst, diff --git a/ml/backend.go b/ml/backend.go index b22ba795..fffc04a4 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -118,6 +118,7 @@ type Tensor interface { Bytes() []byte Floats() []float32 + Neg(ctx Context) Tensor Add(ctx Context, t2 Tensor) Tensor Mul(ctx Context, t2 Tensor) Tensor Mulmat(ctx Context, t2 Tensor) Tensor @@ -132,7 +133,10 @@ type Tensor interface { Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor + IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor + Sin(ctx Context) Tensor + Cos(ctx Context) Tensor Tanh(ctx Context) Tensor GELU(ctx Context) Tensor SILU(ctx Context) Tensor @@ -147,9 +151,13 @@ type Tensor interface { Unpad(ctx Context, shape ...int) Tensor Stack(ctx Context, dim int, s ...Tensor) Tensor + + // Repeat repeats the tensor n times along dimension dim + Repeat(ctx Context, dim, n int) Tensor Concat(ctx Context, t2 Tensor, dim int) Tensor Rows(ctx Context, t2 Tensor) Tensor Copy(ctx Context, t2 Tensor) Tensor + Duplicate(ctx Context) Tensor } // ScaledDotProductAttention implements a fused attention @@ -214,7 +222,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string { return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32) }) case DTypeF16, DTypeQ80, DTypeQ40: - f32 := ctx.Empty(DTypeF32, t.Shape()...) + f32 := ctx.Input().Empty(DTypeF32, t.Shape()...) f32 = t.Copy(ctx, f32) return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string { return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32) diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 17f06384..a106fed5 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -711,6 +711,13 @@ func (t *Tensor) DType() ml.DType { } } +func (t *Tensor) Neg(ctx ml.Context) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_neg(ctx.(*Context).ctx, t.t), + } +} + func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor { return &Tensor{ b: t.b, @@ -718,6 +725,27 @@ func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor { } } +func (t *Tensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor { + if dim < 0 || dim >= C.GGML_MAX_DIMS { + panic("invalid dimension") + } + + shape := make([]C.int64_t, C.GGML_MAX_DIMS) + for i := range C.GGML_MAX_DIMS { + if i == dim { + shape[i] = C.int64_t(t.Dim(i) * n) + } else { + shape[i] = C.int64_t(t.Dim(i)) + } + } + + tmpl := C.ggml_new_tensor(ctx.(*Context).ctx, t.t._type, C.int(len(shape)), unsafe.SliceData(shape)) + return &Tensor{ + b: t.b, + t: C.ggml_repeat(ctx.(*Context).ctx, t.t, tmpl), + } +} + func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor { if len(s) > 0 { return t.Concat(ctx, s[0].Stack(ctx, dim, s[1:]...), dim) @@ -854,6 +882,20 @@ func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor { } } +func (t *Tensor) Sin(ctx ml.Context) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_sin(ctx.(*Context).ctx, t.t), + } +} + +func (t *Tensor) Cos(ctx ml.Context) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_cos(ctx.(*Context).ctx, t.t), + } +} + func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor { return &Tensor{ b: t.b, @@ -942,6 +984,13 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi } } +func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_im2col(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1), true, C.GGML_TYPE_F32), + } +} + func (t *Tensor) GELU(ctx ml.Context) ml.Tensor { return &Tensor{ b: t.b, @@ -1010,3 +1059,10 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) } } + +func (t *Tensor) Duplicate(ctx ml.Context) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_dup(ctx.(*Context).ctx, t.t), + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal index c3610ac0..a2f599ce 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal @@ -3083,6 +3083,13 @@ kernel void kernel_cos( dst[tpig] = cos(src0[tpig]); } +kernel void kernel_neg( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = -src0[tpig]; +} + kernel void kernel_sum_rows( device const float * src0, device float * dst, diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m index e4c093f9..d8422f1b 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m @@ -423,6 +423,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_SQRT, GGML_METAL_KERNEL_TYPE_SIN, GGML_METAL_KERNEL_TYPE_COS, + GGML_METAL_KERNEL_TYPE_NEG, GGML_METAL_KERNEL_TYPE_SUM_ROWS, GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, @@ -1039,6 +1040,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true); @@ -1202,6 +1204,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_NEG: return ggml_is_contiguous(op->src[0]); default: return false; @@ -1873,6 +1876,18 @@ static void ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_UNARY_OP_NEG: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; default: { GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal index f38909d0..bb0ff668 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal @@ -945,6 +945,13 @@ kernel void kernel_cos( dst[tpig] = cos(src0[tpig]); } +kernel void kernel_neg( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = -src0[tpig]; +} + kernel void kernel_sum_rows( device const float * src0, device float * dst, diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index 3b640a96..2d7bb20a 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -11,7 +11,7 @@ import ( "github.com/ollama/ollama/model/input" ) -type TextOptions struct { +type TextConfig struct { hiddenSize, numHeads, numKVHeads int attnKeyLen, attnValLen int eps, ropeScale float32 @@ -28,7 +28,7 @@ type TextModel struct { OutputNorm *nn.RMSNorm `gguf:"output_norm"` Output *nn.Linear `gguf:"output,alt:token_embd"` - *TextOptions + *TextConfig } const ( @@ -55,7 +55,7 @@ func newTextModel(c fs.Config) *TextModel { }, ), Layers: make([]TextLayer, numBlocks), - TextOptions: &TextOptions{ + TextConfig: &TextConfig{ hiddenSize: int(c.Uint("embedding_length")), numHeads: int(c.Uint("attention.head_count")), numKVHeads: int(c.Uint("attention.head_count_kv")), @@ -84,7 +84,7 @@ type TextSelfAttention struct { Output *nn.Linear `gguf:"attn_output"` } -func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { +func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor { batchSize := hiddenState.Dim(1) ropeType := uint32(2) @@ -120,12 +120,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - ropeBase := m.TextOptions.ropeLocalBase + ropeBase := m.TextConfig.ropeLocalBase if (layer+1)%gemmaGlobalCacheCount == 0 { - ropeBase = m.TextOptions.ropeGlobalBase + ropeBase = m.TextConfig.ropeGlobalBase } - return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil + return key.RoPE(ctx, shift, nil, uint32(m.TextConfig.attnKeyLen), uint32(2), ropeBase, m.TextConfig.ropeScale), nil } type TextMLP struct { @@ -134,7 +134,7 @@ type TextMLP struct { Gate *nn.Linear `gguf:"ffn_gate"` } -func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor { +func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor { hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) return mlp.Down.Forward(ctx, hiddenState) } @@ -148,7 +148,7 @@ type TextLayer struct { PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"` } -func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { +func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor { residual := hiddenState hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps) @@ -173,7 +173,7 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor { hiddenState := m.TokenEmbedding.Forward(ctx, inputs) - hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize))) + hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize))) // set image embeddings var except []int @@ -206,7 +206,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor lastLayerOutputs = outputs } - hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions) + hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig) } hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) diff --git a/model/models/mistral3/imageproc.go b/model/models/mistral3/imageproc.go new file mode 100644 index 00000000..3d464bca --- /dev/null +++ b/model/models/mistral3/imageproc.go @@ -0,0 +1,56 @@ +package mistral3 + +import ( + "image" + _ "image/jpeg" + _ "image/png" + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/model/imageproc" +) + +type ImageProcessor struct { + imageSize int + patchSize int + numChannels int + longestEdge int +} + +func newImageProcessor(c fs.Config) ImageProcessor { + return ImageProcessor{ + imageSize: int(c.Uint("vision.image_size", 1540)), + patchSize: int(c.Uint("vision.patch_size", 14)), + numChannels: int(c.Uint("vision.num_channels", 3)), + longestEdge: int(c.Uint("vision.longest_edge", 1540)), + } +} + +// ProcessImage prepares an image for the vision model by: +// 1. Compositing transparent images +// 2. Resizing to fit model constraints while preserving aspect ratio +// 3. Normalizing pixel values +// Returns normalized image data and the final size in pixels +func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, image.Point, error) { + img = imageproc.Composite(img) + + size := img.Bounds().Size() + ratio := max(float64(size.Y)/float64(p.longestEdge), float64(size.X)/float64(p.longestEdge)) + if ratio > 1.0 { + size = image.Point{ + int(math.Floor(float64(size.X) / ratio)), + int(math.Floor(float64(size.Y) / ratio)), + } + } + + patchesX := (size.X-1)/p.patchSize + 1 + patchesY := (size.Y-1)/p.patchSize + 1 + size = image.Point{ + patchesX * p.patchSize, + patchesY * p.patchSize, + } + + img = imageproc.Resize(img, size, imageproc.ResizeBilinear) + data := imageproc.Normalize(img, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true) + return data, size, nil +} diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go new file mode 100644 index 00000000..fca3896c --- /dev/null +++ b/model/models/mistral3/model.go @@ -0,0 +1,189 @@ +package mistral3 + +import ( + "bytes" + "image" + "slices" + "sync" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type Model struct { + model.Base + *TextModel + *VisionModel `gguf:"v,vision"` + *MultiModalProjector `gguf:"mm"` + + ImageProcessor +} + +// Implement MultimodalProcessor interface +var _ model.MultimodalProcessor = (*Model)(nil) + +func New(c fs.Config) (model.Model, error) { + textModel, err := NewTextModel(c) + if err != nil { + return nil, err + } + + m := &Model{ + TextModel: textModel, + VisionModel: newVisionModel(c), + ImageProcessor: newImageProcessor(c), + MultiModalProjector: newMultiModalProjector(c), + } + + m.Cache = kvcache.NewCausalCache(m.TextModel.Shift) + + return m, nil +} + +type PatchMerger struct { + MergingLayer *nn.Linear `gguf:"merging_layer"` +} + +func (pm *PatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, size image.Point, spatialMergeSize int) ml.Tensor { + d := visionOutputs.Dim(0) + imageGrid := visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Reshape(ctx, size.X, size.Y, d) + kernel := ctx.Input().Empty(ml.DTypeF32, spatialMergeSize, spatialMergeSize, d) + patches := kernel.IM2Col(ctx, imageGrid, spatialMergeSize, spatialMergeSize, 0, 0, 1, 1) + reshaped := patches.Reshape(ctx, d*spatialMergeSize*spatialMergeSize, patches.Dim(1)*patches.Dim(2)) + return pm.MergingLayer.Forward(ctx, reshaped) +} + +type MultiModalProjector struct { + Norm *nn.RMSNorm `gguf:"norm"` + Linear1 *nn.Linear `gguf:"linear_1"` + Linear2 *nn.Linear `gguf:"linear_2"` + PatchMerger *PatchMerger `gguf:"patch_merger"` + + spatialMergeSize int + eps float32 + patchSize int +} + +func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, size image.Point) (ml.Tensor, image.Point) { + visionOutputs = p.Norm.Forward(ctx, visionOutputs, p.eps) + patchSizes := image.Point{size.X / p.patchSize, size.Y / p.patchSize} + visionOutputs = p.PatchMerger.Forward(ctx, visionOutputs, patchSizes, p.spatialMergeSize) + visionOutputs = p.Linear1.Forward(ctx, visionOutputs) + visionOutputs = visionOutputs.GELU(ctx) + return p.Linear2.Forward(ctx, visionOutputs), image.Point{patchSizes.X / p.spatialMergeSize, patchSizes.Y / p.spatialMergeSize} +} + +func newMultiModalProjector(c fs.Config) *MultiModalProjector { + return &MultiModalProjector{ + spatialMergeSize: int(c.Uint("spatial_merge_size", 2)), + eps: c.Float("text_config.rms_norm_eps", 1e-5), + patchSize: int(c.Uint("vision.patch_size", 14)), + } +} + +func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) { + if len(m.VisionModel.Layers) == 0 { + return nil, model.ErrNoVisionModel + } + + image, _, err := image.Decode(bytes.NewReader(multimodalData)) + if err != nil { + return nil, err + } + + f32s, size, err := m.ImageProcessor.ProcessImage(image) + if err != nil { + return nil, err + } + + pixelValues, err := ctx.Input().FromFloatSlice(f32s, size.X, size.Y, m.ImageProcessor.numChannels) + if err != nil { + return nil, err + } + + visionOutputs := m.VisionModel.Forward(ctx, pixelValues) + features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size) + + // split into patches to be sent to the text transformer + parent := imageFeatures{tensor: features} + rows := make([]*imageRow, size.Y) + for i := range rows { + rows[i] = &imageRow{parent: &parent, s: i, shape: []int{features.Dim(0), size.X}} + } + + return rows, nil +} + +type imageFeatures struct { + tensor ml.Tensor + + dataOnce sync.Once + data []float32 +} + +type imageRow struct { + parent *imageFeatures + s int + shape []int +} + +func (r *imageRow) data() []float32 { + n := 1 + for _, s := range r.shape { + n *= s + } + + return r.parent.data[r.s*n : (r.s+1)*n] +} + +// PostTokenize arranges Mistral 3's inputs for the forward pass +// In Mistral 3 and Pixtral, the input patches are arranged as follows: +// [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END] +// Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings +// that can be processed together. +func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { + var result []input.Input + for _, inp := range inputs { + if inp.Multimodal == nil { + result = append(result, inp) + } else { + inputMultimodal := inp.Multimodal.([]*imageRow) + for i, row := range inputMultimodal { + // [IMG] + result = append(result, input.Input{Token: 10, Multimodal: row, MultimodalHash: inp.MultimodalHash, SameBatch: row.shape[1]}) + result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.shape[1]-1)...) + if i == len(inputMultimodal)-1 { + // [IMG_END] + result = append(result, input.Input{Token: 13}) + } else { + // [IMG_BREAK] + result = append(result, input.Input{Token: 12}) + } + } + } + } + + return result, nil +} + +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + if err != nil { + return nil, err + } + + outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + if err != nil { + return nil, err + } + + return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil +} + +func init() { + model.Register("mistral3", New) +} diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go new file mode 100644 index 00000000..c256cbf1 --- /dev/null +++ b/model/models/mistral3/model_text.go @@ -0,0 +1,177 @@ +package mistral3 + +import ( + "fmt" + "math" + "strings" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type TextOptions struct { + hiddenSize, numHeads, numKVHeads, headDim int + eps, ropeBase, ropeScale float32 + ropeDim uint32 +} + +type TextModel struct { + model.Base + model.BytePairEncoding + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + Layers []Layer `gguf:"blk"` + OutputNorm *nn.RMSNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + *TextOptions +} + +type SelfAttention struct { + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` +} + +func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { + batchSize := hiddenState.Dim(1) + ropeType := uint32(0) + headDim := opts.headDim + if headDim == 0 { + headDim = opts.hiddenSize / opts.numHeads + } + + q := sa.Query.Forward(ctx, hiddenState) + q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) + q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + + k := sa.Key.Forward(ctx, hiddenState) + k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + + v := sa.Value.Forward(ctx, hiddenState) + v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + + kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache) + kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize) + return sa.Output.Forward(ctx, kqv) +} + +func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return key.RoPE(ctx, shift, nil, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil +} + +type MLP struct { + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` + Gate *nn.Linear `gguf:"ffn_gate"` +} + +func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor { + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + return mlp.Down.Forward(ctx, hiddenState) +} + +type Layer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + SelfAttention *SelfAttention + MLPNorm *nn.RMSNorm `gguf:"ffn_norm"` + MLP *MLP +} + +func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { + residual := hiddenState + + hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps) + hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts) + + // In the final layer (outputs != nil), optimize by pruning to just the token positions + // we need logits for. + if outputs != nil { + hiddenState = hiddenState.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + + hiddenState = hiddenState.Add(ctx, residual) + residual = hiddenState + + hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps) + hiddenState = l.MLP.Forward(ctx, hiddenState, opts) + return hiddenState.Add(ctx, residual) +} + +func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor { + hiddenState := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx) + + // image embeddings + for _, image := range batch.Multimodal { + row := image.Multimodal.(*imageRow) + row.parent.dataOnce.Do(func() { + // use a new, throwaway context so the image tensor is not added to the graph + temp := m.Backend().NewContext() + temp.Forward(row.parent.tensor).Compute(row.parent.tensor) + row.parent.data = row.parent.tensor.Floats() + temp.Close() + }) + + imageFeature, err := ctx.Input().FromFloatSlice(row.data(), row.shape...) + if err != nil { + panic(err) + } + + ctx.Forward(imageFeature.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), imageFeature.Dim(0)*imageFeature.Dim(1)))) + } + + for i, layer := range m.Layers { + cache.SetLayer(i) + + var lastLayerOutputs ml.Tensor + if i == len(m.Layers)-1 { + lastLayerOutputs = outputs + } + + hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions) + } + + hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) + return m.Output.Forward(ctx, hiddenState) +} + +func NewTextModel(c fs.Config) (*TextModel, error) { + if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") { + return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model")) + } + + textModel := &TextModel{ + BytePairEncoding: model.NewBytePairEncoding( + c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`), + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Types: c.Uints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + BOS: int32(c.Uint("tokenizer.ggml.bos_token_id", 1)), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + EOS: int32(c.Uint("tokenizer.ggml.eos_token_id", 2)), + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + }, + ), + Layers: make([]Layer, c.Uint("block_count")), + TextOptions: &TextOptions{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + headDim: int(c.Uint("attention.key_length")), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeBase: c.Float("rope.freq_base"), + ropeScale: c.Float("rope.freq_scale", 1), + ropeDim: c.Uint("rope.dimension_count"), + }, + } + + return textModel, nil +} diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go new file mode 100644 index 00000000..469dc40c --- /dev/null +++ b/model/models/mistral3/model_vision.go @@ -0,0 +1,186 @@ +package mistral3 + +import ( + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" +) + +var batchSize int = 1 + +func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor { + x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)) + x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx) + return x2.Neg(ctx).Concat(ctx, x1, 0) +} + +func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor { + return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin)) +} + +type VisionSelfAttention struct { + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` +} + +func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor { + query := sa.Query.Forward(ctx, hiddenStates) + key := sa.Key.Forward(ctx, hiddenStates) + value := sa.Value.Forward(ctx, hiddenStates) + + query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize) + key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize) + value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize) + + query = applyRotaryPositionalEmbedding(ctx, query, cos, sin) + key = applyRotaryPositionalEmbedding(ctx, key, cos, sin) + + attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil) + attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize) + return sa.Output.Forward(ctx, attention) +} + +type VisionMLP struct { + Gate *nn.Linear `gguf:"ffn_gate"` + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor { + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + return mlp.Down.Forward(ctx, hiddenStates) +} + +type VisionEncoderLayer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + SelfAttention *VisionSelfAttention + FFNNorm *nn.RMSNorm `gguf:"ffn_norm"` + MLP *VisionMLP +} + +func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor { + residual := hiddenStates + hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, opts) + hiddenStates = hiddenStates.Add(ctx, residual) + + residual = hiddenStates + hiddenStates = e.FFNNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts) + return hiddenStates.Add(ctx, residual) +} + +type VisionModelOptions struct { + hiddenSize int + numHeads int + headDim int + intermediateSize int + imageSize int + patchSize int + numChannels int + eps float32 + ropeBase float32 +} + +type VisionModel struct { + PatchEmbedding *nn.Conv2D `gguf:"patch_conv"` + EncoderNorm *nn.RMSNorm `gguf:"encoder_norm"` + Layers []VisionEncoderLayer `gguf:"blk"` + + *VisionModelOptions +} + +func (m *VisionModel) positionalEmbedding(ctx ml.Context, positionIDs ml.Tensor) ml.Tensor { + maxPatchesPerSide := m.imageSize / m.patchSize + frequencies := m.headDim / 2 + frequenciesHeight := make([]float32, frequencies/2*maxPatchesPerSide) + frequenciesWidth := make([]float32, frequencies/2*maxPatchesPerSide) + for i := range frequencies { + for j := range maxPatchesPerSide { + frequency := float32(j) / float32(math.Pow(float64(m.ropeBase), float64(i)*2/float64(m.headDim))) + if i%2 == 0 { + frequenciesHeight[i/2*maxPatchesPerSide+j] = frequency + } else { + frequenciesWidth[i/2*maxPatchesPerSide+j] = frequency + } + } + } + + h, err := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2) + if err != nil { + panic(err) + } + + w, err := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2) + if err != nil { + panic(err) + } + + h = h.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + w = w.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + + h = h.Repeat(ctx, 1, maxPatchesPerSide) + h = h.Reshape(ctx, frequencies/2, maxPatchesPerSide, maxPatchesPerSide).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + w = w.Repeat(ctx, 2, maxPatchesPerSide) + + inverseFrequencies := h.Concat(ctx, w, 0).Reshape(ctx, frequencies, maxPatchesPerSide*maxPatchesPerSide) + inverseFrequencies = inverseFrequencies.Concat(ctx, inverseFrequencies, 0) + return inverseFrequencies.Rows(ctx, positionIDs) +} + +func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor { + numPatchesW := pixelValues.Dim(0) / m.patchSize + numPatchesH := pixelValues.Dim(1) / m.patchSize + numPatches := numPatchesW * numPatchesH + + hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1) + hiddenStates = hiddenStates.Reshape(ctx, numPatches, m.hiddenSize) + hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + hiddenStates = m.EncoderNorm.Forward(ctx, hiddenStates, m.VisionModelOptions.eps) + + // Prepare position IDs for 2D rope + positions := make([]int32, numPatches) + for h := range numPatchesH { + for w := range numPatchesW { + idx := h*numPatchesW + w + positions[idx] = int32(h*m.imageSize/m.patchSize + w) + } + } + + positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions)) + if err != nil { + panic(err) + } + + positionEmbedding := m.positionalEmbedding(ctx, positionIDs) + cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx) + cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1)) + sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1)) + + for _, layer := range m.Layers { + hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions) + } + + return hiddenStates +} + +func newVisionModel(c fs.Config) *VisionModel { + return &VisionModel{ + Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 24)), + VisionModelOptions: &VisionModelOptions{ + hiddenSize: int(c.Uint("vision.embedding_length", 1024)), + numHeads: int(c.Uint("vision.attention.head_count", 16)), + headDim: int(c.Uint("vision.attention.key_length", 64)), + intermediateSize: int(c.Uint("vision.feed_forward_length", 4096)), + imageSize: int(c.Uint("vision.image_size", 1540)), + patchSize: int(c.Uint("vision.patch_size", 14)), + numChannels: int(c.Uint("vision.num_channels", 3)), + eps: c.Float("vision.attention.layer_norm_epsilon", 1e-5), + ropeBase: c.Float("vision.rope.freq_base", 10000.0), + }, + } +} diff --git a/model/models/mllama/model_vision.go b/model/models/mllama/model_vision.go index 2f7d26ca..8b10bde8 100644 --- a/model/models/mllama/model_vision.go +++ b/model/models/mllama/model_vision.go @@ -186,7 +186,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRa hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) hiddenState = m.PreTilePositionEmbedding.Forward(ctx, hiddenState, aspectRatioIDs, m.VisionModelOptions) - hiddenState = m.ClassEmbedding.Stack(ctx, 2, slices.Repeat([]ml.Tensor{m.ClassEmbedding}, m.numTiles-1)...).Concat(ctx, hiddenState, 1) + hiddenState = m.ClassEmbedding.Repeat(ctx, 2, m.numTiles).Concat(ctx, hiddenState, 1) hiddenState = m.PositionEmbedding.Forward(ctx, hiddenState, positionIDs, aspectRatioIDs, numPositions, m.VisionModelOptions) hiddenState = m.PreLayerNorm.Forward(ctx, hiddenState, m.eps) diff --git a/model/models/models.go b/model/models/models.go index ce1d2ce0..c5da2894 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -4,5 +4,6 @@ import ( _ "github.com/ollama/ollama/model/models/gemma2" _ "github.com/ollama/ollama/model/models/gemma3" _ "github.com/ollama/ollama/model/models/llama" + _ "github.com/ollama/ollama/model/models/mistral3" _ "github.com/ollama/ollama/model/models/mllama" ) diff --git a/model/models/pixtral/imageproc.go b/model/models/pixtral/imageproc.go deleted file mode 100644 index 16ec0c41..00000000 --- a/model/models/pixtral/imageproc.go +++ /dev/null @@ -1,68 +0,0 @@ -package pixtral - -import ( - "fmt" - "image" - _ "image/jpeg" - _ "image/png" - "io" - "math" - - "github.com/ollama/ollama/model/imageproc" -) - -func getNumImageTokens(imageSize, patchSize image.Point) image.Point { - return image.Point{ - (imageSize.X-1)/patchSize.X + 1, - (imageSize.Y-1)/patchSize.Y + 1, - } -} - -func getResizeOutputImageSize(img image.Image, longestEdge int, patchSize image.Point) image.Point { - b := img.Bounds() - le := float64(longestEdge) - ratio := math.Max(float64(b.Max.Y)/le, float64(b.Max.X)/le) - - newSize := img.Bounds().Max - - if ratio > 1.0 { - newSize = image.Point{ - int(math.Ceil(float64(b.Max.X) / ratio)), - int(math.Ceil(float64(b.Max.Y) / ratio)), - } - } - - tokens := getNumImageTokens(newSize, patchSize) - return image.Point{ - tokens.X * patchSize.X, - tokens.Y * patchSize.Y, - } -} - -func resizeImage(img image.Image, format string, longestEdge int, patchSize image.Point) image.Image { - if format == "png" { - img = imageproc.Composite(img) - } - - newSize := getResizeOutputImageSize(img, longestEdge, patchSize) - - // todo should be ResizeBicubic, but it doesn't exist - return imageproc.Resize(img, newSize, imageproc.ResizeBilinear) -} - -func Preprocess(imageData io.Reader) ([]float32, map[string]any, error) { - img, format, err := image.Decode(imageData) - if err != nil { - return nil, nil, fmt.Errorf("failed to decode image: %w", err) - } - - longestEdge := 1024 - patchSize := image.Point{16, 16} - - img = resizeImage(img, format, longestEdge, patchSize) - - data := imageproc.Normalize(img, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true) - - opts := map[string]any{} - return data, opts, nil -} diff --git a/model/models/pixtral/imageproc_test.go b/model/models/pixtral/imageproc_test.go deleted file mode 100644 index 1d9e4ffe..00000000 --- a/model/models/pixtral/imageproc_test.go +++ /dev/null @@ -1,219 +0,0 @@ -package pixtral - -import ( - "bytes" - "encoding/binary" - "image" - "image/png" - "math" - "os" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestGetNumImageTokens(t *testing.T) { - type numImageTokensCase struct { - ImageSize image.Point - PatchSize image.Point - Expected image.Point - } - - cases := []numImageTokensCase{ - { - ImageSize: image.Point{1024, 764}, - PatchSize: image.Point{16, 16}, - Expected: image.Point{64, 48}, - }, - { - ImageSize: image.Point{800, 600}, - PatchSize: image.Point{16, 16}, - Expected: image.Point{50, 38}, - }, - { - ImageSize: image.Point{640, 480}, - PatchSize: image.Point{16, 16}, - Expected: image.Point{40, 30}, - }, - { - ImageSize: image.Point{320, 200}, - PatchSize: image.Point{16, 16}, - Expected: image.Point{20, 13}, - }, - { - ImageSize: image.Point{1320, 200}, - PatchSize: image.Point{16, 16}, - Expected: image.Point{83, 13}, - }, - { - ImageSize: image.Point{2000, 200}, - PatchSize: image.Point{16, 16}, - Expected: image.Point{125, 13}, - }, - { - ImageSize: image.Point{10000, 200}, - PatchSize: image.Point{16, 16}, - Expected: image.Point{625, 13}, - }, - { - ImageSize: image.Point{1131, 577}, - PatchSize: image.Point{16, 16}, - Expected: image.Point{71, 37}, - }, - { - ImageSize: image.Point{16, 16}, - PatchSize: image.Point{16, 16}, - Expected: image.Point{1, 1}, - }, - } - - for _, c := range cases { - actual := getNumImageTokens(c.ImageSize, c.PatchSize) - - if diff := cmp.Diff(actual, c.Expected); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) - } - } -} - -func TestGetResizeOutputImageSize(t *testing.T) { - type resizeCase struct { - Image image.Image - LongestEdge int - PatchSize image.Point - Expected image.Point - } - - cases := []resizeCase{ - { - Image: image.NewRGBA(image.Rect(0, 0, 1024, 768)), - LongestEdge: 1024, - PatchSize: image.Point{16, 16}, - Expected: image.Point{1024, 768}, - }, - { - Image: image.NewRGBA(image.Rect(0, 0, 1162, 690)), - LongestEdge: 1024, - PatchSize: image.Point{16, 16}, - Expected: image.Point{1024, 624}, - }, - { - Image: image.NewRGBA(image.Rect(0, 0, 300, 200)), - LongestEdge: 1024, - PatchSize: image.Point{16, 16}, - Expected: image.Point{304, 208}, - }, - { - Image: image.NewRGBA(image.Rect(0, 0, 1862, 522)), - LongestEdge: 1024, - PatchSize: image.Point{16, 16}, - Expected: image.Point{1024, 288}, - }, - } - - for _, c := range cases { - actual := getResizeOutputImageSize(c.Image, c.LongestEdge, c.PatchSize) - - if diff := cmp.Diff(actual, c.Expected); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) - } - } -} - -func TestResize(t *testing.T) { - type resizeCase struct { - Image image.Image - LongestEdge int - PatchSize image.Point - Expected image.Image - } - - cases := []resizeCase{ - { - Image: image.NewRGBA(image.Rect(0, 0, 1862, 522)), - LongestEdge: 1024, - PatchSize: image.Point{16, 16}, - Expected: image.NewRGBA(image.Rect(0, 0, 1024, 288)), - }, - { - Image: image.NewRGBA(image.Rect(0, 0, 10, 10)), - LongestEdge: 1024, - PatchSize: image.Point{16, 16}, - Expected: image.NewRGBA(image.Rect(0, 0, 16, 16)), - }, - } - - for _, c := range cases { - actual := resizeImage(c.Image, "png", c.LongestEdge, c.PatchSize) - - if actual.Bounds() != c.Expected.Bounds() { - t.Errorf("image size incorrect: '%#v': expected: '%#v'", actual.Bounds(), c.Expected.Bounds()) - } - } -} - -func TestPreprocess(t *testing.T) { - type preprocessCase struct { - TestImage image.Image - ExpectedLen int - } - - cases := []preprocessCase{ - { - TestImage: image.NewRGBA(image.Rect(0, 0, 10, 10)), - ExpectedLen: 16 * 16 * 3 * 1, - }, - { - TestImage: image.NewRGBA(image.Rect(0, 0, 2000, 2000)), - ExpectedLen: 1024 * 1024 * 3 * 1, - }, - } - - for _, c := range cases { - var buf bytes.Buffer - err := png.Encode(&buf, c.TestImage) - if err != nil { - t.Fatal(err) - } - - imgData, _, err := Preprocess(&buf) - if err != nil { - t.Fatalf("error processing: %q", err) - } - - switch len(imgData) { - case 0: - t.Errorf("no image data returned") - case c.ExpectedLen: - // ok - default: - t.Errorf("unexpected image data length: %d, expected: %d", len(imgData), c.ExpectedLen) - } - } -} - -func TestPreprocessImages(t *testing.T) { - for _, testFile := range []string{"flight.png", "sportsball.png"} { - f, err := os.Open(testFile) - if err != nil { - t.Skipf("skipping test, no test image found at %s", testFile) - } - defer f.Close() - - imgData, _, err := Preprocess(f) - if err != nil { - t.Fatalf("error processing: %q", err) - } - - byteData := make([]byte, len(imgData)*4) // float32 is 4 bytes - for i, f := range imgData { - binary.LittleEndian.PutUint32(byteData[i*4:], math.Float32bits(f)) - } - - outputPath := "processed_" + testFile + ".bin" - err = os.WriteFile(outputPath, byteData, 0o644) - if err != nil { - t.Fatalf("error writing processed image: %q", err) - } - } -} diff --git a/model/process_text.go b/model/process_text.go index 01af65b6..f0fb7787 100644 --- a/model/process_text.go +++ b/model/process_text.go @@ -263,6 +263,10 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) { continue } + if id := bpe.vocab.Encode(pair.value); id < 0 { + continue + } + merges[pair.a].runes = append(left.runes, right.runes...) merges[pair.b].runes = nil diff --git a/parser/parser.go b/parser/parser.go index 6832351f..9a98c8ea 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -211,16 +211,10 @@ func filesForModel(path string) ([]string, error) { } var files []string - if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 { + if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 { // safetensors files might be unresolved git lfs references; skip if they are // covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors files = append(files, st...) - } else if st, _ := glob(filepath.Join(path, "adapters.safetensors"), "application/octet-stream"); len(st) > 0 { - // covers adapters.safetensors - files = append(files, st...) - } else if st, _ := glob(filepath.Join(path, "adapter_model.safetensors"), "application/octet-stream"); len(st) > 0 { - // covers adapter_model.safetensors - files = append(files, st...) } else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 { // pytorch files might also be unresolved git lfs references; skip if they are // covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin