Add deepseek v3.1 (#13063)
* Add mla for flash attention
* Revert to using chunks
@@ -230,7 +230,7 @@ type Tensor interface {
 // kqv := value.Mulmat(ctx, kq)
 // return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 type ScaledDotProductAttention interface {
-    ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, scale float64) Tensor
+    ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
 }
 
 type number interface {
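The only interface change is the extra `vmla Tensor` parameter: a backend's `ScaledDotProductAttention` now receives an optional projection to apply to the attention output, and callers that don't use MLA simply pass `nil`. A minimal sketch of that calling convention with placeholder types (not the real `ml` package):

```go
package main

import "fmt"

// Placeholder stand-ins for the real ml.Context / ml.Tensor types.
type Context struct{}
type Tensor interface{ Label() string }

type ScaledDotProductAttention interface {
    ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
}

type fakeTensor string

func (t fakeTensor) Label() string { return string(t) }

func (t fakeTensor) ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor {
    out := fakeTensor("attn(" + t.Label() + "," + key.Label() + "," + value.Label() + ")")
    if vmla != nil { // MLA models decompress the output with one extra projection
        out = fakeTensor(vmla.Label() + "·" + out.Label())
    }
    return out
}

func main() {
    var q ScaledDotProductAttention = fakeTensor("q")
    fmt.Println(q.ScaledDotProductAttention(Context{}, fakeTensor("k"), fakeTensor("v"), nil, nil, nil, 1.0).Label())
    fmt.Println(q.ScaledDotProductAttention(Context{}, fakeTensor("k"), fakeTensor("v"), nil, nil, fakeTensor("v_b"), 1.0).Label())
}
```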
@@ -1625,7 +1625,7 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
     }
 }
 
-func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, scale float64) ml.Tensor {
+func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64) ml.Tensor {
     var kqMask *C.struct_ggml_tensor
     if mask != nil {
         kqMask = mask.(*Tensor).t
@@ -1642,6 +1642,16 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
             C.ggml_flash_attn_ext_add_sinks(kqv, sinks.(*Tensor).t)
         }
         C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
+
+        if vmla != nil {
+            var cur ml.Tensor = &Tensor{b: t.b, t: kqv}
+            cur = cur.Permute(ctx, 0, 2, 1, 3)
+            cur = vmla.Mulmat(ctx, cur)
+            cur = cur.Permute(ctx, 0, 2, 1, 3)
+            cur = cur.Contiguous(ctx)
+            kqv = cur.(*Tensor).t
+        }
+
         return &Tensor{b: t.b, t: kqv}
     } else {
         kq := key.MulmatFullPrec(ctx, query)
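On the flash-attention path the fused kernel's output still lives in the compressed KV-latent space, one latent vector per head and token; `vmla` (the model's `attn_v_b` projection) maps each latent up to the real value width, and the surrounding permutes only arrange the head dimension for the batched matmul. A plain-Go sketch of that per-head decompression, with made-up sizes:

```go
package main

import "fmt"

// Illustrative sizes only; real models use something like kv_lora_rank=512, v_head_dim=128.
const (
    kvLoraRank = 4 // latent width the attention ran in
    vHeadDim   = 3 // per-head value width after decompression
)

// decompress applies one head's slice of attn_v_b (vHeadDim x kvLoraRank)
// to that head's latent attention output.
func decompress(wVB [vHeadDim][kvLoraRank]float64, latent [kvLoraRank]float64) [vHeadDim]float64 {
    var out [vHeadDim]float64
    for i := 0; i < vHeadDim; i++ {
        for j := 0; j < kvLoraRank; j++ {
            out[i] += wVB[i][j] * latent[j]
        }
    }
    return out
}

func main() {
    wVB := [vHeadDim][kvLoraRank]float64{{1, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 1, 1}}
    latent := [kvLoraRank]float64{0.1, 0.2, 0.3, 0.4}
    fmt.Println(decompress(wVB, latent)) // [0.1 0.2 0.7]
}
```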
@@ -1654,6 +1664,10 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
         }
 
         kqv := value.Mulmat(ctx, kq)
+        if vmla != nil {
+            kqv = vmla.Mulmat(ctx, kqv)
+        }
+
         return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
     }
 }
@@ -22,10 +22,14 @@ import (
 //
 // Attention output with shape [d_v, heads, seq_len_q]
 func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
-    return AttentionWithSinks(ctx, query, key, value, nil, scale, cache)
+    return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache)
 }
 
 func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
+    return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache)
+}
+
+func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
     ctx.Forward(query)
     if key != nil && value != nil {
         if query.Dim(0) != key.Dim(0) {
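The existing entry points keep their signatures and simply forward to the new, most general `AttentionWithVMLA` with `nil` defaults, so callers outside MLA models don't change. A toy sketch of that delegation pattern (placeholder types, not the real package):

```go
package main

import "fmt"

// Stand-in for ml.Tensor; nil means "not provided", mirroring the real API.
type tensor *string

func name(s string) tensor { return &s }

func attention(q, k, v tensor, scale float64) string {
    return attentionWithVMLA(q, k, v, nil, nil, scale)
}

func attentionWithSinks(q, k, v, sinks tensor, scale float64) string {
    return attentionWithVMLA(q, k, v, sinks, nil, scale)
}

// The most general variant; the older entry points forward to it.
func attentionWithVMLA(q, k, v, sinks, vmla tensor, scale float64) string {
    out := "attention"
    if sinks != nil {
        out += "+sinks"
    }
    if vmla != nil {
        out += "+vmla"
    }
    return out
}

func main() {
    fmt.Println(attention(name("q"), name("k"), name("v"), 1.0))                           // attention
    fmt.Println(attentionWithVMLA(name("q"), name("k"), name("v"), nil, name("v_b"), 1.0)) // attention+vmla
}
```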
@@ -56,7 +60,7 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal
     // Only use the fast SDPA implementation if we have a cache, since that's what
     // will do any expected backend-specific transformations for us
     if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
-        return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, scale)
+        return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale)
     } else {
         query = query.Permute(ctx, 0, 2, 1, 3)
         key = key.Permute(ctx, 0, 2, 1, 3)
@@ -71,6 +75,11 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal
         kq = kq.Softmax(ctx)
 
         kqv := value.Mulmat(ctx, kq)
+
+        if vmla != nil {
+            kqv = vmla.Mulmat(ctx, kqv)
+        }
+
         return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
     }
 }
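Applying `vmla` after the softmax-weighted sum is legitimate because the value projection is linear: `W_vb·(V_latent·a)` equals `(W_vb·V_latent)·a`, so attention can run entirely in the latent space and be decompressed once at the end. A quick numeric check in plain Go, with tiny made-up sizes:

```go
package main

import (
    "fmt"
    "math"
)

func matVec(m [][]float64, v []float64) []float64 {
    out := make([]float64, len(m))
    for i, row := range m {
        for j, x := range row {
            out[i] += x * v[j]
        }
    }
    return out
}

func main() {
    // Latent values for 3 cached tokens; rows are latent dims (kv_lora_rank = 2).
    vLatent := [][]float64{
        {0.5, -1.0, 2.0},
        {1.5, 0.25, -0.5},
    }
    // attn_v_b for one head: v_head_dim = 3 rows, kv_lora_rank = 2 columns.
    wVB := [][]float64{{1, 0}, {0, 1}, {1, 1}}
    // Softmax-normalized attention weights over the 3 tokens.
    a := []float64{0.2, 0.3, 0.5}

    // Path used by the patch: weighted sum in latent space, then decompress.
    latentOut := matVec(vLatent, a)
    afterwards := matVec(wVB, latentOut)

    // Reference: decompress every cached value first, then take the weighted sum.
    decompressed := make([][]float64, len(wVB))
    for i := range wVB {
        decompressed[i] = make([]float64, len(a))
        for t := 0; t < len(a); t++ {
            for j := range wVB[i] {
                decompressed[i][t] += wVB[i][j] * vLatent[j][t]
            }
        }
    }
    upfront := matVec(decompressed, a)

    for i := range afterwards {
        if math.Abs(afterwards[i]-upfront[i]) > 1e-12 {
            panic("mismatch")
        }
    }
    fmt.Println(afterwards, "==", upfront)
}
```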
@@ -3,6 +3,7 @@ package deepseek2
 // uses deepseek 2 architecture but written based on deepseek 3 model
 
 import (
+    "cmp"
     "math"
 
     "github.com/ollama/ollama/fs"
@@ -16,6 +17,7 @@ import (
 )
 
 type Options struct {
+    isMLA bool
     numExpertsUsed int
     numExperts int
     normTopKProb bool
@@ -32,8 +34,6 @@ type Options struct {
     hiddenSize,
     numHeads,
     numKVHeads,
-    keyLength,
-    valueLength,
     originalContextLength int
 
     eps,
@@ -62,6 +62,9 @@ type Attention struct {
     KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
     KVB *nn.Linear `gguf:"attn_kv_b"`
 
+    KB *nn.Linear `gguf:"attn_k_b"`
+    VB *nn.Linear `gguf:"attn_v_b"`
+
     Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
 }
 
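The new `attn_k_b` / `attn_v_b` tensors are, in effect, the fused `attn_kv_b` projection split into its key and value halves so the key half can be absorbed into the query and the value half deferred until after attention. A rough feel for the sizes involved, using hypothetical DeepSeek-V3-style numbers (the real values come from the GGUF metadata):

```go
package main

import "fmt"

func main() {
    // Hypothetical DeepSeek-V3-style dimensions, for illustration only.
    const (
        nHead         = 128
        kvLoraRank    = 512
        qkNopeHeadDim = 128
        vHeadDim      = 128
    )

    // Fused projection: latent -> per-head keys and values in one tensor.
    fmt.Println("attn_kv_b:", kvLoraRank, "->", nHead*(qkNopeHeadDim+vHeadDim))
    // Split halves used on the MLA path.
    fmt.Println("attn_k_b :", kvLoraRank, "->", nHead*qkNopeHeadDim) // absorbed into the query
    fmt.Println("attn_v_b :", kvLoraRank, "->", nHead*vHeadDim)      // applied after attention (vmla)
}
```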
@@ -69,7 +72,7 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
     seqLength := hiddenStates.Dim(1)
 
     var query ml.Tensor
-    if opts.qLoraRank == 0 { // nil {
+    if opts.qLoraRank == 0 {
         query = attn.Q.Forward(ctx, hiddenStates)
     } else {
         query = attn.QA.Forward(ctx, hiddenStates)
@@ -88,21 +91,35 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
         compressedKV.Stride(1), compressedKV.Dim(1),
     )
 
-    kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
-    kPass = attn.KVB.Forward(ctx, kPass)
-
-    kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
-    kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
-
     qRot := fast.RoPE(ctx, queryChunks[1], positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
     kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
+    kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
 
-    kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
+    var attention ml.Tensor
 
-    query = qRot.Concat(ctx, queryChunks[0], 0)
-    key := kRot.Concat(ctx, kvChunks[0], 0)
+    if !opts.isMLA { // v3
+        kPass = attn.KVB.Forward(ctx, kPass)
 
-    attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
+        kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
+        kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
+
+        kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
+        query = qRot.Concat(ctx, queryChunks[0], 0)
+        key := kRot.Concat(ctx, kvChunks[0], 0)
+        attention = nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
+    } else { // v3.1
+        qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
+        qPassAbsorb := attn.KB.Forward(ctx, qPass)
+        qPassAbsorb = qPassAbsorb.Permute(ctx, 0, 2, 1, 3)
+
+        query = qRot.Concat(ctx, qPassAbsorb, 0)
+        kPass = kPass.Reshape(ctx, opts.kvLoraRank, 1, seqLength)
+        key := kRot.Concat(ctx, kPass, 0)
+        value := kPass
+
+        attention = nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache)
+    }
+
     attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
     return attn.Output.Forward(ctx, attention)
 }
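The `// v3.1` branch is the usual MLA absorption trick: the cached key (ignoring the RoPE part) is `W_kb·c` for a latent `c`, so the score `q_nopeᵀ·(W_kb·c)` equals `(W_kbᵀ·q_nope)ᵀ·c`. That lets `attn.KB.Forward` fold the key projection into the query (`qPassAbsorb`) and run attention directly against the latent that is stored in the cache, with `attn.VB.Weight` handed to `AttentionWithVMLA` to decompress values afterwards. A numeric sanity check of the score identity in plain Go (tiny made-up sizes; whether `attn_k_b` stores `W_kb` or its transpose depends on the GGUF layout):

```go
package main

import (
    "fmt"
    "math"
)

func dot(a, b []float64) float64 {
    s := 0.0
    for i := range a {
        s += a[i] * b[i]
    }
    return s
}

// matVec computes m·v for a row-major matrix m.
func matVec(m [][]float64, v []float64) []float64 {
    out := make([]float64, len(m))
    for i, row := range m {
        out[i] = dot(row, v)
    }
    return out
}

// transpose of a row-major matrix.
func transpose(m [][]float64) [][]float64 {
    out := make([][]float64, len(m[0]))
    for j := range out {
        out[j] = make([]float64, len(m))
        for i := range m {
            out[j][i] = m[i][j]
        }
    }
    return out
}

func main() {
    // Tiny illustrative sizes: qk_nope_head_dim = 3, kv_lora_rank = 2.
    wKB := [][]float64{{0.5, -1}, {2, 0.25}, {1, 1}} // per-head key up-projection
    qNope := []float64{0.3, -0.7, 1.1}               // query (no-RoPE part) for one head
    c := []float64{0.9, -0.4}                        // cached KV latent for one token

    // Standard form: build the key from the latent, then score against the query.
    scoreStandard := dot(qNope, matVec(wKB, c))

    // Absorbed form (what the qPassAbsorb step computes): fold W_kb into the
    // query once, then score directly against the latent kept in the cache.
    qAbsorbed := matVec(transpose(wKB), qNope)
    scoreAbsorbed := dot(qAbsorbed, c)

    fmt.Println(scoreStandard, scoreAbsorbed)
    if math.Abs(scoreStandard-scoreAbsorbed) > 1e-12 {
        panic("absorption changed the attention score")
    }
}
```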
@@ -233,6 +250,10 @@ func New(c fs.Config) (model.Model, error) {
     mScale := float32(1.0 + float64(c.Float("rope.scaling.yarn_log_multiplier"))*math.Log(float64(c.Float("rope.scaling.factor"))))
     kqScale := float64(mScale) * float64(mScale) / math.Sqrt(float64(c.Uint("attention.key_length")))
 
+    isMLA := c.Uint("attention.key_length_mla") != 0 && c.Uint("attention.value_length_mla") != 0
+    keyLength := int(cmp.Or(c.Uint("attention.key_length_mla"), c.Uint("attention.key_length")))
+    valueLength := int(cmp.Or(c.Uint("attention.value_length_mla"), c.Uint("attention.value_length")))
+
     m := Model{
         BytePairEncoding: model.NewBytePairEncoding(
             &model.Vocabulary{
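`cmp.Or` returns its first non-zero argument, so the MLA-specific cache widths win whenever `attention.key_length_mla` / `attention.value_length_mla` are present in the GGUF, and the code falls back to the ordinary head sizes otherwise; that presence is also what `isMLA` keys off. For example (illustrative numbers, e.g. a 512-wide latent plus 64 RoPE dims):

```go
package main

import (
    "cmp"
    "fmt"
)

func main() {
    // GGUF without MLA metadata: fall back to the regular key length.
    fmt.Println(cmp.Or(uint32(0), uint32(192))) // 192

    // GGUF with MLA metadata (e.g. kv_lora_rank 512 + 64 RoPE dims): the MLA value wins.
    fmt.Println(cmp.Or(uint32(576), uint32(192))) // 576
}
```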
@@ -254,11 +275,10 @@ func New(c fs.Config) (model.Model, error) {
         ),
         Layers: layers,
         Options: &Options{
+            isMLA: isMLA,
             hiddenSize: int(c.Uint("embedding_length")),
             numHeads: int(c.Uint("attention.head_count")),
             numKVHeads: int(c.Uint("attention.head_count_kv")),
-            keyLength: int(c.Uint("attention.key_length")),
-            valueLength: int(c.Uint("attention.value_length")),
             eps: c.Float("attention.layer_norm_rms_epsilon"),
             ropeBase: c.Float("rope.freq_base"),
             ropeScale: c.Float("rope.scaling.factor", 1),
@@ -266,13 +286,13 @@ func New(c fs.Config) (model.Model, error) {
             numExpertsUsed: int(c.Uint("expert_used_count")),
             normTopKProb: c.Bool("expert_weights_norm", true),
 
-            qLoraRank: int(c.Uint("attention.q_lora_rank")), //&qLoraRankVal,
+            qLoraRank: int(c.Uint("attention.q_lora_rank")),
             kvLoraRank: int(c.Uint("attention.kv_lora_rank")),
-            qkHeadDim: int(c.Uint("attention.key_length")),
-            vHeadDim: int(c.Uint("attention.value_length")),
+            qkHeadDim: keyLength,
+            vHeadDim: valueLength,
             qkRopeHeadDim: int(c.Uint("rope.dimension_count")),
-            qkNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
-            kqNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
+            qkNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
+            kqNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
 
             routedScalingFactor: c.Float("expert_weights_scale"),
             originalContextLength: int(c.Uint("rope.scaling.original_context_length")),