gemma2 impl

This commit is contained in:
Patrick Devine
2025-02-07 15:58:15 -08:00
committed by Michael Yang
parent 4dcf80167a
commit 5f74d1fd47
18 changed files with 1057 additions and 24 deletions

View File

@@ -0,0 +1,206 @@
package gemma2
import (
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeBase, ropeScale float32
attnLogitSoftcap float32
finalLogitSoftcap float32
largeModelScaling bool
}
type Model struct {
model.Base
model.SentencePieceModel
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"` // just set to token_embd?
*Options
}
const (
gemma27BLayerCount = 46
)
func New(c ml.Config) (model.Model, error) {
m := Model{
SentencePieceModel: model.NewSentencePieceModel(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Uints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
},
),
Layers: make([]Layer, c.Uint("block_count")),
Options: &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length")),
attnValLen: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base", 10000.0),
ropeScale: c.Float("rope.freq_scale", 1.0),
attnLogitSoftcap: c.Float("attn_logit_softcapping"),
finalLogitSoftcap: c.Float("final_logit_softcapping"),
},
}
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
return &m, nil
}
type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(2)
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize / opts.numHeads)))
} else {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
}
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
cache.Put(ctx, k, v)
k, v, mask := cache.Get(ctx)
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := k.Mulmat(ctx, q)
// logit softcap
kq = kq.Scale(ctx, 1.0/float64(opts.attnLogitSoftcap))
kq = kq.Tanh(ctx)
kq = kq.Scale(ctx, float64(opts.attnLogitSoftcap))
kq = kq.Add(ctx, mask)
kq = kq.Softmax(ctx)
kqv := v.Mulmat(ctx, kq)
kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *SelfAttention
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
return hiddenState.Add(ctx, residual)
}
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
if len(m.Layers) == gemma27BLayerCount {
m.Options.largeModelScaling = true
}
for i, layer := range m.Layers {
cacheType := i % 2
m.Cache.SetLayer(i)
wc := m.Cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
hiddenState = m.Output.Forward(ctx, hiddenState)
// final logit softcap
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx)
hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
return hiddenState.Rows(ctx, outputs), nil
}
func init() {
model.Register("gemma2", New)
}

View File

@@ -0,0 +1,74 @@
package gemma3
import (
"fmt"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.SentencePieceModel
//*VisionModel `gguf:"v,vision"`
*TextModel
//Projector *nn.Linear `gguf:"mm.0"`
ImageProcessor
}
func New(c ml.Config) (model.Model, error) {
// Verify unified config
if c.Uint("vision.block_count") == 0 {
return nil, fmt.Errorf("non-unified vision model not supported")
}
m := Model{
SentencePieceModel: model.NewSentencePieceModel(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Uints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
ImageProcessor: newImageProcessor(c),
//VisionModel: newVisionModel(c),
TextModel: newTextModel(c),
}
slidingWindowLen := int32(c.Uint("text.attention.sliding_window"))
m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
return &m, nil
}
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, inputs, positions, outputs, m.Cache), nil
}
func init() {
model.Register("gemma3", New)
}

View File

@@ -0,0 +1,193 @@
package gemma3
import (
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
)
type TextOptions struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeScale float32
ropeLocalBase, ropeGlobalBase float32
finalLogitSoftcap float32
largeModelScaling bool
}
type TextModel struct {
model.Base
model.SentencePieceModel
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []TextLayer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextOptions
}
const (
gemma27BLayerCount = 46
)
const (
cacheTypeSWA = iota
cacheTypeCausal
)
func newTextModel(c ml.Config) *TextModel {
m := TextModel{
SentencePieceModel: model.NewSentencePieceModel(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Uints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
},
),
Layers: make([]TextLayer, c.Uint("block_count")),
TextOptions: &TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length")),
attnValLen: int(c.Uint("attention.value_length")),
eps: c.Float("text.attention.layer_norm_rms_epsilon"),
ropeLocalBase: c.Float("text.rope.local.freq_base", 10000.0),
ropeGlobalBase: c.Float("text.rope.global.freq_base", 1000000.0),
ropeScale: c.Float("text.rope.freq_scale", 1.0),
finalLogitSoftcap: c.Float("text.final_logit_softcapping"),
},
}
slidingWindowLen := int32(c.Uint("text.attention.sliding_window"))
m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
return &m
}
type TextSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(2)
ropeBase := opts.ropeLocalBase
if (layer+1)%6 == 0 {
ropeBase = opts.ropeGlobalBase
}
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
} else {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
}
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
scaleFactor := 1.0
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeBase := m.TextOptions.ropeLocalBase
if (layer+1)%6 == 0 {
ropeBase = m.TextOptions.ropeGlobalBase
}
return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
}
type TextMLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type TextLayer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *TextSelfAttention
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *TextMLP
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
}
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
if len(m.Layers) == gemma27BLayerCount {
m.TextOptions.largeModelScaling = true
}
for i, layer := range m.Layers {
// gemma alternates between the sliding window (local) and causal (global)
// kv cache every 6 layers
cacheType := cacheTypeSWA
if (i+1)%6 == 0 {
cacheType = cacheTypeCausal
}
cache.SetLayer(i)
wc := cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
hiddenState = layer.Forward(ctx, i, hiddenState, positions, cache, m.TextOptions)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
hiddenState = m.Output.Forward(ctx, hiddenState)
// final logit softcap
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextOptions.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx)
hiddenState = hiddenState.Scale(ctx, float64(m.TextOptions.finalLogitSoftcap))
return hiddenState.Rows(ctx, outputs)
}

