Mirror of https://github.com/likelovewant/ollama-for-amd.git (synced 2025-12-21 14:26:30 +00:00)
flash attn: add auto mode for llama engine (#13052)
* flash attn: add auto mode for llama engine

  If the user does not specify fa in the environment, use auto mode.

* review comments

* ensure kv cache quantized types have FA explicitly enabled

  additional review comments
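The hunks below only thread the new tri-state type through the llama engine; the resolution of auto mode itself is not part of this excerpt. As a rough, self-contained sketch of the behavior the message above describes (unset means auto, auto follows hardware support, and a quantized KV cache type needs flash attention explicitly enabled), one might write something like the following. The helper name resolveFlashAttention, the OLLAMA_FLASH_ATTENTION parsing, and the f16 check are assumptions for illustration, not code from this commit.

package main

import (
	"fmt"
	"os"
	"strconv"
)

// Tri-state setting mirroring the values added to ml/device.go in this
// commit (aligned with llama_flash_attn_type).
type FlashAttentionType int32

const (
	FlashAttentionAuto     FlashAttentionType = -1
	FlashAttentionDisabled FlashAttentionType = 0
	FlashAttentionEnabled  FlashAttentionType = 1
)

// resolveFlashAttention is a hypothetical helper, not code from this
// commit: it turns the environment setting plus runtime information into
// a concrete enabled/disabled decision.
func resolveFlashAttention(hardwareSupported bool, kvCacheType string) (FlashAttentionType, error) {
	// Unset in the environment means auto mode.
	fa := FlashAttentionAuto
	if v, ok := os.LookupEnv("OLLAMA_FLASH_ATTENTION"); ok && v != "" {
		if enabled, err := strconv.ParseBool(v); err == nil {
			if enabled {
				fa = FlashAttentionEnabled
			} else {
				fa = FlashAttentionDisabled
			}
		}
	}

	// A quantized KV cache only works with flash attention, so it requires
	// an explicit enable rather than relying on auto mode.
	if kvCacheType != "" && kvCacheType != "f16" && fa != FlashAttentionEnabled {
		return FlashAttentionDisabled, fmt.Errorf("kv cache type %q requires flash attention to be explicitly enabled", kvCacheType)
	}

	// In auto mode, fall back to whether the target devices support it.
	if fa == FlashAttentionAuto {
		if hardwareSupported {
			return FlashAttentionEnabled, nil
		}
		return FlashAttentionDisabled, nil
	}
	return fa, nil
}

func main() {
	fa, err := resolveFlashAttention(true, "f16")
	fmt.Println(fa, err) // prints: 1 <nil>
}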
@@ -74,7 +74,7 @@ type BackendParams struct {
 	GPULayers GPULayersList
 
 	// FlashAttention indicates that we should use a fused flash attention kernel
-	FlashAttention bool
+	FlashAttention FlashAttentionType
 }
 
 var backends = make(map[string]func(string, BackendParams) (Backend, error))
@@ -109,7 +109,7 @@ type Backend struct {
 	// btDeviceMemory maps from a buffer type to the memory allocations associated with that device
 	btDeviceMemory map[C.ggml_backend_buffer_type_t]*ml.DeviceMemory
 
-	flashAttention bool
+	flashAttention ml.FlashAttentionType
 
 	// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
 	maxGraphNodes int
@@ -684,7 +684,7 @@ func (b *Backend) NewContextSize(n int) ml.Context {
 }
 
 func (b *Backend) CacheConfig() ml.CacheConfig {
-	if b.flashAttention {
+	if b.flashAttention == ml.FlashAttentionEnabled {
 		return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
 	} else {
 		return ml.CacheConfig{CachePadding: 256, PermutedV: true}
@@ -1676,7 +1676,7 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
 	query := t.Permute(ctx, 0, 2, 1, 3)
 	key = key.Permute(ctx, 0, 2, 1, 3)
 
-	if t.b.flashAttention {
+	if t.b.flashAttention == ml.FlashAttentionEnabled {
 		value = value.Permute(ctx, 0, 2, 1, 3)
 
 		kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
ml/device.go (26 changed lines)
@@ -492,6 +492,32 @@ func FlashAttentionSupported(l []DeviceInfo) bool {
 	return true
 }
 
+type FlashAttentionType int32
+
+const (
+	// Aligned with llama_flash_attn_type
+	FlashAttentionAuto     FlashAttentionType = -1
+	FlashAttentionDisabled FlashAttentionType = 0
+	FlashAttentionEnabled  FlashAttentionType = 1
+)
+
+func (f FlashAttentionType) LogValue() slog.Value {
+	return slog.AnyValue(f.String())
+}
+
+func (f FlashAttentionType) String() string {
+	switch f {
+	case FlashAttentionAuto:
+		return "Auto"
+	case FlashAttentionDisabled:
+		return "Disabled"
+	case FlashAttentionEnabled:
+		return "Enabled"
+	default:
+		return "unknown"
+	}
+}
+
 // Given the list of GPUs this instantiation is targeted for,
 // figure out the visible devices environment variables
 // Set mustFilter true to enable filtering of CUDA devices
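Because the new type implements slog.LogValuer through the LogValue method added above, passing it as a structured log attribute renders the readable name rather than the raw integer. A minimal standalone sketch of that behavior follows; the logging call and message are illustrative, not taken from this commit.

package main

import "log/slog"

// Same shape as the type added in ml/device.go above.
type FlashAttentionType int32

const (
	FlashAttentionAuto     FlashAttentionType = -1
	FlashAttentionDisabled FlashAttentionType = 0
	FlashAttentionEnabled  FlashAttentionType = 1
)

// LogValue lets slog resolve the setting to its string form.
func (f FlashAttentionType) LogValue() slog.Value { return slog.AnyValue(f.String()) }

func (f FlashAttentionType) String() string {
	switch f {
	case FlashAttentionAuto:
		return "Auto"
	case FlashAttentionDisabled:
		return "Disabled"
	case FlashAttentionEnabled:
		return "Enabled"
	default:
		return "unknown"
	}
}

func main() {
	fa := FlashAttentionAuto
	// slog resolves LogValue, so this logs flash_attention=Auto rather than -1.
	slog.Info("loading llama engine", "flash_attention", fa)
}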