deepseekocr

This commit is contained in:
Michael Yang
2025-10-31 19:15:32 -07:00
committed by Michael Yang
parent 8ed1adf3db
commit 92981ae3f2
14 changed files with 975 additions and 7 deletions

View File

@@ -0,0 +1,83 @@
package deepseekocr
import (
"bytes"
"image"
"image/color"
"math"
"slices"
"golang.org/x/image/draw"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/imageproc"
)
type ratio struct {
x, y int
}
func ProcessImage(ctx ml.Context, bts []byte) (ml.Tensor, ml.Tensor, []int, error) {
img, _, err := image.Decode(bytes.NewReader(bts))
if err != nil {
return nil, nil, nil, err
}
minNum, maxNum, imageSize, baseSize := 2, 9, 640, 1024
var targetRatios []ratio
for n := minNum; n <= maxNum; n++ {
for i := 1; i <= n; i++ {
for j := 1; j <= n; j++ {
if i*j <= maxNum && i*j >= minNum && !slices.Contains(targetRatios, ratio{i, j}) {
targetRatios = append(targetRatios, ratio{i, j})
}
}
}
}
targetRatio := findBestAspectRatio(targetRatios, img.Bounds().Dx(), img.Bounds().Dy(), imageSize)
targetWidth, targetHeight := imageSize*targetRatio.x, imageSize*targetRatio.y
blocks := targetRatio.x * targetRatio.y
mean := imageproc.ImageNetStandardMean
std := imageproc.ImageNetStandardSTD
var patches []float32
resized := imageproc.Resize(img, image.Point{X: targetWidth, Y: targetHeight}, imageproc.ResizeBilinear)
for i := range blocks {
patch := image.NewRGBA(image.Rect(0, 0, imageSize, imageSize))
draw.Draw(patch, patch.Bounds(), resized, image.Point{
X: i % (targetWidth / imageSize) * imageSize,
Y: i / (targetWidth / imageSize) * imageSize,
}, draw.Over)
patches = append(patches, imageproc.Normalize(patch, mean, std, true, true)...)
}
img = imageproc.CompositeColor(img, color.Gray{})
img = imageproc.Pad(img, image.Point{X: baseSize, Y: baseSize}, color.Gray{127}, draw.BiLinear)
return ctx.Input().FromFloats(patches, imageSize, imageSize, 3, blocks),
ctx.Input().FromFloats(imageproc.Normalize(img, mean, std, true, true), baseSize, baseSize, 3),
[]int{targetRatio.x, targetRatio.y},
nil
}
func findBestAspectRatio(targetRatios []ratio, width, height, imageSize int) ratio {
bestDiff := math.MaxFloat64
best := ratio{1, 1}
realRatio := float64(width) / float64(height)
for _, target := range targetRatios {
targetRatio := float64(target.x) / float64(target.y)
diff := math.Abs(realRatio - targetRatio)
if diff < bestDiff {
bestDiff = diff
best = target
} else if diff == bestDiff {
if float64(width*height) > 0.5*float64(imageSize*imageSize*best.x*best.y) {
best = target
}
}
}
return best
}

View File

