flash attn: add auto mode for llama engine (#13052)
* flash attn: add auto mode for llama engine

  If the user does not specify fa in the environment, use auto mode.

* review comments

* ensure kv cache quantized types have FA explicitly enabled

  additional review comments
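The hunks below replace the Backend's boolean flashAttention field with a tri-state ml.FlashAttentionType so that "unset" can mean auto rather than off. Only ml.FlashAttentionEnabled is confirmed by the diff; the other constant names and the environment resolver in this sketch are illustrative assumptions, not ollama's actual API:

```go
package ml

import "os"

// FlashAttentionType is the tri-state replacement for the old boolean flag.
// Only FlashAttentionEnabled appears in the diff; the other constants and
// the resolver below are illustrative assumptions.
type FlashAttentionType int

const (
	FlashAttentionAuto     FlashAttentionType = iota // user did not ask either way; engine decides
	FlashAttentionDisabled                           // user explicitly turned flash attention off
	FlashAttentionEnabled                            // user explicitly turned flash attention on
)

// ResolveFlashAttention is a hypothetical helper showing the auto-mode idea:
// an unset environment variable maps to auto instead of defaulting to off.
func ResolveFlashAttention() FlashAttentionType {
	switch os.Getenv("OLLAMA_FLASH_ATTENTION") {
	case "":
		return FlashAttentionAuto
	case "0", "false":
		return FlashAttentionDisabled
	default:
		return FlashAttentionEnabled
	}
}
```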
@@ -109,7 +109,7 @@ type Backend struct {
 	// btDeviceMemory maps from a buffer type to the memory allocations associated with that device
 	btDeviceMemory map[C.ggml_backend_buffer_type_t]*ml.DeviceMemory
 
-	flashAttention bool
+	flashAttention ml.FlashAttentionType
 
 	// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
 	maxGraphNodes int
@@ -684,7 +684,7 @@ func (b *Backend) NewContextSize(n int) ml.Context {
 }
 
 func (b *Backend) CacheConfig() ml.CacheConfig {
-	if b.flashAttention {
+	if b.flashAttention == ml.FlashAttentionEnabled {
 		return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
 	} else {
 		return ml.CacheConfig{CachePadding: 256, PermutedV: true}
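As a side note on the CacheConfig values in this hunk: CachePadding appears to be a multiple that the KV cache length gets rounded up to, which both branches keep at 256. A minimal sketch of that rounding; the helper name is an assumption for illustration, not ollama's API:

```go
// padCacheLen rounds a requested context length up to the backend's cache
// padding multiple. Hypothetical helper for illustration only.
func padCacheLen(numCtx, cachePadding int) int {
	if cachePadding <= 1 {
		return numCtx
	}
	return ((numCtx + cachePadding - 1) / cachePadding) * cachePadding
}
```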
@@ -1676,7 +1676,7 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
 	query := t.Permute(ctx, 0, 2, 1, 3)
 	key = key.Permute(ctx, 0, 2, 1, 3)
 
-	if t.b.flashAttention {
+	if t.b.flashAttention == ml.FlashAttentionEnabled {
 		value = value.Permute(ctx, 0, 2, 1, 3)
 
 		kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
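The third commit-message item ("ensure kv cache quantized types have FA explicitly enabled") is not visible in these hunks. One way such a constraint could be enforced is sketched below; the function name and error text are assumptions, and whether ollama rejects the request or instead forces flash attention on is not shown here. q8_0 and q4_0 are ollama's quantized KV cache types, which require flash attention:

```go
import "fmt"

// requireExplicitFlashAttention sketches the constraint from the commit
// message: a quantized KV cache only works when flash attention is
// explicitly enabled, so auto or disabled modes are rejected here.
// Illustrative only, not ollama's actual function.
func requireExplicitFlashAttention(kvCacheType string, fa ml.FlashAttentionType) error {
	switch kvCacheType {
	case "q8_0", "q4_0":
		if fa != ml.FlashAttentionEnabled {
			return fmt.Errorf("kv cache type %q requires flash attention to be explicitly enabled", kvCacheType)
		}
	}
	return nil
}
```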