From 1108d8b34e43e968812eded0ccda73503ccad77d Mon Sep 17 00:00:00 2001
From: Jesse Gross
Date: Tue, 2 Dec 2025 15:39:27 -0800
Subject: [PATCH] ggml: Enable flash attention for vision encoders

Although the vision component of multimodal models typically already
calls the optimized nn.Attention, the call is converted into non-fused
operations. That is because the backend-specific fused kernels may have
requirements, such as padding, and these are performed by the cache,
which vision encoders don't use.

This implements a fallback path in the backend, softening the
requirements into optimizations. In turn, this allows flash attention
to be used for vision encoders, saving a significant amount of VRAM
and improving performance.
---
 ml/backend.go           |  4 +++-
 ml/backend/ggml/ggml.go | 24 +++++++++++++++++++++++-
 ml/nn/attention.go      |  7 +++----
 3 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/ml/backend.go b/ml/backend.go
index 4d930fe4..6e5a059a 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -233,8 +233,10 @@ type Tensor interface {
 //
 //	kqv := value.Mulmat(ctx, kq)
 //	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+//
+// cacheConfigApplied indicates whether the optimizations requested through CacheConfig have been performed
 type ScaledDotProductAttention interface {
-	ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
+	ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64, cacheConfigApplied bool) Tensor
 }
 
 type number interface {
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index f1a19e0b..18bdc91e 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -1645,7 +1645,29 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
 	}
 }
 
-func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64) ml.Tensor {
+func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64, cacheConfigApplied bool) ml.Tensor {
+	// If the cache didn't help us with required transformations, do them here
+	if !cacheConfigApplied {
+		cacheConfig := t.b.CacheConfig()
+
+		// Padding key and value to CachePadding is a performance optimization, not a requirement, so we don't do it if it wasn't done by the caller
+
+		if cacheConfig.PermutedV {
+			value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+		}
+
+		if mask != nil {
+			padSize := int(pad(C.size_t(mask.Dim(1)), C.size_t(cacheConfig.MaskBatchPadding))) - mask.Dim(1)
+			if padSize > 0 {
+				mask = mask.Pad(ctx, 0, padSize, 0, 0)
+			}
+
+			if mask.DType() != cacheConfig.MaskDType {
+				mask = mask.Cast(ctx, cacheConfig.MaskDType)
+			}
+		}
+	}
+
 	var kqMask *C.struct_ggml_tensor
 	if mask != nil {
 		kqMask = mask.(*Tensor).t
diff --git a/ml/nn/attention.go b/ml/nn/attention.go
index 123ae537..e495e1f6 100644
--- a/ml/nn/attention.go
+++ b/ml/nn/attention.go
@@ -57,10 +57,9 @@ func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla
 		key, value, mask = cache.Get(ctx)
 	}
 
-	// Only use the fast SDPA implementation if we have a cache, since that's what
-	// will do any expected backend-specific transformations for us
-	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
-		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale)
+	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
+		cacheConfigApplied := cache != nil
+		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale, cacheConfigApplied)
 	} else {
 		query = query.Permute(ctx, 0, 2, 1, 3)
 		key = key.Permute(ctx, 0, 2, 1, 3)
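
Illustrative usage (not part of the patch): a minimal sketch of the kind of caller that benefits, a vision-encoder self-attention block that invokes nn.Attention with a nil cache. The package path, struct fields, tensor shapes, and the exact nn.Attention signature below are assumptions made for illustration; only the nil-cache behavior comes from this change.

// Hypothetical vision-encoder self-attention layer. Before this patch, the
// nil cache forced nn.Attention onto the unfused MulMat/Softmax path; with
// the fallback in the backend's ScaledDotProductAttention, the fused flash
// attention kernel can be used even though no cache applied the
// backend-specific transformations (mask padding, mask dtype, permuted V).
package vision

import (
	"math"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

type SelfAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_out"`
}

func (sa *SelfAttention) Forward(ctx ml.Context, hiddenStates ml.Tensor, numHeads, headDim int) ml.Tensor {
	seqLen := hiddenStates.Dim(1)

	query := sa.Query.Forward(ctx, hiddenStates).Reshape(ctx, headDim, numHeads, seqLen)
	key := sa.Key.Forward(ctx, hiddenStates).Reshape(ctx, headDim, numHeads, seqLen)
	value := sa.Value.Forward(ctx, hiddenStates).Reshape(ctx, headDim, numHeads, seqLen)

	// nil cache: the backend now performs any required mask padding, mask
	// dtype conversion, and value permutation itself (cacheConfigApplied is
	// false), so the fused path is still taken here.
	attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
	attention = attention.Reshape(ctx, headDim*numHeads, seqLen)

	return sa.Output.Forward(ctx, attention)
}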