@@ -0,0 +1,192 @@
package deepseekocr
import (
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.TextProcessor
Sam *samModel `gguf:"s"`
Vision *visionModel `gguf:"v"`
Text *textModel
ImageNewline ml.Tensor `gguf:"mm.image_newline"`
//nolint:misspell // this misspelling is upstream. fixing it breaks the model
ViewSeperator ml.Tensor `gguf:"mm.view_seperator"`
Projector *nn.Linear `gguf:"mm.layers"`
}
func (m *Model) EncodeMultimodal(ctx ml.Context, bts []byte) ([]input.Multimodal, error) {
patches, original, crop, err := ProcessImage(ctx, bts)
if err != nil {
return nil, err
}
var outputs []ml.Tensor
if true { // TODO: local features if sum(patches) != 0
samOutputs := m.Sam.Forward(ctx, patches)
visionOutputs := m.Vision.Forward(ctx, patches, samOutputs)
samOutputs = samOutputs.Reshape(ctx, -1, samOutputs.Dim(2), samOutputs.Dim(3)).Permute(ctx, 1, 0, 2, 3)
visionOutputs = visionOutputs.Slice(ctx, 1, 1, visionOutputs.Dim(1), 1)
localOutputs := visionOutputs.Concat(ctx, samOutputs, 0)
localOutputs = m.Projector.Forward(ctx, localOutputs)
hw := int(math.Sqrt(float64(localOutputs.Dim(1))))
localOutputs = localOutputs.Reshape(ctx, -1, hw, crop[0], crop[1])
localOutputs = localOutputs.Permute(ctx, 0, 2, 1, 3)
localOutputs = localOutputs.Contiguous(ctx, -1, crop[0]*hw, crop[1]*hw)
localOutputs = localOutputs.Concat(ctx, m.ImageNewline.Repeat(ctx, 2, localOutputs.Dim(2)), 1)
localOutputs = localOutputs.Reshape(ctx, localOutputs.Dim(0), -1)
outputs = append(outputs, localOutputs)
}
samOutputs := m.Sam.Forward(ctx, original)
visionOutputs := m.Vision.Forward(ctx, original, samOutputs)
samOutputs = samOutputs.Reshape(ctx, -1, samOutputs.Dim(2), samOutputs.Dim(3)).Permute(ctx, 1, 0, 2, 3)
visionOutputs = visionOutputs.Slice(ctx, 1, 1, visionOutputs.Dim(1), 1)
globalOutputs := visionOutputs.Concat(ctx, samOutputs, 0)
globalOutputs = m.Projector.Forward(ctx, globalOutputs)
hw := int(math.Sqrt(float64(globalOutputs.Dim(1))))
globalOutputs = globalOutputs.Reshape(ctx, -1, hw, hw)
globalOutputs = globalOutputs.Concat(ctx, m.ImageNewline.Repeat(ctx, 2, globalOutputs.Dim(2)), 1)
globalOutputs = globalOutputs.Reshape(ctx, globalOutputs.Dim(0), -1)
outputs = append(outputs, globalOutputs, m.ViewSeperator)
return []input.Multimodal{
{Tensor: outputs[0].Stack(ctx, 1, outputs[1:]...)},
}, nil
}
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
outputs := make([]*input.Input, 0, len(inputs))
for i := range inputs {
if inputs[i].Multimodal == nil {
outputs = append(outputs, inputs[i])
continue
}
t := inputs[i].Multimodal[0].Tensor
outputs = append(outputs, &input.Input{
Token: 128815,
Multimodal: inputs[i].Multimodal,
MultimodalHash: inputs[i].MultimodalHash,
SameBatch: t.Dim(1) - 1,
})
outputs = slices.Grow(outputs, t.Dim(1)-1)
outputs = append(outputs, slices.Repeat([]*input.Input{{Token: 128815}}, t.Dim(1)-1)...)
}
return outputs, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
inputsEmbeds := m.Text.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx)
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
for _, mm := range batch.Multimodal {
t := mm.Multimodal[0].Tensor
ctx.Forward(t.Copy(ctx, inputsEmbeds.View(ctx, mm.Index*inputsEmbeds.Stride(1), t.Dim(0)*t.Dim(1))))
}
hiddenStates := inputsEmbeds
for i, block := range m.Text.Blocks {
if m.Cache != nil {
m.Cache.SetLayer(i)
}
var outputs ml.Tensor
if i == len(m.Text.Blocks)-1 {
outputs = batch.Outputs
}
hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Text.Options)
}
hiddenStates = m.Text.OutputNorm.Forward(ctx, hiddenStates, m.Text.Options.eps)
return m.Text.Output.Forward(ctx, hiddenStates), nil
}
func init() {
model.Register("deepseekocr", func(c fs.Config) (model.Model, error) {
textBlocks := make([]textBlock, c.Uint("block_count"))
leadingDenseBlockCount := int(c.Uint("leading_dense_block_count", 1))
for i := range textBlocks {
if i >= leadingDenseBlockCount {
textBlocks[i].FeedForward = &textMoe{}
} else {
textBlocks[i].FeedForward = &textMLP{}
}
}
m := Model{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
// Split regex into multiple parts (according to DeepSeek3's regex)
"\\p{N}{1,3}",
`[一-龥぀-ゟ゠-ヿ]+`,
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
),
Text: &textModel{
Blocks: textBlocks,
Options: textOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
ropeBase: c.Float("rope.freq_base", 10_000),
ropeScale: c.Float("rope.scaling.factor", 1.0),
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-6),
},
},
Vision: &visionModel{
Blocks: make([]visionBlock, c.Uint("vision.block_count")),
Options: visionOptions{
hiddenSize: int(c.Uint("vision.embedding_length")),
numHeads: int(c.Uint("vision.head_count")),
imageSize: int(c.Uint("vision.image_size", 224)),
patchSize: int(c.Uint("vision.patch_size", 14)),
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-5),
},
},
Sam: &samModel{
Blocks: make([]samBlock, c.Uint("sam.block_count")),
Options: samOptions{
hiddenSize: int(c.Uint("sam.embedding_length")),
numHeads: int(c.Uint("sam.head_count")),
eps: c.Float("sam.attention.layer_norm_epsilon", 1e-6),
globalAttentionLayers: c.Ints("sam.global_attention_indexes"),
},
},
}
m.Cache = kvcache.NewCausalCache(m.Text.Shift)
return &m, nil
})
}

