deepseekocr

This commit is contained in:
Michael Yang
2025-10-31 19:15:32 -07:00
committed by Michael Yang
parent 8ed1adf3db
commit 92981ae3f2
14 changed files with 975 additions and 7 deletions

View File

@@ -206,6 +206,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &commandrModel{} conv = &commandrModel{}
case "GptOssForCausalLM": case "GptOssForCausalLM":
conv = &gptossModel{} conv = &gptossModel{}
case "DeepseekOCRForCausalLM":
conv = &deepseekocr{}
default: default:
return fmt.Errorf("unsupported architecture %q", p.Architectures[0]) return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
} }

View File

@@ -0,0 +1,136 @@
package convert
import (
"fmt"
"github.com/ollama/ollama/fs/ggml"
)
type deepseekocr struct {
ModelParameters
LanguageConfig struct {
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
NumRoutedExperts uint32 `json:"n_routed_experts"`
NumSharedExperts uint32 `json:"n_shared_experts"`
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
FirstKDenseReplace uint32 `json:"first_k_dense_replace"`
} `json:"language_config"`
VisionConfig struct {
ImageSize uint32 `json:"image_size"`
Width struct {
Vision struct {
Heads uint32 `json:"heads"`
ImageSize uint32 `json:"image_size"`
Layers uint32 `json:"layers"`
PatchSize uint32 `json:"patch_size"`
Width uint32 `json:"width"`
} `json:"clip-l-14-224"`
Sam struct {
GlobalAttentionIndexes []int32 `json:"global_attn_indexes"`
Heads uint32 `json:"heads"`
Layers uint32 `json:"layers"`
Width uint32 `json:"width"`
} `json:"sam_vit_b"`
}
} `json:"vision_config"`
}
func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "deepseekocr"
kv["block_count"] = m.LanguageConfig.HiddenLayers
kv["context_length"] = m.LanguageConfig.MaxPositionEmbeddings
kv["embedding_length"] = m.LanguageConfig.HiddenSize
kv["feed_forward_length"] = m.LanguageConfig.IntermediateSize
kv["attention.head_count"] = m.LanguageConfig.NumAttentionHeads
kv["attention.head_count_kv"] = m.LanguageConfig.NumKeyValueHeads
kv["expert_count"] = m.LanguageConfig.NumRoutedExperts
kv["expert_used_count"] = m.LanguageConfig.NumExpertsPerToken
kv["leading_dense_block_count"] = m.LanguageConfig.FirstKDenseReplace
kv["vision.block_count"] = m.VisionConfig.Width.Vision.Layers
kv["vision.embedding_length"] = m.VisionConfig.Width.Vision.Width
kv["vision.head_count"] = m.VisionConfig.Width.Vision.Heads
kv["vision.image_size"] = m.VisionConfig.Width.Vision.ImageSize
kv["vision.patch_size"] = m.VisionConfig.Width.Vision.PatchSize
kv["sam.block_count"] = m.VisionConfig.Width.Sam.Layers
kv["sam.embedding_length"] = m.VisionConfig.Width.Sam.Width
kv["sam.head_count"] = m.VisionConfig.Width.Sam.Heads
kv["sam.global_attention_indexes"] = m.VisionConfig.Width.Sam.GlobalAttentionIndexes
return kv
}
func (m *deepseekocr) Tensors(s []Tensor) (out []*ggml.Tensor) {
merges := make([]merge, m.LanguageConfig.HiddenLayers*3)
for i := range m.LanguageConfig.HiddenLayers {
merges[i*3+0] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}
merges[i*3+1] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}
merges[i*3+2] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}
}
out, s = mergeTensors(s, merges...)
for _, t := range s {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (m *deepseekocr) Replacements() []string {
return []string{
"model.embed_tokens", "token_embd",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"post_attention_layernorm", "ffn_norm",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"mlp.down_proj", "ffn_down",
"mlp.gate", "ffn_gate_inp",
"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
"mlp.shared_experts.up_proj", "ffn_up_shexp",
"mlp.shared_experts.down_proj", "ffn_down_shexp",
"model.norm", "output_norm",
"lm_head", "output",
"model.vision_model", "v",
"embeddings.patch_embedding", "patch_embd",
"embeddings.class_embedding", "class_embd",
"embeddings.position_embedding", "position_embd",
"transformer.layers", "blk",
"model.projector", "mm",
"model.image_newline", "mm.image_newline",
//nolint:misspell // this misspelling is upstream. fixing it breaks the model
"model.view_seperator", "mm.view_seperator",
"model.sam_model.patch_embed.proj", "s.patch_embd",
"model.sam_model.pos_embed", "s.position_embd",
"model.sam_model.blocks", "s.blk",
"model.sam_model.neck", "s.neck",
"model.sam_model.net_", "s.net_",
}
}

View File

@@ -44,7 +44,10 @@ func (t tensorBase) Kind() uint32 {
t.name == "v.positional_embedding_vlm" || t.name == "v.positional_embedding_vlm" ||
t.name == "v.tile_position_embd.weight" || t.name == "v.tile_position_embd.weight" ||
t.name == "v.pre_tile_position_embd.weight" || t.name == "v.pre_tile_position_embd.weight" ||
t.name == "v.post_tile_position_embd.weight" { t.name == "v.post_tile_position_embd.weight" ||
t.name == "s.position_embd" ||
strings.HasSuffix(t.name, "rel_pos_h") ||
strings.HasSuffix(t.name, "rel_pos_w") {
// these tensors are always F32 // these tensors are always F32
return tensorKindFP32 return tensorKindFP32
} }

View File

@@ -96,7 +96,10 @@ type safetensor struct {
func (st safetensor) Kind() uint32 { func (st safetensor) Kind() uint32 {
kind := st.tensorBase.Kind() kind := st.tensorBase.Kind()
if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 { if st.dtype == "BF16" &&
!strings.HasPrefix(st.name, "v.") &&
!strings.HasPrefix(st.name, "s.") &&
kind != tensorKindFP32 {
kind = tensorKindBF16 kind = tensorKindBF16
} }

View File

@@ -249,6 +249,7 @@ func (kv KV) OllamaEngineRequired() bool {
"qwen25vl", "qwen25vl",
"qwen3", "qwen3moe", "qwen3", "qwen3moe",
"qwen3vl", "qwen3vlmoe", "qwen3vl", "qwen3vlmoe",
"deepseekocr",
}, kv.Architecture()) }, kv.Architecture())
} }

