package deepseekocr

import (
	"math"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/fast"
	"github.com/ollama/ollama/ml/nn/rope"
)

// textModel is the decoder-only language model: token embeddings, a stack of
// transformer blocks, and a final norm plus output projection.
type textModel struct {
	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Blocks         []textBlock   `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output"`

	Options textOptions
}

// Shift re-applies rotary embeddings to cached keys when the cache slides,
// reusing the same RoPE parameters as the forward pass.
func (m *textModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	return m.Options.applyRotaryPositionalEmbedding(ctx, key, shift), nil
}

type textOptions struct {
	hiddenSize, numHeads, numKVHeads,
	numExperts, numExpertsUsed int
	ropeBase, ropeScale, eps float32
}

func (o textOptions) headDim() int {
	return o.hiddenSize / o.numHeads
}

func (o textOptions) applyRotaryPositionalEmbedding(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
	return fast.RoPE(ctx, t, p, o.headDim(), o.ropeBase, 1/o.ropeScale, rope.WithTypeNeoX())
}

type textBlock struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	Attention     *textAttention
	MLPNorm       *nn.RMSNorm `gguf:"ffn_norm"`
	FeedForward   textFeedForward
}

// Forward applies pre-norm attention and feed-forward sublayers, each with a
// residual connection. When outputs is non-nil, only the selected rows are
// carried forward, e.g. so the last block keeps just the positions that need
// logits.
func (m *textBlock) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts textOptions) ml.Tensor {
	residual := hiddenStates
	hiddenStates = m.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = m.Attention.Forward(ctx, hiddenStates, positions, cache, opts)

	if outputs != nil {
		hiddenStates = hiddenStates.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenStates = hiddenStates.Add(ctx, residual)

	residual = hiddenStates
	hiddenStates = m.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = m.FeedForward.Forward(ctx, hiddenStates, opts)
	return hiddenStates.Add(ctx, residual)
}

type textAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
}

// Forward computes grouped-query attention: queries use numHeads heads while
// keys and values use numKVHeads, with NeoX-style RoPE applied to both.
func (m *textAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts textOptions) ml.Tensor {
	query := m.Query.Forward(ctx, hiddenStates)
	query = query.Reshape(ctx, opts.headDim(), opts.numHeads, -1)

	key := m.Key.Forward(ctx, hiddenStates)
	key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, -1)

	value := m.Value.Forward(ctx, hiddenStates)
	value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, -1)

	query = opts.applyRotaryPositionalEmbedding(ctx, query, positions)
	key = opts.applyRotaryPositionalEmbedding(ctx, key, positions)

	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
	// Flatten the heads back into a single hidden dimension per token.
	attention = attention.Reshape(ctx, -1, attention.Dim(2))
	return m.Output.Forward(ctx, attention)
}

// textFeedForward abstracts the per-block feed-forward: either a dense MLP or
// a mixture of experts.
type textFeedForward interface {
	Forward(ml.Context, ml.Tensor, textOptions) ml.Tensor
}

type textMoe struct {
	Router        *nn.Linear      `gguf:"ffn_gate_inp"`
	Gate          *nn.LinearBatch `gguf:"ffn_gate_exps"`
	Up            *nn.LinearBatch `gguf:"ffn_up_exps"`
	Down          *nn.LinearBatch `gguf:"ffn_down_exps"`
	SharedExperts *textMLP        `gguf:",suf:_shexp"`
}

// Forward routes each token to its top-k experts, runs the gated MLP for each
// selected expert, accumulates the weighted expert outputs, and adds the
// shared experts, which see every token.
func (m *textMoe) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts textOptions) ml.Tensor {
	scores := m.Router.Forward(ctx, hiddenStates).Softmax(ctx)
	indices := scores.TopK(ctx, opts.numExpertsUsed)
	// Gather each token's router weight for its selected experts.
	weights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, indices)

	experts := hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
	experts = m.Gate.Forward(ctx, experts, indices).SILU(ctx, m.Up.Forward(ctx, experts, indices))
	experts = m.Down.Forward(ctx, experts, indices)
	experts = experts.Mul(ctx, weights)

	// expert returns a view of the i-th selected expert's output for every
	// token, so the k expert outputs can be summed without a copy.
	expert := func(i int) ml.Tensor {
		return experts.View(
			ctx,
			i*experts.Stride(1),
			experts.Dim(0),
			experts.Stride(2),
			experts.Dim(2),
		)
	}

	routedStates := expert(0)
	for i := 1; i < opts.numExpertsUsed; i++ {
		routedStates = routedStates.Add(ctx, expert(i))
	}

	sharedStates := m.SharedExperts.Forward(ctx, hiddenStates, opts)
	return routedStates.Add(ctx, sharedStates)
}

type textMLP struct {
	Gate *nn.Linear `gguf:"ffn_gate"`
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
}

// Forward is a SwiGLU-style feed-forward: the SiLU-activated gate projection
// is multiplied with the up projection before the down projection.
func (m *textMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ textOptions) ml.Tensor {
	hiddenStates = m.Gate.Forward(ctx, hiddenStates).SILU(ctx, m.Up.Forward(ctx, hiddenStates))
	return m.Down.Forward(ctx, hiddenStates)
}
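
// The types above are building pieces; wiring them into a full decoding pass
// happens elsewhere in the model. As a rough, illustrative sketch only (the
// entry point name, signature, and loop below are assumptions, not this
// package's actual API), the composition would look something like:
//
//	func (m *textModel) forward(ctx ml.Context, tokens, positions, outputs ml.Tensor, cache kvcache.Cache) ml.Tensor {
//		hiddenStates := m.TokenEmbedding.Forward(ctx, tokens)
//		for i, block := range m.Blocks {
//			cache.SetLayer(i)
//			// Only the final block gathers the rows that need logits.
//			var blockOutputs ml.Tensor
//			if i == len(m.Blocks)-1 {
//				blockOutputs = outputs
//			}
//			hiddenStates = block.Forward(ctx, hiddenStates, positions, blockOutputs, cache, m.Options)
//		}
//		hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.Options.eps)
//		return m.Output.Forward(ctx, hiddenStates) // logits over the vocabulary
//	}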