mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 14:26:30 +00:00
ggml: Enable flash attention for vision encoders
Although the vision component of multimodal models typically already call the optimized nn.Attention, it is converted into non-fused operations. That is because the backend-specific fused kernels may have requirements, such as padding, and they is performed by the cache, which vision encoders don't use. This implements a fallback path in the backend, softening the requirements into optimizations. In turn, this allows flash attention to be used for vision encoders, saving a significant amount of VRAM and improving performance.
This commit is contained in:
@@ -233,8 +233,10 @@ type Tensor interface {
|
||||
//
|
||||
// kqv := value.Mulmat(ctx, kq)
|
||||
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
//
|
||||
// cacheConfigApplied indicates whether the optimizations requested through CacheConfig have been performed
|
||||
type ScaledDotProductAttention interface {
|
||||
ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
|
||||
ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64, cacheConfigApplied bool) Tensor
|
||||
}
|
||||
|
||||
type number interface {
|
||||
|
||||
@@ -1645,7 +1645,29 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64) ml.Tensor {
|
||||
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64, cacheConfigApplied bool) ml.Tensor {
|
||||
// If the cache didn't help us with required transformations, do them here
|
||||
if !cacheConfigApplied {
|
||||
cacheConfig := t.b.CacheConfig()
|
||||
|
||||
// Padding key and value to CachePadding is a performance optimization, not a requirement, so we don't do it if it wasn't done by the caller
|
||||
|
||||
if cacheConfig.PermutedV {
|
||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
}
|
||||
|
||||
if mask != nil {
|
||||
padSize := int(pad(C.size_t(mask.Dim(1)), C.size_t(cacheConfig.MaskBatchPadding))) - mask.Dim(1)
|
||||
if padSize > 0 {
|
||||
mask = mask.Pad(ctx, 0, padSize, 0, 0)
|
||||
}
|
||||
|
||||
if mask.DType() != cacheConfig.MaskDType {
|
||||
mask = mask.Cast(ctx, cacheConfig.MaskDType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var kqMask *C.struct_ggml_tensor
|
||||
if mask != nil {
|
||||
kqMask = mask.(*Tensor).t
|
||||
|
||||
@@ -57,10 +57,9 @@ func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla
|
||||
key, value, mask = cache.Get(ctx)
|
||||
}
|
||||
|
||||
// Only use the fast SDPA implementation if we have a cache, since that's what
|
||||
// will do any expected backend-specific transformations for us
|
||||
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
|
||||
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale)
|
||||
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
|
||||
cacheConfigApplied := cache != nil
|
||||
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale, cacheConfigApplied)
|
||||
} else {
|
||||
query = query.Permute(ctx, 0, 2, 1, 3)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
|
||||
Reference in New Issue
Block a user