model: add rnj-1 inference support (#13354)

This commit is contained in:
Jeffrey Morgan
2025-12-08 16:49:17 -08:00
committed by GitHub
parent 603ceefaa6
commit d2f334c1f7
6 changed files with 208 additions and 69 deletions

View File

@@ -16,7 +16,7 @@ import (
type Model struct {
model.Base
model.SentencePiece
model.TextProcessor
*VisionModel `gguf:"v"`
*TextModel
@@ -54,24 +54,35 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
}
func New(c fs.Config) (model.Model, error) {
m := Model{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{
int32(c.Uint("tokenizer.ggml.eos_token_id")),
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{
int32(c.Uint("tokenizer.ggml.eos_token_id")),
},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
}
var processor model.TextProcessor
switch c.String("tokenizer.ggml.model") {
case "gpt2":
processor = model.NewBytePairEncoding(&vocabulary)
default:
// Previous uploads of Gemma 3 on Ollama did not have token 106
// (i.e. "<end_of_turn>") so we need to add in case it's not already present
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
processor = model.NewSentencePiece(&vocabulary)
}
m := Model{
TextProcessor: processor,
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),
@@ -141,8 +152,16 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
return m.Output.Forward(ctx, hiddenStates), nil
hiddenState := m.TextModel.Forward(ctx, batch, m.Cache)
hiddenState = m.Output.Forward(ctx, hiddenState)
if m.TextConfig.finalLogitSoftcap > 0.0 {
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextConfig.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx)
hiddenState = hiddenState.Scale(ctx, float64(m.TextConfig.finalLogitSoftcap))
}
return hiddenState, nil
}
func init() {

View File

@@ -2,6 +2,7 @@ package gemma3
import (
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
@@ -15,12 +16,32 @@ type TextConfig struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeScale float32
ropeLocalBase, ropeGlobalBase float32
ropeLocalBase float32
largeModelScaling bool
slidingWindowPattern []bool
ropeBase float32
ropeType string
ropeOriginalContext int
ropeExtrapolation float32
ropeBetaFast float32
ropeBetaSlow float32
finalLogitSoftcap float32
}
func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base float32) ml.Tensor {
return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, rope.WithTypeNeoX())
ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
if o.ropeType == "yarn" {
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
ropeOpts = append(ropeOpts,
rope.WithOriginalContextLength(o.ropeOriginalContext),
rope.WithExtrapolationFactor(o.ropeExtrapolation),
rope.WithAttentionFactor(attnFactor),
rope.WithBetaFast(o.ropeBetaFast),
rope.WithBetaSlow(o.ropeBetaSlow),
)
}
return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, ropeOpts...)
}
type TextModel struct {
@@ -48,21 +69,35 @@ func newTextModel(c fs.Config) *TextModel {
m := TextModel{
Layers: make([]TextLayer, numBlocks),
TextConfig: &TextConfig{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length", 256)),
attnValLen: int(c.Uint("attention.value_length", 256)),
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
ropeScale: 1,
// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
// (8 instead of 1)
// ropeScale: c.Float("rope.scaling.factor", 1.0),
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length", 256)),
attnValLen: int(c.Uint("attention.value_length", 256)),
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
ropeBase: c.Float("rope.freq_base", 1000000.0),
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
ropeType: c.String("rope.scaling.type"),
ropeOriginalContext: int(c.Uint("rope.scaling.original_context_length")),
ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1.0),
ropeBetaFast: c.Float("rope.scaling.beta_fast", 64.0),
ropeBetaSlow: c.Float("rope.scaling.beta_slow", 1.0),
ropeScale: c.Float("rope.scaling.factor", 1.0),
finalLogitSoftcap: c.Float("final_logit_softcapping", 0.0),
},
}
// Google's Gemma 3 release with sliding window attention does
// not use final logit softcapping, and so force it to 0.0
// TODO (jmorganca): this should ideally be set to 0.0 in the
// model configuration instead of here, as future versions of
// models may include both sliding window attention and final
// logit softcapping.
if slices.Contains(m.TextConfig.slidingWindowPattern, true) {
m.TextConfig.finalLogitSoftcap = 0.0
}
if numBlocks == gemma27BLayerCount {
m.largeModelScaling = true
}
@@ -79,13 +114,26 @@ type TextSelfAttention struct {
Output *nn.Linear `gguf:"attn_output"`
}
func (opts *TextConfig) ropeBaseForLayer(layer int) float32 {
if opts.slidingWindowPattern != nil && opts.slidingWindowPattern[layer] {
return opts.ropeLocalBase
}
// Standard Gemma3: only every n-th layer is global,
// where n = gemmaGlobalCacheCount, otherwise use
// the local rope base
if (layer+1)%gemmaGlobalCacheCount > 0 {
return opts.ropeLocalBase
}
// default to global rope base
return opts.ropeBase
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeBase := opts.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = opts.ropeGlobalBase
}
ropeBase := opts.ropeBaseForLayer(layer)
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
@@ -114,12 +162,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeBase := m.TextConfig.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = m.TextConfig.ropeGlobalBase
}
return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase), nil
return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.TextConfig.ropeBaseForLayer(layer)), nil
}
type TextMLP struct {
@@ -207,6 +250,5 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return hiddenState
return m.OutputNorm.Forward(ctx, hiddenState, m.eps)
}