mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-23 15:08:27 +00:00
refactor rope
change to a flatter directory structure and group the options with the function update models to call rope in one place
This commit is contained in:
committed by
Michael Yang
parent
e082d60a24
commit
603ceefaa6
@@ -8,7 +8,6 @@ import (
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/fast"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
@@ -95,7 +94,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
|
||||
ropeBase = m.ropeBaseLocal
|
||||
}
|
||||
|
||||
return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase), nil
|
||||
}
|
||||
|
||||
type TextScaledWordEmbedding struct {
|
||||
@@ -256,14 +255,14 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten
|
||||
query := attn.Query.Forward(ctx, hiddenStates)
|
||||
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
|
||||
query = attn.QueryNorm.Forward(ctx, query, opts.eps)
|
||||
query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, ropeBase)
|
||||
|
||||
var key, value ml.Tensor
|
||||
if !sharedKV {
|
||||
key = attn.Key.Forward(ctx, hiddenStates)
|
||||
key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
|
||||
key = attn.KeyNorm.Forward(ctx, key, opts.eps)
|
||||
key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, ropeBase)
|
||||
|
||||
value = attn.Value.Forward(ctx, hiddenStates)
|
||||
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
|
||||
@@ -330,6 +329,10 @@ func (o *TextOptions) isLocal(i int) bool {
|
||||
return o.slidingWindowPattern[i]
|
||||
}
|
||||
|
||||
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor, base float32) ml.Tensor {
|
||||
return nn.RoPE(ctx, t, p, o.headDim(), base, 1./o.ropeScale, rope.WithTypeNeoX())
|
||||
}
|
||||
|
||||
func newTextModel(c fs.Config) *TextModel {
|
||||
return &TextModel{
|
||||
TextLayers: make([]TextLayer, c.Uint("block_count")),
|
||||
|
||||
Reference in New Issue
Block a user