View File

@@ -0,0 +1,225 @@
package deepseekocr
import (
"math"
"slices"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
type samModel struct {
PatchEmbedding *nn.Conv2D `gguf:"patch_embd"`
PositionEmbedding ml.Tensor `gguf:"position_embd"`
Blocks []samBlock `gguf:"blk"`
Neck *samNeck `gguf:"neck"`
Net2 *nn.Conv2D `gguf:"net_2"`
Net3 *nn.Conv2D `gguf:"net_3"`
Options samOptions
}
func (m *samModel) absolutePositionEmbedding(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
source := m.PositionEmbedding.Dim(1)
target := hiddenStates.Dim(2)
if source != target {
positionEmbed := m.PositionEmbedding.Permute(ctx, 2, 0, 1, 3)
positionEmbed = positionEmbed.Interpolate(ctx, [4]int{target, target, hiddenStates.Dim(0), 1}, ml.SamplingModeBilinear)
return positionEmbed.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
}
return m.PositionEmbedding
}
func (m *samModel) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
hiddenStates := m.PatchEmbedding.Forward(ctx, t, 16, 16, 0, 0, 1, 1)
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
if m.PositionEmbedding != nil {
hiddenStates = hiddenStates.Add(ctx, m.absolutePositionEmbedding(ctx, hiddenStates))
}
for i, block := range m.Blocks {
var windowSize int
if !slices.Contains(m.Options.globalAttentionLayers, int32(i)) {
windowSize = 14
}
hiddenStates = block.Forward(ctx, hiddenStates, windowSize, m.Options)
}
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
hiddenStates = m.Neck.Forward(ctx, hiddenStates, m.Options)
hiddenStates = m.Net2.Forward(ctx, hiddenStates, 2, 2, 1, 1, 1, 1)
hiddenStates = m.Net3.Forward(ctx, hiddenStates, 2, 2, 1, 1, 1, 1)
return hiddenStates
}
type samOptions struct {
hiddenSize,
numHeads int
eps float32
globalAttentionLayers []int32
}
func (o samOptions) headDim() int {
return o.hiddenSize / o.numHeads
}
type samBlock struct {
Norm1 *nn.LayerNorm `gguf:"norm1"`
Attention *samAttention `gguf:"attn"`
Norm2 *nn.LayerNorm `gguf:"norm2"`
FeedForward *samMLP `gguf:"mlp"`
}
func (m *samBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, windowSize int, opts samOptions) ml.Tensor {
c, w, h := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
residual := hiddenStates
hiddenStates = m.Norm1.Forward(ctx, hiddenStates, opts.eps)
var pw, ph int
if windowSize > 0 {
pw = (windowSize - hiddenStates.Dim(1)%windowSize) % windowSize
ph = (windowSize - hiddenStates.Dim(2)%windowSize) % windowSize
if pw > 0 || ph > 0 {
hiddenStates = hiddenStates.Pad(ctx, 0, pw, ph, 0)
}
hiddenStates = hiddenStates.Reshape(ctx, c*windowSize, (w+pw)/windowSize, windowSize, -1)
hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, c, windowSize, windowSize, -1)
}
hiddenStates = m.Attention.Forward(ctx, hiddenStates, opts)
if windowSize > 0 {
hiddenStates = hiddenStates.Reshape(ctx, c*windowSize, windowSize, (w+pw)/windowSize, -1)
hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3)
hiddenStates = hiddenStates.Contiguous(ctx, c, w+pw, h+ph, -1)
hiddenStates = hiddenStates.Pad(ctx, 0, -pw, -ph, 0)
}
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = m.Norm2.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.FeedForward.Forward(ctx, hiddenStates, opts)
return hiddenStates.Add(ctx, residual)
}
type samAttention struct {
QKV *nn.Linear `gguf:"qkv"`
Output *nn.Linear `gguf:"proj"`
RelativePosition *struct {
Height ml.Tensor `gguf:"h"`
Width ml.Tensor `gguf:"w"`
} `gguf:",pre:rel_pos_"`
}
func relativeCoordinates(ctx ml.Context, qn, kn int) ml.Tensor {
s := make([]int32, qn*kn)
for i := range qn {
for j := range kn {
q := i * max(kn/qn, 1)
k := j * max(qn/kn, 1)
s[i*kn+j] = int32(q - k + (kn-1)*max(qn/kn, 1))
}
}
return ctx.Input().FromInts(s, qn*kn)
}
func relativePositions(ctx ml.Context, positions ml.Tensor, qn, kn int) ml.Tensor {
maxRelativeDistance := 2*max(qn, kn) - 1
if positions.Dim(1) != maxRelativeDistance {
// linear interpolation kernel not available so approx. with bilinear interpolation
positions = positions.Interpolate(ctx, [4]int{positions.Dim(0), maxRelativeDistance, 1, 1}, ml.SamplingModeBilinear)
}
rc := relativeCoordinates(ctx, qn, kn)
return positions.Rows(ctx, rc).Reshape(ctx, positions.Dim(0), kn, qn)
}
func (m *samAttention) decomposedRelativePositions(ctx ml.Context, query ml.Tensor, qn, kn []int) (ml.Tensor, ml.Tensor) {
qh, qw := qn[0], qn[1]
kh, kw := kn[0], kn[1]
rh := relativePositions(ctx, m.RelativePosition.Height, qh, kh)
rw := relativePositions(ctx, m.RelativePosition.Width, qw, kw)
query = query.Contiguous(ctx, query.Dim(0), qw, qh, -1)
rh = rh.Mulmat(ctx, query).Reshape(ctx, 1, kh, qh*qw, -1)
rw = rw.Mulmat(ctx, query.Permute(ctx, 0, 2, 1, 3)).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, kw, 1, qh*qw, -1)
return rh, rw
}
func (m *samAttention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
w, h, b := hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3)
qkv := m.QKV.Forward(ctx, hiddenStates)
qkv = qkv.Reshape(ctx, opts.headDim(), -1, w*h, b)
chunks := qkv.Chunk(ctx, 1, opts.numHeads)
query, key, value := chunks[0], chunks[1], chunks[2]
ctx.Forward(query, key, value)
query = query.Permute(ctx, 0, 2, 1, 3)
rh, rw := m.decomposedRelativePositions(ctx, query, []int{h, w}, []int{h, w})
mask := rh.Repeat(ctx, 0, rw.Dim(0)).Add(ctx, rw)
mask = mask.Reshape(ctx, h*w, -1, opts.numHeads, b)
key = key.Permute(ctx, 0, 2, 1, 3)
scores := key.MulmatFullPrec(ctx, query)
scores = scores.Scale(ctx, 1/math.Sqrt(float64(opts.headDim())))
scores = scores.Add(ctx, mask)
scores = scores.Softmax(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
attention := value.Mulmat(ctx, scores)
attention = attention.Permute(ctx, 0, 2, 1, 3)
attention = attention.Contiguous(ctx, -1, w, h, b)
return m.Output.Forward(ctx, attention)
}
type samMLP struct {
Lin1 *nn.Linear `gguf:"lin1"`
Lin2 *nn.Linear `gguf:"lin2"`
}
func (m *samMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
return m.Lin2.Forward(ctx, m.Lin1.Forward(ctx, hiddenStates).GELU(ctx))
}
type LayerNorm2D struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (ln *LayerNorm2D) Forward(ctx ml.Context, x ml.Tensor, eps float32) ml.Tensor {
x = x.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
u := x.Mean(ctx)
d := x.Sub(ctx, u)
s := d.Sqr(ctx).Mean(ctx)
x = d.Div(ctx, s.Add(ctx, ctx.Input().FromFloats([]float32{eps}, 1)).Sqrt(ctx))
x = x.Mul(ctx, ln.Weight).Add(ctx, ln.Bias)
return x.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
}
type samNeck struct {
C1 *nn.Conv2D `gguf:"0"`
LN1 *LayerNorm2D `gguf:"1"`
C2 *nn.Conv2D `gguf:"2"`
LN2 *LayerNorm2D `gguf:"3"`
}
func (m *samNeck) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
hiddenStates = m.C1.Forward(ctx, hiddenStates, 1, 1, 0, 0, 1, 1)
hiddenStates = m.LN1.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.C2.Forward(ctx, hiddenStates, 1, 1, 1, 1, 1, 1)
hiddenStates = m.LN2.Forward(ctx, hiddenStates, opts.eps)
return hiddenStates
}

