From cda6f5c66c13b79a3dc9fb172ab5b17da9ce1bed Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Sat, 1 Mar 2025 13:15:14 -0800 Subject: [PATCH 1/7] server/internal/internal/names: validate names (#9400) This commit is a step towards a goal to make names less ceremonial outside of the registry client. Clients of the registry package can treat names as opaque strings, and the registry package will handle parsing, validating, and normalizing names. Ideally we end up with the names package tucked away in an internal package for good. We'll see how things go. Also, this package name is not permanent. This another step in the on-going process of refactoring the server code, and at some point it will most likely be renamed/moved. --- server/internal/client/ollama/registry.go | 149 +++++++++++------- .../internal/client/ollama/registry_test.go | 66 ++++---- server/internal/internal/names/name.go | 111 +++++++------ server/internal/internal/names/name_test.go | 82 +++++++++- server/internal/registry/server.go | 2 +- server/internal/registry/server_test.go | 5 +- 6 files changed, 263 insertions(+), 152 deletions(-) diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go index e4c36d7d..82a8bbca 100644 --- a/server/internal/client/ollama/registry.go +++ b/server/internal/client/ollama/registry.go @@ -24,6 +24,7 @@ import ( "os" "path/filepath" "runtime" + "slices" "strconv" "strings" "sync/atomic" @@ -53,7 +54,7 @@ var ( // ErrMissingModel is returned when the model part of a name is missing // or invalid. - ErrNameInvalid = errors.New("invalid name; must be in the form {scheme://}{host/}{namespace/}[model]{:tag}{@digest}") + ErrNameInvalid = errors.New("invalid or missing name") // ErrCached is passed to [Trace.PushUpdate] when a layer already // exists. It is a non-fatal error and is never returned by [Registry.Push]. @@ -205,10 +206,18 @@ type Registry struct { // It is only used when a layer is larger than [MaxChunkingThreshold]. MaxChunkSize int64 - // NameMask, if set, is the name used to convert non-fully qualified + // Mask, if set, is the name used to convert non-fully qualified // names to fully qualified names. If empty, the default mask // ("registry.ollama.ai/library/_:latest") is used. - NameMask string + Mask string +} + +func (r *Registry) completeName(name string) names.Name { + mask := defaultMask + if r.Mask != "" { + mask = names.Parse(r.Mask) + } + return names.Merge(names.Parse(name), mask) } // DefaultRegistry returns a new Registry configured from the environment. The @@ -243,52 +252,6 @@ func DefaultRegistry() (*Registry, error) { return &rc, nil } -type PushParams struct { - // From is an optional destination name for the model. If empty, the - // destination name is the same as the source name. - From string -} - -// parseName parses name using [names.ParseExtended] and then merges the name with the -// default name, and checks that the name is fully qualified. If a digest is -// present, it parse and returns it with the other fields as their zero values. -// -// It returns an error if the name is not fully qualified, or if the digest, if -// any, is invalid. -// -// The scheme is returned as provided by [names.ParseExtended]. -func parseName(s, mask string) (scheme string, n names.Name, d blob.Digest, err error) { - maskName := defaultMask - if mask != "" { - maskName = names.Parse(mask) - if !maskName.IsFullyQualified() { - return "", names.Name{}, blob.Digest{}, fmt.Errorf("invalid name mask: %s", mask) - } - } - scheme, n, ds := names.ParseExtended(s) - if !n.IsValid() { - return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s) - } - n = names.Merge(n, maskName) - if ds != "" { - // Digest is present. Validate it. - d, err = blob.ParseDigest(ds) - if err != nil { - return "", names.Name{}, blob.Digest{}, err - } - } - - // The name check is deferred until after the digest check because we - // say that digests take precedence over names, and so should there - // errors when being parsed. - if !n.IsFullyQualified() { - return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s) - } - - scheme = cmp.Or(scheme, "https") - return scheme, n, d, nil -} - func (r *Registry) maxStreams() int { n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0)) @@ -308,6 +271,12 @@ func (r *Registry) maxChunkSize() int64 { return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize) } +type PushParams struct { + // From is an optional destination name for the model. If empty, the + // destination name is the same as the source name. + From string +} + // Push pushes the model with the name in the cache to the remote registry. func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error { if p == nil { @@ -337,7 +306,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p * t := traceFromContext(ctx) - scheme, n, _, err := parseName(name, r.NameMask) + scheme, n, _, err := parseName(name, r.Mask) if err != nil { // This should never happen since ResolveLocal should have // already validated the name. @@ -431,7 +400,7 @@ func canRetry(err error) bool { // typically slower than splitting the model up across layers, and is mostly // utilized for layers of type equal to "application/vnd.ollama.image". func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error { - scheme, n, _, err := parseName(name, r.NameMask) + scheme, n, _, err := parseName(name, r.Mask) if err != nil { return err } @@ -582,9 +551,9 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err // Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified // before attempting to unlink the model. func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) { - _, n, _, err := parseName(name, r.NameMask) - if err != nil { - return false, err + n := r.completeName(name) + if !n.IsFullyQualified() { + return false, fmt.Errorf("%w: %q", ErrNameInvalid, name) } return c.Unlink(n.String()) } @@ -658,9 +627,9 @@ type Layer struct { } // ResolveLocal resolves a name to a Manifest in the local cache. The name is -// parsed using [names.ParseExtended] but the scheme is ignored. +// parsed using [names.Split] but the scheme is ignored. func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) { - _, n, d, err := parseName(name, r.NameMask) + _, n, d, err := parseName(name, r.Mask) if err != nil { return nil, err } @@ -686,7 +655,7 @@ func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, erro // Resolve resolves a name to a Manifest in the remote registry. func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) { - scheme, n, d, err := parseName(name, r.NameMask) + scheme, n, d, err := parseName(name, r.Mask) if err != nil { return nil, err } @@ -869,3 +838,69 @@ func maybeUnexpectedEOF(err error) error { } return err } + +type publicError struct { + wrapped error + message string +} + +func withPublicMessagef(err error, message string, args ...any) error { + return publicError{wrapped: err, message: fmt.Sprintf(message, args...)} +} + +func (e publicError) Error() string { return e.message } +func (e publicError) Unwrap() error { return e.wrapped } + +var supportedSchemes = []string{ + "http", + "https", + "https+insecure", +} + +var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", ")) + +// parseName parses and validates an extended name, returning the scheme, name, +// and digest. +// +// If the scheme is empty, scheme will be "https". If an unsupported scheme is +// given, [ErrNameInvalid] wrapped with a display friendly message is returned. +// +// If the digest is invalid, [ErrNameInvalid] wrapped with a display friendly +// message is returned. +// +// If the name is not, once merged with the mask, fully qualified, +// [ErrNameInvalid] wrapped with a display friendly message is returned. +func parseName(s string, mask string) (scheme string, _ names.Name, _ blob.Digest, _ error) { + scheme, name, digest := names.Split(s) + scheme = cmp.Or(scheme, "https") + if !slices.Contains(supportedSchemes, scheme) { + err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage) + return "", names.Name{}, blob.Digest{}, err + } + + var d blob.Digest + if digest != "" { + var err error + d, err = blob.ParseDigest(digest) + if err != nil { + err = withPublicMessagef(ErrNameInvalid, "invalid digest: %q", digest) + return "", names.Name{}, blob.Digest{}, err + } + if name == "" { + // We have can resolve a manifest from a digest only, + // so skip name validation and return the scheme and + // digest. + return scheme, names.Name{}, d, nil + } + } + + maskName := defaultMask + if mask != "" { + maskName = names.Parse(mask) + } + n := names.Merge(names.Parse(name), maskName) + if !n.IsFullyQualified() { + return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s) + } + return scheme, n, d, nil +} diff --git a/server/internal/client/ollama/registry_test.go b/server/internal/client/ollama/registry_test.go index af898c26..20a1f159 100644 --- a/server/internal/client/ollama/registry_test.go +++ b/server/internal/client/ollama/registry_test.go @@ -84,14 +84,14 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) { } } - rc := &Registry{ + r := &Registry{ HTTPClient: &http.Client{ Transport: recordRoundTripper(h), }, } link := func(name string, manifest string) { - _, n, _, err := parseName(name, rc.NameMask) + _, n, _, err := parseName(name, r.Mask) if err != nil { panic(err) } @@ -122,7 +122,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) { commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499}) link("invalid", "!!!!!") - return rc, c + return r, c } func okHandler(w http.ResponseWriter, r *http.Request) { @@ -145,29 +145,6 @@ func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest { return d } -func TestRegistryPushInvalidNames(t *testing.T) { - rc, c := newClient(t, nil) - - cases := []struct { - name string - err error - }{ - {"", ErrNameInvalid}, - {"@", ErrNameInvalid}, - {"@x", blob.ErrInvalidDigest}, - } - - for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - // Create a new registry and push a new image. - err := rc.Push(t.Context(), c, tt.name, nil) - if !errors.Is(err, tt.err) { - t.Errorf("err = %v; want %v", err, tt.err) - } - }) - } -} - func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) { t := &Trace{Update: func(*Layer, int64, error) { panic("unexpected") }} return WithTrace(ctx, t), t @@ -622,7 +599,7 @@ func TestInsecureSkipVerify(t *testing.T) { })) defer s.Close() - const name = "ollama.com/library/insecure" + const name = "library/insecure" var rc Registry url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name) @@ -724,3 +701,38 @@ func TestErrorUnmarshal(t *testing.T) { }) } } + +// TestParseNameErrors tests that parseName returns errors messages with enough +// detail for users to debug naming issues they may encounter. Previous to this +// test, the error messages were not very helpful and each problem was reported +// as the same message. +// +// It is only for testing error messages, not that all invalids and valids are +// covered. Those are in other tests for names.Name and blob.Digest. +func TestParseNameErrors(t *testing.T) { + cases := []struct { + name string + err error + want string + }{ + {"x", nil, ""}, + {"x@", nil, ""}, + + {"", ErrNameInvalid, `invalid or missing name: ""`}, + {"://", ErrNameInvalid, `invalid or missing name: "://"`}, + {"x://", ErrNameInvalid, `unsupported scheme: "x": supported schemes are http, https, https+insecure`}, + + {"@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`}, + {"x@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`}, + } + + for _, tt := range cases { + _, _, _, err := parseName(tt.name, DefaultMask) + if !errors.Is(err, tt.err) { + t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err) + } + if err != nil && !strings.Contains(err.Error(), tt.want) { + t.Errorf("[%s]: err =\n\t%v\nwant\n\t%v", tt.name, err, tt.want) + } + } +} diff --git a/server/internal/internal/names/name.go b/server/internal/internal/names/name.go index 361cce76..f0a1185d 100644 --- a/server/internal/internal/names/name.go +++ b/server/internal/internal/names/name.go @@ -8,7 +8,7 @@ import ( "github.com/ollama/ollama/server/internal/internal/stringsx" ) -const MaxNameLength = 50 + 1 + 50 + 1 + 50 // /: +const MaxNameLength = 350 + 1 + 80 + 1 + 80 + 1 + 80 // //: type Name struct { // Make incomparable to enfoce use of Compare / Equal for @@ -25,19 +25,12 @@ type Name struct { // format of a valid name string is: // // s: -// { host } "/" { namespace } "/" { model } ":" { tag } "@" { digest } // { host } "/" { namespace } "/" { model } ":" { tag } -// { host } "/" { namespace } "/" { model } "@" { digest } // { host } "/" { namespace } "/" { model } -// { namespace } "/" { model } ":" { tag } "@" { digest } // { namespace } "/" { model } ":" { tag } -// { namespace } "/" { model } "@" { digest } // { namespace } "/" { model } -// { model } ":" { tag } "@" { digest } // { model } ":" { tag } -// { model } "@" { digest } // { model } -// "@" { digest } // host: // pattern: { alphanum | "_" } { alphanum | "_" | "-" | "." | ":" }* // length: [1, 350] @@ -50,9 +43,6 @@ type Name struct { // tag: // pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }* // length: [1, 80] -// digest: -// pattern: { alphanum | "_" } { alphanum | "-" | ":" }* -// length: [1, 80] // // The name returned is not guaranteed to be valid. If it is not valid, the // field values are left in an undefined state. Use [Name.IsValid] to check @@ -82,23 +72,17 @@ func Parse(s string) Name { } } -// ParseExtended parses and returns any scheme, Name, and digest from from s in -// the the form [scheme://][name][@digest]. All parts are optional. -// -// If the scheme is present, it must be followed by "://". The digest is -// prefixed by "@" and comes after the name. The name is parsed using [Parse]. -// -// The scheme and digest are stripped before the name is parsed by [Parse]. -// -// For convience, the scheme is never empty. If the scheme is not present, the -// returned scheme is "https". +// Split splits an extended name string into its scheme, name, and digest +// parts. // // Examples: // // http://ollama.com/bmizerany/smol:latest@digest // https://ollama.com/bmizerany/smol:latest // ollama.com/bmizerany/smol:latest@digest // returns "https" scheme. -func ParseExtended(s string) (scheme string, _ Name, digest string) { +// model@digest +// @digest +func Split(s string) (scheme, name, digest string) { i := strings.Index(s, "://") if i >= 0 { scheme = s[:i] @@ -109,21 +93,7 @@ func ParseExtended(s string) (scheme string, _ Name, digest string) { digest = s[i+1:] s = s[:i] } - return scheme, Parse(s), digest -} - -func FormatExtended(scheme string, n Name, digest string) string { - var b strings.Builder - if scheme != "" { - b.WriteString(scheme) - b.WriteString("://") - } - b.WriteString(n.String()) - if digest != "" { - b.WriteByte('@') - b.WriteString(digest) - } - return b.String() + return scheme, s, digest } // Merge merges two names into a single name. Non-empty host, namespace, and @@ -141,39 +111,68 @@ func Merge(a, b Name) Name { // IsValid returns true if the name is valid. func (n Name) IsValid() bool { - if n.h != "" && !isValidHost(n.h) { + if n.h != "" && !isValidPart(partHost, n.h) { return false } - if n.n != "" && !isValidNamespace(n.n) { + if n.n != "" && !isValidPart(partNamespace, n.n) { return false } - if n.m != "" && !isValidModel(n.m) { + if n.t != "" && !isValidPart(partTag, n.t) { return false } - if n.t != "" && !isValidTag(n.t) { - return false - } - return true + + // at bare minimum, model must be present and valid + return n.m != "" && isValidPart(partModel, n.m) } func (n Name) IsFullyQualified() bool { return n.IsValid() && n.h != "" && n.n != "" && n.m != "" && n.t != "" } -func isValidHost(_ string) bool { - return true // TODO: implement +const ( + partHost = iota + partNamespace + partModel + partTag +) + +func isValidPart(kind int, s string) bool { + maxlen := 80 + if kind == partHost { + maxlen = 350 + } + if len(s) > maxlen { + return false + } + + for i := range s { + if i == 0 { + if !isAlphanumericOrUnderscore(s[i]) { + return false + } + continue + } + switch s[i] { + case '_', '-': + case '.': + if kind == partNamespace { + return false + } + case ':': + if kind != partHost { + return false + } + default: + if !isAlphanumericOrUnderscore(s[i]) { + return false + } + } + } + return true } -func isValidNamespace(_ string) bool { - return true // TODO: implement -} - -func isValidModel(_ string) bool { - return true // TODO: implement -} - -func isValidTag(_ string) bool { - return true // TODO: implement +func isAlphanumericOrUnderscore(c byte) bool { + return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9' || c == '_' } func (n Name) Host() string { return n.h } diff --git a/server/internal/internal/names/name_test.go b/server/internal/internal/names/name_test.go index 760fec5f..e3dc5fe3 100644 --- a/server/internal/internal/names/name_test.go +++ b/server/internal/internal/names/name_test.go @@ -81,15 +81,11 @@ func TestParseExtended(t *testing.T) { } for _, tt := range cases { t.Run(tt.in, func(t *testing.T) { - scheme, name, digest := ParseExtended(tt.in) - if scheme != tt.wantScheme || name.Compare(tt.wantName) != 0 || digest != tt.wantDigest { + scheme, name, digest := Split(tt.in) + n := Parse(name) + if scheme != tt.wantScheme || n.Compare(tt.wantName) != 0 || digest != tt.wantDigest { t.Errorf("ParseExtended(%q) = %q, %#v, %q, want %q, %#v, %q", tt.in, scheme, name, digest, tt.wantScheme, tt.wantName, tt.wantDigest) } - - // Round trip - if got := FormatExtended(scheme, name, digest); got != tt.in { - t.Errorf("FormatExtended(%q, %q, %q) = %q", scheme, name, digest, got) - } }) } } @@ -150,3 +146,75 @@ func BenchmarkParseName(b *testing.B) { junkName = Parse("h/n/m:t") } } + +const ( + part80 = "88888888888888888888888888888888888888888888888888888888888888888888888888888888" + part350 = "33333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333" +) + +var testCases = map[string]bool{ // name -> valid + "": false, + + "_why/_the/_lucky:_stiff": true, + + // minimal + "h/n/m:t": true, + + "host/namespace/model:tag": true, + "host/namespace/model": true, + "namespace/model": true, + "model": true, + + // long (but valid) + part80 + "/" + part80 + "/" + part80 + ":" + part80: true, + part350 + "/" + part80 + "/" + part80 + ":" + part80: true, + + // too long + part80 + "/" + part80 + "/" + part80 + ":" + part350: false, + "x" + part350 + "/" + part80 + "/" + part80 + ":" + part80: false, + + "h/nn/mm:t": true, // bare minimum part sizes + + // unqualified + "m": true, + "n/m:": true, + "h/n/m": true, + "@t": false, + "m@d": false, + + // invalids + "^": false, + "mm:": true, + "/nn/mm": true, + "//": false, // empty model + "//mm": true, + "hh//": false, // empty model + "//mm:@": false, + "00@": false, + "@": false, + + // not starting with alphanum + "-hh/nn/mm:tt": false, + "hh/-nn/mm:tt": false, + "hh/nn/-mm:tt": false, + "hh/nn/mm:-tt": false, + + // smells like a flag + "-h": false, + + // hosts + "host:https/namespace/model:tag": true, + + // colon in non-host part before tag + "host/name:space/model:tag": false, +} + +func TestParseNameValidation(t *testing.T) { + for s, valid := range testCases { + got := Parse(s) + if got.IsValid() != valid { + t.Logf("got: %v", got) + t.Errorf("Parse(%q).IsValid() = %v; want !%[2]v", s, got.IsValid()) + } + } +} diff --git a/server/internal/registry/server.go b/server/internal/registry/server.go index 8eb6daf8..6ea590a7 100644 --- a/server/internal/registry/server.go +++ b/server/internal/registry/server.go @@ -204,7 +204,7 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error { return err } if !ok { - return &serverError{404, "manifest_not_found", "manifest not found"} + return &serverError{404, "not_found", "model not found"} } return nil } diff --git a/server/internal/registry/server_test.go b/server/internal/registry/server_test.go index 22267ba7..7ba13d50 100644 --- a/server/internal/registry/server_test.go +++ b/server/internal/registry/server_test.go @@ -109,11 +109,8 @@ func TestServerDelete(t *testing.T) { got = s.send(t, "DELETE", "/api/delete", ``) checkErrorResponse(t, got, 400, "bad_request", "empty request body") - got = s.send(t, "DELETE", "/api/delete", `{"model": "!"}`) - checkErrorResponse(t, got, 404, "manifest_not_found", "not found") - got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`) - checkErrorResponse(t, got, 400, "bad_request", "invalid name") + checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name") got = s.send(t, "DELETE", "/unknown_path", `{}`) // valid body checkErrorResponse(t, got, 404, "not_found", "not found") From e75c6126e92e592440b0b6b0966bcf4a0868e05a Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Sat, 1 Mar 2025 14:02:19 -0800 Subject: [PATCH 2/7] build: set GGML_CUDA_NO_VMM for ggml-hip target (#9449) --- CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 92b1793b..a727e99f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -106,9 +106,11 @@ if(CMAKE_HIP_COMPILER) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip) if (WIN32) - target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY=1) + target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY) endif() + target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_VMM) + set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm) install(TARGETS ggml-hip RUNTIME_DEPENDENCIES From 96a97adf9b973721f1e7401e49984cc1b772cc3b Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Sat, 1 Mar 2025 17:00:31 -0800 Subject: [PATCH 3/7] build: use correct GGML_HIP_NO_VMM compiler definition for ggml-hip (#9451) --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a727e99f..034fc7d7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -109,7 +109,7 @@ if(CMAKE_HIP_COMPILER) target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY) endif() - target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_VMM) + target_compile_definitions(ggml-hip PRIVATE GGML_HIP_NO_VMM) set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm) install(TARGETS ggml-hip From 854a9195f351ffe7c8aaaad34d19022144963e51 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Sat, 22 Feb 2025 21:34:10 -0800 Subject: [PATCH 4/7] attention: Remove unnecessary contiguous operations Prior to performing attention, we need to permute query, key and value. Currently we call Contiguous after each of these permutations, which is correct but expensive. Avoiding the 3 calls to Contiguous increases performance by over 20%. The permutations of query and key do not violate the continuity rules for mulmat and the Contiguous call can be simply removed. Value requires a different permutation and does require Contiguous. However, we can use the copy into the cache as a way to perform this without further overhead. To support this and avoid unexpected tensor shapes that are seen by models, we need tighter integration between attention, cache and backend. Future optimization will also likely need this structure - for example, flash attention has special padding requirements in the cache and other backends may have their own needs. This further contains the operations that go into attention so that these and other optimizations can be handled transparently. Models that have special requirements for attention can still implement their own version of it. --- kvcache/cache.go | 11 ++ kvcache/causal.go | 174 +++++++++++++++++++++++------- kvcache/encoder.go | 29 +++++ kvcache/wrapper.go | 6 ++ ml/backend.go | 25 +++++ ml/backend/ggml/ggml.go | 9 +- ml/nn/attention.go | 57 ++++++---- model/models/llama/model.go | 9 +- model/models/mllama/model.go | 4 +- model/models/mllama/model_text.go | 32 +++--- 10 files changed, 270 insertions(+), 86 deletions(-) diff --git a/kvcache/cache.go b/kvcache/cache.go index 5d8b2f9b..2541f7c1 100644 --- a/kvcache/cache.go +++ b/kvcache/cache.go @@ -29,6 +29,17 @@ type Cache interface { // cache implementation used. Put(ctx ml.Context, key, value ml.Tensor) + // SetConfig controls optimizations (mostly backend-specific) that may transform + // the output of the cache to work better with specific kernels. If not called, + // the backend settings will be used. This works well when calling Attention. + // + // The config can be overridden by models, especially if they require vanilla + // output when implementing their own version of attention. To do this, pass + // an empty ml.CacheConfig. + // + // Most models will not need to use this. + SetConfig(ml.CacheConfig) + // ** cache management ** // Init sets up runtime parameters diff --git a/kvcache/causal.go b/kvcache/causal.go index 69068439..1d4daf80 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -22,6 +22,9 @@ type Causal struct { Capacity int32 windowSize int32 + // config controls mostly backend-specific optimizations + config *ml.CacheConfig + // ** current forward pass ** // the active layer for Get and Put @@ -75,14 +78,34 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal { } func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) { + if c.config == nil { + var config ml.CacheConfig + if cc, ok := backend.(ml.BackendCacheConfig); ok { + config = cc.CacheConfig() + } + c.config = &config + } + + if c.config.CachePadding == 0 { + c.config.CachePadding = 1 + } + c.DType = dtype - c.Capacity = capacity - c.cells = make([]cacheCell, capacity) + c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding)) + c.cells = make([]cacheCell, c.Capacity) c.cellRanges = make(map[int]cellRange) c.backend = backend c.cacheCtx = backend.NewContext() } +func (c *Causal) SetConfig(config ml.CacheConfig) { + if c.config != nil { + panic("config cannot be changed after being previously set, either by the model or backend") + } + + c.config = &config +} + func (c *Causal) Close() { c.cacheCtx.Close() } @@ -157,36 +180,73 @@ func (c *Causal) findStartLoc() (int, error) { return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity) } +func roundDown(length, pad int) int { + return (length / pad) * pad +} + +func roundUp(length, pad int) int { + return ((length + pad - 1) / pad) * pad +} + // Builds a mask of history x batch indicating whether for each token in the batch the // token in the history should apply. This is based on both the sequence and causality (the // position of the history is not ahead of the token in the batch). func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) { - // TODO(jessegross): This does not do padding, which is required for flash attention - len := c.curCellRange.max - c.curCellRange.min + 1 - mask := make([]float32, c.curBatchSize*len) + // TODO(jessegross): This does not do mask padding, which is required for flash attention + // Align and pad the cache range as required by the backend + c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding) + c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1 + + length := c.curCellRange.max - c.curCellRange.min + 1 + mask := make([]float32, c.curBatchSize*length) for i := range c.curBatchSize { for j := c.curCellRange.min; j <= c.curCellRange.max; j++ { if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] || c.cells[j].pos < positions[i]-c.windowSize { - mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1)) + mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1)) } } } - return ctx.FromFloatSlice(mask, len, c.curBatchSize) + return ctx.FromFloatSlice(mask, length, c.curBatchSize) } -func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) { - for _, obj := range objs { - if obj == nil { +func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) { + for i := range c.keys { + if c.keys[i] == nil { continue } - srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len) - dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len) + key := c.keys[i] - ctx.Forward(srcView.Copy(ctx, dstView)) + kHeadDim := key.Dim(0) + numKVHeads := key.Dim(1) + rowSize := key.Stride(2) + + kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len) + kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len) + + value := c.values[i] + var vSrcView, vDstView ml.Tensor + if c.config.PermutedV { + vHeadDim := value.Dim(1) + elemSize := value.Stride(0) + + vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads) + vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads) + } else { + vHeadDim := value.Dim(0) + rowSize := value.Stride(2) + + vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len) + vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len) + } + + ctx.Forward( + kSrcView.Copy(ctx, kDstView), + vSrcView.Copy(ctx, vDstView), + ) } } @@ -238,8 +298,7 @@ func (c *Causal) defrag() { pendingLen++ break } else { - moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen) - moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen) + c.moveCells(ctx, pendingSrc, pendingDst, pendingLen) moves++ } } @@ -263,8 +322,7 @@ func (c *Causal) defrag() { } if pendingLen > 0 { - moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen) - moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen) + c.moveCells(ctx, pendingSrc, pendingDst, pendingLen) moves++ } @@ -305,35 +363,73 @@ func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { key := c.keys[c.curLayer] value := c.values[c.curLayer] - key = key.View(ctx, key.Stride(2)*c.curCellRange.min, - key.Dim(0), key.Stride(1), - key.Dim(1), key.Stride(2), - c.curMask.Dim(0), + kHeadDim := key.Dim(0) + numKVHeads := key.Dim(1) + rowSize := key.Stride(2) + cachedSize := c.curMask.Dim(0) + + key = key.View(ctx, rowSize*c.curCellRange.min, + kHeadDim, key.Stride(1), + numKVHeads, key.Stride(2), + cachedSize, ) - value = value.View(ctx, key.Stride(2)*c.curCellRange.min, - value.Dim(0), value.Stride(1), - value.Dim(1), value.Stride(2), - c.curMask.Dim(0), - ) + if c.config.PermutedV { + vHeadDim := value.Dim(1) + elemSize := value.Stride(0) + + value = value.View(ctx, elemSize*c.curCellRange.min, + cachedSize, value.Stride(1), + vHeadDim, value.Stride(2), + numKVHeads, + ) + } else { + vHeadDim := value.Dim(0) + rowSize := value.Stride(2) + + value = value.View(ctx, rowSize*c.curCellRange.min, + vHeadDim, value.Stride(1), + numKVHeads, value.Stride(2), + cachedSize, + ) + } return key, value, c.curMask } func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) { - if c.curBatchSize != key.Dim(2) { - panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2))) + kHeadDim := key.Dim(0) + vHeadDim := value.Dim(0) + numKVHeads := key.Dim(1) + batchSize := key.Dim(2) + + if c.curBatchSize != batchSize { + panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize)) } if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil { - c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity)) - c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity)) + c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity)) + + if c.config.PermutedV { + c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads) + } else { + c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity)) + } } - ctx.Forward( - key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))), - value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))), - ) + rowSize := c.keys[c.curLayer].Stride(2) + ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize))) + + if c.config.PermutedV { + elemSize := c.values[c.curLayer].Stride(0) + + value = value.Permute(ctx, 1, 2, 0, 3) + ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads))) + } else { + rowSize := c.values[c.curLayer].Stride(2) + + ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize))) + } } func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) { @@ -389,9 +485,13 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error { continue } - key = key.View(ctx, key.Stride(2)*seqRange.min, - key.Dim(0), key.Stride(1), - key.Dim(1), key.Stride(2), + kHeadDim := key.Dim(0) + numKVHeads := key.Dim(1) + rowSize := key.Stride(2) + + key = key.View(ctx, rowSize*seqRange.min, + kHeadDim, key.Stride(1), + numKVHeads, key.Stride(2), size, ) diff --git a/kvcache/encoder.go b/kvcache/encoder.go index b85b1046..c55da2b4 100644 --- a/kvcache/encoder.go +++ b/kvcache/encoder.go @@ -1,6 +1,8 @@ package kvcache import ( + "fmt" + "github.com/ollama/ollama/ml" ) @@ -11,6 +13,9 @@ import ( // // Not currently safe for multiple sequences type EncoderCache struct { + // config controls mostly backend-specific optimizations + config *ml.CacheConfig + // ** current forward pass ** // the active layer for Get and Put @@ -40,9 +45,29 @@ func NewEncoderCache() *EncoderCache { } func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) { + if c.config == nil { + var config ml.CacheConfig + if cc, ok := backend.(ml.BackendCacheConfig); ok { + config = cc.CacheConfig() + } + c.config = &config + } + + if c.config.CachePadding != 0 && c.config.CachePadding != 1 { + panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding)) + } + c.cacheCtx = backend.NewContext() } +func (c *EncoderCache) SetConfig(config ml.CacheConfig) { + if c.config != nil { + panic("config cannot be changed after being previously set, either by the model or backend") + } + + c.config = &config +} + func (c *EncoderCache) Close() { c.cacheCtx.Close() } @@ -75,6 +100,10 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) { c.encoderPos = c.curPos c.encoderCached = true + if c.config.PermutedV { + value = value.Permute(ctx, 1, 2, 0, 3) + } + if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil { c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...) c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...) diff --git a/kvcache/wrapper.go b/kvcache/wrapper.go index 2d4c1089..76956a88 100644 --- a/kvcache/wrapper.go +++ b/kvcache/wrapper.go @@ -28,6 +28,12 @@ func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) } } +func (c *WrapperCache) SetConfig(config ml.CacheConfig) { + for _, cache := range c.caches { + cache.SetConfig(config) + } +} + func (c *WrapperCache) Close() { for _, cache := range c.caches { cache.Close() diff --git a/ml/backend.go b/ml/backend.go index 07bc75b6..ccab915c 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -27,6 +27,27 @@ type Backend interface { SystemInfo() string } +// BackendCacheConfig should be implemented by backends that need special output +// from the cache to meet specific requirements. It is frequently implemented in +// conjunction with ScaledDotProductAttention. +type BackendCacheConfig interface { + CacheConfig() CacheConfig +} + +// CacheConfig controls optimizations (mostly backend-specific) that may transform +// the output the cache to work better with specific kernels. +type CacheConfig struct { + // CachePadding specifies the multiple for the number of tokens of cache history + // that will be returned from cache Get for k, v and mask. The capacity of the + // cache itself will also be increased to a multiple of this size if needed. + CachePadding int + + // PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put + // and return the permuted version via Get. This uses the cache copy operation + // to avoid a Contiguous call on the permuted tensor. + PermutedV bool +} + // BackendParams controls how the backend loads and executes models type BackendParams struct { // NumThreads sets the number of threads to use if running on the CPU @@ -116,6 +137,10 @@ type Tensor interface { // operation equivalent to following code on a tensor named // query: // +// query = query.Permute(ctx, 0, 2, 1, 3) +// key = key.Permute(ctx, 0, 2, 1, 3) +// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) +// // kq := key.MulmatFullPrec(ctx, query) // // kq = kq.Scale(ctx, scale) diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 7f91990c..bddaad46 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -247,6 +247,10 @@ func (b *Backend) NewContext() ml.Context { } } +func (b *Backend) CacheConfig() ml.CacheConfig { + return ml.CacheConfig{CachePadding: 32, PermutedV: true} +} + type Context struct { b *Backend ctx *C.struct_ggml_context @@ -661,7 +665,10 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T kqMask = mask.(*Tensor).t } - kq := key.MulmatFullPrec(ctx, t) + query := t.Permute(ctx, 0, 2, 1, 3) + key = key.Permute(ctx, 0, 2, 1, 3) + + kq := key.MulmatFullPrec(ctx, query) kq = &Tensor{ t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0), } diff --git a/ml/nn/attention.go b/ml/nn/attention.go index 4f0c9fa1..a3f43a1e 100644 --- a/ml/nn/attention.go +++ b/ml/nn/attention.go @@ -3,6 +3,7 @@ package nn import ( "fmt" + "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" ) @@ -11,40 +12,50 @@ import ( // // Parameters: // - ctx: Context for tensor operations -// - query: Query tensor (Q) with shape [d_k, seq_len_q, heads] -// - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads] -// - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads] -// - mask: Optional attention mask that is added to the attention score. If -// provided, should broadcast to [seq_len_k, seq_len_q, heads] +// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q] +// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only +// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only // - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension +// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value // // Returns: // // Attention output with shape [d_v, heads, seq_len_q] -func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor { - if query.Dim(0) != key.Dim(0) { - panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0))) +func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor { + if key != nil && value != nil { + if query.Dim(0) != key.Dim(0) { + panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0))) + } + + if key.Dim(1) != value.Dim(1) { + panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1))) + } + + if key.Dim(2) != value.Dim(2) { + panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2))) + } + + if cache != nil { + cache.Put(ctx, key, value) + } + } else if cache == nil { + panic("key & value tensors must be provided if cache is nil") } - if mask != nil && query.Dim(1) != mask.Dim(1) { - panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1))) + var mask ml.Tensor + if cache != nil { + key, value, mask = cache.Get(ctx) } - if key.Dim(1) != value.Dim(0) { - panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0))) - } - - if mask != nil && key.Dim(1) != mask.Dim(0) { - panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0))) - } - - if key.Dim(2) != value.Dim(2) { - panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2))) - } - - if sdpa, ok := query.(ml.ScaledDotProductAttention); ok { + // Only use the fast SDPA implementation if we have a cache, since that's what + // will do any expected backend-specific transformations for us + if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil { return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale) } else { + query = query.Permute(ctx, 0, 2, 1, 3) + key = key.Permute(ctx, 0, 2, 1, 3) + value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) + kq := key.MulmatFullPrec(ctx, query) kq = kq.Scale(ctx, scale) diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 6106af86..9bf6f497 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -81,15 +81,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - cache.Put(ctx, k, v) - k, v, mask := cache.Get(ctx) - - q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - scaleFactor := 1.0 / math.Sqrt(float64(headDim)) - kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor) + kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache) kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize) return sa.Output.Forward(ctx, kqv) diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index 9b35a262..743f4c32 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -43,7 +43,9 @@ func New(c ml.Config) (model.Model, error) { TextModel: newTextModel(c), } - m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift)) + encoderCache := kvcache.NewEncoderCache() + encoderCache.SetConfig(ml.CacheConfig{}) + m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift)) return &m, nil } diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index 003bf9cb..e294b4c7 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -31,22 +31,15 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - cache.Put(ctx, key, value) - key, value, mask := cache.Get(ctx) - - query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - scaleFactor := 1.0 / math.Sqrt(float64(headDim)) - attention := nn.Attention(ctx, query, key, value, mask, scaleFactor) + attention := nn.Attention(ctx, query, key, value, scaleFactor, cache) attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) return sa.Output.Forward(ctx, attention) } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - // This will only get called for layers in the cache, which are just the self attention layers + // This will only get called for layers in the causal cache, which are just the self attention layers return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil } @@ -107,7 +100,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) query = ca.QueryNorm.Forward(ctx, query, opts.eps) - var key, value, mask ml.Tensor + var key, value ml.Tensor if crossAttentionStates != nil { numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2) @@ -119,16 +112,23 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles) cache.Put(ctx, key, value) - } else { - key, value, mask = cache.Get(ctx) } - query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) + key, value, _ = cache.Get(ctx) scaleFactor := 1.0 / math.Sqrt(float64(headDim)) - attention := nn.Attention(ctx, query, key, value, mask, scaleFactor) + + query = query.Permute(ctx, 0, 2, 1, 3) + key = key.Permute(ctx, 0, 2, 1, 3) + value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) + + kq := key.MulmatFullPrec(ctx, query) + + kq = kq.Scale(ctx, scaleFactor) + kq = kq.Softmax(ctx) + + kqv := value.Mulmat(ctx, kq) + attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) return ca.Output.Forward(ctx, attention) From 55e5776c44659a153fa1b9a0316ec1b6a834a4b8 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Thu, 27 Feb 2025 14:52:39 -0800 Subject: [PATCH 5/7] ggml-backend: Store parent backend as part of tensor It can be important for a tensor to know what backend it came from - for example, to know if flash attention is enabled. --- ml/backend/ggml/ggml.go | 42 ++++++++++++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index bddaad46..24943111 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -219,7 +219,7 @@ func (b *Backend) Get(name string) ml.Tensor { for _, c := range append(b.gpus, b.cpus...) { if t := C.ggml_get_tensor(c.ctx, cname); t != nil { - return &Tensor{t: t} + return &Tensor{b: b, t: t} } } @@ -330,7 +330,7 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { b := C.ggml_backend_alloc_buffer(c.backend, C.ggml_nbytes(t)) C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b)) C.ggml_set_zero(t) - return &Tensor{t: t} + return &Tensor{b: c.b, t: t} } func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) { @@ -339,7 +339,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u if n == 0 { var shape C.int64_t = 0 t := C.ggml_new_tensor(ctx.ctx, dtype, 1, &shape) - return &Tensor{t: t}, nil + return &Tensor{b: ctx.b, t: t}, nil } for _, v := range shape { @@ -354,7 +354,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t)) C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b)) C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t)) - return &Tensor{t: t}, nil + return &Tensor{b: ctx.b, t: t}, nil } func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { @@ -372,6 +372,7 @@ func (c *Context) Close() { } type Tensor struct { + b *Backend t *C.struct_ggml_tensor sync func() } @@ -438,6 +439,7 @@ func (t *Tensor) DType() ml.DType { func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), } } @@ -452,24 +454,28 @@ func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor { func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)), } } func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_cont(ctx.(*Context).ctx, t.t), } } func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), } } func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), } } @@ -479,12 +485,13 @@ func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor { C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32) return &Tensor{ + b: t.b, t: mul, } } func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor { - tt := (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w) + tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w) if b != nil { tt = tt.Add(ctx, b) } @@ -493,7 +500,7 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso } func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor { - return (&Tensor{t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w) + return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w) } func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor { @@ -502,6 +509,7 @@ func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor { } return &Tensor{ + b: t.b, t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])), } } @@ -512,18 +520,21 @@ func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor { } return &Tensor{ + b: t.b, t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])), } } func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), } } func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), } } @@ -532,18 +543,22 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor { switch len(shape) { case 1: return &Tensor{ + b: t.b, t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])), } case 2: return &Tensor{ + b: t.b, t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])), } case 3: return &Tensor{ + b: t.b, t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])), } case 4: return &Tensor{ + b: t.b, t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])), } default: @@ -553,18 +568,21 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor { func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)), } } func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_soft_max(ctx.(*Context).ctx, t.t), } } func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t), } } @@ -575,6 +593,7 @@ func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor { } return &Tensor{ + b: t.b, t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])), } } @@ -583,10 +602,12 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { switch len(shape) { case 1: return &Tensor{ + b: t.b, t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)), } case 3: return &Tensor{ + b: t.b, t: C.ggml_view_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[2]), C.size_t(shape[1]), @@ -594,6 +615,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { } case 5: return &Tensor{ + b: t.b, t: C.ggml_view_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.size_t(shape[1]), C.size_t(shape[3]), @@ -601,6 +623,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { } case 7: return &Tensor{ + b: t.b, t: C.ggml_view_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]), C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]), @@ -617,7 +640,7 @@ const ( func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor { if ropeFactors == nil { - ropeFactors = &Tensor{} + ropeFactors = &Tensor{b: t.b} } dequant := t.t @@ -626,6 +649,7 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi } return &Tensor{ + b: t.b, t: C.ggml_rope_ext( ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t, C.int(ropeDim), @@ -643,18 +667,21 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi func (t *Tensor) GELU(ctx ml.Context) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t), } } func (t *Tensor) SILU(ctx ml.Context) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t), } } func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)), } } @@ -670,6 +697,7 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T kq := key.MulmatFullPrec(ctx, query) kq = &Tensor{ + b: t.b, t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0), } From ee141cc82174ff235c9f26f661c7f9d2cd36d312 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Fri, 28 Feb 2025 17:48:07 -0800 Subject: [PATCH 6/7] ml: Empty tensor constructor for tensors In cases where we allocate a tensor and then fully overwrite it with copied data, it is wasteful to first zero out the memory. --- kvcache/causal_test.go | 12 ++++++++---- kvcache/encoder.go | 4 ++-- ml/backend.go | 3 ++- ml/backend/ggml/ggml.go | 24 +++++++++++++++++------- 4 files changed, 29 insertions(+), 14 deletions(-) diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index bd7d0ae8..84d8de54 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -309,7 +309,7 @@ func (b *testBackend) SystemInfo() string { type testContext struct{} -func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor { +func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor { total := 0 if len(shape) > 0 { @@ -322,8 +322,12 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor { return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape} } +func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor { + return c.Empty(dtype, shape...) +} + func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { - t := c.Zeros(ml.DTypeF32, shape...).(*testTensor) + t := c.Empty(ml.DTypeF32, shape...).(*testTensor) copy(t.data, s) @@ -391,7 +395,7 @@ func (t *testTensor) Floats() []float32 { } func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor { - out := ctx.Zeros(t.DType(), t.Shape()...).(*testTensor) + out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor) for i := range out.data { out.data[i] = t.data[i] + t2.(*testTensor).data[i] @@ -468,7 +472,7 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { context := &testContext{} - view := context.Zeros(t.dtype, s...).(*testTensor) + view := context.Empty(t.dtype, s...).(*testTensor) view.data = t.data[offset : offset+len(view.data)] return view diff --git a/kvcache/encoder.go b/kvcache/encoder.go index c55da2b4..39b4cdfb 100644 --- a/kvcache/encoder.go +++ b/kvcache/encoder.go @@ -105,8 +105,8 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) { } if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil { - c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...) - c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...) + c.keys[c.curLayer] = c.cacheCtx.Empty(key.DType(), key.Shape()...) + c.values[c.curLayer] = c.cacheCtx.Empty(value.DType(), value.Shape()...) } ctx.Forward( diff --git a/ml/backend.go b/ml/backend.go index ccab915c..de2725c0 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -82,6 +82,7 @@ func NewBackend(f *os.File, params BackendParams) (Backend, error) { } type Context interface { + Empty(dtype DType, shape ...int) Tensor Zeros(dtype DType, shape ...int) Tensor FromFloatSlice(s []float32, shape ...int) (Tensor, error) FromIntSlice(s []int32, shape ...int) (Tensor, error) @@ -195,7 +196,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string { return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32) }) case DTypeF16: - f32 := ctx.Zeros(DTypeF32, t.Shape()...) + f32 := ctx.Empty(DTypeF32, t.Shape()...) f32 = t.Copy(ctx, f32) return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string { return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32) diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 24943111..2c7e856c 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -304,7 +304,7 @@ func shapeToGGML(shape []int) *C.int64_t { return &sh[0] } -func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { +func newTensor(ctx Context, dtype ml.DType, zero bool, shape []int) ml.Tensor { if len(shape) < 1 || len(shape) > 4 { panic("unsupported number of dimensions") } @@ -318,19 +318,29 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { var t *C.struct_ggml_tensor switch dtype { case ml.DTypeF32: - t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape)) + t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape)) case ml.DTypeF16: - t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape)) + t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape)) case ml.DTypeI32: - t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape)) + t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape)) default: panic("unsupported dtype") } - b := C.ggml_backend_alloc_buffer(c.backend, C.ggml_nbytes(t)) + b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t)) C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b)) - C.ggml_set_zero(t) - return &Tensor{b: c.b, t: t} + if zero { + C.ggml_set_zero(t) + } + return &Tensor{b: ctx.b, t: t} +} + +func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor { + return newTensor(c, dtype, false, shape) +} + +func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { + return newTensor(c, dtype, true, shape) } func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) { From 21aa666a1eeff87d3fc6f4f8a43167e1fdd0d3ad Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Tue, 25 Feb 2025 17:24:36 -0800 Subject: [PATCH 7/7] ml: Enable support for flash attention The GGML flash attention kernel has specific requirements for padding and permutation. This adds support to the KV cache for conforming to these requirements so that flash attention can be enabled. Flash attention can be used in the same situations as the llama engine and is enabled by the user in the same way. --- kvcache/causal.go | 34 ++++++++++++++++++++++++++++---- ml/backend.go | 11 +++++++++++ ml/backend/ggml/ggml.go | 37 ++++++++++++++++++++++++----------- runner/ollamarunner/runner.go | 12 ++++++------ 4 files changed, 73 insertions(+), 21 deletions(-) diff --git a/kvcache/causal.go b/kvcache/causal.go index 1d4daf80..b2e7b3ab 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -90,6 +90,14 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) { c.config.CachePadding = 1 } + if c.config.MaskBatchPadding == 0 { + c.config.MaskBatchPadding = 1 + } + + if c.config.MaskDType == ml.DTypeOther { + c.config.MaskDType = ml.DTypeF32 + } + c.DType = dtype c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding)) c.cells = make([]cacheCell, c.Capacity) @@ -192,13 +200,14 @@ func roundUp(length, pad int) int { // token in the history should apply. This is based on both the sequence and causality (the // position of the history is not ahead of the token in the batch). func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) { - // TODO(jessegross): This does not do mask padding, which is required for flash attention - // Align and pad the cache range as required by the backend + // Align and pad the two dimensions as required by the backend + batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding) + c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding) c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1 length := c.curCellRange.max - c.curCellRange.min + 1 - mask := make([]float32, c.curBatchSize*length) + mask := make([]float32, batchSize*length) for i := range c.curBatchSize { for j := c.curCellRange.min; j <= c.curCellRange.max; j++ { @@ -209,7 +218,24 @@ func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Te } } - return ctx.FromFloatSlice(mask, length, c.curBatchSize) + // Mask out any padding tokens we added. For padding that we added to the cache history, this + // has already been masked out because the sequence doesn't match. + for i := c.curBatchSize * length; i < len(mask); i++ { + mask[i] = float32(math.Inf(-1)) + } + + maskTensor, err := ctx.FromFloatSlice(mask, length, batchSize) + if err != nil { + return nil, err + } + + if c.config.MaskDType != ml.DTypeF32 { + out := ctx.Empty(c.config.MaskDType, maskTensor.Shape()...) + ctx.Forward(maskTensor.Copy(ctx, out)) + maskTensor = out + } + + return maskTensor, nil } func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) { diff --git a/ml/backend.go b/ml/backend.go index de2725c0..83b7a8c9 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -46,6 +46,14 @@ type CacheConfig struct { // and return the permuted version via Get. This uses the cache copy operation // to avoid a Contiguous call on the permuted tensor. PermutedV bool + + // MaskDType specifies the data type for generating the mask. If unset it will + // default to DTypeF32. + MaskDType DType + + // MaskBatchPadding specifies the multiple for the batch size dimension in the mask. + // Any position that does not correspond to an actual token will be filled with -Inf. + MaskBatchPadding int } // BackendParams controls how the backend loads and executes models @@ -61,6 +69,9 @@ type BackendParams struct { // TensorSplit is the fraction of the model to offload to each GPU TensorSplit []float32 + + // FlashAttention indicates that we should use a fused flash attention kernel + FlashAttention bool } var backends = make(map[string]func(*os.File, BackendParams) (Backend, error)) diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 2c7e856c..f4948fca 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -79,6 +79,8 @@ var devices = sync.OnceValue(func() []device { }) type Backend struct { + flashAttention bool + meta *fs.GGML cpus, gpus []Context tensors map[string]*Context @@ -192,9 +194,10 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) { } return &Backend{ - meta: meta, - cpus: cpus, - gpus: gpus, + flashAttention: params.FlashAttention, + meta: meta, + cpus: cpus, + gpus: gpus, sched: C.ggml_backend_sched_new( (*C.ggml_backend_t)(unsafe.Pointer(&backends[0])), (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])), @@ -248,7 +251,11 @@ func (b *Backend) NewContext() ml.Context { } func (b *Backend) CacheConfig() ml.CacheConfig { - return ml.CacheConfig{CachePadding: 32, PermutedV: true} + if b.flashAttention { + return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD} + } else { + return ml.CacheConfig{CachePadding: 32, PermutedV: true} + } } type Context struct { @@ -705,14 +712,22 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T query := t.Permute(ctx, 0, 2, 1, 3) key = key.Permute(ctx, 0, 2, 1, 3) - kq := key.MulmatFullPrec(ctx, query) - kq = &Tensor{ - b: t.b, - t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0), - } + if t.b.flashAttention { + value = value.Permute(ctx, 0, 2, 1, 3) - kqv := value.Mulmat(ctx, kq) - return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0) + C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32) + return &Tensor{b: t.b, t: kqv} + } else { + kq := key.MulmatFullPrec(ctx, query) + kq = &Tensor{ + b: t.b, + t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0), + } + + kqv := value.Mulmat(ctx, kq) + return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + } } func (b *Backend) SystemInfo() string { diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index db9b271e..5705931a 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -818,7 +818,7 @@ func Execute(args []string) error { batchSize := fs.Int("batch-size", 512, "Batch size") numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU") mainGPU := fs.Int("main-gpu", 0, "Main GPU") - _ = fs.Bool("flash-attn", false, "Enable flash attention") + flashAttention := fs.Bool("flash-attn", false, "Enable flash attention") kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size") kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)") port := fs.Int("port", 8080, "Port to expose the server on") @@ -863,7 +863,6 @@ func Execute(args []string) error { } // TODO(jessegross): Parameters that need to be implemented: - // flash-attn // no-mmap // mlock @@ -878,10 +877,11 @@ func Execute(args []string) error { } params := ml.BackendParams{ - NumThreads: *threads, - NumGPULayers: *numGPULayers, - MainGPU: *mainGPU, - TensorSplit: tensorSplitFloats, + NumThreads: *threads, + NumGPULayers: *numGPULayers, + MainGPU: *mainGPU, + TensorSplit: tensorSplitFloats, + FlashAttention: *flashAttention, } server.ready.Add(1)