From 584e2d646fb4d2f1643b4da81a096d01114f5b2b Mon Sep 17 00:00:00 2001
From: Grace <88872231+gr4ceG@users.noreply.github.com>
Date: Mon, 17 Nov 2025 18:03:21 -0800
Subject: [PATCH] Add deepseek v3.1 (#13063)

* Add mla for flash attention

* Revert to using chunks
---
 ml/backend.go                   |  2 +-
 ml/backend/ggml/ggml.go         | 16 ++++++++-
 ml/nn/attention.go              | 13 +++++--
 model/models/deepseek2/model.go | 60 ++++++++++++++++++++++-----------
 4 files changed, 67 insertions(+), 24 deletions(-)

diff --git a/ml/backend.go b/ml/backend.go
index 36557e62..99d6b146 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -230,7 +230,7 @@ type Tensor interface {
 //	kqv := value.Mulmat(ctx, kq)
 //	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 type ScaledDotProductAttention interface {
-	ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, scale float64) Tensor
+	ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
 }
 
 type number interface {
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index 2aa72190..92ed4413 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -1625,7 +1625,7 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
 	}
 }
 
-func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, scale float64) ml.Tensor {
+func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64) ml.Tensor {
 	var kqMask *C.struct_ggml_tensor
 	if mask != nil {
 		kqMask = mask.(*Tensor).t
@@ -1642,6 +1642,16 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
 			C.ggml_flash_attn_ext_add_sinks(kqv, sinks.(*Tensor).t)
 		}
 		C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
+
+		if vmla != nil {
+			var cur ml.Tensor = &Tensor{b: t.b, t: kqv}
+			cur = cur.Permute(ctx, 0, 2, 1, 3)
+			cur = vmla.Mulmat(ctx, cur)
+			cur = cur.Permute(ctx, 0, 2, 1, 3)
+			cur = cur.Contiguous(ctx)
+			kqv = cur.(*Tensor).t
+		}
+
 		return &Tensor{b: t.b, t: kqv}
 	} else {
 		kq := key.MulmatFullPrec(ctx, query)
@@ -1654,6 +1664,10 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
 		}
 
 		kqv := value.Mulmat(ctx, kq)
+		if vmla != nil {
+			kqv = vmla.Mulmat(ctx, kqv)
+		}
+
 		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	}
 }
diff --git a/ml/nn/attention.go b/ml/nn/attention.go
index 94dbde0b..123ae537 100644
--- a/ml/nn/attention.go
+++ b/ml/nn/attention.go
@@ -22,10 +22,14 @@ import (
 //
 //	Attention output with shape [d_v, heads, seq_len_q]
 func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
-	return AttentionWithSinks(ctx, query, key, value, nil, scale, cache)
+	return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache)
 }
 
 func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
+	return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache)
+}
+
+func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
 	ctx.Forward(query)
 	if key != nil && value != nil {
 		if query.Dim(0) != key.Dim(0) {
@@ -56,7 +60,7 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal
 	// Only use the fast SDPA implementation if we have a cache, since that's what
 	// will do any expected backend-specific transformations for us
 	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
-		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, scale)
+		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale)
 	} else {
 		query = query.Permute(ctx, 0, 2, 1, 3)
 		key = key.Permute(ctx, 0, 2, 1, 3)
@@ -71,6 +75,11 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal
 		kq = kq.Softmax(ctx)
 
 		kqv := value.Mulmat(ctx, kq)
+
+		if vmla != nil {
+			kqv = vmla.Mulmat(ctx, kqv)
+		}
+
 		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	}
 }
