model: add rnj-1 inference support (#13354)
@@ -16,7 +16,7 @@ import (
 type Model struct {
     model.Base
-    model.SentencePiece
+    model.TextProcessor

     *VisionModel `gguf:"v"`
     *TextModel
@@ -54,24 +54,35 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
 }

 func New(c fs.Config) (model.Model, error) {
-    m := Model{
-        SentencePiece: model.NewSentencePiece(
-            &model.Vocabulary{
-                Values: c.Strings("tokenizer.ggml.tokens"),
-                Scores: c.Floats("tokenizer.ggml.scores"),
-                Types:  c.Ints("tokenizer.ggml.token_type"),
-                AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
-                BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
-                AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
-                EOS: append(
-                    []int32{
-                        int32(c.Uint("tokenizer.ggml.eos_token_id")),
-                        int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
-                    },
-                    c.Ints("tokenizer.ggml.eos_token_ids")...,
-                ),
-            },
-        ),
+    vocabulary := model.Vocabulary{
+        Values: c.Strings("tokenizer.ggml.tokens"),
+        Scores: c.Floats("tokenizer.ggml.scores"),
+        Types:  c.Ints("tokenizer.ggml.token_type"),
+        Merges: c.Strings("tokenizer.ggml.merges"),
+        AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
+        BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
+        AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
+        EOS: append(
+            []int32{
+                int32(c.Uint("tokenizer.ggml.eos_token_id")),
+            },
+            c.Ints("tokenizer.ggml.eos_token_ids")...,
+        ),
+    }
+
+    var processor model.TextProcessor
+    switch c.String("tokenizer.ggml.model") {
+    case "gpt2":
+        processor = model.NewBytePairEncoding(&vocabulary)
+    default:
+        // Previous uploads of Gemma 3 on Ollama did not have token 106
+        // (i.e. "<end_of_turn>") so we need to add it in case it's not already present
+        vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
+        processor = model.NewSentencePiece(&vocabulary)
+    }
+
+    m := Model{
+        TextProcessor:  processor,
         ImageProcessor: newImageProcessor(c),
         VisionModel:    newVisionModel(c),
         TextModel:      newTextModel(c),
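As context for the tokenizer change above: the vocabulary is now built once, the EOS id list is assembled from tokenizer.ggml.eos_token_id plus any tokenizer.ggml.eos_token_ids, and the end-of-turn id is appended only on the SentencePiece path. A standalone sketch of that assembly, using made-up token ids (1, 107, 108 are illustrative, not real Gemma ids):

package main

import "fmt"

func main() {
    eosTokenID := int32(1)           // stands in for tokenizer.ggml.eos_token_id
    eosTokenIDs := []int32{107, 108} // stands in for tokenizer.ggml.eos_token_ids (may be empty)
    eotTokenID := int32(106)         // default for tokenizer.ggml.eot_token_id

    // Same shape as the EOS append in New(): single id first, then the list.
    eos := append([]int32{eosTokenID}, eosTokenIDs...)

    // Only the SentencePiece fallback appends the end-of-turn token, since older
    // Gemma 3 uploads may not list it among their EOS ids.
    eos = append(eos, eotTokenID)

    fmt.Println(eos) // [1 107 108 106]
}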
@@ -141,8 +152,16 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 }

 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-    hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
-    return m.Output.Forward(ctx, hiddenStates), nil
+    hiddenState := m.TextModel.Forward(ctx, batch, m.Cache)
+    hiddenState = m.Output.Forward(ctx, hiddenState)
+
+    if m.TextConfig.finalLogitSoftcap > 0.0 {
+        hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextConfig.finalLogitSoftcap))
+        hiddenState = hiddenState.Tanh(ctx)
+        hiddenState = hiddenState.Scale(ctx, float64(m.TextConfig.finalLogitSoftcap))
+    }
+
+    return hiddenState, nil
 }

 func init() {
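The new Forward body applies final logit softcapping: logits are squashed into (-cap, +cap) as cap * tanh(logit / cap). A scalar sketch of the same arithmetic; the 30.0 cap is illustrative (Gemma 2 style), and the diff does this element-wise with Scale and Tanh on tensors:

package main

import (
    "fmt"
    "math"
)

// applySoftcap mirrors the Scale -> Tanh -> Scale sequence in Model.Forward:
// the result is bounded to (-c, +c) while staying roughly linear for logits
// much smaller than the cap.
func applySoftcap(logit, c float64) float64 {
    return c * math.Tanh(logit / c)
}

func main() {
    const softCap = 30.0 // illustrative value for final_logit_softcapping
    for _, x := range []float64{1, 25, 100} {
        fmt.Printf("applySoftcap(%.0f, %v) = %.2f\n", x, softCap, applySoftcap(x, softCap))
    }
    // ~1.00, ~20.47, ~29.92 — large logits saturate just below the cap.
}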
@@ -2,6 +2,7 @@ package gemma3

 import (
     "math"
+    "slices"

     "github.com/ollama/ollama/fs"
     "github.com/ollama/ollama/kvcache"
@@ -15,12 +16,32 @@ type TextConfig struct {
     hiddenSize, numHeads, numKVHeads int
     attnKeyLen, attnValLen           int
     eps, ropeScale                   float32
-    ropeLocalBase, ropeGlobalBase    float32
+    ropeLocalBase                    float32
     largeModelScaling                bool
+    slidingWindowPattern             []bool
+    ropeBase                         float32
+    ropeType                         string
+    ropeOriginalContext              int
+    ropeExtrapolation                float32
+    ropeBetaFast                     float32
+    ropeBetaSlow                     float32
+    finalLogitSoftcap                float32
 }

 func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base float32) ml.Tensor {
-    return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, rope.WithTypeNeoX())
+    ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
+    if o.ropeType == "yarn" {
+        attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
+        ropeOpts = append(ropeOpts,
+            rope.WithOriginalContextLength(o.ropeOriginalContext),
+            rope.WithExtrapolationFactor(o.ropeExtrapolation),
+            rope.WithAttentionFactor(attnFactor),
+            rope.WithBetaFast(o.ropeBetaFast),
+            rope.WithBetaSlow(o.ropeBetaSlow),
+        )
+    }
+
+    return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, ropeOpts...)
 }

 type TextModel struct {
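When rope.scaling.type is "yarn", the attention factor passed to RoPE is 1 / (1 + 0.1*ln(ropeScale)). A small sketch showing how that factor shrinks as the scaling factor grows; the scale values are illustrative:

package main

import (
    "fmt"
    "math"
)

// yarnAttentionFactor reproduces the attnFactor expression from
// applyRotaryPositionEmbeddings: 1 / (1 + 0.1*ln(scale)).
func yarnAttentionFactor(ropeScale float64) float64 {
    return 1.0 / (1.0 + 0.1*math.Log(ropeScale))
}

func main() {
    // Example rope.scaling.factor values, chosen only for illustration.
    for _, s := range []float64{1, 4, 16} {
        fmt.Printf("scale %2.0f -> attention factor %.3f\n", s, yarnAttentionFactor(s))
    }
    // scale  1 -> 1.000 (neutral)
    // scale  4 -> ~0.878
    // scale 16 -> ~0.783
}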
@@ -48,21 +69,35 @@ func newTextModel(c fs.Config) *TextModel {
     m := TextModel{
         Layers: make([]TextLayer, numBlocks),
         TextConfig: &TextConfig{
-            hiddenSize:     int(c.Uint("embedding_length")),
-            numHeads:       int(c.Uint("attention.head_count")),
-            numKVHeads:     int(c.Uint("attention.head_count_kv")),
-            attnKeyLen:     int(c.Uint("attention.key_length", 256)),
-            attnValLen:     int(c.Uint("attention.value_length", 256)),
-            eps:            c.Float("attention.layer_norm_rms_epsilon", 1e-06),
-            ropeLocalBase:  c.Float("rope.local.freq_base", 10000.0),
-            ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
-            ropeScale:      1,
-            // NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
-            // (8 instead of 1)
-            // ropeScale: c.Float("rope.scaling.factor", 1.0),
+            hiddenSize:           int(c.Uint("embedding_length")),
+            numHeads:             int(c.Uint("attention.head_count")),
+            numKVHeads:           int(c.Uint("attention.head_count_kv")),
+            attnKeyLen:           int(c.Uint("attention.key_length", 256)),
+            attnValLen:           int(c.Uint("attention.value_length", 256)),
+            eps:                  c.Float("attention.layer_norm_rms_epsilon", 1e-06),
+            ropeLocalBase:        c.Float("rope.local.freq_base", 10000.0),
+            ropeBase:             c.Float("rope.freq_base", 1000000.0),
+            slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
+            ropeType:             c.String("rope.scaling.type"),
+            ropeOriginalContext:  int(c.Uint("rope.scaling.original_context_length")),
+            ropeExtrapolation:    c.Float("rope.scaling.extrapolation_factor", 1.0),
+            ropeBetaFast:         c.Float("rope.scaling.beta_fast", 64.0),
+            ropeBetaSlow:         c.Float("rope.scaling.beta_slow", 1.0),
+            ropeScale:            c.Float("rope.scaling.factor", 1.0),
+            finalLogitSoftcap:    c.Float("final_logit_softcapping", 0.0),
         },
     }

+    // Google's Gemma 3 release with sliding window attention does
+    // not use final logit softcapping, and so force it to 0.0
+    // TODO (jmorganca): this should ideally be set to 0.0 in the
+    // model configuration instead of here, as future versions of
+    // models may include both sliding window attention and final
+    // logit softcapping.
+    if slices.Contains(m.TextConfig.slidingWindowPattern, true) {
+        m.TextConfig.finalLogitSoftcap = 0.0
+    }
+
     if numBlocks == gemma27BLayerCount {
         m.largeModelScaling = true
     }
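The guard added after the config literal encodes a policy choice: if any layer uses sliding-window attention, final logit softcapping is forced off even when the GGUF metadata carries a non-zero final_logit_softcapping. A reduced sketch of that override with made-up config values:

package main

import (
    "fmt"
    "slices"
)

// textConfig is a cut-down stand-in for the TextConfig fields involved here.
type textConfig struct {
    slidingWindowPattern []bool
    finalLogitSoftcap    float32
}

func main() {
    // Pretend the GGUF metadata supplied both a sliding-window pattern and a softcap.
    cfg := textConfig{
        slidingWindowPattern: []bool{true, true, true, true, true, false}, // illustrative 5:1 pattern
        finalLogitSoftcap:    30.0,                                        // illustrative value
    }

    // Same guard as newTextModel: sliding-window models drop the softcap.
    if slices.Contains(cfg.slidingWindowPattern, true) {
        cfg.finalLogitSoftcap = 0.0
    }

    fmt.Println(cfg.finalLogitSoftcap) // 0
}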
@@ -79,13 +114,26 @@ type TextSelfAttention struct {
     Output    *nn.Linear `gguf:"attn_output"`
 }

+func (opts *TextConfig) ropeBaseForLayer(layer int) float32 {
+    if opts.slidingWindowPattern != nil && opts.slidingWindowPattern[layer] {
+        return opts.ropeLocalBase
+    }
+
+    // Standard Gemma3: only every n-th layer is global,
+    // where n = gemmaGlobalCacheCount, otherwise use
+    // the local rope base
+    if (layer+1)%gemmaGlobalCacheCount > 0 {
+        return opts.ropeLocalBase
+    }
+
+    // default to global rope base
+    return opts.ropeBase
+}
+
 func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
     batchSize := hiddenState.Dim(1)

-    ropeBase := opts.ropeLocalBase
-    if (layer+1)%gemmaGlobalCacheCount == 0 {
-        ropeBase = opts.ropeGlobalBase
-    }
+    ropeBase := opts.ropeBaseForLayer(layer)

     q := sa.Query.Forward(ctx, hiddenState)
     q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
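The new ropeBaseForLayer helper prefers an explicit sliding-window pattern when the model ships one and otherwise keeps the original every-n-th-layer rule. A standalone sketch of the same selection; gemmaGlobalCacheCount = 6 is assumed here for illustration rather than taken from this diff:

package main

import "fmt"

const gemmaGlobalCacheCount = 6 // assumption: every 6th layer is a "global" layer

// ropeBaseForLayer mirrors the helper added in the diff: a layer flagged as
// sliding-window, or any layer that is not an every-6th global layer, uses
// the local base; only the remaining layers use the global base.
func ropeBaseForLayer(layer int, slidingWindow []bool, localBase, globalBase float32) float32 {
    if slidingWindow != nil && slidingWindow[layer] {
        return localBase
    }
    if (layer+1)%gemmaGlobalCacheCount > 0 {
        return localBase
    }
    return globalBase
}

func main() {
    local, global := float32(10000), float32(1000000)
    // Without a pattern, layers 5, 11, ... (0-indexed) get the global base.
    for layer := 0; layer < 12; layer++ {
        fmt.Printf("layer %2d: base %.0f\n", layer, ropeBaseForLayer(layer, nil, local, global))
    }
}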
@@ -114,12 +162,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 }

 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-    ropeBase := m.TextConfig.ropeLocalBase
-    if (layer+1)%gemmaGlobalCacheCount == 0 {
-        ropeBase = m.TextConfig.ropeGlobalBase
-    }
-
-    return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase), nil
+    return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.TextConfig.ropeBaseForLayer(layer)), nil
 }

 type TextMLP struct {
@@ -207,6 +250,5 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
         hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
     }

-    hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
-    return hiddenState
+    return m.OutputNorm.Forward(ctx, hiddenState, m.eps)
 }