View File

@@ -0,0 +1,57 @@
package gemma3
import (
"image"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/imageproc"
)
type ImageProcessor struct {
imageSize, numChannels int
}
func newImageProcessor(c ml.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size")),
numChannels: int(c.Uint("vision.num_channels")),
}
}
func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
var pixelVals []float32
bounds := img.Bounds()
var rVals, gVals, bVals []float32
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
for x := bounds.Min.X; x < bounds.Max.X; x++ {
c := img.At(x, y)
r, g, b, _ := c.RGBA()
rVal := float32(r>>8) / 255.0
gVal := float32(g>>8) / 255.0
bVal := float32(b>>8) / 255.0
rVal = (rVal - mean[0]) / std[0]
gVal = (gVal - mean[1]) / std[1]
bVal = (bVal - mean[2]) / std[2]
rVals = append(rVals, rVal)
gVals = append(gVals, gVal)
bVals = append(bVals, bVal)
}
}
pixelVals = append(pixelVals, rVals...)
pixelVals = append(pixelVals, gVals...)
pixelVals = append(pixelVals, bVals...)
return pixelVals
}
func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
outputSize := image.Point{p.imageSize, p.imageSize}
newImage := imageproc.Composite(img)
newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD)
return data, nil
}

View File

@@ -76,14 +76,15 @@ type SelfAttention struct {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
ropeType := uint32(0)
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -96,7 +97,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
}
type MLP struct {

View File

@@ -20,14 +20,15 @@ type TextSelfAttention struct {
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
ropeType := uint32(0)
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -40,8 +41,9 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the cache, which are just the self attention layers
if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
}
return key, nil

View File

@@ -1,6 +1,8 @@
package models
import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/mllama"
)

View File

