diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go
index 37e688d6..759cc6b3 100644
--- a/model/models/gemma3/model_text.go
+++ b/model/models/gemma3/model_text.go
@@ -28,10 +28,10 @@ type TextConfig struct {
 	finalLogitSoftcap float32
 }
 
-func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base float32) ml.Tensor {
+func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base, scale float32) ml.Tensor {
 	ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
 	if o.ropeType == "yarn" {
-		attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
+		attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(scale))))
 		ropeOpts = append(ropeOpts,
 			rope.WithOriginalContextLength(o.ropeOriginalContext),
 			rope.WithExtrapolationFactor(o.ropeExtrapolation),
@@ -41,7 +41,7 @@ func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positi
 		)
 	}
 
-	return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, ropeOpts...)
+	return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./scale, ropeOpts...)
 }
 
 type TextModel struct {
@@ -83,7 +83,7 @@ func newTextModel(c fs.Config) *TextModel {
 			ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1.0),
 			ropeBetaFast:      c.Float("rope.scaling.beta_fast", 64.0),
 			ropeBetaSlow:      c.Float("rope.scaling.beta_slow", 1.0),
-			ropeScale:         c.Float("rope.scaling.factor", 1.0),
+			ropeScale:         c.Float("rope.scaling.factor", 8.0),
 			finalLogitSoftcap: c.Float("final_logit_softcapping", 0.0),
 		},
 	}
@@ -117,31 +117,31 @@ type TextSelfAttention struct {
 	Output    *nn.Linear  `gguf:"attn_output"`
 }
 
-func (opts *TextConfig) ropeBaseForLayer(layer int) float32 {
+func (opts *TextConfig) ropeValuesForLayer(layer int) (base float32, scale float32) {
 	if opts.slidingWindowPattern != nil && opts.slidingWindowPattern[layer] {
-		return opts.ropeLocalBase
+		return opts.ropeLocalBase, 1.0
 	}
 
 	// Standard Gemma3: only every n-th layer is global,
 	// where n = gemmaGlobalCacheCount, otherwise use
 	// the local rope base
 	if (layer+1)%gemmaGlobalCacheCount > 0 {
-		return opts.ropeLocalBase
+		return opts.ropeLocalBase, 1.0
 	}
 
 	// default to global rope base
-	return opts.ropeBase
+	return opts.ropeBase, opts.ropeScale
 }
 
 func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
-	ropeBase := opts.ropeBaseForLayer(layer)
+	ropeBase, ropeScale := opts.ropeValuesForLayer(layer)
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
 	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
-	q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase)
+	q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase, ropeScale)
 
 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -152,7 +152,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
 	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
-	k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase)
+	k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase, ropeScale)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -165,7 +165,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.TextConfig.ropeBaseForLayer(layer)), nil
+	ropeBase, ropeScale := m.TextConfig.ropeValuesForLayer(layer)
+	return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase, ropeScale), nil
 }
 
 type TextMLP struct {
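
Not part of the patch: a minimal standalone sketch of the per-layer YaRN numbers the diff threads through. The attention-factor expression 1/(1 + 0.1*ln(scale)) and the 1./scale frequency scale passed to nn.RoPE are taken from the hunks above; the helper name yarnAttnFactor and the main function are illustrative only.

package main

import (
	"fmt"
	"math"
)

// yarnAttnFactor mirrors the attention-factor expression in the patch:
// 1 / (1 + 0.1*ln(scale)). Local (sliding-window) layers receive scale = 1.0,
// so their factor stays 1.0; global layers use the configured rope scale,
// which the patch now defaults to 8.0.
func yarnAttnFactor(scale float64) float64 {
	return 1.0 / (1.0 + 0.1*math.Log(scale))
}

func main() {
	for _, scale := range []float64{1.0, 8.0} {
		// freqScale corresponds to the 1./scale argument passed to nn.RoPE above.
		fmt.Printf("scale=%.1f attnFactor=%.4f freqScale=%.4f\n",
			scale, yarnAttnFactor(scale), 1.0/scale)
	}
}

With the new default factor of 8.0 this prints an attention factor of about 0.83 and a frequency scale of 0.125, while layers that keep scale = 1.0 are left unchanged.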