ggml: Enable flash attention for vision encoders

Although the vision component of multimodal models typically already
call the optimized nn.Attention, it is converted into non-fused
operations. That is because the backend-specific fused kernels may
have requirements, such as padding, and they is performed by the
cache, which vision encoders don't use.

This implements a fallback path in the backend, softening the
requirements into optimizations. In turn, this allows flash attention
to be used for vision encoders, saving a significant amount of VRAM
and improving performance.
This commit is contained in:
Jesse Gross
2025-12-02 15:39:27 -08:00
committed by Jesse Gross
parent 7837a5bc7e
commit 1108d8b34e
3 changed files with 29 additions and 6 deletions

View File

@@ -233,8 +233,10 @@ type Tensor interface {
//
// kqv := value.Mulmat(ctx, kq)
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
//
// cacheConfigApplied indicates whether the optimizations requested through CacheConfig have been performed
type ScaledDotProductAttention interface {
ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64, cacheConfigApplied bool) Tensor
}
type number interface {