diff --git a/model/models/deepseek2/model.go b/model/models/deepseek2/model.go
index 68b12cd9..c1251ecf 100644
--- a/model/models/deepseek2/model.go
+++ b/model/models/deepseek2/model.go
@@ -3,6 +3,7 @@ package deepseek2
 // uses deepseek 2 architecture but written based on deepseek 3 model
 
 import (
+	"cmp"
 	"math"
 
 	"github.com/ollama/ollama/fs"
@@ -16,6 +17,7 @@ import (
 )
 
 type Options struct {
+	isMLA          bool
 	numExpertsUsed int
 	numExperts     int
 	normTopKProb   bool
@@ -32,8 +34,6 @@ type Options struct {
 	hiddenSize,
 	numHeads,
 	numKVHeads,
-	keyLength,
-	valueLength,
 	originalContextLength int
 
 	eps,
@@ -62,6 +62,9 @@ type Attention struct {
 	KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
 	KVB     *nn.Linear  `gguf:"attn_kv_b"`
 
+	KB *nn.Linear `gguf:"attn_k_b"`
+	VB *nn.Linear `gguf:"attn_v_b"`
+
 	Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
 }
 
@@ -69,7 +72,7 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
 	seqLength := hiddenStates.Dim(1)
 
 	var query ml.Tensor
-	if opts.qLoraRank == 0 { // nil {
+	if opts.qLoraRank == 0 {
 		query = attn.Q.Forward(ctx, hiddenStates)
 	} else {
 		query = attn.QA.Forward(ctx, hiddenStates)
@@ -88,21 +91,35 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
 		compressedKV.Stride(1), compressedKV.Dim(1),
 	)
 
-	kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
-	kPass = attn.KVB.Forward(ctx, kPass)
-
-	kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
-	kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
-
 	qRot := fast.RoPE(ctx, queryChunks[1], positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
 	kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
+	kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
 
-	kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
+	var attention ml.Tensor
 
-	query = qRot.Concat(ctx, queryChunks[0], 0)
-	key := kRot.Concat(ctx, kvChunks[0], 0)
+	if !opts.isMLA { // v3
+		kPass = attn.KVB.Forward(ctx, kPass)
+
+		kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
+		kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
+
+		kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
+		query = qRot.Concat(ctx, queryChunks[0], 0)
+		key := kRot.Concat(ctx, kvChunks[0], 0)
+		attention = nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
+	} else { // v3.1
+		qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
+		qPassAbsorb := attn.KB.Forward(ctx, qPass)
+		qPassAbsorb = qPassAbsorb.Permute(ctx, 0, 2, 1, 3)
+
+		query = qRot.Concat(ctx, qPassAbsorb, 0)
+		kPass = kPass.Reshape(ctx, opts.kvLoraRank, 1, seqLength)
+		key := kRot.Concat(ctx, kPass, 0)
+		value := kPass
+
+		attention = nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache)
+	}
 
-	attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
 	attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
 	return attn.Output.Forward(ctx, attention)
 }
@@ -233,6 +250,10 @@ func New(c fs.Config) (model.Model, error) {
 	mScale := float32(1.0 + float64(c.Float("rope.scaling.yarn_log_multiplier"))*math.Log(float64(c.Float("rope.scaling.factor"))))
 	kqScale := float64(mScale) * float64(mScale) / math.Sqrt(float64(c.Uint("attention.key_length")))
 
+	isMLA := c.Uint("attention.key_length_mla") != 0 && c.Uint("attention.value_length_mla") != 0
+	keyLength := int(cmp.Or(c.Uint("attention.key_length_mla"), c.Uint("attention.key_length")))
+	valueLength := int(cmp.Or(c.Uint("attention.value_length_mla"), c.Uint("attention.value_length")))
+
 	m := Model{
 		BytePairEncoding: model.NewBytePairEncoding(
 			&model.Vocabulary{
@@ -254,11 +275,10 @@ func New(c fs.Config) (model.Model, error) {
 		),
 		Layers: layers,
 		Options: &Options{
+			isMLA:                 isMLA,
 			hiddenSize:            int(c.Uint("embedding_length")),
 			numHeads:              int(c.Uint("attention.head_count")),
 			numKVHeads:            int(c.Uint("attention.head_count_kv")),
-			keyLength:             int(c.Uint("attention.key_length")),
-			valueLength:           int(c.Uint("attention.value_length")),
 			eps:                   c.Float("attention.layer_norm_rms_epsilon"),
 			ropeBase:              c.Float("rope.freq_base"),
 			ropeScale:             c.Float("rope.scaling.factor", 1),
@@ -266,13 +286,13 @@ func New(c fs.Config) (model.Model, error) {
 			numExpertsUsed: int(c.Uint("expert_used_count")),
 			normTopKProb:   c.Bool("expert_weights_norm", true),
 
-			qLoraRank:     int(c.Uint("attention.q_lora_rank")), //&qLoraRankVal,
+			qLoraRank:     int(c.Uint("attention.q_lora_rank")),
 			kvLoraRank:    int(c.Uint("attention.kv_lora_rank")),
-			qkHeadDim:     int(c.Uint("attention.key_length")),
-			vHeadDim:      int(c.Uint("attention.value_length")),
+			qkHeadDim:     keyLength,
+			vHeadDim:      valueLength,
 			qkRopeHeadDim: int(c.Uint("rope.dimension_count")),
-			qkNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
-			kqNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
+			qkNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
+			kqNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
 
 			routedScalingFactor:   c.Float("expert_weights_scale"),
 			originalContextLength: int(c.Uint("rope.scaling.original_context_length")),