View File

@@ -0,0 +1,140 @@
package deepseekocr
import (
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
)
type textModel struct {
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Blocks []textBlock `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output"`
Options textOptions
}
func (m *textModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.Options.applyRotaryPositionalEmbedding(ctx, key, shift), nil
}
type textOptions struct {
hiddenSize,
numHeads,
numKVHeads,
numExperts,
numExpertsUsed int
ropeBase,
ropeScale,
eps float32
}
func (o textOptions) headDim() int {
return o.hiddenSize / o.numHeads
}
func (o textOptions) applyRotaryPositionalEmbedding(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
return fast.RoPE(ctx, t, p, o.headDim(), o.ropeBase, 1/o.ropeScale, rope.WithTypeNeoX())
}
type textBlock struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Attention *textAttention
MLPNNorm *nn.RMSNorm `gguf:"ffn_norm"`
FeedForward textFeedForward
}
func (m *textBlock) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts textOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = m.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = m.MLPNNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.FeedForward.Forward(ctx, hiddenStates, opts)
return hiddenStates.Add(ctx, residual)
}
type textAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (m *textAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts textOptions) ml.Tensor {
query := m.Query.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, -1)
key := m.Key.Forward(ctx, hiddenStates)
key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, -1)
value := m.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, -1)
query = opts.applyRotaryPositionalEmbedding(ctx, query, positions)
key = opts.applyRotaryPositionalEmbedding(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
attention = attention.Reshape(ctx, -1, attention.Dim(2))
return m.Output.Forward(ctx, attention)
}
type textFeedForward interface {
Forward(ml.Context, ml.Tensor, textOptions) ml.Tensor
}
type textMoe struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
SharedExperts *textMLP `gguf:",suf:_shexp"`
}
func (m *textMoe) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts textOptions) ml.Tensor {
scores := m.Router.Forward(ctx, hiddenStates).Softmax(ctx)
indices := scores.TopK(ctx, opts.numExpertsUsed)
weights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, indices)
experts := hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
experts = m.Gate.Forward(ctx, experts, indices).SILU(ctx, m.Up.Forward(ctx, experts, indices))
experts = m.Down.Forward(ctx, experts, indices)
experts = experts.Mul(ctx, weights)
expert := func(i int) ml.Tensor {
return experts.View(
ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2),
)
}
routedStates := expert(0)
for i := 1; i < opts.numExpertsUsed; i++ {
routedStates = routedStates.Add(ctx, expert(i))
}
sharedStates := m.SharedExperts.Forward(ctx, hiddenStates, opts)
return routedStates.Add(ctx, sharedStates)
}
type textMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (m *textMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ textOptions) ml.Tensor {
hiddenStates = m.Gate.Forward(ctx, hiddenStates).SILU(ctx, m.Up.Forward(ctx, hiddenStates))
return m.Down.Forward(ctx, hiddenStates)
}