View File

@@ -173,6 +173,7 @@ type Tensor interface {
Cos(ctx Context) Tensor Cos(ctx Context) Tensor
Tanh(ctx Context) Tensor Tanh(ctx Context) Tensor
GELU(ctx Context, up ...Tensor) Tensor GELU(ctx Context, up ...Tensor) Tensor
QuickGELU(ctx Context, up ...Tensor) Tensor
SILU(ctx Context, up ...Tensor) Tensor SILU(ctx Context, up ...Tensor) Tensor
RELU(ctx Context, up ...Tensor) Tensor RELU(ctx Context, up ...Tensor) Tensor
Sigmoid(ctx Context) Tensor Sigmoid(ctx Context) Tensor
@@ -207,6 +208,8 @@ type Tensor interface {
Stddev(ctx Context) Tensor Stddev(ctx Context) Tensor
Sqr(ctx Context) Tensor Sqr(ctx Context) Tensor
Sqrt(ctx Context) Tensor Sqrt(ctx Context) Tensor
Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
} }
// ScaledDotProductAttention implements a fused attention // ScaledDotProductAttention implements a fused attention
@@ -372,3 +375,10 @@ const (
DTypeI32 DTypeI32
DTypeMXFP4 DTypeMXFP4
) )
type SamplingMode int
const (
SamplingModeNearest SamplingMode = iota
SamplingModeBilinear
)

View File

