Mirror of https://github.com/likelovewant/ollama-for-amd.git, synced 2025-12-21 22:33:56 +00:00
model: fix global layer rope scale values for gemma 3 (#13452)
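The diff below makes the RoPE scaling factor a per-layer value for Gemma 3: sliding-window (local) layers keep a scale of 1.0, while the periodic global-attention layers use the configured rope.scaling.factor (whose default also moves to 8.0). As a quick orientation before the hunks, here is a self-contained sketch of that selection logic using plain Go types; the layer count, the every-6th-layer period, and the base values in main are illustrative stand-ins, not values read from the model.

```go
package main

import "fmt"

// ropeValuesForLayer mirrors the per-layer selection introduced in this commit,
// with plain types instead of the model's config struct (illustrative only).
func ropeValuesForLayer(layer int, slidingWindow []bool, globalEvery int,
	localBase, globalBase, globalScale float32) (base, scale float32) {
	if slidingWindow != nil && slidingWindow[layer] {
		return localBase, 1.0 // local (sliding-window) layer: no positional scaling
	}
	if (layer+1)%globalEvery > 0 {
		return localBase, 1.0 // not a global layer either: local base, unscaled
	}
	return globalBase, globalScale // global layer: global base plus scaled RoPE
}

func main() {
	// Hypothetical 12-layer model where every 6th layer is global.
	for layer := 0; layer < 12; layer++ {
		base, scale := ropeValuesForLayer(layer, nil, 6, 10000, 1000000, 8)
		fmt.Printf("layer %2d: base=%.0f scale=%g\n", layer, base, scale)
	}
}
```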
@@ -28,10 +28,10 @@ type TextConfig struct {
     finalLogitSoftcap float32
 }
 
-func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base float32) ml.Tensor {
+func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base, scale float32) ml.Tensor {
     ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
     if o.ropeType == "yarn" {
-        attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
+        attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(scale))))
         ropeOpts = append(ropeOpts,
             rope.WithOriginalContextLength(o.ropeOriginalContext),
             rope.WithExtrapolationFactor(o.ropeExtrapolation),
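In the YaRN branch above, the attention factor is now computed from the per-layer scale instead of the global o.ropeScale field, so local layers (scale 1.0) get an attention factor of exactly 1 while scaled global layers are damped. A quick, self-contained check of the formula 1 / (1 + 0.1*ln(scale)) with the two values that appear in this diff:

```go
package main

import (
	"fmt"
	"math"
)

// attnFactor reproduces the YaRN attention-factor expression from the diff:
// 1 / (1 + 0.1 * ln(scale)).
func attnFactor(scale float64) float64 {
	return 1.0 / (1.0 + 0.1*math.Log(scale))
}

func main() {
	fmt.Println(attnFactor(1.0)) // local layers: ln(1) = 0, so the factor is exactly 1
	fmt.Println(attnFactor(8.0)) // global layers with the new default: ≈ 0.828
}
```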
@@ -41,7 +41,7 @@ func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positi
         )
     }
 
-    return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, ropeOpts...)
+    return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./scale, ropeOpts...)
 }
 
 type TextModel struct {
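nn.RoPE now receives 1./scale as its frequency-scale argument, so rotation angles are reduced by the scaling factor only on global layers. The sketch below is a generic illustration of where base and that 1/scale term enter the usual RoPE angle formula; it is not Ollama's nn.RoPE, and the dimension, position, and base values are arbitrary examples.

```go
package main

import (
	"fmt"
	"math"
)

// ropeAngle returns the rotation angle for dimension pair j (0..dim/2-1) at a
// given position, in the usual RoPE formulation: pos * freqScale * base^(-2j/dim).
func ropeAngle(pos, base, freqScale float64, j, dim int) float64 {
	invFreq := math.Pow(base, -2.0*float64(j)/float64(dim))
	return pos * freqScale * invFreq
}

func main() {
	const dim = 256 // stand-in for the attention key length
	pos, j := 1024.0, 4

	local := ropeAngle(pos, 10000, 1.0, j, dim)      // scale 1 -> freqScale 1
	global := ropeAngle(pos, 10000, 1.0/8.0, j, dim) // scale 8 -> freqScale 1/8

	fmt.Println(local, global, local/global) // the global-layer angle is 8x smaller
}
```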
@@ -83,7 +83,7 @@ func newTextModel(c fs.Config) *TextModel {
             ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1.0),
             ropeBetaFast:      c.Float("rope.scaling.beta_fast", 64.0),
             ropeBetaSlow:      c.Float("rope.scaling.beta_slow", 1.0),
-            ropeScale:         c.Float("rope.scaling.factor", 1.0),
+            ropeScale:         c.Float("rope.scaling.factor", 8.0),
             finalLogitSoftcap: c.Float("final_logit_softcapping", 0.0),
         },
     }
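The functional change in this hunk is the fallback value: when the GGUF metadata omits rope.scaling.factor, the model now assumes 8.0 instead of 1.0, while an explicit value in the metadata is still used as-is. A minimal, hypothetical stand-in for that lookup-with-default pattern, using a plain map rather than fs.Config:

```go
package main

import "fmt"

// floatOr is a hypothetical stand-in for the config lookup: return the stored
// value if the key exists, otherwise the supplied default.
func floatOr(kv map[string]float32, key string, def float32) float32 {
	if v, ok := kv[key]; ok {
		return v
	}
	return def
}

func main() {
	withKey := map[string]float32{"rope.scaling.factor": 4.0}
	withoutKey := map[string]float32{}

	fmt.Println(floatOr(withKey, "rope.scaling.factor", 8.0))    // 4 (metadata value wins)
	fmt.Println(floatOr(withoutKey, "rope.scaling.factor", 8.0)) // 8 (new default)
}
```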
@@ -117,31 +117,31 @@ type TextSelfAttention struct {
     Output *nn.Linear `gguf:"attn_output"`
 }
 
-func (opts *TextConfig) ropeBaseForLayer(layer int) float32 {
+func (opts *TextConfig) ropeValuesForLayer(layer int) (base float32, scale float32) {
     if opts.slidingWindowPattern != nil && opts.slidingWindowPattern[layer] {
-        return opts.ropeLocalBase
+        return opts.ropeLocalBase, 1.0
     }
 
     // Standard Gemma3: only every n-th layer is global,
     // where n = gemmaGlobalCacheCount, otherwise use
     // the local rope base
     if (layer+1)%gemmaGlobalCacheCount > 0 {
-        return opts.ropeLocalBase
+        return opts.ropeLocalBase, 1.0
     }
 
     // default to global rope base
-    return opts.ropeBase
+    return opts.ropeBase, opts.ropeScale
 }
 
 func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
     batchSize := hiddenState.Dim(1)
 
-    ropeBase := opts.ropeBaseForLayer(layer)
+    ropeBase, ropeScale := opts.ropeValuesForLayer(layer)
 
     q := sa.Query.Forward(ctx, hiddenState)
     q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
     q = sa.QueryNorm.Forward(ctx, q, opts.eps)
-    q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase)
+    q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase, ropeScale)
 
     if opts.largeModelScaling {
         q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -152,7 +152,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
     k := sa.Key.Forward(ctx, hiddenState)
     k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
     k = sa.KeyNorm.Forward(ctx, k, opts.eps)
-    k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase)
+    k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase, ropeScale)
 
     v := sa.Value.Forward(ctx, hiddenState)
     v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -165,7 +165,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-    return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.TextConfig.ropeBaseForLayer(layer)), nil
+    ropeBase, ropeScale := m.TextConfig.ropeValuesForLayer(layer)
+    return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase, ropeScale), nil
 }
 
 type TextMLP struct {