View File

@@ -0,0 +1,117 @@
package deepseekocr
import (
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
type visionModel struct {
PatchEmbedding *nn.Conv2D `gguf:"patch_embd"`
ClassEmbedding ml.Tensor `gguf:"class_embd"`
PositionEmbedding *nn.Embedding `gguf:"position_embd"`
PreLayerNorm *nn.LayerNorm `gguf:"pre_layrnorm"`
Blocks []visionBlock `gguf:"blk"`
Options visionOptions
}
func (m *visionModel) absolutePositionEmbedding(ctx ml.Context, embeds ml.Tensor) ml.Tensor {
numPatches := m.Options.imageSize / m.Options.patchSize * m.Options.imageSize / m.Options.patchSize
positions := ctx.Arange(0, float32(numPatches+1), 1, ml.DTypeI32)
positionEmbeds := m.PositionEmbedding.Forward(ctx, positions)
source := int(math.Sqrt(float64(positionEmbeds.Dim(1) - 1)))
target := int(math.Sqrt(float64(embeds.Dim(1) - 1)))
if source != target {
newPositionEmbeds := positionEmbeds.Slice(ctx, 1, 1, positionEmbeds.Dim(1), 1)
newPositionEmbeds = newPositionEmbeds.Reshape(ctx, -1, source, source)
newPositionEmbeds = newPositionEmbeds.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
newPositionEmbeds = newPositionEmbeds.Interpolate(ctx, [4]int{target, target, embeds.Dim(0), 1}, ml.SamplingModeBilinear)
newPositionEmbeds = newPositionEmbeds.Permute(ctx, 1, 2, 0, 3)
newPositionEmbeds = newPositionEmbeds.Contiguous(ctx, -1, target*target)
positionEmbeds = positionEmbeds.Slice(ctx, 1, 0, 1, 1).Concat(ctx, newPositionEmbeds, 1)
}
return positionEmbeds
}
func (m *visionModel) Forward(ctx ml.Context, pixelValues, patchEmbeds ml.Tensor) ml.Tensor {
if patchEmbeds == nil {
patchEmbeds = m.PatchEmbedding.Forward(ctx, pixelValues, m.Options.patchSize, m.Options.patchSize, 0, 0, 1, 1)
}
patchEmbeds = patchEmbeds.Reshape(ctx, -1, patchEmbeds.Dim(2), patchEmbeds.Dim(3))
patchEmbeds = patchEmbeds.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
classEmbeds := m.ClassEmbedding.Repeat(ctx, 2, patchEmbeds.Dim(2))
embeds := classEmbeds.Concat(ctx, patchEmbeds, 1)
embeds = embeds.Add(ctx, m.absolutePositionEmbedding(ctx, embeds))
hiddenStates := m.PreLayerNorm.Forward(ctx, embeds, m.Options.eps)
for _, block := range m.Blocks {
hiddenStates = block.Forward(ctx, hiddenStates, m.Options)
}
return hiddenStates
}
type visionOptions struct {
hiddenSize,
numHeads int
eps float32
imageSize, patchSize int
}
func (o visionOptions) headDim() int {
return o.hiddenSize / o.numHeads
}
type visionBlock struct {
Norm1 *nn.LayerNorm `gguf:"layer_norm1"`
Attention *visionAttention `gguf:"self_attn"`
Norm2 *nn.LayerNorm `gguf:"layer_norm2"`
FeedForward *visionMLP `gguf:"mlp"`
}
func (m *visionBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts visionOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = m.Norm1.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.Attention.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = m.Norm2.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.FeedForward.Forward(ctx, hiddenStates)
hiddenStates = hiddenStates.Add(ctx, residual)
return hiddenStates
}
type visionAttention struct {
QKV *nn.Linear `gguf:"qkv_proj"`
Output *nn.Linear `gguf:"out_proj"`
}
func (m *visionAttention) Forward(ctx ml.Context, t ml.Tensor, opts visionOptions) ml.Tensor {
qkv := m.QKV.Forward(ctx, t)
qkv = qkv.Reshape(ctx, opts.headDim(), -1, qkv.Dim(1), qkv.Dim(2))
chunks := qkv.Chunk(ctx, 1, opts.numHeads)
query, key, value := chunks[0], chunks[1], chunks[2]
attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
attention = attention.Reshape(ctx, -1, attention.Dim(2), attention.Dim(3))
return m.Output.Forward(ctx, attention)
}
type visionMLP struct {
FC1 *nn.Linear `gguf:"fc1"`
FC2 *nn.Linear `gguf:"fc2"`
}
func (m *visionMLP) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
return m.FC2.Forward(ctx, m.FC1.Forward(ctx, t).QuickGELU(ctx))
}

View File

@@ -3,6 +3,7 @@ package models
import (
_ "github.com/ollama/ollama/model/models/bert"
_ "github.com/ollama/ollama/model/models/deepseek2"
_ "github.com/ollama/ollama/model/models/deepseekocr"
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n"