mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 14:26:30 +00:00
Add deepseek v3.1 (#13063)
* Add mla for flash attention * Revert to using chunks
This commit is contained in:
@@ -3,6 +3,7 @@ package deepseek2
|
||||
// uses deepseek 2 architecture but written based on deepseek 3 model
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
isMLA bool
|
||||
numExpertsUsed int
|
||||
numExperts int
|
||||
normTopKProb bool
|
||||
@@ -32,8 +34,6 @@ type Options struct {
|
||||
hiddenSize,
|
||||
numHeads,
|
||||
numKVHeads,
|
||||
keyLength,
|
||||
valueLength,
|
||||
originalContextLength int
|
||||
|
||||
eps,
|
||||
@@ -62,6 +62,9 @@ type Attention struct {
|
||||
KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
|
||||
KVB *nn.Linear `gguf:"attn_kv_b"`
|
||||
|
||||
KB *nn.Linear `gguf:"attn_k_b"`
|
||||
VB *nn.Linear `gguf:"attn_v_b"`
|
||||
|
||||
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
|
||||
}
|
||||
|
||||
@@ -69,7 +72,7 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
|
||||
seqLength := hiddenStates.Dim(1)
|
||||
|
||||
var query ml.Tensor
|
||||
if opts.qLoraRank == 0 { // nil {
|
||||
if opts.qLoraRank == 0 {
|
||||
query = attn.Q.Forward(ctx, hiddenStates)
|
||||
} else {
|
||||
query = attn.QA.Forward(ctx, hiddenStates)
|
||||
@@ -88,21 +91,35 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
|
||||
compressedKV.Stride(1), compressedKV.Dim(1),
|
||||
)
|
||||
|
||||
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
|
||||
kPass = attn.KVB.Forward(ctx, kPass)
|
||||
|
||||
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
|
||||
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
|
||||
|
||||
qRot := fast.RoPE(ctx, queryChunks[1], positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
|
||||
kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
|
||||
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
|
||||
|
||||
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
|
||||
var attention ml.Tensor
|
||||
|
||||
query = qRot.Concat(ctx, queryChunks[0], 0)
|
||||
key := kRot.Concat(ctx, kvChunks[0], 0)
|
||||
if !opts.isMLA { // v3
|
||||
kPass = attn.KVB.Forward(ctx, kPass)
|
||||
|
||||
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
|
||||
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
|
||||
|
||||
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
|
||||
query = qRot.Concat(ctx, queryChunks[0], 0)
|
||||
key := kRot.Concat(ctx, kvChunks[0], 0)
|
||||
attention = nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
|
||||
} else { // v3.1
|
||||
qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
|
||||
qPassAbsorb := attn.KB.Forward(ctx, qPass)
|
||||
qPassAbsorb = qPassAbsorb.Permute(ctx, 0, 2, 1, 3)
|
||||
|
||||
query = qRot.Concat(ctx, qPassAbsorb, 0)
|
||||
kPass = kPass.Reshape(ctx, opts.kvLoraRank, 1, seqLength)
|
||||
key := kRot.Concat(ctx, kPass, 0)
|
||||
value := kPass
|
||||
|
||||
attention = nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache)
|
||||
}
|
||||
|
||||
attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
|
||||
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
|
||||
return attn.Output.Forward(ctx, attention)
|
||||
}
|
||||
@@ -233,6 +250,10 @@ func New(c fs.Config) (model.Model, error) {
|
||||
mScale := float32(1.0 + float64(c.Float("rope.scaling.yarn_log_multiplier"))*math.Log(float64(c.Float("rope.scaling.factor"))))
|
||||
kqScale := float64(mScale) * float64(mScale) / math.Sqrt(float64(c.Uint("attention.key_length")))
|
||||
|
||||
isMLA := c.Uint("attention.key_length_mla") != 0 && c.Uint("attention.value_length_mla") != 0
|
||||
keyLength := int(cmp.Or(c.Uint("attention.key_length_mla"), c.Uint("attention.key_length")))
|
||||
valueLength := int(cmp.Or(c.Uint("attention.value_length_mla"), c.Uint("attention.value_length")))
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
@@ -254,11 +275,10 @@ func New(c fs.Config) (model.Model, error) {
|
||||
),
|
||||
Layers: layers,
|
||||
Options: &Options{
|
||||
isMLA: isMLA,
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
keyLength: int(c.Uint("attention.key_length")),
|
||||
valueLength: int(c.Uint("attention.value_length")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
@@ -266,13 +286,13 @@ func New(c fs.Config) (model.Model, error) {
|
||||
numExpertsUsed: int(c.Uint("expert_used_count")),
|
||||
normTopKProb: c.Bool("expert_weights_norm", true),
|
||||
|
||||
qLoraRank: int(c.Uint("attention.q_lora_rank")), //&qLoraRankVal,
|
||||
qLoraRank: int(c.Uint("attention.q_lora_rank")),
|
||||
kvLoraRank: int(c.Uint("attention.kv_lora_rank")),
|
||||
qkHeadDim: int(c.Uint("attention.key_length")),
|
||||
vHeadDim: int(c.Uint("attention.value_length")),
|
||||
qkHeadDim: keyLength,
|
||||
vHeadDim: valueLength,
|
||||
qkRopeHeadDim: int(c.Uint("rope.dimension_count")),
|
||||
qkNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
|
||||
kqNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
|
||||
qkNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
|
||||
kqNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
|
||||
|
||||
routedScalingFactor: c.Float("expert_weights_scale"),
|
||||
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
|
||||
|
||||
Reference in New Issue
Block a user