ggml: Enable flash attention for vision encoders

Although the vision component of multimodal models typically already
calls the optimized nn.Attention, it gets converted into non-fused
operations. That is because the backend-specific fused kernels may
have requirements, such as padding, and those are normally satisfied
by the cache, which vision encoders don't use.
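
As a rough illustration, the cache-side requirements mentioned above
are described by a config struct along these lines (a sketch based on
the fields referenced in the diff below; the exact types and the full
field set in ollama's ml package are assumptions here):

// Sketch only: a stand-alone illustration of the kind of cache
// configuration the backend exposes via CacheConfig(). Field names
// mirror the ones used in the diff below; types are assumptions.
package sketch

// DType stands in for the backend's tensor element type (assumption).
type DType int

// CacheConfig describes transformations the fused flash-attention
// kernel expects its inputs to have had applied, which a KV cache
// normally handles on behalf of the model.
type CacheConfig struct {
	CachePadding     int   // pad K/V length to a multiple of this (an optimization, not a requirement)
	PermutedV        bool  // store V permuted so the fused kernel can read it directly
	MaskBatchPadding int   // pad the mask's batch dimension to a multiple of this
	MaskDType        DType // element type the fused kernel expects for the mask
}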

This implements a fallback path in the backend, softening the
requirements into optimizations. In turn, this allows flash attention
to be used for vision encoders, saving a significant amount of VRAM
and improving performance.

Author:    Jesse Gross
Committer: Jesse Gross
Date:      2025-12-02 15:39:27 -08:00
Commit:    1108d8b34e
Parent:    7837a5bc7e

3 changed files with 29 additions and 6 deletions

@@ -1645,7 +1645,29 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
 	}
 }
 
-func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64) ml.Tensor {
+func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64, cacheConfigApplied bool) ml.Tensor {
+	// If the cache didn't help us with required transformations, do them here
+	if !cacheConfigApplied {
+		cacheConfig := t.b.CacheConfig()
+
+		// Padding key and value to CachePadding is a performance optimization,
+		// not a requirement, so we don't do it if it wasn't done by the caller
+		if cacheConfig.PermutedV {
+			value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+		}
+
+		if mask != nil {
+			padSize := int(pad(C.size_t(mask.Dim(1)), C.size_t(cacheConfig.MaskBatchPadding))) - mask.Dim(1)
+			if padSize > 0 {
+				mask = mask.Pad(ctx, 0, padSize, 0, 0)
+			}
+
+			if mask.DType() != cacheConfig.MaskDType {
+				mask = mask.Cast(ctx, cacheConfig.MaskDType)
+			}
+		}
+	}
+
 	var kqMask *C.struct_ggml_tensor
 	if mask != nil {
 		kqMask = mask.(*Tensor).t
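
For reference, a minimal sketch of how callers are expected to use the
new cacheConfigApplied flag (the Context and Tensor interfaces below
are trimmed stand-ins for illustration, not ollama's actual ml API):

// Sketch only: stand-in interfaces showing the intended call pattern
// for the new trailing flag on ScaledDotProductAttention.
package sketch

import "math"

type Context interface{}

type Tensor interface {
	ScaledDotProductAttention(ctx Context, key, value, mask, sinks, vmla Tensor,
		scale float64, cacheConfigApplied bool) Tensor
}

// visionAttention: no KV cache is involved, so cacheConfigApplied is
// false and the backend's new fallback path permutes V and pads/casts
// the mask itself before using the fused flash-attention kernel.
func visionAttention(ctx Context, q, k, v Tensor, headDim int) Tensor {
	scale := 1.0 / math.Sqrt(float64(headDim))
	return q.ScaledDotProductAttention(ctx, k, v, nil, nil, nil, scale, false)
}

// cachedAttention: the KV cache has already applied CacheConfig, so
// the flag is true and the fallback in the diff above is skipped.
func cachedAttention(ctx Context, q, k, v, mask Tensor, headDim int) Tensor {
	scale := 1.0 / math.Sqrt(float64(headDim))
	return q.ScaledDotProductAttention(ctx, k, v, mask, nil, nil, scale, true)
}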