@@ -314,7 +314,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
"altup_proj", "altup_unembd_proj", "altup_proj", "altup_unembd_proj",
"per_layer_token_embd", "per_layer_model_proj", "per_layer_proj_norm"): "per_layer_token_embd", "per_layer_model_proj", "per_layer_proj_norm"):
createTensor(tensor{source: t}, output.bts, blocks) createTensor(tensor{source: t}, output.bts, blocks)
case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."): case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm.") || strings.HasPrefix(t.Name, "s."):
// TODO: assign vision tensors to the gpu if possible // TODO: assign vision tensors to the gpu if possible
createTensor(tensor{source: t}, output.bts, blocks) createTensor(tensor{source: t}, output.bts, blocks)
case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"): case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"):
@@ -1567,6 +1567,16 @@ func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
} }
} }
func (t *Tensor) QuickGELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
var tt *C.struct_ggml_tensor
if len(t2) > 0 {
tt = C.ggml_geglu_quick_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t)
} else {
tt = C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t)
}
return &Tensor{b: t.b, t: tt}
}
func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor { func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
if len(t2) > 0 { if len(t2) > 0 {
return &Tensor{ return &Tensor{
@@ -1724,6 +1734,23 @@ func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
} }
} }
func (t *Tensor) Interpolate(ctx ml.Context, dims [4]int, samplingMode ml.SamplingMode) ml.Tensor {
var mode C.uint32_t
switch samplingMode {
case ml.SamplingModeNearest:
mode = C.GGML_SCALE_MODE_NEAREST
case ml.SamplingModeBilinear:
mode = C.GGML_SCALE_MODE_BILINEAR
default:
panic("unsupported interpolate mode")
}
return &Tensor{
b: t.b,
t: C.ggml_interpolate(ctx.(*Context).ctx, t.t, C.int64_t(dims[0]), C.int64_t(dims[1]), C.int64_t(dims[2]), C.int64_t(dims[3]), mode),
}
}
// Slice returns a view of the tensor sliced along dim from low to high in step steps. // Slice returns a view of the tensor sliced along dim from low to high in step steps.
// Slice panics if the dimension is invalid or the slice parameters are out of range. // Slice panics if the dimension is invalid or the slice parameters are out of range.
// If dim=0 and step>1, the tensor is a copy rather than a view to ensure proper shape. // If dim=0 and step>1, the tensor is a copy rather than a view to ensure proper shape.

View File

@@ -25,12 +25,15 @@ const (
// Composite returns an image with the alpha channel removed by drawing over a white background. // Composite returns an image with the alpha channel removed by drawing over a white background.
func Composite(img image.Image) image.Image { func Composite(img image.Image) image.Image {
dst := image.NewRGBA(img.Bounds())
white := color.RGBA{255, 255, 255, 255} white := color.RGBA{255, 255, 255, 255}
draw.Draw(dst, dst.Bounds(), &image.Uniform{white}, image.Point{}, draw.Src) return CompositeColor(img, white)
draw.Draw(dst, dst.Bounds(), img, img.Bounds().Min, draw.Over) }
// CompositeColor returns an image with the alpha channel removed by drawing over a white background.
func CompositeColor(img image.Image, color color.Color) image.Image {
dst := image.NewRGBA(img.Bounds())
draw.Draw(dst, dst.Bounds(), &image.Uniform{color}, image.Point{}, draw.Src)
draw.Draw(dst, dst.Bounds(), img, img.Bounds().Min, draw.Over)
return dst return dst
} }
@@ -55,6 +58,31 @@ func Resize(img image.Image, newSize image.Point, method int) image.Image {
return dst return dst
} }
// Pad returns an image which has been resized to fit within a new size, preserving aspect ratio, and padded with a color.
func Pad(img image.Image, newSize image.Point, color color.Color, kernel draw.Interpolator) image.Image {
dst := image.NewRGBA(image.Rect(0, 0, newSize.X, newSize.Y))
draw.Draw(dst, dst.Bounds(), &image.Uniform{color}, image.Point{}, draw.Src)
var minPoint, maxPoint image.Point
if img.Bounds().Dx() > img.Bounds().Dy() {
// landscape
height := newSize.X * img.Bounds().Dy() / img.Bounds().Dx()
minPoint = image.Point{0, (newSize.Y - height) / 2}
maxPoint = image.Point{newSize.X, height + minPoint.Y}
} else {
// portrait
width := newSize.Y * img.Bounds().Dx() / img.Bounds().Dy()
minPoint = image.Point{(newSize.X - width) / 2, 0}
maxPoint = image.Point{minPoint.X + width, newSize.Y}
}
kernel.Scale(dst, image.Rectangle{
Min: minPoint,
Max: maxPoint,
}, img, img.Bounds(), draw.Over, nil)
return dst
}
// Normalize returns a slice of float32 containing each of the r, g, b values for an image normalized around a value. // Normalize returns a slice of float32 containing each of the r, g, b values for an image normalized around a value.
func Normalize(img image.Image, mean, std [3]float32, rescale bool, channelFirst bool) []float32 { func Normalize(img image.Image, mean, std [3]float32, rescale bool, channelFirst bool) []float32 {
var pixelVals []float32 var pixelVals []float32

View File

@@ -0,0 +1,83 @@
package deepseekocr
import (
"bytes"
"image"
"image/color"
"math"
"slices"
"golang.org/x/image/draw"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/imageproc"
)
type ratio struct {
x, y int
}
func ProcessImage(ctx ml.Context, bts []byte) (ml.Tensor, ml.Tensor, []int, error) {
img, _, err := image.Decode(bytes.NewReader(bts))
if err != nil {
return nil, nil, nil, err
}
minNum, maxNum, imageSize, baseSize := 2, 9, 640, 1024
var targetRatios []ratio
for n := minNum; n <= maxNum; n++ {
for i := 1; i <= n; i++ {
for j := 1; j <= n; j++ {
if i*j <= maxNum && i*j >= minNum && !slices.Contains(targetRatios, ratio{i, j}) {
targetRatios = append(targetRatios, ratio{i, j})
}
}
}
}
targetRatio := findBestAspectRatio(targetRatios, img.Bounds().Dx(), img.Bounds().Dy(), imageSize)
targetWidth, targetHeight := imageSize*targetRatio.x, imageSize*targetRatio.y
blocks := targetRatio.x * targetRatio.y
mean := imageproc.ImageNetStandardMean
std := imageproc.ImageNetStandardSTD
var patches []float32
resized := imageproc.Resize(img, image.Point{X: targetWidth, Y: targetHeight}, imageproc.ResizeBilinear)
for i := range blocks {
patch := image.NewRGBA(image.Rect(0, 0, imageSize, imageSize))
draw.Draw(patch, patch.Bounds(), resized, image.Point{
X: i % (targetWidth / imageSize) * imageSize,
Y: i / (targetWidth / imageSize) * imageSize,
}, draw.Over)
patches = append(patches, imageproc.Normalize(patch, mean, std, true, true)...)
}
img = imageproc.CompositeColor(img, color.Gray{})
img = imageproc.Pad(img, image.Point{X: baseSize, Y: baseSize}, color.Gray{127}, draw.BiLinear)
return ctx.Input().FromFloats(patches, imageSize, imageSize, 3, blocks),
ctx.Input().FromFloats(imageproc.Normalize(img, mean, std, true, true), baseSize, baseSize, 3),
[]int{targetRatio.x, targetRatio.y},
nil
}
func findBestAspectRatio(targetRatios []ratio, width, height, imageSize int) ratio {
bestDiff := math.MaxFloat64
best := ratio{1, 1}
realRatio := float64(width) / float64(height)
for _, target := range targetRatios {
targetRatio := float64(target.x) / float64(target.y)
diff := math.Abs(realRatio - targetRatio)
if diff < bestDiff {
bestDiff = diff
best = target
} else if diff == bestDiff {
if float64(width*height) > 0.5*float64(imageSize*imageSize*best.x*best.y) {
best = target
}
}
}
return best
}

View File

@@ -0,0 +1,192 @@
package deepseekocr
import (
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.TextProcessor
Sam *samModel `gguf:"s"`
Vision *visionModel `gguf:"v"`
Text *textModel
ImageNewline ml.Tensor `gguf:"mm.image_newline"`
//nolint:misspell // this misspelling is upstream. fixing it breaks the model
ViewSeperator ml.Tensor `gguf:"mm.view_seperator"`
Projector *nn.Linear `gguf:"mm.layers"`
}
func (m *Model) EncodeMultimodal(ctx ml.Context, bts []byte) ([]input.Multimodal, error) {
patches, original, crop, err := ProcessImage(ctx, bts)
if err != nil {
return nil, err
}
var outputs []ml.Tensor
if true { // TODO: local features if sum(patches) != 0
samOutputs := m.Sam.Forward(ctx, patches)
visionOutputs := m.Vision.Forward(ctx, patches, samOutputs)
samOutputs = samOutputs.Reshape(ctx, -1, samOutputs.Dim(2), samOutputs.Dim(3)).Permute(ctx, 1, 0, 2, 3)
visionOutputs = visionOutputs.Slice(ctx, 1, 1, visionOutputs.Dim(1), 1)
localOutputs := visionOutputs.Concat(ctx, samOutputs, 0)
localOutputs = m.Projector.Forward(ctx, localOutputs)
hw := int(math.Sqrt(float64(localOutputs.Dim(1))))
localOutputs = localOutputs.Reshape(ctx, -1, hw, crop[0], crop[1])
localOutputs = localOutputs.Permute(ctx, 0, 2, 1, 3)
localOutputs = localOutputs.Contiguous(ctx, -1, crop[0]*hw, crop[1]*hw)
localOutputs = localOutputs.Concat(ctx, m.ImageNewline.Repeat(ctx, 2, localOutputs.Dim(2)), 1)
localOutputs = localOutputs.Reshape(ctx, localOutputs.Dim(0), -1)
outputs = append(outputs, localOutputs)
}
samOutputs := m.Sam.Forward(ctx, original)
visionOutputs := m.Vision.Forward(ctx, original, samOutputs)
samOutputs = samOutputs.Reshape(ctx, -1, samOutputs.Dim(2), samOutputs.Dim(3)).Permute(ctx, 1, 0, 2, 3)
visionOutputs = visionOutputs.Slice(ctx, 1, 1, visionOutputs.Dim(1), 1)
globalOutputs := visionOutputs.Concat(ctx, samOutputs, 0)
globalOutputs = m.Projector.Forward(ctx, globalOutputs)
hw := int(math.Sqrt(float64(globalOutputs.Dim(1))))
globalOutputs = globalOutputs.Reshape(ctx, -1, hw, hw)
globalOutputs = globalOutputs.Concat(ctx, m.ImageNewline.Repeat(ctx, 2, globalOutputs.Dim(2)), 1)
globalOutputs = globalOutputs.Reshape(ctx, globalOutputs.Dim(0), -1)
outputs = append(outputs, globalOutputs, m.ViewSeperator)
return []input.Multimodal{
{Tensor: outputs[0].Stack(ctx, 1, outputs[1:]...)},
}, nil
}
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
outputs := make([]*input.Input, 0, len(inputs))
for i := range inputs {
if inputs[i].Multimodal == nil {
outputs = append(outputs, inputs[i])
continue
}
t := inputs[i].Multimodal[0].Tensor
outputs = append(outputs, &input.Input{
Token: 128815,
Multimodal: inputs[i].Multimodal,
MultimodalHash: inputs[i].MultimodalHash,
SameBatch: t.Dim(1) - 1,
})
outputs = slices.Grow(outputs, t.Dim(1)-1)
outputs = append(outputs, slices.Repeat([]*input.Input{{Token: 128815}}, t.Dim(1)-1)...)
}
return outputs, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
inputsEmbeds := m.Text.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx)
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
for _, mm := range batch.Multimodal {
t := mm.Multimodal[0].Tensor
ctx.Forward(t.Copy(ctx, inputsEmbeds.View(ctx, mm.Index*inputsEmbeds.Stride(1), t.Dim(0)*t.Dim(1))))
}
hiddenStates := inputsEmbeds
for i, block := range m.Text.Blocks {
if m.Cache != nil {
m.Cache.SetLayer(i)
}
var outputs ml.Tensor
if i == len(m.Text.Blocks)-1 {
outputs = batch.Outputs
}
hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Text.Options)
}
hiddenStates = m.Text.OutputNorm.Forward(ctx, hiddenStates, m.Text.Options.eps)
return m.Text.Output.Forward(ctx, hiddenStates), nil
}
func init() {
model.Register("deepseekocr", func(c fs.Config) (model.Model, error) {
textBlocks := make([]textBlock, c.Uint("block_count"))
leadingDenseBlockCount := int(c.Uint("leading_dense_block_count", 1))
for i := range textBlocks {
if i >= leadingDenseBlockCount {
textBlocks[i].FeedForward = &textMoe{}
} else {
textBlocks[i].FeedForward = &textMLP{}
}
}
m := Model{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
// Split regex into multiple parts (according to DeepSeek3's regex)
"\\p{N}{1,3}",
`[一-龥぀-ゟ゠-ヿ]+`,
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
),
Text: &textModel{
Blocks: textBlocks,
Options: textOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
ropeBase: c.Float("rope.freq_base", 10_000),
ropeScale: c.Float("rope.scaling.factor", 1.0),
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-6),
},
},
Vision: &visionModel{
Blocks: make([]visionBlock, c.Uint("vision.block_count")),
Options: visionOptions{
hiddenSize: int(c.Uint("vision.embedding_length")),
numHeads: int(c.Uint("vision.head_count")),
imageSize: int(c.Uint("vision.image_size", 224)),
patchSize: int(c.Uint("vision.patch_size", 14)),
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-5),
},
},
Sam: &samModel{
Blocks: make([]samBlock, c.Uint("sam.block_count")),
Options: samOptions{
hiddenSize: int(c.Uint("sam.embedding_length")),
numHeads: int(c.Uint("sam.head_count")),
eps: c.Float("sam.attention.layer_norm_epsilon", 1e-6),
globalAttentionLayers: c.Ints("sam.global_attention_indexes"),
},
},
}
m.Cache = kvcache.NewCausalCache(m.Text.Shift)
return &m, nil
})
}

