refactor rope

change to a flatter directory structure and group the options with the
function

update models to call rope in one place
This commit is contained in:
Michael Yang
2025-11-18 15:17:03 -08:00
committed by Michael Yang
parent e082d60a24
commit 603ceefaa6
21 changed files with 114 additions and 91 deletions

View File

@@ -8,7 +8,6 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
)
@@ -26,11 +25,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, sa.RopeFactors)
key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, sa.RopeFactors)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -44,8 +43,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the cache, which are just the self attention layers
if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil
if layer, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
return m.applyRotaryPositionEmbeddings(ctx, key, shift, layer.SelfAttention.RopeFactors), nil
}
return key, nil
@@ -206,6 +205,10 @@ type TextModelOptions struct {
crossAttentionLayers []int32
}
func (o TextModelOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions, factors ml.Tensor) ml.Tensor {
return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, rope.WithFactors(factors))
}
type TextModel struct {
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Transformer *TextDecoder `gguf:"blk"`