Add deepseek v3.1 (#13063)

* Add mla for flash attention
* Revert to using chunks
Grace
2025-11-17 18:03:21 -08:00
committed by GitHub
parent 1fd4cb87b2
commit 584e2d646f
4 changed files with 67 additions and 24 deletions
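The change threads an optional vmla tensor through the attention stack (the ml.Tensor interface, the ggml backend, the nn attention helpers, and the deepseek2 model) so DeepSeek V3.1 can run multi-head latent attention (MLA) directly against the compressed KV latent. A rough per-head sketch of the absorption identity this relies on, following the DeepSeek-V2 paper's formulation (the notation is an assumption, not taken from the code):

    c_s = W^{DKV} h_s                               (shared latent, dim = kv_lora_rank)
    k_s^{nope} = W^{UK} c_s,    v_s = W^{UV} c_s    (per-head up-projections)

    (q_t^{nope})^T k_s^{nope} = ((W^{UK})^T q_t^{nope})^T c_s
    o_t = \sum_s a_{t,s} v_s  = W^{UV} \sum_s a_{t,s} c_s

Attention can therefore be evaluated against the cached latent c_s alone: the no-RoPE part of the query is pre-multiplied by (W^{UK})^T (the new attn_k_b tensor, applied via attn.KB.Forward), and W^{UV} (attn_v_b) is applied to the attention output afterwards. That second multiply is what the new vmla argument carries.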

View File

@@ -230,7 +230,7 @@ type Tensor interface {
 //	kqv := value.Mulmat(ctx, kq)
 //	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 type ScaledDotProductAttention interface {
-	ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, scale float64) Tensor
+	ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
 }
 
 type number interface {
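For reference, the sketch below spells out what the extended contract computes using plain slices instead of the ml.Tensor API: a single head, no mask or sinks, and an optional vmla matrix applied to the attention output. A minimal sketch under those assumptions; none of these names exist in the repository.

package main

import (
	"fmt"
	"math"
)

// sdpaRef computes softmax(scale * q·kᵀ)·v for one head over a sequence of
// cached keys/values, then optionally applies vmla to the result. vmla is an
// [outDim][len(v[0])] matrix, standing in for the new vmla argument.
func sdpaRef(q []float64, k, v, vmla [][]float64, scale float64) []float64 {
	// scaled dot products against every cached key
	scores := make([]float64, len(k))
	for s := range k {
		for d := range q {
			scores[s] += q[d] * k[s][d]
		}
		scores[s] *= scale
	}

	// softmax over the sequence axis
	maxScore := math.Inf(-1)
	for _, s := range scores {
		maxScore = math.Max(maxScore, s)
	}
	var sum float64
	for i, s := range scores {
		scores[i] = math.Exp(s - maxScore)
		sum += scores[i]
	}

	// probability-weighted sum of values
	out := make([]float64, len(v[0]))
	for s := range v {
		w := scores[s] / sum
		for d := range v[s] {
			out[d] += w * v[s][d]
		}
	}
	if vmla == nil {
		return out // old behaviour: no extra projection
	}

	// vmla projection of the attention output (e.g. MLA latent -> value head dim)
	proj := make([]float64, len(vmla))
	for i, row := range vmla {
		for d := range row {
			proj[i] += row[d] * out[d]
		}
	}
	return proj
}

func main() {
	q := []float64{1, 0}
	k := [][]float64{{1, 0}, {0, 1}}
	v := [][]float64{{1, 2}, {3, 4}}
	vmla := [][]float64{{1, 0}, {0, 1}, {1, 1}} // toy 3x2 projection
	fmt.Println(sdpaRef(q, k, v, nil, 1))       // plain attention
	fmt.Println(sdpaRef(q, k, v, vmla, 1))      // with the vmla projection
}

sdpaRef(q, k, v, nil, scale) matches the original interface; passing a non-nil vmla is the new behaviour.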

View File

@@ -1625,7 +1625,7 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
 	}
 }
 
-func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, scale float64) ml.Tensor {
+func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64) ml.Tensor {
 	var kqMask *C.struct_ggml_tensor
 	if mask != nil {
 		kqMask = mask.(*Tensor).t
@@ -1642,6 +1642,16 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
 			C.ggml_flash_attn_ext_add_sinks(kqv, sinks.(*Tensor).t)
 		}
 		C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
+
+		if vmla != nil {
+			var cur ml.Tensor = &Tensor{b: t.b, t: kqv}
+			cur = cur.Permute(ctx, 0, 2, 1, 3)
+			cur = vmla.Mulmat(ctx, cur)
+			cur = cur.Permute(ctx, 0, 2, 1, 3)
+			cur = cur.Contiguous(ctx)
+			kqv = cur.(*Tensor).t
+		}
+
 		return &Tensor{b: t.b, t: kqv}
 	} else {
 		kq := key.MulmatFullPrec(ctx, query)
@@ -1654,6 +1664,10 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
 		}
 
 		kqv := value.Mulmat(ctx, kq)
+		if vmla != nil {
+			kqv = vmla.Mulmat(ctx, kqv)
+		}
+
 		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	}
 }
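Both branches apply the same projection: the non-flash path multiplies by vmla straight after value.Mulmat(ctx, kq), while the flash path wraps the raw C tensor back into an ml.Tensor and brackets the multiply with a pair of permutes so it lands on the right axis of the flash-attention output (the exact layout is a backend detail). A minimal sketch of that branch as a free-standing helper over the public ml API; the helper name and axis comments are assumptions, not part of the commit.

package sketch

import "github.com/ollama/ollama/ml"

// applyVMLA mirrors the flash-attention branch: swap axes 1 and 2, apply the
// per-head vmla projection, swap back, then make the result contiguous.
// With vmla == nil it is a no-op, matching the pre-existing behaviour.
func applyVMLA(ctx ml.Context, kqv, vmla ml.Tensor) ml.Tensor {
	if vmla == nil {
		return kqv
	}
	cur := kqv.Permute(ctx, 0, 2, 1, 3) // swap axes 1 and 2
	cur = vmla.Mulmat(ctx, cur)         // latent dim -> per-head value dim
	cur = cur.Permute(ctx, 0, 2, 1, 3)  // restore the original axis order
	return cur.Contiguous(ctx)
}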

View File

@@ -22,10 +22,14 @@ import (
 //
 // Attention output with shape [d_v, heads, seq_len_q]
 func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
-	return AttentionWithSinks(ctx, query, key, value, nil, scale, cache)
+	return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache)
 }
 
 func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
+	return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache)
+}
+
+func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
 	ctx.Forward(query)
 	if key != nil && value != nil {
 		if query.Dim(0) != key.Dim(0) {
@@ -56,7 +60,7 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal
 	// Only use the fast SDPA implementation if we have a cache, since that's what
 	// will do any expected backend-specific transformations for us
 	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
-		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, scale)
+		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale)
 	} else {
 		query = query.Permute(ctx, 0, 2, 1, 3)
 		key = key.Permute(ctx, 0, 2, 1, 3)
@@ -71,6 +75,11 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal
 		kq = kq.Softmax(ctx)
 
 		kqv := value.Mulmat(ctx, kq)
+
+		if vmla != nil {
+			kqv = vmla.Mulmat(ctx, kqv)
+		}
+
 		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	}
 }

View File

@@ -3,6 +3,7 @@ package deepseek2
 // uses deepseek 2 architecture but written based on deepseek 3 model
 
 import (
+	"cmp"
 	"math"
 
 	"github.com/ollama/ollama/fs"
@@ -16,6 +17,7 @@ import (
 )
 
 type Options struct {
+	isMLA          bool
 	numExpertsUsed int
 	numExperts     int
 	normTopKProb   bool
@@ -32,8 +34,6 @@ type Options struct {
 	hiddenSize,
 	numHeads,
 	numKVHeads,
-	keyLength,
-	valueLength,
 	originalContextLength int
 
 	eps,
@@ -62,6 +62,9 @@ type Attention struct {
 	KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
 	KVB     *nn.Linear  `gguf:"attn_kv_b"`
+	KB      *nn.Linear  `gguf:"attn_k_b"`
+	VB      *nn.Linear  `gguf:"attn_v_b"`
 	Output  *nn.Linear  `gguf:"attn_out,alt:attn_output"`
 }
@@ -69,7 +72,7 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
 	seqLength := hiddenStates.Dim(1)
 
 	var query ml.Tensor
-	if opts.qLoraRank == 0 { // nil {
+	if opts.qLoraRank == 0 {
 		query = attn.Q.Forward(ctx, hiddenStates)
 	} else {
 		query = attn.QA.Forward(ctx, hiddenStates)
@@ -88,21 +91,35 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
 		compressedKV.Stride(1), compressedKV.Dim(1),
 	)
 
-	kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
-	kPass = attn.KVB.Forward(ctx, kPass)
-
-	kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
-	kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
-
 	qRot := fast.RoPE(ctx, queryChunks[1], positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
 	kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
 
-	kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
-
-	query = qRot.Concat(ctx, queryChunks[0], 0)
-	key := kRot.Concat(ctx, kvChunks[0], 0)
-
-	attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
+	kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
+
+	var attention ml.Tensor
+	if !opts.isMLA { // v3
+		kPass = attn.KVB.Forward(ctx, kPass)
+		kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
+		kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
+
+		kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
+		query = qRot.Concat(ctx, queryChunks[0], 0)
+		key := kRot.Concat(ctx, kvChunks[0], 0)
+
+		attention = nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
+	} else { // v3.1
+		qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
+		qPassAbsorb := attn.KB.Forward(ctx, qPass)
+		qPassAbsorb = qPassAbsorb.Permute(ctx, 0, 2, 1, 3)
+		query = qRot.Concat(ctx, qPassAbsorb, 0)
+
+		kPass = kPass.Reshape(ctx, opts.kvLoraRank, 1, seqLength)
+		key := kRot.Concat(ctx, kPass, 0)
+		value := kPass
+
+		attention = nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache)
+	}
+
 	attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
 	return attn.Output.Forward(ctx, attention)
 }
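To make the two branches concrete, the sketch below prints the shapes and per-token key/value widths implied by the code above. The constants are typical DeepSeek-V3-style hyperparameters assumed for illustration, not values read from a GGUF.

package main

import "fmt"

func main() {
	// Assumed DeepSeek-V3-style hyperparameters (illustrative only; the real
	// values come from the GGUF config keys read in New()).
	const (
		numHeads      = 128
		qkRopeHeadDim = 64  // rope.dimension_count
		qkNopeHeadDim = 128 // keyLength - qkRopeHeadDim
		vHeadDim      = 128
		kvLoraRank    = 512 // attention.kv_lora_rank
	)

	// v3 path (!isMLA): keys and values are up-projected per head before attention.
	fmt.Println("v3   key per token (all heads):  ", numHeads*(qkNopeHeadDim+qkRopeHeadDim), "floats")
	fmt.Println("v3   value per token (all heads):", numHeads*vHeadDim, "floats")

	// v3.1 path (isMLA): attention runs in the compressed latent space with a
	// single KV head; vmla (attn_v_b) restores vHeadDim per head afterwards.
	fmt.Println("v3.1 query shape:                [", kvLoraRank+qkRopeHeadDim, ",", numHeads, ", seq ]")
	fmt.Println("v3.1 key per token (one KV head):", kvLoraRank+qkRopeHeadDim, "floats")
	fmt.Println("v3.1 value per token:            ", kvLoraRank, "floats")
	fmt.Println("v3.1 attention out before vmla:  [", kvLoraRank, ",", numHeads, ", seq ]")
	fmt.Println("v3.1 attention out after vmla:   [", vHeadDim, ",", numHeads, ", seq ]")
}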
@@ -233,6 +250,10 @@ func New(c fs.Config) (model.Model, error) {
 	mScale := float32(1.0 + float64(c.Float("rope.scaling.yarn_log_multiplier"))*math.Log(float64(c.Float("rope.scaling.factor"))))
 	kqScale := float64(mScale) * float64(mScale) / math.Sqrt(float64(c.Uint("attention.key_length")))
 
+	isMLA := c.Uint("attention.key_length_mla") != 0 && c.Uint("attention.value_length_mla") != 0
+	keyLength := int(cmp.Or(c.Uint("attention.key_length_mla"), c.Uint("attention.key_length")))
+	valueLength := int(cmp.Or(c.Uint("attention.value_length_mla"), c.Uint("attention.value_length")))
+
 	m := Model{
 		BytePairEncoding: model.NewBytePairEncoding(
 			&model.Vocabulary{
@@ -254,11 +275,10 @@ func New(c fs.Config) (model.Model, error) {
 		),
 		Layers: layers,
 		Options: &Options{
+			isMLA:       isMLA,
 			hiddenSize:  int(c.Uint("embedding_length")),
 			numHeads:    int(c.Uint("attention.head_count")),
 			numKVHeads:  int(c.Uint("attention.head_count_kv")),
-			keyLength:   int(c.Uint("attention.key_length")),
-			valueLength: int(c.Uint("attention.value_length")),
 			eps:         c.Float("attention.layer_norm_rms_epsilon"),
 			ropeBase:    c.Float("rope.freq_base"),
 			ropeScale:   c.Float("rope.scaling.factor", 1),
@@ -266,13 +286,13 @@ func New(c fs.Config) (model.Model, error) {
 			numExpertsUsed: int(c.Uint("expert_used_count")),
 			normTopKProb:   c.Bool("expert_weights_norm", true),
-			qLoraRank:      int(c.Uint("attention.q_lora_rank")), //&qLoraRankVal,
+			qLoraRank:      int(c.Uint("attention.q_lora_rank")),
 			kvLoraRank:     int(c.Uint("attention.kv_lora_rank")),
-			qkHeadDim:      int(c.Uint("attention.key_length")),
-			vHeadDim:       int(c.Uint("attention.value_length")),
+			qkHeadDim:      keyLength,
+			vHeadDim:       valueLength,
 			qkRopeHeadDim:  int(c.Uint("rope.dimension_count")),
-			qkNopeHeadDim:  int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
-			kqNopeHeadDim:  int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
+			qkNopeHeadDim:  keyLength - int(c.Uint("rope.dimension_count")),
+			kqNopeHeadDim:  keyLength - int(c.Uint("rope.dimension_count")),
 			routedScalingFactor:   c.Float("expert_weights_scale"),
 			originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
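cmp.Or (Go 1.22+) returns the first of its arguments that is not the zero value, so attention.key_length_mla and attention.value_length_mla take precedence when a v3.1 GGUF provides them, and the loader falls back to the plain key/value lengths otherwise. A standalone illustration with made-up values:

package main

import (
	"cmp"
	"fmt"
)

func main() {
	// Made-up config values for illustration.
	var keyLengthMLA uint32 = 576 // attention.key_length_mla present in an MLA GGUF
	var keyLength uint32 = 192    // attention.key_length

	// cmp.Or returns the first non-zero argument.
	fmt.Println(cmp.Or(keyLengthMLA, keyLength)) // 576: the MLA value wins
	fmt.Println(cmp.Or(uint32(0), keyLength))    // 192: fallback when _mla is absent
}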