View File

@@ -0,0 +1,225 @@
package deepseekocr
import (
"math"
"slices"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
type samModel struct {
PatchEmbedding *nn.Conv2D `gguf:"patch_embd"`
PositionEmbedding ml.Tensor `gguf:"position_embd"`
Blocks []samBlock `gguf:"blk"`
Neck *samNeck `gguf:"neck"`
Net2 *nn.Conv2D `gguf:"net_2"`
Net3 *nn.Conv2D `gguf:"net_3"`
Options samOptions
}
func (m *samModel) absolutePositionEmbedding(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
source := m.PositionEmbedding.Dim(1)
target := hiddenStates.Dim(2)
if source != target {
positionEmbed := m.PositionEmbedding.Permute(ctx, 2, 0, 1, 3)
positionEmbed = positionEmbed.Interpolate(ctx, [4]int{target, target, hiddenStates.Dim(0), 1}, ml.SamplingModeBilinear)
return positionEmbed.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
}
return m.PositionEmbedding
}
func (m *samModel) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
hiddenStates := m.PatchEmbedding.Forward(ctx, t, 16, 16, 0, 0, 1, 1)
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
if m.PositionEmbedding != nil {
hiddenStates = hiddenStates.Add(ctx, m.absolutePositionEmbedding(ctx, hiddenStates))
}
for i, block := range m.Blocks {
var windowSize int
if !slices.Contains(m.Options.globalAttentionLayers, int32(i)) {
windowSize = 14
}
hiddenStates = block.Forward(ctx, hiddenStates, windowSize, m.Options)
}
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
hiddenStates = m.Neck.Forward(ctx, hiddenStates, m.Options)
hiddenStates = m.Net2.Forward(ctx, hiddenStates, 2, 2, 1, 1, 1, 1)
hiddenStates = m.Net3.Forward(ctx, hiddenStates, 2, 2, 1, 1, 1, 1)
return hiddenStates
}
type samOptions struct {
hiddenSize,
numHeads int
eps float32
globalAttentionLayers []int32
}
func (o samOptions) headDim() int {
return o.hiddenSize / o.numHeads
}
type samBlock struct {
Norm1 *nn.LayerNorm `gguf:"norm1"`
Attention *samAttention `gguf:"attn"`
Norm2 *nn.LayerNorm `gguf:"norm2"`
FeedForward *samMLP `gguf:"mlp"`
}
func (m *samBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, windowSize int, opts samOptions) ml.Tensor {
c, w, h := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
residual := hiddenStates
hiddenStates = m.Norm1.Forward(ctx, hiddenStates, opts.eps)
var pw, ph int
if windowSize > 0 {
pw = (windowSize - hiddenStates.Dim(1)%windowSize) % windowSize
ph = (windowSize - hiddenStates.Dim(2)%windowSize) % windowSize
if pw > 0 || ph > 0 {
hiddenStates = hiddenStates.Pad(ctx, 0, pw, ph, 0)
}
hiddenStates = hiddenStates.Reshape(ctx, c*windowSize, (w+pw)/windowSize, windowSize, -1)
hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, c, windowSize, windowSize, -1)
}
hiddenStates = m.Attention.Forward(ctx, hiddenStates, opts)
if windowSize > 0 {
hiddenStates = hiddenStates.Reshape(ctx, c*windowSize, windowSize, (w+pw)/windowSize, -1)
hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3)
hiddenStates = hiddenStates.Contiguous(ctx, c, w+pw, h+ph, -1)
hiddenStates = hiddenStates.Pad(ctx, 0, -pw, -ph, 0)
}
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = m.Norm2.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.FeedForward.Forward(ctx, hiddenStates, opts)
return hiddenStates.Add(ctx, residual)
}
type samAttention struct {
QKV *nn.Linear `gguf:"qkv"`
Output *nn.Linear `gguf:"proj"`
RelativePosition *struct {
Height ml.Tensor `gguf:"h"`
Width ml.Tensor `gguf:"w"`
} `gguf:",pre:rel_pos_"`
}
func relativeCoordinates(ctx ml.Context, qn, kn int) ml.Tensor {
s := make([]int32, qn*kn)
for i := range qn {
for j := range kn {
q := i * max(kn/qn, 1)
k := j * max(qn/kn, 1)
s[i*kn+j] = int32(q - k + (kn-1)*max(qn/kn, 1))
}
}
return ctx.Input().FromInts(s, qn*kn)
}
func relativePositions(ctx ml.Context, positions ml.Tensor, qn, kn int) ml.Tensor {
maxRelativeDistance := 2*max(qn, kn) - 1
if positions.Dim(1) != maxRelativeDistance {
// linear interpolation kernel not available so approx. with bilinear interpolation
positions = positions.Interpolate(ctx, [4]int{positions.Dim(0), maxRelativeDistance, 1, 1}, ml.SamplingModeBilinear)
}
rc := relativeCoordinates(ctx, qn, kn)
return positions.Rows(ctx, rc).Reshape(ctx, positions.Dim(0), kn, qn)
}
func (m *samAttention) decomposedRelativePositions(ctx ml.Context, query ml.Tensor, qn, kn []int) (ml.Tensor, ml.Tensor) {
qh, qw := qn[0], qn[1]
kh, kw := kn[0], kn[1]
rh := relativePositions(ctx, m.RelativePosition.Height, qh, kh)
rw := relativePositions(ctx, m.RelativePosition.Width, qw, kw)
query = query.Contiguous(ctx, query.Dim(0), qw, qh, -1)
rh = rh.Mulmat(ctx, query).Reshape(ctx, 1, kh, qh*qw, -1)
rw = rw.Mulmat(ctx, query.Permute(ctx, 0, 2, 1, 3)).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, kw, 1, qh*qw, -1)
return rh, rw
}
func (m *samAttention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
w, h, b := hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3)
qkv := m.QKV.Forward(ctx, hiddenStates)
qkv = qkv.Reshape(ctx, opts.headDim(), -1, w*h, b)
chunks := qkv.Chunk(ctx, 1, opts.numHeads)
query, key, value := chunks[0], chunks[1], chunks[2]
ctx.Forward(query, key, value)
query = query.Permute(ctx, 0, 2, 1, 3)
rh, rw := m.decomposedRelativePositions(ctx, query, []int{h, w}, []int{h, w})
mask := rh.Repeat(ctx, 0, rw.Dim(0)).Add(ctx, rw)
mask = mask.Reshape(ctx, h*w, -1, opts.numHeads, b)
key = key.Permute(ctx, 0, 2, 1, 3)
scores := key.MulmatFullPrec(ctx, query)
scores = scores.Scale(ctx, 1/math.Sqrt(float64(opts.headDim())))
scores = scores.Add(ctx, mask)
scores = scores.Softmax(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
attention := value.Mulmat(ctx, scores)
attention = attention.Permute(ctx, 0, 2, 1, 3)
attention = attention.Contiguous(ctx, -1, w, h, b)
return m.Output.Forward(ctx, attention)
}
type samMLP struct {
Lin1 *nn.Linear `gguf:"lin1"`
Lin2 *nn.Linear `gguf:"lin2"`
}
func (m *samMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
return m.Lin2.Forward(ctx, m.Lin1.Forward(ctx, hiddenStates).GELU(ctx))
}
type LayerNorm2D struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (ln *LayerNorm2D) Forward(ctx ml.Context, x ml.Tensor, eps float32) ml.Tensor {
x = x.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
u := x.Mean(ctx)
d := x.Sub(ctx, u)
s := d.Sqr(ctx).Mean(ctx)
x = d.Div(ctx, s.Add(ctx, ctx.Input().FromFloats([]float32{eps}, 1)).Sqrt(ctx))
x = x.Mul(ctx, ln.Weight).Add(ctx, ln.Bias)
return x.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
}
type samNeck struct {
C1 *nn.Conv2D `gguf:"0"`
LN1 *LayerNorm2D `gguf:"1"`
C2 *nn.Conv2D `gguf:"2"`
LN2 *LayerNorm2D `gguf:"3"`
}
func (m *samNeck) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
hiddenStates = m.C1.Forward(ctx, hiddenStates, 1, 1, 0, 0, 1, 1)
hiddenStates = m.LN1.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.C2.Forward(ctx, hiddenStates, 1, 1, 1, 1, 1, 1)
hiddenStates = m.LN2.Forward(ctx, hiddenStates, opts.eps)
return hiddenStates
}

