Add deepseek v3.1 (#13063)

* Add mla for flash attention
* Revert to using chunks
This commit is contained in:
Grace
2025-11-17 18:03:21 -08:00
committed by GitHub
parent 1fd4cb87b2
commit 584e2d646f
4 changed files with 67 additions and 24 deletions

View File

@@ -3,6 +3,7 @@ package deepseek2
// uses deepseek 2 architecture but written based on deepseek 3 model
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
@@ -16,6 +17,7 @@ import (
)
type Options struct {
isMLA bool
numExpertsUsed int
numExperts int
normTopKProb bool
@@ -32,8 +34,6 @@ type Options struct {
hiddenSize,
numHeads,
numKVHeads,
keyLength,
valueLength,
originalContextLength int
eps,
@@ -62,6 +62,9 @@ type Attention struct {
KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
KVB *nn.Linear `gguf:"attn_kv_b"`
KB *nn.Linear `gguf:"attn_k_b"`
VB *nn.Linear `gguf:"attn_v_b"`
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
}
@@ -69,7 +72,7 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
seqLength := hiddenStates.Dim(1)
var query ml.Tensor
if opts.qLoraRank == 0 { // nil {
if opts.qLoraRank == 0 {
query = attn.Q.Forward(ctx, hiddenStates)
} else {
query = attn.QA.Forward(ctx, hiddenStates)
@@ -88,21 +91,35 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
compressedKV.Stride(1), compressedKV.Dim(1),
)
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
kPass = attn.KVB.Forward(ctx, kPass)
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
qRot := fast.RoPE(ctx, queryChunks[1], positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
var attention ml.Tensor
query = qRot.Concat(ctx, queryChunks[0], 0)
key := kRot.Concat(ctx, kvChunks[0], 0)
if !opts.isMLA { // v3
kPass = attn.KVB.Forward(ctx, kPass)
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
query = qRot.Concat(ctx, queryChunks[0], 0)
key := kRot.Concat(ctx, kvChunks[0], 0)
attention = nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
} else { // v3.1
qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
qPassAbsorb := attn.KB.Forward(ctx, qPass)
qPassAbsorb = qPassAbsorb.Permute(ctx, 0, 2, 1, 3)
query = qRot.Concat(ctx, qPassAbsorb, 0)
kPass = kPass.Reshape(ctx, opts.kvLoraRank, 1, seqLength)
key := kRot.Concat(ctx, kPass, 0)
value := kPass
attention = nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache)
}
attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
return attn.Output.Forward(ctx, attention)
}
@@ -233,6 +250,10 @@ func New(c fs.Config) (model.Model, error) {
mScale := float32(1.0 + float64(c.Float("rope.scaling.yarn_log_multiplier"))*math.Log(float64(c.Float("rope.scaling.factor"))))
kqScale := float64(mScale) * float64(mScale) / math.Sqrt(float64(c.Uint("attention.key_length")))
isMLA := c.Uint("attention.key_length_mla") != 0 && c.Uint("attention.value_length_mla") != 0
keyLength := int(cmp.Or(c.Uint("attention.key_length_mla"), c.Uint("attention.key_length")))
valueLength := int(cmp.Or(c.Uint("attention.value_length_mla"), c.Uint("attention.value_length")))
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
@@ -254,11 +275,10 @@ func New(c fs.Config) (model.Model, error) {
),
Layers: layers,
Options: &Options{
isMLA: isMLA,
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
keyLength: int(c.Uint("attention.key_length")),
valueLength: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
@@ -266,13 +286,13 @@ func New(c fs.Config) (model.Model, error) {
numExpertsUsed: int(c.Uint("expert_used_count")),
normTopKProb: c.Bool("expert_weights_norm", true),
qLoraRank: int(c.Uint("attention.q_lora_rank")), //&qLoraRankVal,
qLoraRank: int(c.Uint("attention.q_lora_rank")),
kvLoraRank: int(c.Uint("attention.kv_lora_rank")),
qkHeadDim: int(c.Uint("attention.key_length")),
vHeadDim: int(c.Uint("attention.value_length")),
qkHeadDim: keyLength,
vHeadDim: valueLength,
qkRopeHeadDim: int(c.Uint("rope.dimension_count")),
qkNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
kqNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
qkNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
kqNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
routedScalingFactor: c.Float("expert_weights_scale"),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),