fix(qwen3): deepseek distill

deepseek's qwen3 distill uses a different rope scheme so support both
This commit is contained in:
Michael Yang
2025-10-13 12:09:53 -07:00
committed by Michael Yang
parent 6544e14735
commit 6c833d5f8d

View File

@@ -15,11 +15,17 @@ import (
) )
type Options struct { type Options struct {
hiddenSize, numHeads, numKVHeads int hiddenSize,
eps float32 numHeads,
ropeBase, ropeScale float32 numKVHeads,
keyLength,
valueLength int
keyLength, valueLength int eps,
ropeBase,
ropeScale float32
ropeType string
originalContextLength int
numExperts, numExpertsUsed int numExperts, numExpertsUsed int
normTopKProb bool normTopKProb bool
@@ -29,6 +35,19 @@ func (o Options) headDim() int {
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads) return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
} }
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
opts := []func(*rope.Options){rope.WithTypeNeoX()}
if o.ropeType == "yarn" {
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
opts = append(opts,
rope.WithOriginalContextLength(o.originalContextLength),
rope.WithExtrapolationFactor(1.),
rope.WithAttentionFactor(attnFactor),
)
}
return fast.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1./o.ropeScale, opts...)
}
type Attention struct { type Attention struct {
Query *nn.Linear `gguf:"attn_q"` Query *nn.Linear `gguf:"attn_q"`
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"` QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
@@ -52,8 +71,8 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
query = sa.QueryNorm.Forward(ctx, query, opts.eps) query = sa.QueryNorm.Forward(ctx, query, opts.eps)
key = sa.KeyNorm.Forward(ctx, key, opts.eps) key = sa.KeyNorm.Forward(ctx, key, opts.eps)
query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache) attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
@@ -183,7 +202,7 @@ func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
} }
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil return m.Options.applyRotaryPositionEmbeddings(ctx, key, shift), nil
} }
var _ model.Model = (*Model)(nil) var _ model.Model = (*Model)(nil)
@@ -216,17 +235,19 @@ func New(c fs.Config) (model.Model, error) {
), ),
Layers: layers, Layers: layers,
Options: &Options{ Options: &Options{
hiddenSize: int(c.Uint("embedding_length")), hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")), numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")), numKVHeads: int(c.Uint("attention.head_count_kv")),
keyLength: int(c.Uint("attention.key_length")), keyLength: int(c.Uint("attention.key_length")),
valueLength: int(c.Uint("attention.value_length")), valueLength: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"), eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"), ropeType: c.String("rope.scaling.type"),
ropeScale: c.Float("rope.scaling.factor", 1), ropeBase: c.Float("rope.freq_base"),
numExperts: int(c.Uint("expert_count")), ropeScale: c.Float("rope.scaling.factor", 1),
numExpertsUsed: int(c.Uint("expert_used_count")), originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
normTopKProb: c.Bool("norm_top_k_prob", true), numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
normTopKProb: c.Bool("norm_top_k_prob", true),
}, },
} }