View File

@@ -0,0 +1,140 @@
package deepseekocr
import (
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
)
type textModel struct {
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Blocks []textBlock `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output"`
Options textOptions
}
func (m *textModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.Options.applyRotaryPositionalEmbedding(ctx, key, shift), nil
}
type textOptions struct {
hiddenSize,
numHeads,
numKVHeads,
numExperts,
numExpertsUsed int
ropeBase,
ropeScale,
eps float32
}
func (o textOptions) headDim() int {
return o.hiddenSize / o.numHeads
}
func (o textOptions) applyRotaryPositionalEmbedding(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
return fast.RoPE(ctx, t, p, o.headDim(), o.ropeBase, 1/o.ropeScale, rope.WithTypeNeoX())
}
type textBlock struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Attention *textAttention
MLPNNorm *nn.RMSNorm `gguf:"ffn_norm"`
FeedForward textFeedForward
}
func (m *textBlock) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts textOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = m.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = m.MLPNNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.FeedForward.Forward(ctx, hiddenStates, opts)
return hiddenStates.Add(ctx, residual)
}
type textAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (m *textAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts textOptions) ml.Tensor {
query := m.Query.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, -1)
key := m.Key.Forward(ctx, hiddenStates)
key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, -1)
value := m.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, -1)
query = opts.applyRotaryPositionalEmbedding(ctx, query, positions)
key = opts.applyRotaryPositionalEmbedding(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
attention = attention.Reshape(ctx, -1, attention.Dim(2))
return m.Output.Forward(ctx, attention)
}
type textFeedForward interface {
Forward(ml.Context, ml.Tensor, textOptions) ml.Tensor
}
type textMoe struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
SharedExperts *textMLP `gguf:",suf:_shexp"`
}
func (m *textMoe) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts textOptions) ml.Tensor {
scores := m.Router.Forward(ctx, hiddenStates).Softmax(ctx)
indices := scores.TopK(ctx, opts.numExpertsUsed)
weights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, indices)
experts := hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
experts = m.Gate.Forward(ctx, experts, indices).SILU(ctx, m.Up.Forward(ctx, experts, indices))
experts = m.Down.Forward(ctx, experts, indices)
experts = experts.Mul(ctx, weights)
expert := func(i int) ml.Tensor {
return experts.View(
ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2),
)
}
routedStates := expert(0)
for i := 1; i < opts.numExpertsUsed; i++ {
routedStates = routedStates.Add(ctx, expert(i))
}
sharedStates := m.SharedExperts.Forward(ctx, hiddenStates, opts)
return routedStates.Add(ctx, sharedStates)
}
type textMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (m *textMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ textOptions) ml.Tensor {
hiddenStates = m.Gate.Forward(ctx, hiddenStates).SILU(ctx, m.Up.Forward(ctx, hiddenStates))
return m.Down.Forward(ctx, hiddenStates)
}

