Mirror of https://github.com/likelovewant/ollama-for-amd.git
fix(qwen3): deepseek distill
deepseek's qwen3 distill uses a different rope scheme (yarn), so support both
Committed by Michael Yang
parent 6544e14735, commit 6c833d5f8d
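The yarn branch below hinges on one piece of arithmetic: when rope.scaling.type is "yarn", attention magnitude is rescaled by 1/(1 + 0.1*ln(ropeScale)), which is exactly 1 at scale 1 and shrinks as the context-extension factor grows. A standalone sketch of just that computation (illustrative scale values, not part of the commit):

package main

import (
	"fmt"
	"math"
)

// yarnAttentionFactor mirrors the attnFactor expression introduced in
// this commit: 1 / (1 + 0.1*ln(scale)). At scale 1 it is exactly 1,
// so models without yarn scaling are unaffected.
func yarnAttentionFactor(ropeScale float64) float64 {
	return 1.0 / (1.0 + 0.1*math.Log(ropeScale))
}

func main() {
	// Hypothetical scale factors; real models read rope.scaling.factor.
	for _, s := range []float64{1, 2, 4} {
		fmt.Printf("scale=%g -> attention factor=%.4f\n", s, yarnAttentionFactor(s))
	}
}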
@@ -15,11 +15,17 @@ import (
 )
 
 type Options struct {
-	hiddenSize, numHeads, numKVHeads int
-	eps float32
-	ropeBase, ropeScale float32
-
-	keyLength, valueLength int
+	hiddenSize,
+	numHeads,
+	numKVHeads,
+	keyLength,
+	valueLength int
+
+	eps,
+	ropeBase,
+	ropeScale float32
+	ropeType string
+	originalContextLength int
 
 	numExperts, numExpertsUsed int
 	normTopKProb bool
@@ -29,6 +35,19 @@ func (o Options) headDim() int {
 	return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
 }
 
+func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
+	opts := []func(*rope.Options){rope.WithTypeNeoX()}
+	if o.ropeType == "yarn" {
+		attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
+		opts = append(opts,
+			rope.WithOriginalContextLength(o.originalContextLength),
+			rope.WithExtrapolationFactor(1.),
+			rope.WithAttentionFactor(attnFactor),
+		)
+	}
+	return fast.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1./o.ropeScale, opts...)
+}
+
 type Attention struct {
 	Query *nn.Linear `gguf:"attn_q"`
 	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
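The new helper builds its option list with Go's functional-options idiom: start from an unconditional base option (NeoX layout) and append the yarn-specific ones only when needed. A self-contained sketch of that idiom (the ropeOptions type and with* constructors here are hypothetical stand-ins, not ollama's actual rope package):

package main

import "fmt"

// Hypothetical option struct standing in for rope.Options.
type ropeOptions struct {
	neox                  bool
	originalContextLength int
	attentionFactor       float32
}

func withTypeNeoX() func(*ropeOptions) {
	return func(o *ropeOptions) { o.neox = true }
}

func withOriginalContextLength(n int) func(*ropeOptions) {
	return func(o *ropeOptions) { o.originalContextLength = n }
}

func withAttentionFactor(f float32) func(*ropeOptions) {
	return func(o *ropeOptions) { o.attentionFactor = f }
}

func main() {
	// Mirrors the shape of the diff: unconditional base option,
	// yarn extras appended only when the model asks for them.
	opts := []func(*ropeOptions){withTypeNeoX()}
	yarn := true // stand-in for o.ropeType == "yarn"
	if yarn {
		opts = append(opts, withOriginalContextLength(32768), withAttentionFactor(0.93))
	}

	var o ropeOptions
	for _, apply := range opts {
		apply(&o)
	}
	fmt.Printf("%+v\n", o)
}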
@@ -52,8 +71,8 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
 	query = sa.QueryNorm.Forward(ctx, query, opts.eps)
 	key = sa.KeyNorm.Forward(ctx, key, opts.eps)
 
-	query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
-	key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
+	query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
+	key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
 
 	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
 	attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
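Both the old direct calls and the new helper keep rope.WithTypeNeoX() as the base option: NeoX-style RoPE pairs element i with element i+d/2 (the "rotate half" layout) rather than interleaving adjacent elements. A simplified single-vector sketch of that rotation, ignoring yarn scaling entirely (illustrative only; the real kernel operates on tensors):

package main

import (
	"fmt"
	"math"
)

// neoxRotate applies NeoX-style rotary position embedding to one
// head-sized vector: dimension i rotates against dimension i+d/2.
func neoxRotate(x []float64, pos int, base float64) []float64 {
	d := len(x)
	out := make([]float64, d)
	for i := 0; i < d/2; i++ {
		theta := float64(pos) / math.Pow(base, float64(2*i)/float64(d))
		sin, cos := math.Sincos(theta)
		out[i] = x[i]*cos - x[i+d/2]*sin
		out[i+d/2] = x[i]*sin + x[i+d/2]*cos
	}
	return out
}

func main() {
	// Position 0 leaves the vector unchanged; later positions rotate it.
	fmt.Println(neoxRotate([]float64{1, 0, 0, 1}, 0, 10000))
	fmt.Println(neoxRotate([]float64{1, 0, 0, 1}, 3, 10000))
}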
@@ -183,7 +202,7 @@ func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
+	return m.Options.applyRotaryPositionEmbeddings(ctx, key, shift), nil
 }
 
 var _ model.Model = (*Model)(nil)
@@ -216,17 +235,19 @@ func New(c fs.Config) (model.Model, error) {
 		),
 		Layers: layers,
 		Options: &Options{
 			hiddenSize: int(c.Uint("embedding_length")),
 			numHeads: int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			keyLength: int(c.Uint("attention.key_length")),
 			valueLength: int(c.Uint("attention.value_length")),
 			eps: c.Float("attention.layer_norm_rms_epsilon"),
-			ropeBase: c.Float("rope.freq_base"),
-			ropeScale: c.Float("rope.scaling.factor", 1),
-			numExperts: int(c.Uint("expert_count")),
-			numExpertsUsed: int(c.Uint("expert_used_count")),
-			normTopKProb: c.Bool("norm_top_k_prob", true),
+			ropeType: c.String("rope.scaling.type"),
+			ropeBase: c.Float("rope.freq_base"),
+			ropeScale: c.Float("rope.scaling.factor", 1),
+			originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
+			numExperts: int(c.Uint("expert_count")),
+			numExpertsUsed: int(c.Uint("expert_used_count")),
+			normTopKProb: c.Bool("norm_top_k_prob", true),
 		},
 	}
 
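The constructor above now reads two extra metadata keys, and scheme selection ultimately reduces to a string check on rope.scaling.type. A minimal sketch of that dispatch over GGUF-style keys (the map-based config is a stand-in for ollama's fs.Config, and the values are made up):

package main

import "fmt"

func main() {
	// Hypothetical metadata as a deepseek qwen3 distill might carry it;
	// upstream qwen3 checkpoints would simply omit rope.scaling.type.
	metadata := map[string]any{
		"rope.freq_base":                       float32(1000000),
		"rope.scaling.type":                    "yarn",
		"rope.scaling.factor":                  float32(4),
		"rope.scaling.original_context_length": uint32(32768),
	}

	ropeType, _ := metadata["rope.scaling.type"].(string)
	if ropeType == "yarn" {
		fmt.Println("yarn rope: apply original context length + attention factor")
	} else {
		fmt.Println("default rope: plain NeoX")
	}
}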