embed: cleanup (#12299)

* cleanup

* use pooling.TypeNone

* pooling test
This commit is contained in:
Michael Yang
2025-09-16 09:48:42 -07:00
committed by GitHub
parent a1cff89b30
commit c253433d68
6 changed files with 104 additions and 19 deletions

View File

@@ -37,7 +37,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options)
}
hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType)
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
if m.normalize {
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
}

View File

@@ -22,7 +22,7 @@ type embedModel struct {
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType)
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
for _, dense := range m.Dense {
hiddenStates = dense.Forward(ctx, hiddenStates)
}