View File

@@ -0,0 +1,117 @@
package deepseekocr
import (
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
type visionModel struct {
PatchEmbedding *nn.Conv2D `gguf:"patch_embd"`
ClassEmbedding ml.Tensor `gguf:"class_embd"`
PositionEmbedding *nn.Embedding `gguf:"position_embd"`
PreLayerNorm *nn.LayerNorm `gguf:"pre_layrnorm"`
Blocks []visionBlock `gguf:"blk"`
Options visionOptions
}
func (m *visionModel) absolutePositionEmbedding(ctx ml.Context, embeds ml.Tensor) ml.Tensor {
numPatches := m.Options.imageSize / m.Options.patchSize * m.Options.imageSize / m.Options.patchSize
positions := ctx.Arange(0, float32(numPatches+1), 1, ml.DTypeI32)
positionEmbeds := m.PositionEmbedding.Forward(ctx, positions)
source := int(math.Sqrt(float64(positionEmbeds.Dim(1) - 1)))
target := int(math.Sqrt(float64(embeds.Dim(1) - 1)))
if source != target {
newPositionEmbeds := positionEmbeds.Slice(ctx, 1, 1, positionEmbeds.Dim(1), 1)
newPositionEmbeds = newPositionEmbeds.Reshape(ctx, -1, source, source)
newPositionEmbeds = newPositionEmbeds.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
newPositionEmbeds = newPositionEmbeds.Interpolate(ctx, [4]int{target, target, embeds.Dim(0), 1}, ml.SamplingModeBilinear)
newPositionEmbeds = newPositionEmbeds.Permute(ctx, 1, 2, 0, 3)
newPositionEmbeds = newPositionEmbeds.Contiguous(ctx, -1, target*target)
positionEmbeds = positionEmbeds.Slice(ctx, 1, 0, 1, 1).Concat(ctx, newPositionEmbeds, 1)
}
return positionEmbeds
}
func (m *visionModel) Forward(ctx ml.Context, pixelValues, patchEmbeds ml.Tensor) ml.Tensor {
if patchEmbeds == nil {
patchEmbeds = m.PatchEmbedding.Forward(ctx, pixelValues, m.Options.patchSize, m.Options.patchSize, 0, 0, 1, 1)
}
patchEmbeds = patchEmbeds.Reshape(ctx, -1, patchEmbeds.Dim(2), patchEmbeds.Dim(3))
patchEmbeds = patchEmbeds.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
classEmbeds := m.ClassEmbedding.Repeat(ctx, 2, patchEmbeds.Dim(2))
embeds := classEmbeds.Concat(ctx, patchEmbeds, 1)
embeds = embeds.Add(ctx, m.absolutePositionEmbedding(ctx, embeds))
hiddenStates := m.PreLayerNorm.Forward(ctx, embeds, m.Options.eps)
for _, block := range m.Blocks {
hiddenStates = block.Forward(ctx, hiddenStates, m.Options)
}
return hiddenStates
}
type visionOptions struct {
hiddenSize,
numHeads int
eps float32
imageSize, patchSize int
}
func (o visionOptions) headDim() int {
return o.hiddenSize / o.numHeads
}
type visionBlock struct {
Norm1 *nn.LayerNorm `gguf:"layer_norm1"`
Attention *visionAttention `gguf:"self_attn"`
Norm2 *nn.LayerNorm `gguf:"layer_norm2"`
FeedForward *visionMLP `gguf:"mlp"`
}
func (m *visionBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts visionOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = m.Norm1.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.Attention.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = m.Norm2.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.FeedForward.Forward(ctx, hiddenStates)
hiddenStates = hiddenStates.Add(ctx, residual)
return hiddenStates
}
type visionAttention struct {
QKV *nn.Linear `gguf:"qkv_proj"`
Output *nn.Linear `gguf:"out_proj"`
}
func (m *visionAttention) Forward(ctx ml.Context, t ml.Tensor, opts visionOptions) ml.Tensor {
qkv := m.QKV.Forward(ctx, t)
qkv = qkv.Reshape(ctx, opts.headDim(), -1, qkv.Dim(1), qkv.Dim(2))
chunks := qkv.Chunk(ctx, 1, opts.numHeads)
query, key, value := chunks[0], chunks[1], chunks[2]
attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
attention = attention.Reshape(ctx, -1, attention.Dim(2), attention.Dim(3))
return m.Output.Forward(ctx, attention)
}
type visionMLP struct {
FC1 *nn.Linear `gguf:"fc1"`
FC2 *nn.Linear `gguf:"fc2"`
}
func (m *visionMLP) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
return m.FC2.Forward(ctx, m.FC1.Forward(ctx, t).QuickGELU(ctx))
}

View File

@@ -3,6 +3,7 @@ package models
import ( import (
_ "github.com/ollama/ollama/model/models/bert" _ "github.com/ollama/ollama/model/models/bert"
_ "github.com/ollama/ollama/model/models/deepseek2" _ "github.com/ollama/ollama/model/models/deepseek2"
_ "github.com/ollama/ollama/model/models/deepseekocr"
_ "github.com/ollama/ollama/model/models/gemma2" _ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3" _ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n" _ "github.com/ollama/ollama/model/models/gemma3n"