Mirror of https://github.com/likelovewant/ollama-for-amd.git (synced 2025-12-21 14:26:30 +00:00)
model: fix global layer rope scale values for gemma 3 (#13452)
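What the diff below does: applyRotaryPositionEmbeddings used to read its scale from o.ropeScale for every layer, so local and global layers shared the same value. The fix threads a per-layer scale through the call chain: ropeBaseForLayer becomes ropeValuesForLayer and returns both a base and a scale, local sliding-window layers always use a scale of 1.0, and only global layers use the configured rope.scaling.factor (whose default changes from 1.0 to 8.0). The minimal Go sketch below isolates that selection logic together with the YaRN attention factor the diff computes; the layer pattern and base frequencies used here (every 6th layer global, 10,000 local, 1,000,000 global) are illustrative assumptions rather than values taken from this commit, and the explicit slidingWindowPattern check is omitted.

// Standalone sketch of the per-layer RoPE base/scale selection this commit
// introduces. Constants below are assumptions for illustration; only the
// selection logic and the YaRN attention-factor formula mirror the diff.
package main

import (
	"fmt"
	"math"
)

// Assumption: every gemmaGlobalCacheCount-th layer is a global-attention layer.
const gemmaGlobalCacheCount = 6

// ropeValuesForLayer returns the RoPE base and scale for a layer: local
// (sliding-window) layers are never frequency-scaled, global layers use the
// configured scaling factor. The real method also consults an explicit
// slidingWindowPattern when one is present; that check is left out here.
func ropeValuesForLayer(layer int, localBase, globalBase, globalScale float32) (base, scale float32) {
	if (layer+1)%gemmaGlobalCacheCount > 0 {
		return localBase, 1.0
	}
	return globalBase, globalScale
}

func main() {
	for layer := 0; layer < 12; layer++ {
		base, scale := ropeValuesForLayer(layer, 10_000, 1_000_000, 8.0)
		// Same YaRN attention factor the diff computes when the rope type is "yarn".
		attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(scale))))
		fmt.Printf("layer %2d: base=%9.0f scale=%.1f attnFactor=%.3f\n", layer, base, scale, attnFactor)
	}
}

For a scale of 1.0 the attention factor stays at 1.0; for the new default of 8.0 it works out to 1/(1 + 0.1*ln 8), roughly 0.83, and 1/scale = 0.125 is the value the changed code passes to nn.RoPE on global layers.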
@@ -28,10 +28,10 @@ type TextConfig struct {
     finalLogitSoftcap float32
 }

-func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base float32) ml.Tensor {
+func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base, scale float32) ml.Tensor {
     ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
     if o.ropeType == "yarn" {
-        attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
+        attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(scale))))
         ropeOpts = append(ropeOpts,
             rope.WithOriginalContextLength(o.ropeOriginalContext),
             rope.WithExtrapolationFactor(o.ropeExtrapolation),
@@ -41,7 +41,7 @@ func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positi
         )
     }

-    return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, ropeOpts...)
+    return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./scale, ropeOpts...)
 }

 type TextModel struct {
@@ -83,7 +83,7 @@ func newTextModel(c fs.Config) *TextModel {
             ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1.0),
             ropeBetaFast:      c.Float("rope.scaling.beta_fast", 64.0),
             ropeBetaSlow:      c.Float("rope.scaling.beta_slow", 1.0),
-            ropeScale:         c.Float("rope.scaling.factor", 1.0),
+            ropeScale:         c.Float("rope.scaling.factor", 8.0),
             finalLogitSoftcap: c.Float("final_logit_softcapping", 0.0),
         },
     }
@@ -117,31 +117,31 @@ type TextSelfAttention struct {
     Output *nn.Linear `gguf:"attn_output"`
 }

-func (opts *TextConfig) ropeBaseForLayer(layer int) float32 {
+func (opts *TextConfig) ropeValuesForLayer(layer int) (base float32, scale float32) {
     if opts.slidingWindowPattern != nil && opts.slidingWindowPattern[layer] {
-        return opts.ropeLocalBase
+        return opts.ropeLocalBase, 1.0
     }

     // Standard Gemma3: only every n-th layer is global,
     // where n = gemmaGlobalCacheCount, otherwise use
     // the local rope base
     if (layer+1)%gemmaGlobalCacheCount > 0 {
-        return opts.ropeLocalBase
+        return opts.ropeLocalBase, 1.0
     }

     // default to global rope base
-    return opts.ropeBase
+    return opts.ropeBase, opts.ropeScale
 }

 func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
     batchSize := hiddenState.Dim(1)

-    ropeBase := opts.ropeBaseForLayer(layer)
+    ropeBase, ropeScale := opts.ropeValuesForLayer(layer)

     q := sa.Query.Forward(ctx, hiddenState)
     q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
     q = sa.QueryNorm.Forward(ctx, q, opts.eps)
-    q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase)
+    q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase, ropeScale)

     if opts.largeModelScaling {
         q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -152,7 +152,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
     k := sa.Key.Forward(ctx, hiddenState)
     k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
     k = sa.KeyNorm.Forward(ctx, k, opts.eps)
-    k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase)
+    k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase, ropeScale)

     v := sa.Value.Forward(ctx, hiddenState)
     v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -165,7 +165,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 }

 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-    return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.TextConfig.ropeBaseForLayer(layer)), nil
+    ropeBase, ropeScale := m.TextConfig.ropeValuesForLayer(layer)
+    return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase, ropeScale), nil
 }

 type TextMLP struct {