mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 14:26:30 +00:00
nomic-embed-text:v2: model implementation (#13162)
This commit is contained in:
@@ -34,19 +34,23 @@ type Options struct {
|
||||
poolingType pooling.Type
|
||||
normalize bool
|
||||
ropeFreqBase float32
|
||||
|
||||
// MoE specific options (used by v2 / MoE models only)
|
||||
numExperts int
|
||||
numExpertsUsed int
|
||||
moeEveryNLayers int
|
||||
}
|
||||
|
||||
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||
return nn.RoPE(ctx, states, positions, o.headDim, o.ropeFreqBase, 1.0, rope.WithTypeNeoX())
|
||||
}
|
||||
|
||||
// Single Encoder Layer
|
||||
type EncoderLayer struct {
|
||||
*Attention
|
||||
|
||||
AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"`
|
||||
|
||||
*MLP
|
||||
FeedForward FeedForward
|
||||
|
||||
MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"`
|
||||
}
|
||||
@@ -56,12 +60,63 @@ type Attention struct {
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
type FeedForward interface {
|
||||
Forward(ml.Context, ml.Tensor, *Options) ml.Tensor
|
||||
}
|
||||
|
||||
type dense struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
|
||||
hidden := mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, hidden)
|
||||
}
|
||||
|
||||
// denseGELU implements MLP with GELU activation for v2 MoE dense layers
|
||||
type denseGELU struct {
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *denseGELU) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
|
||||
return mlp.Down.Forward(ctx, mlp.Up.Forward(ctx, hiddenStates).GELU(ctx))
|
||||
}
|
||||
|
||||
// sparse implements MoE with expert routing
|
||||
type sparse struct {
|
||||
Router *nn.Linear `gguf:"ffn_gate_inp"`
|
||||
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
|
||||
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
|
||||
}
|
||||
|
||||
func (moe *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||
hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize)
|
||||
|
||||
routerLogits := moe.Router.Forward(ctx, hiddenStates)
|
||||
routingWeights := routerLogits.Softmax(ctx)
|
||||
selectedExperts := routingWeights.TopK(ctx, opts.numExpertsUsed)
|
||||
|
||||
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, selectedExperts)
|
||||
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
|
||||
|
||||
hiddenStates = moe.Up.Forward(ctx, hiddenStates, selectedExperts).GELU(ctx)
|
||||
experts := moe.Down.Forward(ctx, hiddenStates, selectedExperts)
|
||||
|
||||
experts = experts.Mul(ctx, routingWeights)
|
||||
|
||||
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
|
||||
for i := 1; i < opts.numExpertsUsed; i++ {
|
||||
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
|
||||
}
|
||||
|
||||
return nextStates
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
|
||||
@@ -92,7 +147,7 @@ func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions
|
||||
hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
|
||||
residual = hiddenStates
|
||||
hiddenStates = e.MLP.Forward(ctx, hiddenStates)
|
||||
hiddenStates = e.FeedForward.Forward(ctx, hiddenStates, opts)
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
hiddenStates = e.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
|
||||
@@ -118,12 +173,6 @@ func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions ml
|
||||
return a.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
|
||||
hidden := m.Gate.Forward(ctx, hiddenStates).SILU(ctx, m.Up.Forward(ctx, hiddenStates))
|
||||
|
||||
return m.Down.Forward(ctx, hidden)
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
hiddenSize := int(c.Uint("embedding_length"))
|
||||
numHeads := int(c.Uint("attention.head_count"))
|
||||
@@ -152,17 +201,37 @@ func New(c fs.Config) (model.Model, error) {
|
||||
false,
|
||||
)
|
||||
|
||||
blockCount := int(c.Uint("block_count"))
|
||||
moeEveryNLayers := int(c.Uint("moe_every_n_layers", 0))
|
||||
layers := make([]EncoderLayer, blockCount)
|
||||
|
||||
for i := range layers {
|
||||
if moeEveryNLayers > 0 {
|
||||
// Layer uses MoE if (i+1) % moe_every_n_layers == 0
|
||||
if (i+1)%moeEveryNLayers == 0 {
|
||||
layers[i].FeedForward = &sparse{}
|
||||
} else {
|
||||
layers[i].FeedForward = &denseGELU{}
|
||||
}
|
||||
} else {
|
||||
layers[i].FeedForward = &dense{}
|
||||
}
|
||||
}
|
||||
|
||||
return &Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]EncoderLayer, c.Uint("block_count")),
|
||||
Layers: layers,
|
||||
Options: Options{
|
||||
hiddenSize: hiddenSize,
|
||||
numHeads: numHeads,
|
||||
headDim: headDim,
|
||||
eps: c.Float("attention.layer_norm_epsilon"),
|
||||
poolingType: pooling.Type(c.Uint("pooling_type")),
|
||||
normalize: c.Bool("normalize_embeddings", false),
|
||||
ropeFreqBase: c.Float("rope.freq_base", 1000.0),
|
||||
hiddenSize: hiddenSize,
|
||||
numHeads: numHeads,
|
||||
headDim: headDim,
|
||||
eps: c.Float("attention.layer_norm_epsilon"),
|
||||
poolingType: pooling.Type(c.Uint("pooling_type")),
|
||||
normalize: c.Bool("normalize_embeddings", false),
|
||||
ropeFreqBase: c.Float("rope.freq_base", 1000.0),
|
||||
numExperts: int(c.Uint("expert_count")),
|
||||
numExpertsUsed: int(c.Uint("expert_used_count")),
|
||||
moeEveryNLayers: moeEveryNLayers,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
@@ -170,4 +239,6 @@ func New(c fs.Config) (model.Model, error) {
|
||||
func init() {
|
||||
model.Register("nomic-bert", New)
|
||||
model.Register("nomic-bert_embed", New)
|
||||
model.Register("nomic-bert-moe", New)
|
||||
model.Register("nomic-bert-moe_embed", New)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user