mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 22:33:56 +00:00
ggml: Enable flash attention for vision encoders
Although the vision component of multimodal models typically already call the optimized nn.Attention, it is converted into non-fused operations. That is because the backend-specific fused kernels may have requirements, such as padding, and they is performed by the cache, which vision encoders don't use. This implements a fallback path in the backend, softening the requirements into optimizations. In turn, this allows flash attention to be used for vision encoders, saving a significant amount of VRAM and improving performance.
This commit is contained in:
@@ -57,10 +57,9 @@ func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla
|
||||
key, value, mask = cache.Get(ctx)
|
||||
}
|
||||
|
||||
// Only use the fast SDPA implementation if we have a cache, since that's what
|
||||
// will do any expected backend-specific transformations for us
|
||||
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
|
||||
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale)
|
||||
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
|
||||
cacheConfigApplied := cache != nil
|
||||
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale, cacheConfigApplied)
|
||||
} else {
|
||||
query = query.Permute(ctx, 0, 2, 1, 3)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
|
||||
Reference in New Issue
Block a user