interleaved mrope (#12807)

* ml(ggml): mrope
* interleave mrope
Author: Michael Yang
Date: 2025-10-30 11:29:00 -07:00
Committed by: GitHub
Parent: 75e75d9afe
Commit: f67a6df110
10 changed files with 209 additions and 119 deletions
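Previously, the model precomputed interleaved cos/sin rotary tables itself (the rotaryEmbedding function removed at the bottom of this diff) and threaded them through every layer. This change hands multimodal rotary positions (mrope) straight to ggml's fused RoPE via fast.RoPE with rope.WithMRoPESections: positions grow from 3 to 4 components per token to match what ggml's mrope expects, layers take the raw positions tensor instead of cos/sin tensors, and the KV cache shift callback can now re-rotate cached keys instead of returning kvcache.ErrNotSupported.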


@@ -112,7 +112,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 }
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	positionSlice := slices.Collect(makeSlice2D[int32](3, len(batch.Positions)))
+	// ggml mrope requires 4 positions per token: [time, height, width, extra]
+	positionSlice := slices.Collect(makeSlice2D[int32](4, len(batch.Positions)))
 	for i, id := range batch.Positions {
 		if id < int32(len(m.positionCache)) {
 			id = m.positionCache[id]
@@ -123,6 +124,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 		positionSlice[0][i] = id
 		positionSlice[1][i] = id
 		positionSlice[2][i] = id
+		// positionSlice[3] is intentionally left as zeros
 	}
 
 	hiddenStates := m.TextModel.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx)
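For illustration, a minimal plain-Go sketch (not part of the commit; the helper name is hypothetical) of the position layout the hunk above builds: four rows, one entry per token, flattened row-major before upload.

```go
package main

import "fmt"

// mropePositions mirrors the loop above on plain slices: for text-only
// tokens, time/height/width all carry the sequence index and the fourth
// ("extra") row stays zero. The rows flatten to [time..., height...,
// width..., extra...], the 1-D layout handed to FromInts below.
func mropePositions(seq []int32) []int32 {
	n := len(seq)
	flat := make([]int32, 4*n)
	for i, id := range seq {
		flat[0*n+i] = id // time
		flat[1*n+i] = id // height
		flat[2*n+i] = id // width
		// flat[3*n+i] stays zero, like positionSlice[3]
	}
	return flat
}

func main() {
	fmt.Println(mropePositions([]int32{0, 1, 2}))
	// Output: [0 1 2 0 1 2 0 1 2 0 0 0]
}
```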
@@ -147,8 +149,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 		}
 	}
 
-	positions := ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0]), len(positionSlice))
-	cos, sin := m.rotaryEmbedding(ctx, positions)
+	positions := ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0])*len(positionSlice))
 
 	for i, layer := range m.TextModel.Layers {
 		if m.Cache != nil {
 			m.Cache.SetLayer(i)
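makeSlice2D itself is not part of this diff; a plausible stand-in (an assumption, not the repo's definition) that fits the call sites above:

```go
package main

import (
	"fmt"
	"iter"
	"slices"
)

// makeSlice2D yields m zeroed rows of length n, so slices.Collect produces
// a [][]T of shape m×n and slices.Concat flattens it to length m*n — the
// shapes the positionSlice code above relies on.
func makeSlice2D[T any](m, n int) iter.Seq[[]T] {
	return func(yield func([]T) bool) {
		for range m {
			if !yield(make([]T, n)) {
				return
			}
		}
	}
}

func main() {
	rows := slices.Collect(makeSlice2D[int32](4, 3))
	fmt.Println(len(rows), len(rows[0]))      // 4 3
	fmt.Println(len(slices.Concat(rows...))) // 12
}
```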
@@ -159,7 +160,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 			outputs = batch.Outputs
 		}
 
-		hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, outputs, m.Cache, m.Options)
+		hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
 		if i < len(deepstackVisualEmbeds) {
 			hiddenStates = hiddenStates.Add(ctx, deepstackVisualEmbeds[i])
 		}
@@ -191,9 +192,10 @@ func New(c fs.Config) (model.Model, error) {
 		ImageProcessor: newImageProcessor(c),
 	}
 
-	m.Cache = kvcache.NewCausalCache(func(ctx ml.Context, layer int, key, position ml.Tensor) (ml.Tensor, error) {
+	m.Cache = kvcache.NewCausalCache(func(ctx ml.Context, layer int, key, positions ml.Tensor) (ml.Tensor, error) {
 		m.positionCache = nil
-		return nil, kvcache.ErrNotSupported
+		positions = positions.Repeat(ctx, 1, 4).Reshape(ctx, -1)
+		return m.Options.applyRotaryPositionalEmbedding(ctx, key, positions), nil
 	})
 
 	return &m, nil
 }
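The remaining hunks move the RoPE application into the text model. First, a plain-slice reading (an illustration, not the tensor code) of what positions.Repeat(ctx, 1, 4).Reshape(ctx, -1) does in the new shift callback: the per-token shift positions are replicated across all four mrope rows before the cached keys are re-rotated.

```go
package main

import "fmt"

// repeat4 replicates a 1-D position vector four times, matching the
// flattened [time..., height..., width..., extra...] layout the rope
// helper expects.
func repeat4(shift []int32) []int32 {
	out := make([]int32, 0, 4*len(shift))
	for range 4 {
		out = append(out, shift...)
	}
	return out
}

func main() {
	fmt.Println(repeat4([]int32{5, 6})) // [5 6 5 6 5 6 5 6]
}
```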


@@ -10,6 +10,8 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/ml/nn/fast"
+	"github.com/ollama/ollama/ml/nn/rope"
 	"github.com/ollama/ollama/model"
 )
@@ -27,14 +29,18 @@ type TextOptions struct {
 	numExperts, numExpertsUsed int
 	normTopKProb               bool
-	inverseFrequenciesCache    []float32
 }
 
 func (o TextOptions) headDim() int {
 	return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
 }
 
+func (o TextOptions) applyRotaryPositionalEmbedding(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
+	return fast.RoPE(ctx, t, p, o.headDim(), o.ropeBase, 1/float32(math.Sqrt(float64(o.ropeScale))),
+		rope.WithMRoPESections(o.mropeSections),
+	)
+}
+
 type TextAttention struct {
 	Query     *nn.Linear  `gguf:"attn_q"`
 	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
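fast.RoPE with rope.WithMRoPESections replaces the hand-rolled table computation: the backend applies the per-section position components itself. mropeSections is an existing field, presumably populated from the model config (not shown in this diff); by the usual mrope convention the section counts partition the rotary half-dimension, so a sanity check might look like the sketch below. Both the convention and the helper are assumptions, not something this diff states.

```go
// sectionsCoverHalfDim reports whether the mrope section counts sum to at
// most headDim/2 (each section counts pairs of rotated dimensions).
// Assumption: this is the common mrope convention, not spelled out here.
func sectionsCoverHalfDim(sections []int, headDim int) bool {
	sum := 0
	for _, s := range sections {
		sum += s
	}
	return sum <= headDim/2
}
```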
@@ -44,7 +50,7 @@ type TextAttention struct {
 	Output *nn.Linear `gguf:"attn_output"`
 }
 
-func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
+func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
 	batchSize := hiddenStates.Dim(1)
 
 	query := sa.Query.Forward(ctx, hiddenStates)
@@ -58,8 +64,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
 	query = sa.QueryNorm.Forward(ctx, query, opts.eps)
 	key = sa.KeyNorm.Forward(ctx, key, opts.eps)
 
-	query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
-	key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
+	query = opts.applyRotaryPositionalEmbedding(ctx, query, positions)
+	key = opts.applyRotaryPositionalEmbedding(ctx, key, positions)
 
 	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
 	attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
@@ -125,10 +131,10 @@ type TextLayer struct {
 	TextMLP
 }
 
-func (d *TextLayer) Forward(ctx ml.Context, hiddenStates, cos, sin, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
+func (d *TextLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
 	residual := hiddenStates
 	hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
-	hiddenStates = d.TextAttention.Forward(ctx, hiddenStates, cos, sin, cache, opts)
+	hiddenStates = d.TextAttention.Forward(ctx, hiddenStates, positions, cache, opts)
 
 	if outputs != nil {
 		hiddenStates = hiddenStates.Rows(ctx, outputs)
@@ -153,42 +159,6 @@ type TextModel struct {
 	Options *TextOptions
 }
 
-func (m *TextModel) rotaryEmbedding(ctx ml.Context, positions ml.Tensor) (_, _ ml.Tensor) {
-	positions = positions.Reshape(ctx, 1, positions.Dim(0), positions.Dim(1))
-
-	if len(m.Options.inverseFrequenciesCache) == 0 {
-		m.Options.inverseFrequenciesCache = make([]float32, m.Options.headDim()/2)
-		for i := range m.Options.inverseFrequenciesCache {
-			frequency := float32(math.Pow(float64(m.Options.ropeBase), float64(i*2)/float64(m.Options.headDim())))
-			m.Options.inverseFrequenciesCache[i] = 1 / frequency
-		}
-	}
-
-	inverseFrequencies := ctx.Input().FromFloats(m.Options.inverseFrequenciesCache, 1, len(m.Options.inverseFrequenciesCache))
-
-	positions = positions.Cast(ctx, ml.DTypeF32)
-	frequencies := inverseFrequencies.Mulmat(ctx, positions)
-
-	interleaved := frequencies.View(ctx,
-		0, frequencies.Dim(0),
-		frequencies.Stride(1), frequencies.Dim(1),
-	)
-
-	for _, i := range []int{1, 2} {
-		args := []int{
-			i * frequencies.Stride(0), 1,
-			3 * frequencies.Stride(0), m.Options.mropeSections[i],
-			frequencies.Stride(1), frequencies.Dim(1),
-		}
-		ctx.Forward(frequencies.View(ctx, i*frequencies.Stride(2)+args[0], args[1:]...).
-			Copy(ctx, interleaved.View(ctx, args[0], args[1:]...)))
-	}
-
-	interleaved = interleaved.Concat(ctx, interleaved, 0)
-	interleaved = interleaved.Reshape(ctx, interleaved.Dim(0), 1, interleaved.Dim(1), interleaved.Dim(2))
-
-	return interleaved.Cos(ctx), interleaved.Sin(ctx)
-}
-
 var _ model.Model = (*Model)(nil)
 
 func newTextModel(c fs.Config) *TextModel {
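For reference, the removed rotaryEmbedding implemented the interleaving by hand with strided views: starting from the time-component angles, it copied the height component into every frequency slot j with j%3 == 1 (for j/3 < mropeSections[1]) and the width component into every slot j with j%3 == 2 (for j/3 < mropeSections[2]). A plain-slice reading of that mapping (an illustration derived from the deleted code, not the tensor implementation; the section values in main are illustrative):

```go
package main

import (
	"fmt"
	"math"
)

// interleavedAngles computes angles[j][i] for frequency slot j and token i.
// pos[c][i] holds position component c (0=time, 1=height, 2=width). Slots
// default to the time component; height and width are interleaved in with
// stride 3 at offsets 1 and 2, capped by their section counts — the same
// mapping as the removed strided-view copies.
func interleavedAngles(pos [3][]float64, sections [3]int, headDim int, base float64) [][]float64 {
	half := headDim / 2
	angles := make([][]float64, half)
	for j := range angles {
		invFreq := 1 / math.Pow(base, float64(2*j)/float64(headDim))
		c := 0
		switch {
		case j%3 == 1 && j/3 < sections[1]:
			c = 1
		case j%3 == 2 && j/3 < sections[2]:
			c = 2
		}
		row := make([]float64, len(pos[c]))
		for i, p := range pos[c] {
			row[i] = p * invFreq
		}
		angles[j] = row
	}
	// The removed code then concatenated this table with itself along the
	// frequency axis and took Cos and Sin of the result.
	return angles
}

func main() {
	pos := [3][]float64{{0, 1}, {0, 2}, {0, 3}}
	a := interleavedAngles(pos, [3]int{16, 24, 24}, 128, 1e6)
	fmt.Println(len(a), a[0][1], a[1][1]) // 64 slots; time slot; height slot
}
```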