diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go
index cc58e4a2..9fd6e313 100644
--- a/model/models/qwen3/model.go
+++ b/model/models/qwen3/model.go
@@ -15,11 +15,17 @@ import (
 )
 
 type Options struct {
-	hiddenSize, numHeads, numKVHeads int
-	eps                              float32
-	ropeBase, ropeScale              float32
+	hiddenSize,
+	numHeads,
+	numKVHeads,
+	keyLength,
+	valueLength int
 
-	keyLength, valueLength int
+	eps,
+	ropeBase,
+	ropeScale float32
+	ropeType              string
+	originalContextLength int
 
 	numExperts, numExpertsUsed int
 	normTopKProb               bool
@@ -29,6 +35,19 @@ func (o Options) headDim() int {
 	return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
 }
 
+func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
+	opts := []func(*rope.Options){rope.WithTypeNeoX()}
+	if o.ropeType == "yarn" {
+		attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
+		opts = append(opts,
+			rope.WithOriginalContextLength(o.originalContextLength),
+			rope.WithExtrapolationFactor(1.),
+			rope.WithAttentionFactor(attnFactor),
+		)
+	}
+	return fast.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1./o.ropeScale, opts...)
+}
+
 type Attention struct {
 	Query     *nn.Linear  `gguf:"attn_q"`
 	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
@@ -52,8 +71,8 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
 	query = sa.QueryNorm.Forward(ctx, query, opts.eps)
 	key = sa.KeyNorm.Forward(ctx, key, opts.eps)
 
-	query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
-	key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
+	query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
+	key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
 
 	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
 	attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
@@ -183,7 +202,7 @@ func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
+	return m.Options.applyRotaryPositionEmbeddings(ctx, key, shift), nil
 }
 
 var _ model.Model = (*Model)(nil)
@@ -216,17 +235,19 @@ func New(c fs.Config) (model.Model, error) {
 		),
 		Layers: layers,
 		Options: &Options{
-			hiddenSize:     int(c.Uint("embedding_length")),
-			numHeads:       int(c.Uint("attention.head_count")),
-			numKVHeads:     int(c.Uint("attention.head_count_kv")),
-			keyLength:      int(c.Uint("attention.key_length")),
-			valueLength:    int(c.Uint("attention.value_length")),
-			eps:            c.Float("attention.layer_norm_rms_epsilon"),
-			ropeBase:       c.Float("rope.freq_base"),
-			ropeScale:      c.Float("rope.scaling.factor", 1),
-			numExperts:     int(c.Uint("expert_count")),
-			numExpertsUsed: int(c.Uint("expert_used_count")),
-			normTopKProb:   c.Bool("norm_top_k_prob", true),
+			hiddenSize:            int(c.Uint("embedding_length")),
+			numHeads:              int(c.Uint("attention.head_count")),
+			numKVHeads:            int(c.Uint("attention.head_count_kv")),
+			keyLength:             int(c.Uint("attention.key_length")),
+			valueLength:           int(c.Uint("attention.value_length")),
+			eps:                   c.Float("attention.layer_norm_rms_epsilon"),
+			ropeType:              c.String("rope.scaling.type"),
+			ropeBase:              c.Float("rope.freq_base"),
+			ropeScale:             c.Float("rope.scaling.factor", 1),
+			originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
+			numExperts:            int(c.Uint("expert_count")),
+			numExpertsUsed:        int(c.Uint("expert_used_count")),
+			normTopKProb:          c.Bool("norm_top_k_prob", true),
 		},
 	}