feat: qwen3 embed (#12301)

* cleanup

* use pooling.TypeNone

* pooling test

* qwen3 embed
This commit is contained in:
Michael Yang
2025-09-18 15:50:32 -07:00
committed by GitHub
parent 22ccdd74c2
commit 7460259eb3
2 changed files with 88 additions and 4 deletions

View File

@@ -151,14 +151,25 @@ type Model struct {
*Options
}
// Forward implements model.Model.
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates, err := m.forward(ctx, batch)
if err != nil {
return nil, err
}
return m.Output.Forward(ctx, hiddenStates), nil
}
// Forward implements model.Model.
func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
if m.Cache != nil {
m.Cache.SetLayer(i)
}
var outputs ml.Tensor
if i == len(m.Layers)-1 {
@@ -168,8 +179,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
return m.OutputNorm.Forward(ctx, hiddenStates, m.eps), nil
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
@@ -227,4 +237,5 @@ func New(c fs.Config) (model.Model, error) {
func init() {
model.Register("qwen3", New)
model.Register("qwen3moe", New)
model.Register("qwen3_embed", newEmbed)
}