mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 22:33:56 +00:00
embed: cleanup (#12299)
* cleanup * use pooling.TypeNone * pooling test
This commit is contained in:
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"math"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
@@ -21,6 +20,7 @@ import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
_ "github.com/ollama/ollama/ml/backend"
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
@@ -108,7 +108,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
||||
}
|
||||
|
||||
arch := b.Config().Architecture()
|
||||
if b.Config().Uint("pooling_type", math.MaxUint32) != math.MaxUint32 {
|
||||
if pooling.Type(b.Config().Uint("pooling_type")) != pooling.TypeNone {
|
||||
arch = arch + "_embed"
|
||||
}
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options)
|
||||
}
|
||||
|
||||
hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType)
|
||||
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
|
||||
if m.normalize {
|
||||
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
|
||||
}
|
||||
@@ -22,7 +22,7 @@ type embedModel struct {
|
||||
|
||||
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
|
||||
hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType)
|
||||
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
|
||||
for _, dense := range m.Dense {
|
||||
hiddenStates = dense.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user