mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 22:33:56 +00:00
embed: cleanup (#12299)
* cleanup * use pooling.TypeNone * pooling test
This commit is contained in:
@@ -11,26 +11,32 @@ const (
|
|||||||
TypeMean
|
TypeMean
|
||||||
TypeCLS
|
TypeCLS
|
||||||
TypeLast
|
TypeLast
|
||||||
TypeRank
|
|
||||||
|
|
||||||
TypeUnknown = 0xFFFFFFFE
|
|
||||||
TypeUnspecified = 0xFFFFFFFF
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Pooling(ctx ml.Context, hiddenStates ml.Tensor, poolingType Type) ml.Tensor {
|
func (t Type) String() string {
|
||||||
switch poolingType {
|
switch t {
|
||||||
case TypeNone:
|
case TypeMean:
|
||||||
return hiddenStates
|
return "Mean"
|
||||||
|
case TypeCLS:
|
||||||
|
return "CLS"
|
||||||
|
case TypeLast:
|
||||||
|
return "Last"
|
||||||
|
default:
|
||||||
|
return "Unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
|
||||||
|
switch t {
|
||||||
case TypeMean:
|
case TypeMean:
|
||||||
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
|
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
|
||||||
return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||||
case TypeCLS:
|
case TypeCLS:
|
||||||
return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
|
return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
|
||||||
case TypeLast:
|
case TypeLast:
|
||||||
panic("not implemented")
|
hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0))
|
||||||
case TypeRank:
|
return hiddenStates
|
||||||
panic("not implemented")
|
|
||||||
default:
|
default:
|
||||||
panic("not implemented")
|
panic("unknown pooling type")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
79
ml/nn/pooling/pooling_test.go
Normal file
79
ml/nn/pooling/pooling_test.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package pooling_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"os"
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/ollama/ollama/discover"
|
||||||
|
fsggml "github.com/ollama/ollama/fs/ggml"
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/ml/backend/ggml"
|
||||||
|
"github.com/ollama/ollama/ml/nn/pooling"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setup(tb testing.TB, n int) ml.Backend {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
|
f, err := os.CreateTemp(tb.TempDir(), "*.bin")
|
||||||
|
if err != nil {
|
||||||
|
tb.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
if err := fsggml.WriteGGUF(f, fsggml.KV{
|
||||||
|
"general.architecture": "test",
|
||||||
|
"test.block_count": uint32(1),
|
||||||
|
}, []*fsggml.Tensor{
|
||||||
|
{Name: "blk.0.weight", Shape: []uint64{1}, WriterTo: bytes.NewBuffer(make([]byte, 4))},
|
||||||
|
}); err != nil {
|
||||||
|
tb.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var gpuLayers ml.GPULayersList
|
||||||
|
if gpus := discover.GetGPUInfo(); len(gpus) > 0 {
|
||||||
|
gpuLayers = append(gpuLayers, ml.GPULayers{
|
||||||
|
ID: gpus[0].ID,
|
||||||
|
Layers: slices.Collect(func(yield func(int) bool) {
|
||||||
|
for i := range n {
|
||||||
|
if !yield(i) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
b, err := ggml.New(f.Name(), ml.BackendParams{AllocMemory: true, GPULayers: gpuLayers})
|
||||||
|
if err != nil {
|
||||||
|
tb.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForward(t *testing.T) {
|
||||||
|
cases := map[pooling.Type][]float32{
|
||||||
|
pooling.TypeMean: {4, 5, 6, 7, 8, 9, 10, 11},
|
||||||
|
pooling.TypeCLS: {0, 1, 2, 3, 4, 5, 6, 7},
|
||||||
|
pooling.TypeLast: {8, 9, 10, 11, 12, 13, 14, 15},
|
||||||
|
}
|
||||||
|
for typ, want := range cases {
|
||||||
|
t.Run(typ.String(), func(t *testing.T) {
|
||||||
|
b := setup(t, 99)
|
||||||
|
defer b.Close()
|
||||||
|
|
||||||
|
ctx := b.NewContext()
|
||||||
|
defer ctx.Close()
|
||||||
|
|
||||||
|
tt := ctx.Input().Arange(0, 16, 1, ml.DTypeF32).Reshape(ctx, 8, 2)
|
||||||
|
tt = typ.Forward(ctx, tt)
|
||||||
|
|
||||||
|
ctx.Forward(tt).Compute(tt)
|
||||||
|
if diff := cmp.Diff(want, tt.Floats()); diff != "" {
|
||||||
|
t.Error(diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
_ "image/jpeg"
|
_ "image/jpeg"
|
||||||
_ "image/png"
|
_ "image/png"
|
||||||
"math"
|
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -21,6 +20,7 @@ import (
|
|||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
_ "github.com/ollama/ollama/ml/backend"
|
_ "github.com/ollama/ollama/ml/backend"
|
||||||
|
"github.com/ollama/ollama/ml/nn/pooling"
|
||||||
"github.com/ollama/ollama/model/input"
|
"github.com/ollama/ollama/model/input"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -108,7 +108,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
arch := b.Config().Architecture()
|
arch := b.Config().Architecture()
|
||||||
if b.Config().Uint("pooling_type", math.MaxUint32) != math.MaxUint32 {
|
if pooling.Type(b.Config().Uint("pooling_type")) != pooling.TypeNone {
|
||||||
arch = arch + "_embed"
|
arch = arch + "_embed"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||||||
hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options)
|
hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options)
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType)
|
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
|
||||||
if m.normalize {
|
if m.normalize {
|
||||||
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
|
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
|
||||||
}
|
}
|
||||||
@@ -22,7 +22,7 @@ type embedModel struct {
|
|||||||
|
|
||||||
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
|
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
|
||||||
hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType)
|
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
|
||||||
for _, dense := range m.Dense {
|
for _, dense := range m.Dense {
|
||||||
hiddenStates = dense.Forward(ctx, hiddenStates)
|
hiddenStates = dense.Forward(ctx, hiddenStates)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"image"
|
"image"
|
||||||
"log"
|
"log"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"math"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
@@ -32,6 +31,7 @@ import (
|
|||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/ml/nn/pooling"
|
||||||
"github.com/ollama/ollama/model"
|
"github.com/ollama/ollama/model"
|
||||||
"github.com/ollama/ollama/model/input"
|
"github.com/ollama/ollama/model/input"
|
||||||
"github.com/ollama/ollama/runner/common"
|
"github.com/ollama/ollama/runner/common"
|
||||||
@@ -405,7 +405,7 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
|||||||
func (s *Server) run(ctx context.Context) {
|
func (s *Server) run(ctx context.Context) {
|
||||||
s.ready.Wait()
|
s.ready.Wait()
|
||||||
|
|
||||||
supportsAsync := s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32
|
supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone
|
||||||
|
|
||||||
var activeBatch batchState
|
var activeBatch batchState
|
||||||
for {
|
for {
|
||||||
@@ -900,7 +900,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||||
if s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 {
|
if pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone {
|
||||||
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
|
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user