flash attn: add auto mode for llama engine (#13052)

* flash attn: add auto mode for llama engine

If the user does not specify flash attention in the environment, use auto mode (sketched below).

* review comments

* ensure quantized KV cache types have flash attention explicitly enabled

additional review comments
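
Taken together, the notes above describe a small resolution policy: no user setting means auto mode, except that a quantized KV cache is treated as requiring flash attention and gets it explicitly enabled. The Go sketch below illustrates one plausible reading of that policy; resolveFlashAttention, isQuantizedKVCache, the locally defined FlashAttentionType constants, and the OLLAMA_FLASH_ATTENTION variable name are assumptions for illustration, not the code added by this commit.

// Package main sketches the flash-attention resolution policy described above.
package main

import (
	"fmt"
	"os"
	"strconv"
	"strings"
)

// FlashAttentionType stands in for the tri-state type this change introduces.
// The constant names match those used in the diff; the underlying representation
// here is an assumption for illustration.
type FlashAttentionType int

const (
	FlashAttentionAuto FlashAttentionType = iota
	FlashAttentionEnabled
	FlashAttentionDisabled
)

// resolveFlashAttention is a hypothetical helper, not ollama's actual code.
// Policy as described in the commit message: if the user did not set the
// flash-attention environment variable, fall back to auto mode, unless the
// KV cache type is quantized, in which case flash attention is explicitly enabled.
func resolveFlashAttention(envVal, kvCacheType string) FlashAttentionType {
	if envVal == "" {
		if isQuantizedKVCache(kvCacheType) {
			return FlashAttentionEnabled
		}
		return FlashAttentionAuto
	}
	if enabled, err := strconv.ParseBool(envVal); err == nil && enabled {
		return FlashAttentionEnabled
	}
	return FlashAttentionDisabled
}

// isQuantizedKVCache assumes quantized cache types are the q* variants;
// the exact set supported by the engine may differ.
func isQuantizedKVCache(kvCacheType string) bool {
	return strings.HasPrefix(strings.ToLower(kvCacheType), "q")
}

func main() {
	// OLLAMA_FLASH_ATTENTION is assumed to be the relevant environment variable.
	mode := resolveFlashAttention(os.Getenv("OLLAMA_FLASH_ATTENTION"), "q8_0")
	fmt.Println(mode) // prints 1 (FlashAttentionEnabled) when the variable is unset
}

However the policy is computed, the resolved value is what the updated NewContextParams in the diff below consumes.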
Author: Daniel Hiltgen
Date: 2025-12-12 13:27:19 -08:00
Committed by: GitHub
Parent: 3af5d3b738
Commit: bd6c1d6b49
7 changed files with 101 additions and 25 deletions


@@ -118,7 +118,7 @@ type ContextParams struct {
 	c C.struct_llama_context_params
 }
 
-func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
+func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention ml.FlashAttentionType, kvCacheType string) ContextParams {
 	params := C.llama_context_default_params()
 	params.n_ctx = C.uint(numCtx)
 	params.n_batch = C.uint(batchSize * numSeqMax)
@@ -127,10 +127,13 @@ func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
 	params.n_threads = C.int(threads)
 	params.n_threads_batch = params.n_threads
 	params.embeddings = C.bool(true)
-	if flashAttention {
-		params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_ENABLED
-	} else {
-		params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_DISABLED
+	switch flashAttention {
+	case ml.FlashAttentionEnabled:
+		params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_ENABLED)
+	case ml.FlashAttentionDisabled:
+		params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_DISABLED)
+	case ml.FlashAttentionAuto:
+		params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_AUTO)
 	}
 	params.type_k = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
 	params.type_v = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
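
For reference, a sketch of a call site using the updated signature; the argument values are illustrative, and the import paths assume ollama's module layout (github.com/ollama/ollama), which is not shown in this diff.

package main

import (
	"github.com/ollama/ollama/llama"
	"github.com/ollama/ollama/ml"
)

func main() {
	// Illustrative values only; real callers derive these from model metadata
	// and server options.
	params := llama.NewContextParams(
		4096,                  // numCtx
		512,                   // batchSize
		1,                     // numSeqMax
		8,                     // threads
		ml.FlashAttentionAuto, // let the engine decide unless the user asked explicitly
		"f16",                 // kvCacheType
	)
	_ = params
}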