@@ -18,6 +18,15 @@ const (
SpecialEOS
)
const (
TOKEN_TYPE_NORMAL = iota + 1
TOKEN_TYPE_UNKNOWN
TOKEN_TYPE_CONTROL
TOKEN_TYPE_USER_DEFINED
TOKEN_TYPE_UNUSED
TOKEN_TYPE_BYTE
)
type TextProcessor interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
@@ -27,7 +36,7 @@ type TextProcessor interface {
type Vocabulary struct {
Values []string
Types []uint32
Scores []uint32
Scores []float32
Merges []string
BOS, EOS int32
@@ -76,7 +85,7 @@ func (v *Vocabulary) Decode(id int32) string {
func (v *Vocabulary) SpecialVocabulary() []string {
v.specialOnce.Do(func() {
for i := range v.Values {
if v.Types[i] == 3 {
if v.Types[i] == TOKEN_TYPE_CONTROL {
v.special = append(v.special, v.Values[i])
}
}

221
model/process_text_spm.go Normal file
View File

@@ -0,0 +1,221 @@
package model
import (
"iter"
"log/slog"
"strings"
"github.com/dlclark/regexp2"
queue "github.com/emirpasic/gods/v2/queues/priorityqueue"
)
const spmWhitespaceSep = "▁"
func replaceWhitespaceBySeperator(s string) string {
return strings.ReplaceAll(s, " ", spmWhitespaceSep)
}
type SentencePieceModel struct {
maxTokenLen int
pre *regexp2.Regexp
vocab *Vocabulary
}
func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
counter := map[int]int{}
var maxTokenLen int
for cnt := range vocab.Types {
switch vocab.Types[cnt] {
case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED:
maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt]))
fallthrough
default:
counter[int(vocab.Types[cnt])] += 1
}
}
slog.Debug("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
"max token len", maxTokenLen)
return SentencePieceModel{
maxTokenLen: maxTokenLen,
pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
vocab: vocab,
}
}
func (spm SentencePieceModel) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}
func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
return func(yield func(string) bool) {
for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
if !yield(m.String()) {
break
}
}
}
}
func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
id := spm.vocab.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
if len(frag.ids) > 0 {
continue
}
var middle []fragment
switch i := strings.Index(frag.value, special); {
case i < 0:
middle = append(middle, frag)
case i > 0:
middle = append(middle, fragment{value: frag.value[:i]})
fallthrough
default:
middle = append(middle, fragment{value: special, ids: []int32{id}})
if rest := frag.value[i+len(special):]; rest != "" {
middle = append(middle, fragment{value: rest})
}
}
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}
slog.Debug("fragments", "frags", fragments)
var ids []int32
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
continue
}
for split := range spm.split(frag.value) {
split = replaceWhitespaceBySeperator(split)
var sb strings.Builder
sb.Write([]byte(split))
if id := spm.vocab.Encode(sb.String()); id >= 0 {
ids = append(ids, id)
continue
}
runes := []rune(sb.String())
pq := queue.NewWith(func(a, b any) int {
priA := a.(*candidate)
priB := b.(*candidate)
if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) {
return -1
}
return 1
})
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
slog.Debug("tokenizer", "merges", merges)
pairwise := func(a, b int) *candidate {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
if id := spm.vocab.Encode(left + right); id >= 0 {
return &candidate{
a: a,
b: b,
score: spm.vocab.Scores[id],
}
}
return nil
}
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
pq.Enqueue(pair)
}
}
pqv := pq.Values()
for _, v := range pqv {
e := v.(*candidate)
slog.Debug("candidate", "candidate", e)
}
for !pq.Empty() {
v, _ := pq.Dequeue()
pair := v.(*candidate)
left, right := merges[pair.a], merges[pair.b]
slog.Debug("pair", "left", left, "right", right)
if len(left.runes) == 0 || len(right.runes) == 0 {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
pq.Enqueue(pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
pq.Enqueue(pair)
}
}
slog.Debug("merges", "merges", merges)
for _, merge := range merges {
if len(merge.runes) > 0 {
if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id)
} else {
slog.Debug("missing token", "token", string(merge.runes))
}
}
}
}
}
slog.Debug("encoded", "ids", ids)
return ids, nil
}
type candidate struct {
a, b int
score float32
}
func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
data := spm.vocab.Decode(id)
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
if _, err := sb.WriteString(data); err != nil {
return "", err
}
}
slog.Debug("decoded", "ids", ids, "text", sb.String())
return sb.String(), nil
}

View File

@@ -0,0 +1,110 @@
package model
import (
"log/slog"
"os"
"path/filepath"
"slices"
"testing"
"google.golang.org/protobuf/proto"
"github.com/ollama/ollama/convert/sentencepiece"
)
func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
t.Helper()
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
if err != nil {
t.Fatal(err)
}
var spm sentencepiece.ModelProto
if err := proto.Unmarshal(bts, &spm); err != nil {
t.Fatal(err)
}
preTokenizer := `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`
var v Vocabulary
for _, piece := range spm.GetPieces() {
v.Values = append(v.Values, piece.GetPiece())
v.Scores = append(v.Scores, piece.GetScore())
switch t := piece.GetType(); t {
case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
sentencepiece.ModelProto_SentencePiece_CONTROL,
sentencepiece.ModelProto_SentencePiece_UNUSED,
sentencepiece.ModelProto_SentencePiece_BYTE:
v.Types = append(v.Types, uint32(t))
default:
tt := uint32(sentencepiece.ModelProto_SentencePiece_NORMAL)
// todo parse the special tokens file
// - this will roundtrip correctly but the <start_of_turn> and
// <end_of_turn> tokens aren't processed
v.Types = append(v.Types, tt)
}
}
return NewSentencePieceModel(preTokenizer, &v)
}
func TestSentencePieceEncode(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
slog.SetDefault(logger)
tokenizer := loadSentencePieceVocab(t)
t.Run("basic roundtrip", func(t *testing.T) {
t.Parallel()
cases := []string{
"hello",
"hello ",
"hello ",
" hello",
" hello ",
" hello ",
"hello world",
"请考试我的软件12345",
"你好",
"Hello 你好 world!",
}
for _, want := range cases {
ids, err := tokenizer.Encode(want)
if err != nil {
t.Fatal(err)
}
if got, err := tokenizer.Decode(ids); err != nil {
t.Fatal(err)
} else if got != want {
t.Errorf("got %q, want %q [%#v]", got, want, ids)
}
}
})
t.Run("special tokens", func(t *testing.T) {
type candidate struct {
token string
ids []int32
}
cases := []candidate{
{"<bos>", []int32{2}},
{"<eos>", []int32{1}},
}
for _, want := range cases {
ids, err := tokenizer.Encode(want.token)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(ids, want.ids) {
t.Errorf("got %#v, want %#v", ids, want.ids)
}
}
})
}

BIN
model/testdata/gemma2/tokenizer.model vendored Normal file

Binary file not shown.