Mirror of https://github.com/likelovewant/ollama-for-amd.git (synced 2025-12-21 22:33:56 +00:00)
ggml: Enable flash attention for vision encoders
Although the vision component of multimodal models typically already calls the optimized nn.Attention, it is converted into non-fused operations. That is because the backend-specific fused kernels may have requirements, such as padding, and these are normally handled by the cache, which vision encoders don't use.

This implements a fallback path in the backend, softening the requirements into optimizations. In turn, this allows flash attention to be used for vision encoders, saving a significant amount of VRAM and improving performance.
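For orientation before the diff: vision encoders run attention without a KV cache, so under the old gating they always fell back to the non-fused path. Below is a minimal caller-side sketch of what the change enables; the function name VisionSelfAttention and the exact signature of the nn attention helper are assumptions for illustration, not part of this patch.

// Sketch only, assuming the usual nn attention helper that takes a cache
// argument. Vision encoders pass nil for the cache; after this change the
// backend can still use the fused flash-attention kernel and, because
// cacheConfigApplied is false, applies the mask/value transformations itself.
func VisionSelfAttention(ctx ml.Context, query, key, value ml.Tensor, scale float64) ml.Tensor {
	return nn.Attention(ctx, query, key, value, scale, nil)
}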
@@ -233,8 +233,10 @@ type Tensor interface {
 //
 // kqv := value.Mulmat(ctx, kq)
 // return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+//
+// cacheConfigApplied indicates whether the optimizations requested through CacheConfig have been performed
 type ScaledDotProductAttention interface {
-	ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
+	ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64, cacheConfigApplied bool) Tensor
 }
 
 type number interface {
@@ -1645,7 +1645,29 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
 	}
 }
 
-func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64) ml.Tensor {
+func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64, cacheConfigApplied bool) ml.Tensor {
+	// If the cache didn't help us with required transformations, do them here
+	if !cacheConfigApplied {
+		cacheConfig := t.b.CacheConfig()
+
+		// Padding key and value to CachePadding is a performance optimization, not a requirement, so we don't do it if it wasn't done by the caller
+
+		if cacheConfig.PermutedV {
+			value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+		}
+
+		if mask != nil {
+			padSize := int(pad(C.size_t(mask.Dim(1)), C.size_t(cacheConfig.MaskBatchPadding))) - mask.Dim(1)
+			if padSize > 0 {
+				mask = mask.Pad(ctx, 0, padSize, 0, 0)
+			}
+
+			if mask.DType() != cacheConfig.MaskDType {
+				mask = mask.Cast(ctx, cacheConfig.MaskDType)
+			}
+		}
+	}
+
 	var kqMask *C.struct_ggml_tensor
 	if mask != nil {
 		kqMask = mask.(*Tensor).t
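The padSize computation in this hunk rounds mask.Dim(1) up to a multiple of cacheConfig.MaskBatchPadding; the pad helper itself is not shown in the hunk. A minimal standalone sketch of that round-up arithmetic in plain Go, without the C types used by the backend (the helper name padTo is illustrative):

// padTo rounds n up to the nearest multiple of m.
func padTo(n, m int) int {
	return ((n + m - 1) / m) * m
}

// Example: if MaskBatchPadding is 32 and the mask's batch dimension is 17,
// padTo(17, 32) returns 32, so padSize = 32 - 17 = 15 rows of padding.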
@@ -57,10 +57,9 @@ func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla
 		key, value, mask = cache.Get(ctx)
 	}
 
-	// Only use the fast SDPA implementation if we have a cache, since that's what
-	// will do any expected backend-specific transformations for us
-	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
-		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale)
+	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
+		cacheConfigApplied := cache != nil
+		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale, cacheConfigApplied)
 	} else {
 		query = query.Permute(ctx, 0, 2, 1, 3)
 		key = key.Permute(ctx, 0, 2, 1, 3)