flash attn: add auto mode for llama engine (#13052)

* flash attn: add auto mode for llama engine

If the user does not specify flash attention in the environment, use auto mode.

* review comments

* ensure quantized KV cache types have FA explicitly enabled

additional review comments
Daniel Hiltgen
2025-12-12 13:27:19 -08:00
committed by GitHub
parent 3af5d3b738
commit bd6c1d6b49
7 changed files with 101 additions and 25 deletions
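
The commit message above describes falling back to auto mode when flash attention is not set in the environment. Below is a minimal sketch of that fallback, assuming the documented OLLAMA_FLASH_ATTENTION variable and plain os/strconv parsing; the real commit routes this through ollama's envconfig package, so the helper name and parsing details here are illustrative only.

package main

import (
    "fmt"
    "os"
    "strconv"

    "github.com/ollama/ollama/ml"
)

// flashAttentionFromEnv is a hypothetical helper: unset or empty means Auto,
// otherwise the boolean value maps to Enabled or Disabled.
func flashAttentionFromEnv() ml.FlashAttentionType {
    v, ok := os.LookupEnv("OLLAMA_FLASH_ATTENTION")
    if !ok || v == "" {
        return ml.FlashAttentionAuto // let the engine decide per model and GPU
    }
    if enabled, err := strconv.ParseBool(v); err == nil && enabled {
        return ml.FlashAttentionEnabled
    }
    return ml.FlashAttentionDisabled
}

func main() {
    fmt.Println(flashAttentionFromEnv()) // prints Auto when the variable is unset
}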


@@ -74,7 +74,7 @@ type BackendParams struct {
     GPULayers GPULayersList

     // FlashAttention indicates that we should use a fused flash attention kernel
-    FlashAttention bool
+    FlashAttention FlashAttentionType
 }

 var backends = make(map[string]func(string, BackendParams) (Backend, error))
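
With the field now typed as FlashAttentionType, callers choose one of three states instead of a bool; a minimal illustration of constructing the params (other fields elided, external-package qualification assumed):

params := ml.BackendParams{
    // ... other fields unchanged ...
    FlashAttention: ml.FlashAttentionAuto, // was a bool; Auto defers the decision to the backend
}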


@@ -109,7 +109,7 @@ type Backend struct {
     // btDeviceMemory maps from a buffer type to the memory allocations associated with that device
     btDeviceMemory map[C.ggml_backend_buffer_type_t]*ml.DeviceMemory

-    flashAttention bool
+    flashAttention ml.FlashAttentionType

     // maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
     maxGraphNodes int
@@ -684,7 +684,7 @@ func (b *Backend) NewContextSize(n int) ml.Context {
 }

 func (b *Backend) CacheConfig() ml.CacheConfig {
-    if b.flashAttention {
+    if b.flashAttention == ml.FlashAttentionEnabled {
         return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
     } else {
         return ml.CacheConfig{CachePadding: 256, PermutedV: true}
@@ -1676,7 +1676,7 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
     query := t.Permute(ctx, 0, 2, 1, 3)
     key = key.Permute(ctx, 0, 2, 1, 3)

-    if t.b.flashAttention {
+    if t.b.flashAttention == ml.FlashAttentionEnabled {
         value = value.Permute(ctx, 0, 2, 1, 3)

         kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
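
Both call sites in this file now compare against ml.FlashAttentionEnabled instead of testing a bool, so FlashAttentionAuto must be resolved to Enabled or Disabled before the graph is built; per the commit message, a quantized KV cache additionally requires flash attention to be explicitly enabled. The resolution itself is not shown in this excerpt; here is a hedged sketch using the FlashAttentionSupported helper from the next file, where params and gpus stand in for the caller's state:

// Hypothetical placement: resolve Auto once, before constructing the backend.
if params.FlashAttention == ml.FlashAttentionAuto {
    if ml.FlashAttentionSupported(gpus) { // gpus: the []ml.DeviceInfo this instance targets
        params.FlashAttention = ml.FlashAttentionEnabled
    } else {
        params.FlashAttention = ml.FlashAttentionDisabled
    }
}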


@@ -492,6 +492,32 @@ func FlashAttentionSupported(l []DeviceInfo) bool {
     return true
 }

+type FlashAttentionType int32
+
+const (
+    // Aligned with llama_flash_attn_type
+    FlashAttentionAuto     FlashAttentionType = -1
+    FlashAttentionDisabled FlashAttentionType = 0
+    FlashAttentionEnabled  FlashAttentionType = 1
+)
+
+func (f FlashAttentionType) LogValue() slog.Value {
+    return slog.AnyValue(f.String())
+}
+
+func (f FlashAttentionType) String() string {
+    switch f {
+    case FlashAttentionAuto:
+        return "Auto"
+    case FlashAttentionDisabled:
+        return "Disabled"
+    case FlashAttentionEnabled:
+        return "Enabled"
+    default:
+        return "unknown"
+    }
+}
+
 // Given the list of GPUs this instantiation is targeted for,
 // figure out the visible devices environment variables
 // Set mustFilter true to enable filtering of CUDA devices
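
Because FlashAttentionType implements slog.LogValuer through the LogValue method above, values render by name rather than as raw integers in structured logs. A small usage example (the log message and key are illustrative):

package main

import (
    "log/slog"

    "github.com/ollama/ollama/ml"
)

func main() {
    fa := ml.FlashAttentionAuto
    // slog resolves LogValue, so this prints flash_attention=Auto, not -1.
    slog.Info("loading model", "flash_attention", fa)
}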