mirror of https://github.com/likelovewant/ollama-for-amd.git
Merge branch 'main' into drifkin/array-head-count-simple
@@ -15,6 +15,7 @@ import (
	"path/filepath"
	"slices"
	"strings"
	"sync/atomic"

	"github.com/gin-gonic/gin"
@@ -23,7 +24,6 @@ import (
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/llama"
	"github.com/ollama/ollama/template"
	"github.com/ollama/ollama/types/errtypes"
	"github.com/ollama/ollama/types/model"
@@ -425,9 +425,14 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,

func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.ProgressResponse)) (*layerGGML, error) {
	ft := layer.GGML.KV().FileType()
	fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantizeType)})

	want, err := ggml.ParseFileType(quantizeType)
	var doneBytes atomic.Uint64
	totalBytes := uint64(layer.Size) - layer.GGML.Tensors().Offset
	fnWrap := func(n uint64) {
		done := doneBytes.Add(n)
		progress := float32(done) / float32(totalBytes)
		fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantizeType), Digest: "0", Total: layer.Size, Completed: int64(progress * float32(layer.Size))})
	}
	ftype, err := ggml.ParseFileType(quantizeType)
	if err != nil {
		return nil, err
	}
@@ -436,6 +441,11 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
	if err != nil {
		return nil, err
	}
	fp, err := os.Open(blob)
	if err != nil {
		return nil, err
	}
	defer fp.Close()

	temp, err := os.CreateTemp(filepath.Dir(blob), quantizeType)
	if err != nil {
@@ -444,15 +454,15 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
	defer temp.Close()
	defer os.Remove(temp.Name())

	if err := llama.Quantize(blob, temp.Name(), uint32(want)); err != nil {
	if err := quantize(fp, temp, layer.GGML, ftype, fnWrap); err != nil {
		return nil, err
	}

	temp.Seek(0, io.SeekStart)
	fn(api.ProgressResponse{Status: "verifying conversion"})
	newLayer, err := NewLayer(temp, layer.MediaType)
	if err != nil {
		return nil, err
	}

	if _, err := temp.Seek(0, io.SeekStart); err != nil {
		return nil, err
	}
@@ -462,7 +472,6 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
		slog.Error(fmt.Sprintf("error decoding ggml: %s\n", err))
		return nil, err
	}

	return &layerGGML{newLayer, f}, nil
}
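The fnWrap closure above is a small, reusable pattern: a lock-free byte counter shared by whoever processes the data, translated into a completion ratio for the UI callback. A minimal standalone sketch of the same idea, assuming nothing from the Ollama codebase (all names here are illustrative):

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

func main() {
	const totalBytes = 1 << 20 // pretend we must process 1 MiB

	var done atomic.Uint64
	report := func(n uint64) {
		d := done.Add(n) // safe to call from any goroutine
		fmt.Printf("progress: %.1f%%\n", 100*float64(d)/float64(totalBytes))
	}

	var wg sync.WaitGroup
	for range 4 { // four workers, each handling a quarter of the data
		wg.Add(1)
		go func() {
			defer wg.Done()
			report(totalBytes / 4)
		}()
	}
	wg.Wait()
}
```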
@@ -106,6 +106,11 @@ func (m *Model) Capabilities() []model.Capability {
		capabilities = append(capabilities, model.CapabilityInsert)
	}

	// Check for vision capability in projector-based models
	if len(m.ProjectorPaths) > 0 {
		capabilities = append(capabilities, model.CapabilityVision)
	}

	return capabilities
}
@@ -3,6 +3,7 @@ package server
import (
	"bytes"
	"encoding/binary"
	"errors"
	"os"
	"path/filepath"
	"strings"
@@ -91,11 +92,7 @@ func createMockGGUFData(architecture string, vision bool) []byte {

func TestModelCapabilities(t *testing.T) {
	// Create a temporary directory for test files
	tempDir, err := os.MkdirTemp("", "model_capabilities_test")
	if err != nil {
		t.Fatalf("Failed to create temp directory: %v", err)
	}
	defer os.RemoveAll(tempDir)
	tempDir := t.TempDir()

	// Create different types of mock model files
	completionModelPath := filepath.Join(tempDir, "model.bin")
@@ -104,21 +101,13 @@ func TestModelCapabilities(t *testing.T) {
	// Create a simple model file for tests that don't depend on GGUF content
	simpleModelPath := filepath.Join(tempDir, "simple_model.bin")

	err = os.WriteFile(completionModelPath, createMockGGUFData("llama", false), 0o644)
	if err != nil {
		t.Fatalf("Failed to create completion model file: %v", err)
	}
	err = os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644)
	if err != nil {
		t.Fatalf("Failed to create completion model file: %v", err)
	}
	err = os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644)
	if err != nil {
		t.Fatalf("Failed to create embedding model file: %v", err)
	}
	err = os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644)
	if err != nil {
		t.Fatalf("Failed to create simple model file: %v", err)
	if err := errors.Join(
		os.WriteFile(completionModelPath, createMockGGUFData("llama", false), 0o644),
		os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644),
		os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644),
		os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644),
	); err != nil {
		t.Fatalf("Failed to create model files: %v", err)
	}

	toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
@@ -236,27 +225,18 @@ func TestModelCapabilities(t *testing.T) {

func TestModelCheckCapabilities(t *testing.T) {
	// Create a temporary directory for test files
	tempDir, err := os.MkdirTemp("", "model_check_capabilities_test")
	if err != nil {
		t.Fatalf("Failed to create temp directory: %v", err)
	}
	defer os.RemoveAll(tempDir)
	tempDir := t.TempDir()

	visionModelPath := filepath.Join(tempDir, "vision_model.bin")
	simpleModelPath := filepath.Join(tempDir, "model.bin")
	embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")

	err = os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644)
	if err != nil {
		t.Fatalf("Failed to create simple model file: %v", err)
	}
	err = os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644)
	if err != nil {
		t.Fatalf("Failed to create vision model file: %v", err)
	}
	err = os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644)
	if err != nil {
		t.Fatalf("Failed to create embedding model file: %v", err)
	if err := errors.Join(
		os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644),
		os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644),
		os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644),
	); err != nil {
		t.Fatalf("Failed to create model files: %v", err)
	}

	toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
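Both tests above move from manual MkdirTemp/RemoveAll bookkeeping to t.TempDir(), which the test framework removes automatically, and collapse repeated WriteFile error checks into a single errors.Join. A minimal standalone sketch of the same fixture pattern (file names here are illustrative):

```go
package example

import (
	"errors"
	"os"
	"path/filepath"
	"testing"
)

func TestFixtureSetup(t *testing.T) {
	dir := t.TempDir() // removed automatically when the test finishes

	// errors.Join returns nil only if every argument is nil, so a single
	// check covers all of the fixture writes.
	if err := errors.Join(
		os.WriteFile(filepath.Join(dir, "a.bin"), []byte("a"), 0o644),
		os.WriteFile(filepath.Join(dir, "b.bin"), []byte("b"), 0o644),
	); err != nil {
		t.Fatalf("failed to create fixtures: %v", err)
	}
}
```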
@@ -1,224 +0,0 @@
// safetensors provides a reader for the safetensor directories and files.
package safetensors

import (
	"encoding/json"
	"fmt"
	"io"
	"io/fs"
	"iter"
	"slices"
	"strconv"
	"strings"
)

// Tensor represents a single tensor in a safetensors file.
//
// Its zero value is not valid. Use [Model.Tensors] to get valid tensors.
//
// It is not safe for use across multiple goroutines.
type Tensor struct {
	name     string
	dataType string
	shape    []int64

	fsys   fs.FS
	fname  string // entry name in fsys
	offset int64
	size   int64
}

type Model struct {
	fsys fs.FS
}

func Read(fsys fs.FS) (*Model, error) {
	return &Model{fsys: fsys}, nil
}
func (m *Model) Tensors() iter.Seq2[*Tensor, error] {
	return func(yield func(*Tensor, error) bool) {
		entries, err := fs.Glob(m.fsys, "*.safetensors")
		if err != nil {
			yield(nil, err)
			return
		}
		for _, e := range entries {
			tt, err := m.readTensors(e)
			if err != nil {
				yield(nil, err)
				return
			}
			for _, t := range tt {
				if !yield(t, nil) {
					return
				}
			}
		}
	}
}
func (m *Model) readTensors(fname string) ([]*Tensor, error) {
	f, err := m.fsys.Open(fname)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	finfo, err := f.Stat()
	if err != nil {
		return nil, err
	}

	headerSize, err := readInt64(f)
	if err != nil {
		return nil, err
	}

	data := make([]byte, headerSize)
	_, err = io.ReadFull(f, data)
	if err != nil {
		return nil, err
	}

	var raws map[string]json.RawMessage
	if err := json.Unmarshal(data, &raws); err != nil {
		return nil, err
	}

	endOfHeader := 8 + headerSize // 8 bytes for header size plus the header itself

	// TODO(bmizerany): do something with metadata? This could be another
	// header read if needed. We also need to figure out if the metadata is
	// present in only one .safetensors file or if each file may have its
	// own and if it needs to follow each tensor. Currently, I (bmizerany)
	// am only seeing them show up with one entry for file type which is
	// always "pt".

	tt := make([]*Tensor, 0, len(raws))
	for name, raw := range raws {
		if name == "__metadata__" {
			// TODO(bmizerany): do something with metadata?
			continue
		}
		var v struct {
			DataType string  `json:"dtype"`
			Shape    []int64 `json:"shape"`
			Offsets  []int64 `json:"data_offsets"`
		}
		if err := json.Unmarshal(raw, &v); err != nil {
			return nil, fmt.Errorf("error unmarshalling layer %q: %w", name, err)
		}
		if len(v.Offsets) != 2 {
			return nil, fmt.Errorf("invalid offsets for %q: %v", name, v.Offsets)
		}

		// TODO(bmizerany): after collecting, validate all offsets make
		// tensors contiguous?
		begin := endOfHeader + v.Offsets[0]
		end := endOfHeader + v.Offsets[1]
		if err := checkBeginEnd(finfo.Size(), begin, end); err != nil {
			return nil, err
		}

		// TODO(bmizerany): just yield.. don't be silly and make a slice :)
		tt = append(tt, &Tensor{
			name:     name,
			dataType: v.DataType,
			shape:    v.Shape,
			fsys:     m.fsys,
			fname:    fname,
			offset:   begin,
			size:     end - begin,
		})
	}
	return tt, nil
}
func checkBeginEnd(size, begin, end int64) error {
	if begin < 0 {
		return fmt.Errorf("begin must not be negative: %d", begin)
	}
	if end < 0 {
		return fmt.Errorf("end must not be negative: %d", end)
	}
	if end < begin {
		return fmt.Errorf("end must be >= begin: %d < %d", end, begin)
	}
	if end > size {
		return fmt.Errorf("end must be <= size: %d > %d", end, size)
	}
	return nil
}
func readInt64(r io.Reader) (int64, error) {
	var v uint64
	var buf [8]byte
	if _, err := io.ReadFull(r, buf[:]); err != nil {
		return 0, err
	}
	for i := range buf {
		v |= uint64(buf[i]) << (8 * i)
	}
	return int64(v), nil
}
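For reference, the on-disk layout parsed by readTensors above is: an 8-byte little-endian header length, a JSON header mapping tensor names to dtype/shape/data_offsets (offsets relative to the end of the header), then the raw tensor data. A hedged sketch that writes a minimal single-tensor file this reader would accept (the tensor name and dtype are illustrative, not taken from the source):

```go
package main

import (
	"encoding/binary"
	"encoding/json"
	"os"
)

func main() {
	payload := make([]byte, 16) // 4 little-endian float32 zeros
	header, _ := json.Marshal(map[string]any{
		"t.weight": map[string]any{
			"dtype":        "F32",
			"shape":        []int64{4},
			"data_offsets": []int64{0, int64(len(payload))},
		},
	})

	f, err := os.Create("model.safetensors")
	if err != nil {
		panic(err)
	}
	defer f.Close()

	// 8-byte little-endian header size, then the JSON header, then data.
	binary.Write(f, binary.LittleEndian, uint64(len(header)))
	f.Write(header)
	f.Write(payload)
}
```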
type Shape []int64

func (s Shape) String() string {
	var b strings.Builder
	b.WriteByte('[')
	for i, v := range s {
		if i > 0 {
			b.WriteByte(',')
		}
		b.WriteString(strconv.FormatInt(v, 10))
	}
	b.WriteByte(']')
	return b.String()
}
func (t *Tensor) Name() string     { return t.name }
func (t *Tensor) DataType() string { return t.dataType }
func (t *Tensor) Size() int64      { return t.size }
func (t *Tensor) Shape() Shape     { return slices.Clone(t.shape) }

func (t *Tensor) Reader() (io.ReadCloser, error) {
	f, err := t.fsys.Open(t.fname)
	if err != nil {
		return nil, err
	}
	r := newSectionReader(f, t.offset, t.size)
	rc := struct {
		io.Reader
		io.Closer
	}{r, f}
	return rc, nil
}
// newSectionReader returns a new io.Reader that reads n bytes from r starting
// at offset. It is a convenience function for creating an io.SectionReader
// when r may not be an io.ReaderAt.
//
// If r is already an io.ReaderAt, it is wrapped in an io.SectionReader
// directly. Otherwise, if r is an io.ReadSeeker, it is seeked to offset and
// limited to n bytes. Failing both, the slow path is taken: r does not
// implement io.ReaderAt or io.Seeker, so the reader must discard the data up
// to offset before returning a limited reader.
func newSectionReader(r io.Reader, offset, n int64) io.Reader {
	if r, ok := r.(io.ReaderAt); ok {
		return io.NewSectionReader(r, offset, n)
	}
	if r, ok := r.(io.ReadSeeker); ok {
		r.Seek(offset, io.SeekStart)
		return io.LimitReader(r, n)
	}
	// Discard to offset and return a limited reader.
	_, err := io.CopyN(io.Discard, r, offset)
	if err != nil {
		return nil
	}
	return io.LimitReader(r, n)
}
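As a usage sketch, the helper lets Tensor.Reader treat any fs.File uniformly. Here is a minimal, self-contained illustration of the same dispatch idea (the readSection helper is hypothetical, not part of this package):

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"strings"
)

// readSection mirrors the dispatch above: prefer io.ReaderAt, otherwise
// fall back to discarding the prefix before reading the section.
func readSection(r io.Reader, offset, n int64) ([]byte, error) {
	if ra, ok := r.(io.ReaderAt); ok {
		buf := make([]byte, n)
		_, err := ra.ReadAt(buf, offset)
		return buf, err
	}
	if _, err := io.CopyN(io.Discard, r, offset); err != nil {
		return nil, err
	}
	return io.ReadAll(io.LimitReader(r, n))
}

func main() {
	// bytes.Reader implements io.ReaderAt; io.MultiReader hides that
	// interface, which exercises the discard-based slow path.
	fast, _ := readSection(bytes.NewReader([]byte("0123456789")), 2, 4)
	fmt.Printf("%s\n", fast) // 2345
	slow, _ := readSection(io.MultiReader(strings.NewReader("0123456789")), 2, 4)
	fmt.Printf("%s\n", slow) // 2345
}
```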
@@ -1,375 +0,0 @@
package main

import (
	"bytes"
	"cmp"
	"context"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"io"
	"log"
	"mime"
	"net/http"
	"os"
	"runtime"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/ollama/ollama/server/internal/cache/blob"
	"github.com/ollama/ollama/server/internal/client/ollama"
	"github.com/ollama/ollama/server/internal/cmd/opp/internal/safetensors"
	"golang.org/x/sync/errgroup"
)
var stdout io.Writer = os.Stdout

const usage = `Opp is a tool for pushing and pulling Ollama models.

Usage:

    opp [flags] <push|pull|import>

Commands:

    push      Upload a model to the Ollama server.
    pull      Download a model from the Ollama server.
    import    Import a model from a local safetensor directory.

Examples:

    # Pull a model from the Ollama server.
    opp pull library/llama3.2:latest

    # Push a model to the Ollama server.
    opp push username/my_model:8b

    # Import a model from a local safetensor directory.
    opp import /path/to/safetensor

Environment Variables:

    OLLAMA_MODELS
        The directory where models are pushed and pulled from
        (default ~/.ollama/models).
`
func main() {
	flag.Usage = func() {
		fmt.Fprint(os.Stderr, usage)
	}
	flag.Parse()

	ctx := context.Background()

	err := func() error {
		switch cmd := flag.Arg(0); cmd {
		case "pull":
			rc, err := ollama.DefaultRegistry()
			if err != nil {
				log.Fatal(err)
			}

			return cmdPull(ctx, rc)
		case "push":
			rc, err := ollama.DefaultRegistry()
			if err != nil {
				log.Fatal(err)
			}
			return cmdPush(ctx, rc)
		case "import":
			c, err := ollama.DefaultCache()
			if err != nil {
				log.Fatal(err)
			}
			return cmdImport(ctx, c)
		default:
			if cmd == "" {
				flag.Usage()
			} else {
				fmt.Fprintf(os.Stderr, "unknown command %q\n", cmd)
			}
			os.Exit(2)
			return errors.New("unreachable")
		}
	}()
	if err != nil {
		fmt.Fprintf(os.Stderr, "opp: %v\n", err)
		os.Exit(1)
	}
}
func cmdPull(ctx context.Context, rc *ollama.Registry) error {
	model := flag.Arg(1)
	if model == "" {
		flag.Usage()
		os.Exit(1)
	}

	tr := http.DefaultTransport.(*http.Transport).Clone()
	// TODO(bmizerany): configure transport?
	rc.HTTPClient = &http.Client{Transport: tr}

	var mu sync.Mutex
	p := make(map[blob.Digest][2]int64) // digest -> [total, downloaded]

	var pb bytes.Buffer
	printProgress := func() {
		pb.Reset()
		mu.Lock()
		for d, s := range p {
			// Write progress to a buffer first to avoid blocking
			// on stdout while holding the lock.
			stamp := time.Now().Format("2006/01/02 15:04:05")
			fmt.Fprintf(&pb, "%s %s pulling %d/%d (%.1f%%)\n", stamp, d.Short(), s[1], s[0], 100*float64(s[1])/float64(s[0]))
			if s[0] == s[1] {
				delete(p, d)
			}
		}
		mu.Unlock()
		io.Copy(stdout, &pb)
	}

	ctx = ollama.WithTrace(ctx, &ollama.Trace{
		Update: func(l *ollama.Layer, n int64, err error) {
			if err != nil && !errors.Is(err, ollama.ErrCached) {
				fmt.Fprintf(stdout, "opp: pull %s ! %v\n", l.Digest.Short(), err)
				return
			}

			mu.Lock()
			p[l.Digest] = [2]int64{l.Size, n}
			mu.Unlock()
		},
	})

	errc := make(chan error)
	go func() {
		errc <- rc.Pull(ctx, model)
	}()

	t := time.NewTicker(time.Second)
	defer t.Stop()
	for {
		select {
		case <-t.C:
			printProgress()
		case err := <-errc:
			printProgress()
			return err
		}
	}
}
func cmdPush(ctx context.Context, rc *ollama.Registry) error {
	args := flag.Args()[1:]
	flag := flag.NewFlagSet("push", flag.ExitOnError)
	flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
	flag.Usage = func() {
		fmt.Fprintf(os.Stderr, "Usage: opp push <model>\n")
		flag.PrintDefaults()
	}
	flag.Parse(args)

	model := flag.Arg(0)
	if model == "" {
		return fmt.Errorf("missing model argument")
	}

	from := cmp.Or(*flagFrom, model)
	m, err := rc.ResolveLocal(from)
	if err != nil {
		return err
	}

	ctx = ollama.WithTrace(ctx, &ollama.Trace{
		Update: func(l *ollama.Layer, n int64, err error) {
			switch {
			case errors.Is(err, ollama.ErrCached):
				fmt.Fprintf(stdout, "opp: uploading %s %d (existed)", l.Digest.Short(), n)
			case err != nil:
				fmt.Fprintf(stdout, "opp: uploading %s %d ! %v\n", l.Digest.Short(), n, err)
			case n == 0:
				l := m.Layer(l.Digest)
				mt, p, _ := mime.ParseMediaType(l.MediaType)
				mt, _ = strings.CutPrefix(mt, "application/vnd.ollama.image.")
				switch mt {
				case "tensor":
					fmt.Fprintf(stdout, "opp: uploading tensor %s %s\n", l.Digest.Short(), p["name"])
				default:
					fmt.Fprintf(stdout, "opp: uploading %s %s\n", l.Digest.Short(), l.MediaType)
				}
			}
		},
	})

	return rc.Push(ctx, model, &ollama.PushParams{
		From: from,
	})
}
type trackingReader struct {
	io.Reader
	n *atomic.Int64
}

func (r *trackingReader) Read(p []byte) (n int, err error) {
	n, err = r.Reader.Read(p)
	r.n.Add(int64(n))
	return n, err
}
func cmdImport(ctx context.Context, c *blob.DiskCache) error {
	args := flag.Args()[1:]
	flag := flag.NewFlagSet("import", flag.ExitOnError)
	flagAs := flag.String("as", "", "Import using the provided name.")
	flag.Usage = func() {
		fmt.Fprintf(os.Stderr, "Usage: opp import <SafetensorDir>\n")
		flag.PrintDefaults()
	}
	flag.Parse(args)
	if *flagAs == "" {
		return fmt.Errorf("missing -as flag")
	}
	as := ollama.CompleteName(*flagAs)

	dir := cmp.Or(flag.Arg(0), ".")
	fmt.Fprintf(os.Stderr, "Reading %s\n", dir)

	m, err := safetensors.Read(os.DirFS(dir))
	if err != nil {
		return err
	}

	var total int64
	var tt []*safetensors.Tensor
	for t, err := range m.Tensors() {
		if err != nil {
			return err
		}
		tt = append(tt, t)
		total += t.Size()
	}

	var n atomic.Int64
	done := make(chan error)
	go func() {
		layers := make([]*ollama.Layer, len(tt))
		var g errgroup.Group
		g.SetLimit(runtime.GOMAXPROCS(0))
		var ctxErr error
		for i, t := range tt {
			if ctx.Err() != nil {
				// The context may cancel AFTER we exit the
				// loop, and so if we use ctx.Err() after the
				// loop we may report it as the error that
				// broke the loop, when it was not. This can
				// manifest as a false-negative, leading the
				// user to think their import failed when it
				// did not, so capture it if and only if we
				// exit the loop because of a ctx.Err() and
				// report it.
				ctxErr = ctx.Err()
				break
			}
			g.Go(func() (err error) {
				rc, err := t.Reader()
				if err != nil {
					return err
				}
				defer rc.Close()
				tr := &trackingReader{rc, &n}
				d, err := c.Import(tr, t.Size())
				if err != nil {
					return err
				}
				if err := rc.Close(); err != nil {
					return err
				}

				layers[i] = &ollama.Layer{
					Digest: d,
					Size:   t.Size(),
					MediaType: mime.FormatMediaType("application/vnd.ollama.image.tensor", map[string]string{
						"name":  t.Name(),
						"dtype": t.DataType(),
						"shape": t.Shape().String(),
					}),
				}

				return nil
			})
		}

		done <- func() error {
			if err := errors.Join(g.Wait(), ctxErr); err != nil {
				return err
			}
			m := &ollama.Manifest{Layers: layers}
			data, err := json.MarshalIndent(m, "", "  ")
			if err != nil {
				return err
			}
			d := blob.DigestFromBytes(data)
			err = blob.PutBytes(c, d, data)
			if err != nil {
				return err
			}
			return c.Link(as, d)
		}()
	}()

	fmt.Fprintf(stdout, "Importing %d tensors from %s\n", len(tt), dir)

	csiHideCursor(stdout)
	defer csiShowCursor(stdout)

	csiSavePos(stdout)
	writeProgress := func() {
		csiRestorePos(stdout)
		nn := n.Load()
		fmt.Fprintf(stdout, "Imported %s/%s bytes (%d%%)%s\n",
			formatNatural(nn),
			formatNatural(total),
			nn*100/total,
			ansiClearToEndOfLine,
		)
	}

	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			writeProgress()
		case err := <-done:
			writeProgress()
			fmt.Println()
			fmt.Println("Successfully imported", as)
			return err
		}
	}
}
func formatNatural(n int64) string {
	switch {
	case n < 1024:
		return fmt.Sprintf("%d B", n)
	case n < 1024*1024:
		return fmt.Sprintf("%.1f KB", float64(n)/1024)
	case n < 1024*1024*1024:
		return fmt.Sprintf("%.1f MB", float64(n)/(1024*1024))
	default:
		return fmt.Sprintf("%.1f GB", float64(n)/(1024*1024*1024))
	}
}
const ansiClearToEndOfLine = "\033[K"

func csiSavePos(w io.Writer)    { fmt.Fprint(w, "\033[s") }
func csiRestorePos(w io.Writer) { fmt.Fprint(w, "\033[u") }
func csiHideCursor(w io.Writer) { fmt.Fprint(w, "\033[?25l") }
func csiShowCursor(w io.Writer) { fmt.Fprint(w, "\033[?25h") }
@@ -3,7 +3,6 @@
package backoff

import (
	"context"
	"testing"
	"testing/synctest"
	"time"
@@ -29,7 +28,7 @@ func TestLoopAllocs(t *testing.T) {
}

func BenchmarkLoop(b *testing.B) {
	ctx := context.Background()
	ctx := b.Context()
	synctest.Run(func() {
		for n := range Loop(ctx, 100*time.Millisecond) {
			if n == b.N {
@@ -64,7 +64,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
	}
	defer blob.Close()

	f, _, err := ggml.Decode(blob, 1024)
	f, _, err := ggml.Decode(blob, -1)
	if err != nil {
		return nil, err
	}
@@ -1,7 +1,6 @@
package server

import (
	"os"
	"path/filepath"
	"testing"

@@ -11,9 +10,7 @@ import (

func TestGetBlobsPath(t *testing.T) {
	// GetBlobsPath expects an actual directory to exist
	dir, err := os.MkdirTemp("", "ollama-test")
	require.NoError(t, err)
	defer os.RemoveAll(dir)
	tempDir := t.TempDir()

	tests := []struct {
		name string
@@ -24,19 +21,19 @@ func TestGetBlobsPath(t *testing.T) {
		{
			"empty digest",
			"",
			filepath.Join(dir, "blobs"),
			filepath.Join(tempDir, "blobs"),
			nil,
		},
		{
			"valid with colon",
			"sha256:456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
			filepath.Join(dir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
			filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
			nil,
		},
		{
			"valid with dash",
			"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
			filepath.Join(dir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
			filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
			nil,
		},
		{
@@ -60,7 +57,7 @@ func TestGetBlobsPath(t *testing.T) {
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			t.Setenv("OLLAMA_MODELS", dir)
			t.Setenv("OLLAMA_MODELS", tempDir)

			got, err := GetBlobsPath(tc.digest)
@@ -2,7 +2,6 @@ package server

import (
	"bytes"
	"context"
	"image"
	"image/png"
	"testing"
@@ -318,7 +317,7 @@ func TestChatPrompt(t *testing.T) {
		t.Run(tt.name, func(t *testing.T) {
			model := tt.model
			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
			prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
			prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
			if tt.error == nil && err != nil {
				t.Fatal(err)
			} else if tt.error != nil && err != tt.error {
server/quantization.go (new file, +274 lines)
@@ -0,0 +1,274 @@
package server

import (
	"fmt"
	"io"
	"log/slog"
	"maps"
	"os"
	"strings"
	"unsafe"

	fsggml "github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/ml/backend/ggml"
)
type quantizer struct {
	*os.File
	offset     uint64
	from, to   *fsggml.Tensor
	progressFn func(n uint64)
}

func (q quantizer) WriteTo(w io.Writer) (int64, error) {
	quantize := q.from.Kind != q.to.Kind
	sr := io.NewSectionReader(q, int64(q.offset), int64(q.from.Size()))
	if !quantize {
		n, err := io.Copy(w, sr)
		q.progressFn(q.from.Size())
		return n, err
	}
	data, err := io.ReadAll(sr)
	if err != nil {
		slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
		return 0, fmt.Errorf("unable to read tensor %s from %s: %s", q.from.Name, q.Name(), err)
	}
	var f32s []float32
	newType := fsggml.TensorType(q.to.Kind)
	if fsggml.TensorType(q.from.Kind) == fsggml.TensorTypeF32 {
		f32s = unsafe.Slice((*float32)(unsafe.Pointer(&data[0])), q.from.Elements())
	} else {
		f32s = ggml.ConvertToF32(data, q.from.Kind, q.from.Elements())
	}
	data = ggml.Quantize(newType, f32s, q.from.Shape)
	n, err := w.Write(data)
	q.progressFn(q.from.Size())
	return int64(n), err
}
type quantizeState struct {
	nAttnV    int  // Number of attn_*v* weight tensors
	nFfnDown  int  // Number of ffn_down tensors
	iAttnV    int  // Running counter of number of attn_v tensors that have been processed
	iFfnDown  int  // Running counter of number of ffn_down tensors that have been processed
	hasOutput bool // used to figure out if a model shares tok_embd with the output weight
}

func useMoreBits(iLayer, nLayers int) bool {
	return iLayer < (nLayers/8) || iLayer >= 7*nLayers/8 || (iLayer-nLayers/8)%3 == 2
}
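useMoreBits mirrors llama.cpp's heuristic: the first eighth of layers, the last eighth, and every third layer in between get a higher-precision type. A quick standalone sketch that prints which layers qualify for a hypothetical 32-layer model:

```go
package main

import "fmt"

func useMoreBits(iLayer, nLayers int) bool {
	return iLayer < (nLayers/8) || iLayer >= 7*nLayers/8 || (iLayer-nLayers/8)%3 == 2
}

func main() {
	// For nLayers = 32: layers 0-3 (first eighth), 28-31 (last eighth),
	// and 6, 9, 12, ..., 27 (every third layer after the first eighth).
	for i := range 32 {
		if useMoreBits(i, 32) {
			fmt.Print(i, " ")
		}
	}
	fmt.Println()
}
```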
func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType, name string, shape []uint64, ftype fsggml.FileType) fsggml.TensorType {
	// Ported from llama_tensor_get_type, removed unsupported quantization types
	nExperts := max(1, kv.Uint("expert_count", 0))
	if name == "output.weight" || name == "output_norm.weight" || (!qs.hasOutput && name == "token_embd.weight") {
		nx := shape[0]
		qk_k := newType.BlockSize()
		if nx%qk_k != 0 {
			newType = fsggml.TensorTypeQ8_0
		} else if newType != fsggml.TensorTypeQ8_0 {
			newType = fsggml.TensorTypeQ6_K
		}
	} else if strings.Contains(name, "attn_v.weight") {
		if ftype == fsggml.FileTypeQ2_K {
			if kv.GQA() >= 4 {
				newType = fsggml.TensorTypeQ4_K
			} else {
				newType = fsggml.TensorTypeQ3_K
			}
		} else if ftype == fsggml.FileTypeQ2_K_S && kv.GQA() >= 4 {
			newType = fsggml.TensorTypeQ4_K
		} else if ftype == fsggml.FileTypeQ3_K_M {
			if qs.iAttnV < 2 {
				newType = fsggml.TensorTypeQ5_K
			} else {
				newType = fsggml.TensorTypeQ4_K
			}
		} else if ftype == fsggml.FileTypeQ3_K_L {
			newType = fsggml.TensorTypeQ5_K
		} else if (ftype == fsggml.FileTypeQ4_K_M || ftype == fsggml.FileTypeQ5_K_M) &&
			useMoreBits(qs.iAttnV, qs.nAttnV) {
			newType = fsggml.TensorTypeQ6_K
		} else if ftype == fsggml.FileTypeQ4_K_S && qs.iAttnV < 4 {
			newType = fsggml.TensorTypeQ5_K
		}

		// TODO
		// if (qs.model.type == LLM_TYPE_70B) {
		//	// In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is
		//	// 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with
		//	// nearly negligible increase in model size by quantizing this tensor with more bits:
		//	if (newType == GGML_TYPE_Q3_K || newType == GGML_TYPE_Q4_K) newType = GGML_TYPE_Q5_K;
		// }

		if nExperts == 8 {
			// for the 8-expert model, bumping this to Q8_0 trades just ~128MB
			newType = fsggml.TensorTypeQ8_0
		}
		qs.iAttnV++
	} else if strings.Contains(name, "attn_k.weight") {
		if nExperts == 8 {
			// for the 8-expert model, bumping this to Q8_0 trades just ~128MB
			newType = fsggml.TensorTypeQ8_0
		}
	} else if strings.Contains(name, "ffn_down") {
		iLayer := qs.iFfnDown
		n_layer := qs.nFfnDown
		if ftype == fsggml.FileTypeQ2_K {
			newType = fsggml.TensorTypeQ3_K
		} else if ftype == fsggml.FileTypeQ2_K_S {
			if iLayer < n_layer/8 {
				newType = fsggml.TensorTypeQ4_K
			}
		} else if ftype == fsggml.FileTypeQ3_K_M {
			if iLayer < n_layer/16 {
				newType = fsggml.TensorTypeQ5_K
			} else if useMoreBits(iLayer, n_layer) {
				newType = fsggml.TensorTypeQ4_K
			} else {
				newType = fsggml.TensorTypeQ3_K
			}
		} else if ftype == fsggml.FileTypeQ3_K_L {
			newType = fsggml.TensorTypeQ5_K
		} else if ftype == fsggml.FileTypeQ4_K_M {
			if useMoreBits(iLayer, n_layer) {
				newType = fsggml.TensorTypeQ6_K
			}
		} else if ftype == fsggml.FileTypeQ5_K_M && useMoreBits(iLayer, n_layer) {
			newType = fsggml.TensorTypeQ6_K
		} else if ftype == fsggml.FileTypeQ4_K_S && iLayer < n_layer/8 {
			newType = fsggml.TensorTypeQ5_K
		}
		qs.iFfnDown++
	} else if strings.Contains(name, "attn_output.weight") {
		if nExperts == 8 {
			if ftype == fsggml.FileTypeQ2_K || ftype == fsggml.FileTypeQ3_K_S || ftype == fsggml.FileTypeQ3_K_M ||
				ftype == fsggml.FileTypeQ4_K_S || ftype == fsggml.FileTypeQ4_K_M {
				newType = fsggml.TensorTypeQ5_K
			}
		} else {
			if ftype == fsggml.FileTypeQ2_K {
				newType = fsggml.TensorTypeQ3_K
			} else if ftype == fsggml.FileTypeQ3_K_M {
				newType = fsggml.TensorTypeQ4_K
			} else if ftype == fsggml.FileTypeQ3_K_L {
				newType = fsggml.TensorTypeQ5_K
			}
		}
	} else if strings.Contains(name, "attn_qkv.weight") {
		if ftype == fsggml.FileTypeQ3_K_M || ftype == fsggml.FileTypeQ3_K_L {
			newType = fsggml.TensorTypeQ4_K
		} else if ftype == fsggml.FileTypeQ4_K_M {
			newType = fsggml.TensorTypeQ5_K
		} else if ftype == fsggml.FileTypeQ5_K_M {
			newType = fsggml.TensorTypeQ6_K
		}
	}

	if newType.IsQuantized() {
		nx := shape[0]
		ny := uint64(1)
		if len(shape) > 1 {
			ny = shape[1]
		}
		qk_k := newType.BlockSize()
		if nx%qk_k != 0 {
			slog.Warn(fmt.Sprintf("tensor cols %d x %d are not divisible by %d, required for %s. Falling back to quantization %s", nx, ny, qk_k, newType.String(), fsggml.TensorTypeF16.String()))
			newType = fsggml.TensorTypeF16
		}
	}
	return newType
}
func quantize(in, out *os.File, orig *fsggml.GGML, newFileType fsggml.FileType, progressFn func(n uint64)) error {
	kv := maps.Clone(orig.KV())
	kv["general.file_type"] = newFileType
	// kv["general.quantization_version"] = ggml.QuantizationVersion()
	qs := &quantizeState{}
	// Build up the quantize state so newType can adjust types
	layerCount := 0
	for k, l := range orig.Tensors().GroupLayers() {
		if strings.HasPrefix(k, "blk.") {
			layerCount++
		}
		for _, tensor := range l {
			if strings.Contains(tensor.Name, "attn_v.weight") ||
				strings.Contains(tensor.Name, "attn_qkv.weight") ||
				strings.Contains(tensor.Name, "attn_kv_b.weight") {
				qs.nAttnV++
			} else if tensor.Name == "output.weight" {
				qs.hasOutput = true
			}
		}
	}
	qs.nFfnDown = layerCount

	origTensors := orig.Tensors().Items()
	outputTensors := make([]*fsggml.Tensor, len(origTensors))
	for i, tensor := range origTensors {
		tensor := tensor
		newType := newType(tensor, kv, qs, newFileType)
		newTensor := &fsggml.Tensor{
			Name:  tensor.Name,
			Shape: tensor.Shape,
			Kind:  uint32(newType),
		}
		outputTensors[i] = newTensor
		outputTensors[i].WriterTo = quantizer{
			File:       in,
			offset:     orig.Tensors().Offset + tensor.Offset,
			from:       tensor,
			to:         newTensor,
			progressFn: progressFn,
		}
	}
	return fsggml.WriteGGUF(out, kv, outputTensors)
}
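Note that quantize is lazy: it transforms no tensor data itself. Each output tensor's WriterTo is wired to a quantizer, so conversion happens tensor-by-tensor while fsggml.WriteGGUF streams the file out. A minimal standalone sketch of that io.WriterTo deferral pattern (the upperCaser type is illustrative, not from this package):

```go
package main

import (
	"fmt"
	"io"
	"os"
	"strings"
)

// upperCaser defers its work until WriteTo is driven by the consumer,
// the same way quantizer defers quantization until WriteGGUF writes it.
type upperCaser struct{ src io.Reader }

func (u upperCaser) WriteTo(w io.Writer) (int64, error) {
	data, err := io.ReadAll(u.src) // transform lazily, at write time
	if err != nil {
		return 0, err
	}
	n, err := w.Write([]byte(strings.ToUpper(string(data))))
	return int64(n), err
}

func main() {
	// Nothing is read or transformed until WriteTo is called.
	wt := upperCaser{src: strings.NewReader("hello gguf")}
	if _, err := wt.WriteTo(os.Stdout); err != nil {
		panic(err)
	}
	fmt.Println()
}
```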
func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.FileType) fsggml.TensorType {
	defaultType := ftype.ToTensorType()
	name := t.Name
	quantize := strings.HasSuffix(name, "weight")

	// don't quantize vision stuff
	quantize = quantize && (!strings.Contains(name, "v.") || strings.Contains(name, "_v."))
	quantize = quantize && !strings.Contains(name, "mm.")

	// quantize only 2D and 3D tensors (experts)
	quantize = quantize && (len(t.Shape) >= 2)

	// do not quantize norm tensors
	quantize = quantize && !strings.Contains(name, "_norm.weight")

	// do not quantize expert gating tensors
	quantize = quantize && !strings.Contains(name, "ffn_gate_inp.weight")

	// do not quantize positional embeddings and token types (BERT)
	quantize = quantize && (name != "position_embd.weight")
	quantize = quantize && (name != "token_types.weight")

	// do not quantize Mamba's small yet 2D weights
	// NOTE: can't use LLM_TN here because the layer number is not known
	quantize = quantize && !strings.Contains(name, "ssm_conv1d.weight")

	// do not quantize RWKV's time_mix_first tensors
	quantize = quantize && !strings.Contains(name, "time_mix_first.weight")
	quantize = quantize && !strings.Contains(name, "time_mix_w1.weight")
	quantize = quantize && !strings.Contains(name, "time_mix_w2.weight")
	quantize = quantize && !strings.Contains(name, "time_mix_decay_w1.weight")
	quantize = quantize && !strings.Contains(name, "time_mix_decay_w2.weight")
	quantize = quantize && !strings.Contains(name, "time_mix_lerp_fused.weight")

	// do not quantize relative position bias (T5)
	quantize = quantize && !strings.Contains(name, "attn_rel_b.weight")

	newType := fsggml.TensorType(t.Kind)
	if quantize {
		// get more optimal quantization type based on the tensor shape, layer, etc.
		newType = getTensorNewType(kv, qs, defaultType, t.Name, t.Shape, ftype)
		if newType != defaultType {
			slog.Debug("tensor quantization adjusted for better quality", "name", t.Name, "requested", defaultType, "quantization", newType)
		}
	}
	return newType
}
server/quantization_test.go (new file, +882 lines)
@@ -0,0 +1,882 @@
package server

import (
	"bytes"
	"fmt"
	"math"
	"os"
	"strings"
	"testing"

	fsggml "github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/ml/backend/ggml"
)
func TestGetTensorNewType(t *testing.T) {
	cases := []struct {
		name          string
		kv            map[string]any
		qs            quantizeState
		newType       fsggml.TensorType
		tensor_name   string
		shape         []uint64
		ftype         fsggml.FileType
		expected      fsggml.TensorType
		expectedPanic string
	}{
		{
			name:        "output_unsupported",
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "output.weight",
			shape:       []uint64{100, 100},
			ftype:       fsggml.FileTypeF32,
			expected:    fsggml.TensorTypeF16,
		},
		{
			name:        "output_Q8",
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "output.weight",
			shape:       []uint64{1024, 1024},
			ftype:       fsggml.FileTypeF32,
			expected:    fsggml.TensorTypeQ6_K,
		},
		{
			name: "attn_v.weight_q4_k",
			kv: map[string]any{
				"general.architecture":        "foo",
				"foo.attention.head_count":    uint32(4),
				"foo.attention.head_count_kv": uint32(1),
			},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "blk.0.attn_v.weight",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ2_K,
			expected:    fsggml.TensorTypeQ4_K,
		},
		{
			name:        "attn_v.weight_q3_k",
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "blk.0.attn_v.weight",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ2_K,
			expected:    fsggml.TensorTypeQ3_K,
		},
		{
			name: "attn_v.weight_q2_k_s_q4_k",
			kv: map[string]any{
				"general.architecture":        "foo",
				"foo.attention.head_count":    uint32(4),
				"foo.attention.head_count_kv": uint32(1),
			},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "blk.0.attn_v.weight",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ2_K_S,
			expected:    fsggml.TensorTypeQ4_K,
		},
		{
			name:        "attn_v.weight_q3_k_m",
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "blk.0.attn_v.weight",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ3_K_M,
			expected:    fsggml.TensorTypeQ5_K,
		},
		{
			name: "attn_v.weight_q3_k_m_i",
			qs: quantizeState{
				iAttnV: 2,
			},
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "blk.0.attn_v.weight",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ3_K_M,
			expected:    fsggml.TensorTypeQ4_K,
		},
		{
			name:        "attn_v.weight_q3_k_l",
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "blk.0.attn_v.weight",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ3_K_L,
			expected:    fsggml.TensorTypeQ5_K,
		},
		{
			name: "attn_v.weight_q4_k_m",
			qs: quantizeState{
				iAttnV: 2,
				nAttnV: 3 * 8,
			},
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "blk.0.attn_v.weight",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ4_K_M,
			expected:    fsggml.TensorTypeQ6_K,
		},
		{
			name:        "attn_v.weight_q4_k_s",
			qs:          quantizeState{},
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "blk.0.attn_v.weight",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ4_K_S,
			expected:    fsggml.TensorTypeQ5_K,
		},
		{
			name: "attn_v.weight_8_expert",
			qs:   quantizeState{},
			kv: map[string]any{
				"general.architecture": "foo",
				"foo.expert_count":     uint32(8),
			},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "blk.0.attn_v.weight",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeF32,
			expected:    fsggml.TensorTypeQ8_0,
		},
		{
			name: "attn_k.weight_8_expert",
			qs:   quantizeState{},
			kv: map[string]any{
				"general.architecture": "foo",
				"foo.expert_count":     uint32(8),
			},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "blk.0.attn_k.weight",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeF32,
			expected:    fsggml.TensorTypeQ8_0,
		},
		{
			name:        "ffn_down_q2_k",
			qs:          quantizeState{},
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "ffn_down",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ2_K,
			expected:    fsggml.TensorTypeQ3_K,
		},
		{
			name:        "ffn_down_q2_k_s",
			qs:          quantizeState{},
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "ffn_down",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ2_K_S,
			expected:    fsggml.TensorTypeQ4_0,
		},
		{
			name: "ffn_down_q2_k_s_layers",
			qs: quantizeState{
				iFfnDown: 2,
				nFfnDown: 3 * 8,
			},
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "ffn_down",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ2_K_S,
			expected:    fsggml.TensorTypeQ4_K,
		},
		{
			name: "ffn_down_q3_k_m_base",
			qs: quantizeState{
				iFfnDown: 1,
				nFfnDown: 8,
			},
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "ffn_down",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ3_K_M,
			expected:    fsggml.TensorTypeQ3_K,
		},
		{
			name: "ffn_down_q3_k_m_16",
			qs: quantizeState{
				iFfnDown: 2,
				nFfnDown: 3 * 16,
			},
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "ffn_down",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ3_K_M,
			expected:    fsggml.TensorTypeQ5_K,
		},
		{
			name: "ffn_down_q3_k_m_8",
			qs: quantizeState{
				iFfnDown: 2,
				nFfnDown: 3 * 8,
			},
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "ffn_down",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ3_K_M,
			expected:    fsggml.TensorTypeQ4_K,
		},
		{
			name:        "ffn_down_q3_k_l",
			qs:          quantizeState{},
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "ffn_down",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ3_K_L,
			expected:    fsggml.TensorTypeQ5_K,
		},
		{
			name: "ffn_down_q4_k_m",
			qs: quantizeState{
				iFfnDown: 1,
				nFfnDown: 8,
			},
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "ffn_down",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ4_K_M,
			expected:    fsggml.TensorTypeQ4_0,
		},
		{
			name: "ffn_down_q4_k_m_6",
			qs: quantizeState{
				iFfnDown: 2,
				nFfnDown: 3 * 8,
			},
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "ffn_down",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ4_K_M,
			expected:    fsggml.TensorTypeQ6_K,
		},
		{
			name: "ffn_down_q5_k_m",
			qs: quantizeState{
				iFfnDown: 2,
				nFfnDown: 3 * 8,
			},
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "ffn_down",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ5_K_M,
			expected:    fsggml.TensorTypeQ6_K,
		},
		{
			name: "ffn_down_q4_k_s",
			qs: quantizeState{
				iFfnDown: 2,
				nFfnDown: 3 * 8,
			},
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "ffn_down",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ4_K_S,
			expected:    fsggml.TensorTypeQ5_K,
		},
		{
			name: "attn_output.weight_8_expert",
			qs:   quantizeState{},
			kv: map[string]any{
				"general.architecture": "foo",
				"foo.expert_count":     uint32(8),
			},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "blk.0.attn_output.weight",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ2_K,
			expected:    fsggml.TensorTypeQ5_K,
		},
		{
			name:        "attn_output.weight_q2",
			qs:          quantizeState{},
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "blk.0.attn_output.weight",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ2_K,
			expected:    fsggml.TensorTypeQ3_K,
		},
		{
			name:        "attn_output.weight_q3_k_m",
			qs:          quantizeState{},
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "blk.0.attn_output.weight",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ3_K_M,
			expected:    fsggml.TensorTypeQ4_K,
		},
		{
			name:        "attn_output.weight_q3_k_l",
			qs:          quantizeState{},
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "blk.0.attn_output.weight",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ3_K_L,
			expected:    fsggml.TensorTypeQ5_K,
		},
		{
			name:        "attn_qkv.weight_q3_k_m",
			qs:          quantizeState{},
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "blk.0.attn_qkv.weight",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ3_K_M,
			expected:    fsggml.TensorTypeQ4_K,
		},
		{
			name:        "attn_qkv.weight_q4_k_m",
			qs:          quantizeState{},
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "blk.0.attn_qkv.weight",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ4_K_M,
			expected:    fsggml.TensorTypeQ5_K,
		},
		{
			name:        "attn_qkv.weight_q5_k_m",
			qs:          quantizeState{},
			kv:          map[string]any{},
			newType:     fsggml.TensorTypeQ4_0,
			tensor_name: "blk.0.attn_qkv.weight",
			shape:       []uint64{256},
			ftype:       fsggml.FileTypeQ5_K_M,
			expected:    fsggml.TensorTypeQ6_K,
		},
	}
	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			if tt.expectedPanic != "" {
				defer func() {
					e := recover()
					if !strings.Contains(fmt.Sprintf("%v", e), tt.expectedPanic) {
						t.Fatalf("incorrect panic\ngot: %v\nexpected: %s", e, tt.expectedPanic)
					}
				}()
			} else {
				defer func() {
					e := recover()
					if e != nil {
						t.Fatalf("hit unexpected panic %v", e)
					}
				}()
			}
			ret := getTensorNewType(tt.kv, &tt.qs, tt.newType, tt.tensor_name, tt.shape, tt.ftype)
			if ret != tt.expected {
				t.Fatalf("incorrect type returned\ngot: %d\nexpected: %d", ret, tt.expected)
			}
		})
	}
}
func TestQuantizeModel(t *testing.T) {
	cases := []struct {
		name                string
		kv                  map[string]any
		tensors             []*fsggml.Tensor
		newType             string
		expectedTensorTypes map[string]fsggml.TensorType
	}{
		{
			name: "f16_q4_k",
			kv: map[string]any{
				"general.architecture": "foo",
			},
			tensors: []*fsggml.Tensor{
				{
					Name: "blk.0.attn.weight", Kind: uint32(fsggml.TensorTypeF16),
					Offset: uint64(0), Shape: []uint64{512, 2},
					WriterTo: bytes.NewReader(
						append(append(append(quantBytes[fsggml.TensorTypeF16], quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...),
					),
				},
				{
					Name: "output.weight", Kind: uint32(fsggml.TensorTypeF16),
					Offset: uint64(0), Shape: []uint64{256, 4},
					WriterTo: bytes.NewReader(
						append(append(append(quantBytes[fsggml.TensorTypeF16], quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...),
					),
				},
			},
			newType: "Q4_K",
			expectedTensorTypes: map[string]fsggml.TensorType{
				"blk.0.attn.weight": fsggml.TensorTypeQ4_K,
				"output.weight":     fsggml.TensorTypeQ6_K,
			},
		},
		{
			name: "f32_q4_k",
			kv: map[string]any{
				"general.architecture": "foo",
			},
			tensors: []*fsggml.Tensor{
				{
					Name: "blk.0.attn_v.weight", Kind: uint32(fsggml.TensorTypeF32),
					Offset: uint64(0), Shape: []uint64{512, 2},
					WriterTo: bytes.NewReader(
						append(append(append(quantBytes[fsggml.TensorTypeF32], quantBytes[fsggml.TensorTypeF32]...), quantBytes[fsggml.TensorTypeF32]...), quantBytes[fsggml.TensorTypeF32]...),
					),
				},
				{
					Name: "output.weight", Kind: uint32(fsggml.TensorTypeF32),
					Offset: uint64(0), Shape: []uint64{512},
					WriterTo: bytes.NewReader(append(quantBytes[fsggml.TensorTypeF32], quantBytes[fsggml.TensorTypeF32]...)),
				},
			},
			newType: "Q4_K",
			expectedTensorTypes: map[string]fsggml.TensorType{
				"blk.0.attn_v.weight": fsggml.TensorTypeQ6_K,
				"output.weight":       fsggml.TensorTypeF32,
			},
		},
		{
			name: "f16_q8_0",
			kv: map[string]any{
				"general.architecture": "foo",
			},
			tensors: []*fsggml.Tensor{
				{
					Name: "blk.0.attn.weight", Kind: uint32(fsggml.TensorTypeF16),
					Offset: uint64(0), Shape: []uint64{32, 16, 2},
					WriterTo: bytes.NewReader(
						append(append(append(quantBytes[fsggml.TensorTypeF16], quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...),
					),
				},
				{
					Name: "output.weight", Kind: uint32(fsggml.TensorTypeF16),
					Offset: uint64(0), Shape: []uint64{256, 4},
					WriterTo: bytes.NewReader(
						append(append(append(quantBytes[fsggml.TensorTypeF16], quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...),
					),
				},
			},
			newType: "Q8_0",
			expectedTensorTypes: map[string]fsggml.TensorType{
				"blk.0.attn.weight": fsggml.TensorTypeQ8_0,
				"output.weight":     fsggml.TensorTypeQ8_0,
			},
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			f, err := os.CreateTemp(t.TempDir(), tt.name)
			if err != nil {
				t.Fatal(err.Error())
			}
			defer f.Close()
			err = fsggml.WriteGGUF(f, tt.kv, tt.tensors)
			if err != nil {
				t.Fatalf("failed to create initial model: %s", err)
			}
			fp, err := os.Open(f.Name())
			if err != nil {
				t.Fatal(err.Error())
			}
			defer fp.Close()
			meta, _, err := fsggml.Decode(fp, -1)
			if err != nil {
				t.Fatal(err.Error())
			}
			progressCalled := false
			progress := func(n uint64) {
				// fmt.Fprintf(os.Stderr, "progress: %f\n", p)
				progressCalled = true
			}
			tmp, err := os.CreateTemp(t.TempDir(), tt.name+".out")
			if err != nil {
				t.Fatal(err.Error())
			}
			defer tmp.Close()
			ftype, err := fsggml.ParseFileType(tt.newType)
			if err != nil {
				t.Fatal(err.Error())
			}

			err = quantize(fp, tmp, meta, ftype, progress)
			if err != nil {
				t.Fatalf("error during quantize: %s", err)
			}
			if !progressCalled {
				t.Fatalf("progress was not reported")
			}
			// Now attempt to load it back and make sure types match expected
			fpNew, err := os.Open(tmp.Name())
			if err != nil {
				t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err)
			}
			defer fpNew.Close()
			newMeta, _, err := fsggml.Decode(fpNew, -1)
			if err != nil {
				t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err)
			}
			tensors := newMeta.Tensors()
			for _, l := range tensors.GroupLayers() {
				for _, tensor := range l {
					if fsggml.TensorType(tensor.Kind) != tt.expectedTensorTypes[tensor.Name] {
						t.Fatalf("incorrect output type for %s\ngot:%s\nexpected:%s", tensor.Name, fsggml.TensorType(tensor.Kind), tt.expectedTensorTypes[tensor.Name])
					}
				}
			}
		})
	}
}
func TestConvertToF32(t *testing.T) {
	expected := make([]float32, 256)
	for i := range expected {
		expected[i] = float32(i)
	}
	for dtype, data := range quantBytes {
		// Skip the no-op
		if dtype == fsggml.TensorTypeF32 {
			continue
		}
		t.Run(dtype.String(), func(t *testing.T) {
			fp32 := ggml.ConvertToF32(data, uint32(dtype), 256)
			similarity := cosineSimilarity(expected, fp32)
			if similarity < 0.999 {
				t.Fatalf("Results not similar enough: %s %f", dtype.String(), similarity)
			}
		})
	}
}

func dotProduct[V float32 | float64](v1, v2 []V) V {
	var result V = 0
	for i := range v1 {
		result += v1[i] * v2[i]
	}
	return result
}

func magnitude[V float32 | float64](v []V) V {
	var result V = 0
	for _, val := range v {
		result += val * val
	}
	return V(math.Sqrt(float64(result)))
}

func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
	return dotProduct(v1, v2) / (magnitude(v1) * magnitude(v2))
}
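For intuition, cosine similarity compares direction rather than magnitude, so a dequantized tensor that is a slightly noisy copy of the original still scores near 1. A tiny worked example:

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	a := []float64{1, 2, 3}
	b := []float64{2, 4, 6.1} // roughly 2*a, so nearly the same direction

	var dot, na, nb float64
	for i := range a {
		dot += a[i] * b[i]
		na += a[i] * a[i]
		nb += b[i] * b[i]
	}
	// cos = a·b / (|a||b|); prints a value just under 1.0
	fmt.Println(dot / (math.Sqrt(na) * math.Sqrt(nb)))
}
```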
|
||||
// Precomputed quantized data - arange 256
|
||||
// # For gguf-py supported types
|
||||
// import gguf
|
||||
// import numpy as np
|
||||
// print(repr(gguf.quantize(np.arange(256, dtype=np.float16), gguf.GGMLQuantizationType.Q4_0)))
|
||||
//
|
||||
// For types not supported by gguf-py converted via ggml_fp32_to_fp16_row and quantize_XXX
|
||||
//
|
||||
// data := make([]byte, 256*2)
|
||||
// fp32 := make([]float32, 256)
|
||||
// for i := range 256 {
|
||||
// fp32[i] = float32(i)
|
||||
// }
|
||||
// l := C.quantize_q6_K((*C.float)(&fp32[0]), unsafe.Pointer(&data[0]), 1, 256, nil)
|
||||
// for i := range data[:int(l)] {
|
||||
// fmt.Printf("%d, ", data[i])
|
||||
// }
|
||||
var (
	quantBytes = map[fsggml.TensorType][]byte{
		fsggml.TensorTypeQ4_0: {
			192, 195, 72, 72, 55, 55, 55, 55, 38, 38, 38, 38, 21,
			21, 21, 21, 4, 4, 224, 199, 36, 36, 36, 36, 19, 19,
			19, 19, 19, 19, 19, 19, 2, 2, 2, 2, 240, 201, 19,
			19, 18, 18, 18, 18, 18, 18, 18, 18, 2, 2, 2, 2,
			1, 1, 240, 203, 18, 18, 18, 18, 18, 18, 18, 18, 1,
			1, 1, 1, 1, 1, 1, 1, 248, 204, 18, 18, 17, 17,
			17, 17, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 248,
			205, 17, 17, 17, 17, 1, 1, 1, 1, 1, 1, 1, 1,
			1, 1, 1, 1, 248, 206, 17, 17, 1, 1, 1, 1, 1,
			1, 1, 1, 1, 1, 1, 1, 1, 1, 248, 207, 1, 1,
			1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
			1,
		},
		fsggml.TensorTypeQ4_1: {
			34, 64, 0, 0, 128, 128, 145, 145, 162, 162, 179, 179, 196,
			196, 213, 213, 230, 230, 247, 247, 34, 64, 0, 80, 128, 128,
			145, 145, 162, 162, 179, 179, 196, 196, 213, 213, 230, 230, 247,
			247, 34, 64, 0, 84, 128, 128, 145, 145, 162, 162, 179, 179,
			196, 196, 213, 213, 230, 230, 247, 247, 34, 64, 0, 86, 128,
			128, 145, 145, 162, 162, 179, 179, 196, 196, 213, 213, 230, 230,
			247, 247, 34, 64, 0, 88, 128, 128, 145, 145, 162, 162, 179,
			179, 196, 196, 213, 213, 230, 230, 247, 247, 34, 64, 0, 89,
			128, 128, 145, 145, 162, 162, 179, 179, 196, 196, 213, 213, 230,
			230, 247, 247, 34, 64, 0, 90, 128, 128, 145, 145, 162, 162,
			179, 179, 196, 196, 213, 213, 230, 230, 247, 247, 34, 64, 0,
			91, 128, 128, 145, 145, 162, 162, 179, 179, 196, 196, 213, 213,
			230, 230, 247, 247,
		},
		fsggml.TensorTypeQ5_0: {
			192, 191, 1, 0, 0, 0, 128, 127, 127, 110, 110, 93, 93,
			76, 76, 59, 59, 42, 42, 25, 25, 8, 224, 195, 0, 0,
			0, 0, 72, 72, 55, 55, 55, 55, 38, 38, 38, 38, 21,
			21, 21, 21, 4, 4, 240, 197, 0, 0, 0, 0, 53, 37,
			37, 37, 37, 36, 36, 20, 20, 20, 20, 19, 19, 3, 3,
			3, 240, 199, 0, 0, 0, 0, 36, 36, 36, 36, 19, 19,
			19, 19, 19, 19, 19, 19, 2, 2, 2, 2, 248, 200, 0,
			0, 0, 0, 35, 19, 19, 19, 19, 19, 19, 18, 18, 18,
			18, 2, 2, 2, 2, 2, 248, 201, 0, 0, 0, 0, 19,
			19, 18, 18, 18, 18, 18, 18, 18, 18, 2, 2, 2, 2,
			1, 1, 248, 202, 0, 0, 0, 0, 18, 18, 18, 18, 18,
			18, 18, 18, 18, 2, 2, 1, 1, 1, 1, 1, 248, 203,
			0, 0, 0, 0, 18, 18, 18, 18, 18, 18, 18, 18, 1,
			1, 1, 1, 1, 1, 1, 1,
		},
		fsggml.TensorTypeQ5_1: {
			0, 60, 0, 0, 0, 0, 255, 255, 0, 17, 34, 51, 68,
			85, 102, 119, 136, 153, 170, 187, 204, 221, 238, 255, 0, 60,
			0, 80, 0, 0, 255, 255, 0, 17, 34, 51, 68, 85, 102,
			119, 136, 153, 170, 187, 204, 221, 238, 255, 0, 60, 0, 84,
			0, 0, 255, 255, 0, 17, 34, 51, 68, 85, 102, 119, 136,
			153, 170, 187, 204, 221, 238, 255, 0, 60, 0, 86, 0, 0,
			255, 255, 0, 17, 34, 51, 68, 85, 102, 119, 136, 153, 170,
			187, 204, 221, 238, 255, 0, 60, 0, 88, 0, 0, 255, 255,
			0, 17, 34, 51, 68, 85, 102, 119, 136, 153, 170, 187, 204,
			221, 238, 255, 0, 60, 0, 89, 0, 0, 255, 255, 0, 17,
			34, 51, 68, 85, 102, 119, 136, 153, 170, 187, 204, 221, 238,
			255, 0, 60, 0, 90, 0, 0, 255, 255, 0, 17, 34, 51,
			68, 85, 102, 119, 136, 153, 170, 187, 204, 221, 238, 255, 0,
			60, 0, 91, 0, 0, 255, 255, 0, 17, 34, 51, 68, 85,
			102, 119, 136, 153, 170, 187, 204, 221, 238, 255,
		},
		fsggml.TensorTypeQ8_0: {
			208, 51, 0, 4, 8, 12, 16, 20, 25, 29, 33, 37, 41,
			45, 49, 53, 57, 61, 66, 70, 74, 78, 82, 86, 90, 94,
			98, 102, 107, 111, 115, 119, 123, 127, 240, 55, 65, 67, 69,
			71, 73, 75, 77, 79, 81, 83, 85, 87, 89, 91, 93, 95,
			97, 99, 101, 103, 105, 107, 109, 111, 113, 115, 117, 119, 121,
			123, 125, 127, 252, 57, 86, 87, 88, 90, 91, 92, 94, 95,
			96, 98, 99, 100, 102, 103, 104, 106, 107, 108, 110, 111, 112,
			114, 115, 116, 118, 119, 120, 122, 123, 124, 126, 127, 0, 60,
			96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108,
			109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,
			122, 123, 124, 125, 126, 127, 2, 61, 102, 103, 104, 105, 105,
			106, 107, 108, 109, 109, 110, 111, 112, 113, 113, 114, 115, 116,
			117, 117, 118, 119, 120, 121, 121, 122, 123, 124, 125, 125, 126,
			127, 4, 62, 106, 107, 108, 108, 109, 110, 110, 111, 112, 112,
			113, 114, 114, 115, 116, 116, 117, 118, 118, 119, 120, 120, 121,
			122, 122, 123, 124, 124, 125, 126, 126, 127, 6, 63, 109, 110,
			110, 111, 112, 112, 113, 113, 114, 114, 115, 116, 116, 117, 117,
			118, 118, 119, 120, 120, 121, 121, 122, 122, 123, 124, 124, 125,
			125, 126, 126, 127, 4, 64, 112, 112, 113, 113, 114, 114, 115,
			115, 116, 116, 117, 117, 118, 118, 119, 119, 120, 120, 121, 121,
			122, 122, 123, 123, 124, 124, 125, 125, 126, 126, 127, 127,
		},
		fsggml.TensorTypeBF16: {
			0, 0, 128, 63, 0, 64, 64, 64, 128, 64, 160, 64, 192,
			64, 224, 64, 0, 65, 16, 65, 32, 65, 48, 65, 64, 65,
			80, 65, 96, 65, 112, 65, 128, 65, 136, 65, 144, 65, 152,
			65, 160, 65, 168, 65, 176, 65, 184, 65, 192, 65, 200, 65,
			208, 65, 216, 65, 224, 65, 232, 65, 240, 65, 248, 65, 0,
			66, 4, 66, 8, 66, 12, 66, 16, 66, 20, 66, 24, 66,
			28, 66, 32, 66, 36, 66, 40, 66, 44, 66, 48, 66, 52,
			66, 56, 66, 60, 66, 64, 66, 68, 66, 72, 66, 76, 66,
			80, 66, 84, 66, 88, 66, 92, 66, 96, 66, 100, 66, 104,
			66, 108, 66, 112, 66, 116, 66, 120, 66, 124, 66, 128, 66,
			130, 66, 132, 66, 134, 66, 136, 66, 138, 66, 140, 66, 142,
			66, 144, 66, 146, 66, 148, 66, 150, 66, 152, 66, 154, 66,
			156, 66, 158, 66, 160, 66, 162, 66, 164, 66, 166, 66, 168,
			66, 170, 66, 172, 66, 174, 66, 176, 66, 178, 66, 180, 66,
			182, 66, 184, 66, 186, 66, 188, 66, 190, 66, 192, 66, 194,
			66, 196, 66, 198, 66, 200, 66, 202, 66, 204, 66, 206, 66,
			208, 66, 210, 66, 212, 66, 214, 66, 216, 66, 218, 66, 220,
			66, 222, 66, 224, 66, 226, 66, 228, 66, 230, 66, 232, 66,
			234, 66, 236, 66, 238, 66, 240, 66, 242, 66, 244, 66, 246,
			66, 248, 66, 250, 66, 252, 66, 254, 66, 0, 67, 1, 67,
			2, 67, 3, 67, 4, 67, 5, 67, 6, 67, 7, 67, 8,
			67, 9, 67, 10, 67, 11, 67, 12, 67, 13, 67, 14, 67,
			15, 67, 16, 67, 17, 67, 18, 67, 19, 67, 20, 67, 21,
			67, 22, 67, 23, 67, 24, 67, 25, 67, 26, 67, 27, 67,
			28, 67, 29, 67, 30, 67, 31, 67, 32, 67, 33, 67, 34,
			67, 35, 67, 36, 67, 37, 67, 38, 67, 39, 67, 40, 67,
			41, 67, 42, 67, 43, 67, 44, 67, 45, 67, 46, 67, 47,
			67, 48, 67, 49, 67, 50, 67, 51, 67, 52, 67, 53, 67,
			54, 67, 55, 67, 56, 67, 57, 67, 58, 67, 59, 67, 60,
			67, 61, 67, 62, 67, 63, 67, 64, 67, 65, 67, 66, 67,
			67, 67, 68, 67, 69, 67, 70, 67, 71, 67, 72, 67, 73,
			67, 74, 67, 75, 67, 76, 67, 77, 67, 78, 67, 79, 67,
			80, 67, 81, 67, 82, 67, 83, 67, 84, 67, 85, 67, 86,
			67, 87, 67, 88, 67, 89, 67, 90, 67, 91, 67, 92, 67,
			93, 67, 94, 67, 95, 67, 96, 67, 97, 67, 98, 67, 99,
			67, 100, 67, 101, 67, 102, 67, 103, 67, 104, 67, 105, 67,
			106, 67, 107, 67, 108, 67, 109, 67, 110, 67, 111, 67, 112,
			67, 113, 67, 114, 67, 115, 67, 116, 67, 117, 67, 118, 67,
			119, 67, 120, 67, 121, 67, 122, 67, 123, 67, 124, 67, 125,
			67, 126, 67, 127, 67,
		},
		fsggml.TensorTypeF16: {
			0, 0, 0, 60, 0, 64, 0, 66, 0, 68, 0, 69, 0, 70, 0, 71, 0,
			72, 128, 72, 0, 73, 128, 73, 0, 74, 128, 74, 0, 75, 128, 75,
			0, 76, 64, 76, 128, 76, 192, 76, 0, 77, 64, 77, 128, 77, 192,
			77, 0, 78, 64, 78, 128, 78, 192, 78, 0, 79, 64, 79, 128, 79,
			192, 79, 0, 80, 32, 80, 64, 80, 96, 80, 128, 80, 160, 80,
			192, 80, 224, 80, 0, 81, 32, 81, 64, 81, 96, 81, 128, 81,
			160, 81, 192, 81, 224, 81, 0, 82, 32, 82, 64, 82, 96, 82,
			128, 82, 160, 82, 192, 82, 224, 82, 0, 83, 32, 83, 64, 83,
			96, 83, 128, 83, 160, 83, 192, 83, 224, 83, 0, 84, 16, 84,
			32, 84, 48, 84, 64, 84, 80, 84, 96, 84, 112, 84, 128, 84,
			144, 84, 160, 84, 176, 84, 192, 84, 208, 84, 224, 84, 240,
			84, 0, 85, 16, 85, 32, 85, 48, 85, 64, 85, 80, 85, 96, 85,
			112, 85, 128, 85, 144, 85, 160, 85, 176, 85, 192, 85, 208,
			85, 224, 85, 240, 85, 0, 86, 16, 86, 32, 86, 48, 86, 64,
			86, 80, 86, 96, 86, 112, 86, 128, 86, 144, 86, 160, 86,
			176, 86, 192, 86, 208, 86, 224, 86, 240, 86, 0, 87, 16,
			87, 32, 87, 48, 87, 64, 87, 80, 87, 96, 87, 112, 87, 128,
			87, 144, 87, 160, 87, 176, 87, 192, 87, 208, 87, 224, 87,
			240, 87, 0, 88, 8, 88, 16, 88, 24, 88, 32, 88, 40, 88,
			48, 88, 56, 88, 64, 88, 72, 88, 80, 88, 88, 88, 96, 88,
			104, 88, 112, 88, 120, 88, 128, 88, 136, 88, 144, 88, 152,
			88, 160, 88, 168, 88, 176, 88, 184, 88, 192, 88, 200, 88,
			208, 88, 216, 88, 224, 88, 232, 88, 240, 88, 248, 88, 0,
			89, 8, 89, 16, 89, 24, 89, 32, 89, 40, 89, 48, 89, 56, 89,
			64, 89, 72, 89, 80, 89, 88, 89, 96, 89, 104, 89, 112, 89,
			120, 89, 128, 89, 136, 89, 144, 89, 152, 89, 160, 89, 168,
			89, 176, 89, 184, 89, 192, 89, 200, 89, 208, 89, 216, 89,
			224, 89, 232, 89, 240, 89, 248, 89, 0, 90, 8, 90, 16, 90,
			24, 90, 32, 90, 40, 90, 48, 90, 56, 90, 64, 90, 72, 90, 80,
			90, 88, 90, 96, 90, 104, 90, 112, 90, 120, 90, 128, 90,
			136, 90, 144, 90, 152, 90, 160, 90, 168, 90, 176, 90, 184,
			90, 192, 90, 200, 90, 208, 90, 216, 90, 224, 90, 232, 90,
			240, 90, 248, 90, 0, 91, 8, 91, 16, 91, 24, 91, 32, 91, 40,
			91, 48, 91, 56, 91, 64, 91, 72, 91, 80, 91, 88, 91, 96, 91,
			104, 91, 112, 91, 120, 91, 128, 91, 136, 91, 144, 91, 152,
			91, 160, 91, 168, 91, 176, 91, 184, 91, 192, 91, 200, 91,
			208, 91, 216, 91, 224, 91, 232, 91, 240, 91, 248, 91,
		},
		fsggml.TensorTypeF32: {
			0, 0, 0, 0, 0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128,
			64, 0, 0, 160, 64, 0, 0, 192, 64, 0, 0, 224, 64, 0, 0, 0, 65, 0,
			0, 16, 65, 0, 0, 32, 65, 0, 0, 48, 65, 0, 0, 64, 65, 0, 0, 80, 65,
			0, 0, 96, 65, 0, 0, 112, 65, 0, 0, 128, 65, 0, 0, 136, 65, 0, 0,
			144, 65, 0, 0, 152, 65, 0, 0, 160, 65, 0, 0, 168, 65, 0, 0, 176,
			65, 0, 0, 184, 65, 0, 0, 192, 65, 0, 0, 200, 65, 0, 0, 208, 65, 0,
			0, 216, 65, 0, 0, 224, 65, 0, 0, 232, 65, 0, 0, 240, 65, 0, 0, 248,
			65, 0, 0, 0, 66, 0, 0, 4, 66, 0, 0, 8, 66, 0, 0, 12, 66, 0, 0, 16,
			66, 0, 0, 20, 66, 0, 0, 24, 66, 0, 0, 28, 66, 0, 0, 32, 66, 0, 0,
			36, 66, 0, 0, 40, 66, 0, 0, 44, 66, 0, 0, 48, 66, 0, 0, 52, 66, 0,
			0, 56, 66, 0, 0, 60, 66, 0, 0, 64, 66, 0, 0, 68, 66, 0, 0, 72, 66,
			0, 0, 76, 66, 0, 0, 80, 66, 0, 0, 84, 66, 0, 0, 88, 66, 0, 0, 92, 66,
			0, 0, 96, 66, 0, 0, 100, 66, 0, 0, 104, 66, 0, 0, 108, 66, 0, 0, 112,
			66, 0, 0, 116, 66, 0, 0, 120, 66, 0, 0, 124, 66, 0, 0, 128, 66, 0, 0,
			130, 66, 0, 0, 132, 66, 0, 0, 134, 66, 0, 0, 136, 66, 0, 0, 138, 66,
			0, 0, 140, 66, 0, 0, 142, 66, 0, 0, 144, 66, 0, 0, 146, 66, 0, 0, 148,
			66, 0, 0, 150, 66, 0, 0, 152, 66, 0, 0, 154, 66, 0, 0, 156, 66, 0, 0,
			158, 66, 0, 0, 160, 66, 0, 0, 162, 66, 0, 0, 164, 66, 0, 0, 166, 66,
			0, 0, 168, 66, 0, 0, 170, 66, 0, 0, 172, 66, 0, 0, 174, 66, 0, 0, 176,
			66, 0, 0, 178, 66, 0, 0, 180, 66, 0, 0, 182, 66, 0, 0, 184, 66, 0, 0,
			186, 66, 0, 0, 188, 66, 0, 0, 190, 66, 0, 0, 192, 66, 0, 0, 194, 66, 0,
			0, 196, 66, 0, 0, 198, 66, 0, 0, 200, 66, 0, 0, 202, 66, 0, 0, 204, 66,
			0, 0, 206, 66, 0, 0, 208, 66, 0, 0, 210, 66, 0, 0, 212, 66, 0, 0, 214, 66,
			0, 0, 216, 66, 0, 0, 218, 66, 0, 0, 220, 66, 0, 0, 222, 66, 0, 0, 224, 66,
			0, 0, 226, 66, 0, 0, 228, 66, 0, 0, 230, 66, 0, 0, 232, 66, 0, 0, 234, 66,
			0, 0, 236, 66, 0, 0, 238, 66, 0, 0, 240, 66, 0, 0, 242, 66, 0, 0, 244, 66,
			0, 0, 246, 66, 0, 0, 248, 66, 0, 0, 250, 66, 0, 0, 252, 66, 0, 0, 254, 66,
			0, 0, 0, 67, 0, 0, 1, 67, 0, 0, 2, 67, 0, 0, 3, 67, 0, 0, 4, 67, 0, 0, 5, 67,
			0, 0, 6, 67, 0, 0, 7, 67, 0, 0, 8, 67, 0, 0, 9, 67, 0, 0, 10, 67, 0, 0, 11,
			67, 0, 0, 12, 67, 0, 0, 13, 67, 0, 0, 14, 67, 0, 0, 15, 67, 0, 0, 16, 67,
			0, 0, 17, 67, 0, 0, 18, 67, 0, 0, 19, 67, 0, 0, 20, 67, 0, 0, 21, 67, 0, 0,
			22, 67, 0, 0, 23, 67, 0, 0, 24, 67, 0, 0, 25, 67, 0, 0, 26, 67, 0, 0, 27,
			67, 0, 0, 28, 67, 0, 0, 29, 67, 0, 0, 30, 67, 0, 0, 31, 67, 0, 0, 32, 67,
			0, 0, 33, 67, 0, 0, 34, 67, 0, 0, 35, 67, 0, 0, 36, 67, 0, 0, 37, 67, 0, 0,
			38, 67, 0, 0, 39, 67, 0, 0, 40, 67, 0, 0, 41, 67, 0, 0, 42, 67, 0, 0, 43, 67,
			0, 0, 44, 67, 0, 0, 45, 67, 0, 0, 46, 67, 0, 0, 47, 67, 0, 0, 48, 67, 0, 0,
			49, 67, 0, 0, 50, 67, 0, 0, 51, 67, 0, 0, 52, 67, 0, 0, 53, 67, 0, 0, 54, 67,
			0, 0, 55, 67, 0, 0, 56, 67, 0, 0, 57, 67, 0, 0, 58, 67, 0, 0, 59, 67, 0, 0,
			60, 67, 0, 0, 61, 67, 0, 0, 62, 67, 0, 0, 63, 67, 0, 0, 64, 67, 0, 0, 65, 67,
			0, 0, 66, 67, 0, 0, 67, 67, 0, 0, 68, 67, 0, 0, 69, 67, 0, 0, 70, 67, 0, 0, 71,
			67, 0, 0, 72, 67, 0, 0, 73, 67, 0, 0, 74, 67, 0, 0, 75, 67, 0, 0, 76, 67, 0,
			0, 77, 67, 0, 0, 78, 67, 0, 0, 79, 67, 0, 0, 80, 67, 0, 0, 81, 67, 0, 0, 82,
			67, 0, 0, 83, 67, 0, 0, 84, 67, 0, 0, 85, 67, 0, 0, 86, 67, 0, 0, 87, 67, 0,
			0, 88, 67, 0, 0, 89, 67, 0, 0, 90, 67, 0, 0, 91, 67, 0, 0, 92, 67, 0, 0, 93,
			67, 0, 0, 94, 67, 0, 0, 95, 67, 0, 0, 96, 67, 0, 0, 97, 67, 0, 0, 98, 67, 0,
			0, 99, 67, 0, 0, 100, 67, 0, 0, 101, 67, 0, 0, 102, 67, 0, 0, 103, 67, 0, 0,
			104, 67, 0, 0, 105, 67, 0, 0, 106, 67, 0, 0, 107, 67, 0, 0, 108, 67, 0, 0, 109,
			67, 0, 0, 110, 67, 0, 0, 111, 67, 0, 0, 112, 67, 0, 0, 113, 67, 0, 0, 114, 67,
			0, 0, 115, 67, 0, 0, 116, 67, 0, 0, 117, 67, 0, 0, 118, 67, 0, 0, 119, 67, 0,
			0, 120, 67, 0, 0, 121, 67, 0, 0, 122, 67, 0, 0, 123, 67, 0, 0, 124, 67, 0, 0,
			125, 67, 0, 0, 126, 67, 0, 0, 127, 67,
		},
		fsggml.TensorTypeQ4_K: {
			52, 52, 0, 0, 136, 208, 216, 223, 0, 0, 0, 0, 8, 0, 8, 15, 128,
			128, 129, 129, 146, 146, 147, 147, 164, 164, 165, 165, 166, 182,
			183, 183, 184, 200, 201, 201, 202, 218, 218, 219, 219, 236, 236,
			237, 237, 254, 254, 255, 202, 202, 202, 203, 203, 203, 219, 219,
			219, 220, 220, 220, 220, 220, 236, 237, 237, 237, 237, 237,
			237, 237, 238, 254, 254, 254, 254, 254, 255, 255, 255, 255, 220,
			220, 220, 220, 221, 221, 221, 221, 221, 221, 221, 237, 237, 237,
			238, 238, 238, 238, 238, 238, 238, 238, 238, 254, 254, 255, 255,
			255, 255, 255, 255, 255, 237, 237, 237, 237, 237, 237, 237, 238,
			238, 238, 238, 238, 238, 238, 238, 238, 254, 254, 254, 254, 254,
			254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
		},
		fsggml.TensorTypeQ2_K: {
			1, 2, 3, 3, 4, 5, 7, 7, 8, 9, 10, 11, 12, 13, 14, 15, 184, 184,
			184, 185, 249, 249, 249, 249, 249, 250, 250, 254, 254, 254, 254,
			255, 253, 253, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
			254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
			255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
			255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 171, 69, 0, 0,
		},
		fsggml.TensorTypeQ5_K: {
			32, 48, 0, 0, 136, 208, 216, 223, 0, 0, 0, 0, 8, 0, 7, 15, 254,
			254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
			254, 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
			255, 255, 255, 255, 255, 0, 1, 2, 19, 20, 37, 38, 55, 56, 73, 74,
			91, 92, 109, 110, 127, 112, 128, 129, 146, 147, 164, 165, 182, 183,
			200, 201, 218, 219, 236, 237, 254, 133, 133, 149, 150, 150, 150,
			167, 167, 167, 168, 184, 184, 185, 185, 201, 202, 202, 202, 219,
			219, 219, 219, 236, 236, 236, 237, 253, 253, 254, 254, 254, 255,
			169, 169, 169, 169, 186, 186, 186, 186, 186, 187, 187, 203, 203,
			203, 204, 204, 204, 220, 220, 221, 221, 221, 221, 237, 237, 238,
			238, 238, 238, 254, 255, 255, 203, 203, 203, 204, 204, 204, 204,
			204, 220, 220, 220, 221, 221, 221, 221, 221, 237, 237, 238, 238,
			238, 238, 238, 238, 254, 255, 255, 255, 255, 255, 255, 255,
		},
		fsggml.TensorTypeQ6_K: {
			96, 110, 92, 90, 88, 70, 68, 50, 48, 46, 44, 42, 24, 22, 4, 2, 80,
			95, 78, 77, 76, 59, 58, 57, 40, 39, 38, 21, 20, 19, 2, 1, 75, 75,
			74, 57, 57, 56, 55, 39, 38, 37, 21, 20, 20, 19, 2, 2, 72, 55, 55,
			54, 54, 37, 37, 36, 36, 19, 19, 18, 18, 1, 1, 0, 35, 35, 35, 35,
			34, 18, 18, 18, 17, 17, 17, 1, 1, 0, 0, 0, 35, 35, 34, 34, 18,
			18, 18, 17, 17, 17, 17, 1, 0, 0, 0, 0, 35, 35, 35, 19, 19, 18, 18,
			18, 18, 18, 1, 1, 1, 1, 1, 1, 34, 34, 18, 18, 18, 18, 17, 17, 17,
			17, 1, 1, 0, 0, 0, 0, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
			0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
			0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
			0, 0, 0, 0, 0, 248, 240, 231, 224, 216, 208, 200, 192, 184, 176,
			166, 160, 152, 144, 136, 128, 235, 43,
		},
		fsggml.TensorTypeQ3_K: {
			1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
			0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 20, 20, 23, 23, 7, 7, 6, 6, 6, 2,
			1, 1, 1, 1, 0, 0, 22, 22, 6, 6, 5, 5, 5, 1, 1, 1, 1, 1, 0, 0, 0,
			0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
			0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 238, 204, 170, 136, 102, 68,
			34, 1, 5, 5, 5, 5, 189, 63,
		},
	}
)

@@ -18,6 +18,7 @@ import (
	"os"
	"os/signal"
	"path/filepath"
	"regexp"
	"slices"
	"strings"
	"syscall"

@@ -1169,6 +1170,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
	corsConfig.AllowOrigins = envconfig.AllowedOrigins()

	r := gin.Default()
	r.HandleMethodNotAllowed = true
	r.Use(
		cors.New(corsConfig),
		allowedHostsMiddleware(s.addr),

@@ -1512,6 +1514,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
	if req.Messages[0].Role != "system" && m.System != "" {
		msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
	}
	msgs = filterThinkTags(msgs, m)

	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
	if err != nil {

@@ -1640,3 +1643,23 @@ func handleScheduleError(c *gin.Context, name string, err error) {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
	}
}

var thinkTagRegexp = regexp.MustCompile(`<think>(?s).*?</think>(\n)*`)

func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
	if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
		finalUserIndex := -1
		for i, msg := range msgs {
			if msg.Role == "user" {
				finalUserIndex = i
			}
		}

		for i, msg := range msgs {
			if msg.Role == "assistant" && i < finalUserIndex {
				msgs[i].Content = thinkTagRegexp.ReplaceAllString(msg.Content, "")
			}
		}
	}
	return msgs
}
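
In the pattern above, the (?s) flag makes . match newlines, so multi-line reasoning blocks are stripped along with any trailing newlines, and the lazy .*? keeps each match scoped to a single think block. A standalone sketch of the substitution (hypothetical message content, not from this change):

package main

import (
	"fmt"
	"regexp"
)

func main() {
	re := regexp.MustCompile(`<think>(?s).*?</think>(\n)*`)
	in := "<think>step 1\nstep 2</think>\n\nfinal answer"
	fmt.Printf("%q\n", re.ReplaceAllString(in, "")) // "final answer"
}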

@@ -24,7 +24,7 @@ import (

var stream bool = false

func createBinFile(t *testing.T, kv map[string]any, ti []ggml.Tensor) (string, string) {
func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string, string) {
	t.Helper()
	t.Setenv("OLLAMA_MODELS", cmp.Or(os.Getenv("OLLAMA_MODELS"), t.TempDir()))

@@ -87,7 +87,7 @@ func TestGenerateChat(t *testing.T) {
		},
	}

	go s.sched.Run(context.TODO())
	go s.sched.Run(t.Context())

	_, digest := createBinFile(t, ggml.KV{
		"general.architecture": "llama",
@@ -99,7 +99,7 @@ func TestGenerateChat(t *testing.T) {
		"tokenizer.ggml.tokens": []string{""},
		"tokenizer.ggml.scores": []float32{0},
		"tokenizer.ggml.token_type": []int32{0},
	}, []ggml.Tensor{
	}, []*ggml.Tensor{
		{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
@@ -158,7 +158,7 @@ func TestGenerateChat(t *testing.T) {
	_, digest := createBinFile(t, ggml.KV{
		"general.architecture": "bert",
		"bert.pooling_type": uint32(0),
	}, []ggml.Tensor{})
	}, []*ggml.Tensor{})
	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model: "bert",
		Files: map[string]string{"bert.gguf": digest},

@@ -299,9 +299,6 @@ func TestGenerateChat(t *testing.T) {
			{Role: "user", Content: "Hello!"},
		},
		Stream: &stream,
		Options: map[string]any{
			"num_ctx": 1024,
		},
	})

	if w.Code != http.StatusOK {
@@ -324,9 +321,6 @@ func TestGenerateChat(t *testing.T) {
			{Role: "user", Content: "Hello!"},
		},
		Stream: &stream,
		Options: map[string]any{
			"num_ctx": 1024,
		},
	})

	if w.Code != http.StatusOK {
@@ -350,9 +344,6 @@ func TestGenerateChat(t *testing.T) {
			{Role: "user", Content: "Help me write tests."},
		},
		Stream: &stream,
		Options: map[string]any{
			"num_ctx": 1024,
		},
	})

	if w.Code != http.StatusOK {

@@ -640,7 +631,7 @@ func TestGenerate(t *testing.T) {
		},
	}

	go s.sched.Run(context.TODO())
	go s.sched.Run(t.Context())

	_, digest := createBinFile(t, ggml.KV{
		"general.architecture": "llama",
@@ -652,7 +643,7 @@ func TestGenerate(t *testing.T) {
		"tokenizer.ggml.tokens": []string{""},
		"tokenizer.ggml.scores": []float32{0},
		"tokenizer.ggml.token_type": []int32{0},
	}, []ggml.Tensor{
	}, []*ggml.Tensor{
		{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
@@ -707,7 +698,7 @@ func TestGenerate(t *testing.T) {
	_, digest := createBinFile(t, ggml.KV{
		"general.architecture": "bert",
		"bert.pooling_type": uint32(0),
	}, []ggml.Tensor{})
	}, []*ggml.Tensor{})

	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model: "bert",

@@ -15,6 +15,7 @@ import (
	"net/http/httptest"
	"os"
	"path/filepath"
	"reflect"
	"sort"
	"strings"
	"testing"

@@ -473,14 +474,24 @@ func TestRoutes(t *testing.T) {
				t.Fatalf("failed to read response body: %v", err)
			}

			var retrieveResp api.RetrieveModelResponse
			err = json.Unmarshal(body, &retrieveResp)
			var m openai.Model
			err = json.Unmarshal(body, &m)
			if err != nil {
				t.Fatalf("failed to unmarshal response body: %v", err)
			}

			if retrieveResp.Id != "show-model" || retrieveResp.OwnedBy != "library" {
				t.Errorf("expected model 'show-model' owned by 'library', got %v", retrieveResp)
			if m.Id != "show-model" || m.OwnedBy != "library" {
				t.Errorf("expected model 'show-model' owned by 'library', got %v", m)
			}
		},
	},
	{
		Name: "Method Not Allowed",
		Method: http.MethodGet,
		Path: "/api/show",
		Expected: func(t *testing.T, resp *http.Response) {
			if resp.StatusCode != 405 {
				t.Errorf("expected status code 405, got %d", resp.StatusCode)
			}
		},
	},
@@ -516,7 +527,7 @@ func TestRoutes(t *testing.T) {
	for _, tc := range testCases {
		t.Run(tc.Name, func(t *testing.T) {
			u := httpSrv.URL + tc.Path
			req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
			req, err := http.NewRequestWithContext(t.Context(), tc.Method, u, nil)
			if err != nil {
				t.Fatalf("failed to create request: %v", err)
			}

@@ -746,3 +757,128 @@ func TestNormalize(t *testing.T) {
		})
	}
}

func TestFilterThinkTags(t *testing.T) {
	type testCase struct {
		msgs []api.Message
		want []api.Message
		model *Model
	}
	testCases := []testCase{
		{
			msgs: []api.Message{
				{Role: "user", Content: "Hello, world!"},
				{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
				{Role: "user", Content: "What is the answer?"},
			},
			want: []api.Message{
				{Role: "user", Content: "Hello, world!"},
				{Role: "assistant", Content: "abc"},
				{Role: "user", Content: "What is the answer?"},
			},
			model: &Model{
				Config: ConfigV2{
					ModelFamily: "qwen3",
				},
			},
		},
		// with newlines inside the think tag and newlines after
		{
			msgs: []api.Message{
				{Role: "user", Content: "Hello, world!"},
				{Role: "assistant", Content: "<think>Thinking... \n\nabout \nthe answer</think>\n\nabc\ndef"},
				{Role: "user", Content: "What is the answer?"},
			},
			want: []api.Message{
				{Role: "user", Content: "Hello, world!"},
				{Role: "assistant", Content: "abc\ndef"},
				{Role: "user", Content: "What is the answer?"},
			},
			model: &Model{
				Config: ConfigV2{
					ModelFamily: "qwen3",
				},
			},
		},
		// should leave thinking tags if it's after the last user message
		{
			msgs: []api.Message{
				{Role: "user", Content: "Hello, world!"},
				{Role: "assistant", Content: "<think>Thinking...</think>after"},
				{Role: "user", Content: "What is the answer?"},
				{Role: "assistant", Content: "<think>thinking again</think>hjk"},
				{Role: "assistant", Content: "<think>thinking yet again</think>hjk"},
			},
			want: []api.Message{
				{Role: "user", Content: "Hello, world!"},
				{Role: "assistant", Content: "after"},
				{Role: "user", Content: "What is the answer?"},
				{Role: "assistant", Content: "<think>thinking again</think>hjk"},
				{Role: "assistant", Content: "<think>thinking yet again</think>hjk"},
			},
			model: &Model{
				Config: ConfigV2{
					ModelFamily: "qwen3",
				},
			},
		},
		{
			// shouldn't strip anything because the model family isn't one of the hardcoded ones
			msgs: []api.Message{
				{Role: "user", Content: "Hello, world!"},
				{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
				{Role: "user", Content: "What is the answer?"},
			},
			want: []api.Message{
				{Role: "user", Content: "Hello, world!"},
				{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
				{Role: "user", Content: "What is the answer?"},
			},
			model: &Model{
				Config: ConfigV2{
					ModelFamily: "llama3",
				},
			},
		},
		{
			// deepseek-r1:-prefixed model
			msgs: []api.Message{
				{Role: "user", Content: "Hello, world!"},
				{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
				{Role: "user", Content: "What is the answer?"},
			},
			want: []api.Message{
				{Role: "user", Content: "Hello, world!"},
				{Role: "assistant", Content: "abc"},
				{Role: "user", Content: "What is the answer?"},
			},
			model: &Model{
				Name: "registry.ollama.ai/library/deepseek-r1:latest",
				ShortName: "deepseek-r1:7b",
				Config: ConfigV2{},
			},
		},
	}

	for i, tc := range testCases {
		filtered := filterThinkTags(tc.msgs, tc.model)

		if !reflect.DeepEqual(filtered, tc.want) {
			t.Errorf("messages differ for case %d:", i)
			for i := range tc.want {
				if i >= len(filtered) {
					t.Errorf(" missing message %d: %+v", i, tc.want[i])
					continue
				}
				if !reflect.DeepEqual(filtered[i], tc.want[i]) {
					t.Errorf(" message %d:\n want: %+v\n got: %+v", i, tc.want[i], filtered[i])
				}
			}
			if len(filtered) > len(tc.want) {
				for i := len(tc.want); i < len(filtered); i++ {
					t.Errorf(" extra message %d: %+v", i, filtered[i])
				}
			}
		}
	}
}

server/sched.go

@@ -81,6 +81,10 @@ func InitScheduler(ctx context.Context) *Scheduler {

// context must be canceled to decrement ref count and release the runner
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
	if opts.NumCtx < 4 {
		opts.NumCtx = 4
	}

	req := &LlmRequest{
		ctx: c,
		model: model,

@@ -110,11 +114,6 @@ func (s *Scheduler) Run(ctx context.Context) {
	}()
}

const (
	defaultContextLength = 4096
	smallGpuContextLength = 2048
)

func (s *Scheduler) processPending(ctx context.Context) {
	for {
		select {
@@ -148,6 +147,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
		s.loadedMu.Unlock()
		if runner != nil {
			if runner.needsReload(ctx, pending) {
				slog.Debug("reloading", "runner", runner)
				runnerToExpire = runner
			} else {
				// Runner is usable, return it
@@ -167,17 +167,6 @@ func (s *Scheduler) processPending(ctx context.Context) {
				gpus = s.getGpuFn()
			}

			if pending.origNumCtx == -1 {
				if len(gpus) == 1 && gpus[0].Library != "cpu" && gpus[0].TotalMemory <= 4096*1024*1024 {
					slog.Info("GPU is small, limiting default context window", "num_ctx", smallGpuContextLength)
					pending.opts.NumCtx = smallGpuContextLength
					pending.origNumCtx = smallGpuContextLength
				} else {
					pending.opts.NumCtx = defaultContextLength
					pending.origNumCtx = defaultContextLength
				}
			}

			if envconfig.MaxRunners() <= 0 {
				// No user specified MaxRunners, so figure out what automatic setting to use
				// If all GPUs have reliable free memory reporting, defaultModelsPerGPU * the number of GPUs
@@ -294,7 +283,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
			}
			// Trigger an expiration to unload once it's done
			runnerToExpire.refMu.Lock()
			slog.Debug("resetting model to expire immediately to make room", "modelPath", runnerToExpire.modelPath, "refCount", runnerToExpire.refCount)
			slog.Debug("resetting model to expire immediately to make room", "runner", runnerToExpire, "refCount", runnerToExpire.refCount)
			if runnerToExpire.expireTimer != nil {
				runnerToExpire.expireTimer.Stop()
				runnerToExpire.expireTimer = nil
@@ -307,13 +296,13 @@ func (s *Scheduler) processPending(ctx context.Context) {
			// Wait for the unload to happen
			// Note: at this point we're queueing up all incoming requests, even if they were for
			// a different model that's loaded and not scheduled to be removed.
			slog.Debug("waiting for pending requests to complete and unload to occur", "modelPath", runnerToExpire.modelPath)
			slog.Debug("waiting for pending requests to complete and unload to occur", "runner", runnerToExpire)
			select {
			case <-ctx.Done():
				slog.Debug("shutting down scheduler pending loop")
				return
			case <-s.unloadedCh:
				slog.Debug("unload completed", "modelPath", runnerToExpire.modelPath)
				slog.Debug("unload completed", "runner", runnerToExpire)
				continue
			}
		}
@@ -343,16 +332,16 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
			runner.refCount--
			if runner.refCount <= 0 {
				if runner.sessionDuration <= 0 {
					slog.Debug("runner with zero duration has gone idle, expiring to unload", "modelPath", runner.modelPath)
					slog.Debug("runner with zero duration has gone idle, expiring to unload", "runner", runner)
					if runner.expireTimer != nil {
						runner.expireTimer.Stop()
						runner.expireTimer = nil
					}
					s.expiredCh <- runner
				} else if runner.expireTimer == nil {
					slog.Debug("runner with non-zero duration has gone idle, adding timer", "modelPath", runner.modelPath, "duration", runner.sessionDuration)
					slog.Debug("runner with non-zero duration has gone idle, adding timer", "runner", runner, "duration", runner.sessionDuration)
					runner.expireTimer = time.AfterFunc(runner.sessionDuration, func() {
						slog.Debug("timer expired, expiring to unload", "modelPath", runner.modelPath)
						slog.Debug("timer expired, expiring to unload", "runner", runner)
						runner.refMu.Lock()
						defer runner.refMu.Unlock()
						if runner.expireTimer != nil {
@@ -363,18 +352,18 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
					})
					runner.expiresAt = time.Now().Add(runner.sessionDuration)
				} else {
					slog.Debug("runner with non-zero duration has gone idle, resetting timer", "modelPath", runner.modelPath, "duration", runner.sessionDuration)
					slog.Debug("runner with non-zero duration has gone idle, resetting timer", "runner", runner, "duration", runner.sessionDuration)
					runner.expireTimer.Reset(runner.sessionDuration)
					runner.expiresAt = time.Now().Add(runner.sessionDuration)
				}
			}
			slog.Debug("after processing request finished event", "modelPath", runner.modelPath, "refCount", runner.refCount)
			slog.Debug("after processing request finished event", "runner", runner, "refCount", runner.refCount)
			runner.refMu.Unlock()
		case runner := <-s.expiredCh:
			slog.Debug("runner expired event received", "modelPath", runner.modelPath)
			slog.Debug("runner expired event received", "runner", runner)
			runner.refMu.Lock()
			if runner.refCount > 0 {
				slog.Debug("expired event with positive ref count, retrying", "modelPath", runner.modelPath, "refCount", runner.refCount)
				slog.Debug("expired event with positive ref count, retrying", "runner", runner, "refCount", runner.refCount)
				go func(runner *runnerRef) {
					// We can't unload yet, but want to as soon as the current request completes
					// So queue up another expired event
@@ -386,17 +375,29 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
			}

			s.loadedMu.Lock()
			slog.Debug("got lock to unload", "modelPath", runner.modelPath)
			finished := runner.waitForVRAMRecovery()
			runner.unload()
			delete(s.loaded, runner.modelPath)
			s.loadedMu.Unlock()
			slog.Debug("runner released", "modelPath", runner.modelPath)
			runner.refMu.Unlock()

			<-finished
			slog.Debug("sending an unloaded event", "modelPath", runner.modelPath)
			s.unloadedCh <- struct{}{}
			slog.Debug("got lock to unload expired event", "runner", runner)
			runnerToUnload := s.loaded[runner.modelPath]
			if runnerToUnload == nil {
				// If runnerToUnload is nil, we already processed an event and
				// unloaded it. This double unload can happen if the initial
				// request is canceled and we're trying to load another model
				// that requires this one to be evicted, or the settings change
				// and require a reload
				s.loadedMu.Unlock()
				runner.refMu.Unlock()
				slog.Debug("duplicate expired event, ignoring", "runner", runner)
			} else {
				slog.Debug("starting background wait for VRAM recovery", "runner", runner)
				finished := runner.waitForVRAMRecovery()
				runner.unload()
				delete(s.loaded, runner.modelPath)
				s.loadedMu.Unlock()
				slog.Debug("runner terminated and removed from list, blocking for VRAM recovery", "runner", runner)
				<-finished
				runner.refMu.Unlock()
				slog.Debug("sending an unloaded event", "runner", runner)
				s.unloadedCh <- struct{}{}
			}
		}
	}
}
@@ -418,7 +419,7 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
	pending.successCh <- runner
	go func() {
		<-pending.ctx.Done()
		slog.Debug("context for request finished")
		slog.Debug("context for request finished", "runner", runner)
		finished <- pending
	}()
}
@@ -453,12 +454,19 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
		estimatedVRAM: llama.EstimatedVRAM(),
		estimatedTotal: llama.EstimatedTotal(),
		loading: true,
		refCount: 1,
		pid: llama.Pid(),
	}
	runner.numParallel = numParallel
	runner.refMu.Lock()
	runner.refMu.Lock() // hold lock until running or aborted

	s.loadedMu.Lock()
	if oldRunner, ok := s.loaded[req.model.ModelPath]; ok {
		// Shouldn't happen, but safeguard against leaking a runner
		slog.Warn("model was still loaded", "old_runner", oldRunner, "new_runner", runner)
		oldRunner.refMu.Lock()
		oldRunner.unload()
		oldRunner.refMu.Unlock()
	}
	s.loaded[req.model.ModelPath] = runner
	slog.Info("loaded runners", "count", len(s.loaded))
	s.loadedMu.Unlock()
@@ -467,13 +475,16 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
		defer runner.refMu.Unlock()
		if err = llama.WaitUntilRunning(req.ctx); err != nil {
			slog.Error("error loading llama server", "error", err)
			runner.refCount--
			req.errCh <- err
			slog.Debug("triggering expiration for failed load", "model", runner.modelPath)
			slog.Debug("triggering expiration for failed load", "runner", runner)
			s.expiredCh <- runner
			return
		}
		slog.Debug("finished setting up runner", "model", req.model.ModelPath)
		slog.Debug("finished setting up", "runner", runner)
		if runner.pid < 0 {
			runner.pid = llama.Pid()
		}
		runner.refCount++
		runner.loading = false
		go func() {
			<-req.ctx.Done()
@@ -491,7 +502,12 @@ func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
	}
	predMap := map[predKey]uint64{} // Sum up the total predicted usage per GPU for all runners
	s.loadedMu.Lock()
	runners := make([]*runnerRef, 0, len(s.loaded))
	for _, r := range s.loaded {
		runners = append(runners, r)
	}
	s.loadedMu.Unlock()
	for _, r := range runners {
		r.refMu.Lock()
		if r.llama != nil {
			for _, gpu := range allGpus {
@@ -502,7 +518,6 @@ func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
		}
		r.refMu.Unlock()
	}
	s.loadedMu.Unlock()

	// Now that we've summed up all the GPU usage predictions across all the loaded runners, update the gpu list
	for i := range allGpus {
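
The updateFreeSpace change above applies a common Go locking pattern: snapshot the map's values while holding loadedMu, then iterate the snapshot with the lock released, so the per-runner refMu is never taken while loadedMu is held. A generic sketch of the pattern (hypothetical names, not this codebase's types):

package sketch

import "sync"

type runner struct{ /* ... */ }

var (
	mu     sync.Mutex
	loaded = map[string]*runner{}
)

// forEachRunner copies the map's values under mu, then calls fn on each
// entry with mu released, keeping any per-entry locking out of mu's scope.
func forEachRunner(fn func(*runner)) {
	mu.Lock()
	snapshot := make([]*runner, 0, len(loaded))
	for _, r := range loaded {
		snapshot = append(snapshot, r)
	}
	mu.Unlock()

	for _, r := range snapshot {
		fn(r)
	}
}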

@@ -549,12 +564,11 @@ func (s *Scheduler) filterGPUsWithoutLoadingModels(allGpus discover.GpuInfoList)

// TODO consolidate sched_types.go
type runnerRef struct {
	refMu sync.Mutex
	// refCond sync.Cond // Signaled on transition from 1 -> 0 refCount
	refMu sync.Mutex
	refCount uint // prevent unloading if > 0
	// unloading bool // set to true when we are trying to unload the runner

	llama llm.LlamaServer
	pid int
	loading bool // True only during initial load, then false forever
	gpus discover.GpuInfoList // Recorded at time of provisioning
	estimatedVRAM uint64
@@ -639,6 +653,7 @@ func (runner *runnerRef) waitForVRAMRecovery() chan any {
		(len(runner.gpus) == 1 && (runner.gpus[0].Library == "cpu" || runner.gpus[0].Library == "metal")) ||
		(runtime.GOOS == "windows" && runner.gpus[0].Library != "cuda") {
		finished <- struct{}{}
		slog.Debug("no need to wait for VRAM recovery", "runner", runner)
		return finished
	}
	start := time.Now()
@@ -657,7 +672,7 @@ func (runner *runnerRef) waitForVRAMRecovery() chan any {
		for {
			<-ticker.C
			if time.Now().After(expiresAt) {
				slog.Warn("gpu VRAM usage didn't recover within timeout", "seconds", time.Since(start).Seconds(), "model", runner.modelPath)
				slog.Warn("gpu VRAM usage didn't recover within timeout", "seconds", time.Since(start).Seconds(), "runner", runner)
				finished <- struct{}{}
			}

@@ -670,7 +685,7 @@ func (runner *runnerRef) waitForVRAMRecovery() chan any {
			}
			// If we're within ~80% of the estimated memory usage recovered, bail out
			if float32(freeMemoryNow-freeMemoryBefore) > float32(runner.estimatedVRAM)*0.8 {
				slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()), "model", runner.modelPath)
				slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()), "runner", runner)
				finished <- struct{}{}
				return
			}
@@ -679,6 +694,33 @@ func (runner *runnerRef) waitForVRAMRecovery() chan any {
	return finished
}

func (runner *runnerRef) LogValue() slog.Value {
	if runner == nil {
		return slog.StringValue("nil")
	}
	attrs := []slog.Attr{}
	if runner.model != nil {
		attrs = append(attrs, slog.String("name", runner.model.Name))
	}
	if len(runner.gpus) > 0 {
		attrs = append(attrs,
			slog.String("inference", runner.gpus[0].Library),
			slog.Int("devices", len(runner.gpus)),
		)
	}
	attrs = append(attrs,
		slog.String("size", format.HumanBytes2(runner.estimatedTotal)),
		slog.String("vram", format.HumanBytes2(runner.estimatedVRAM)),
		slog.Int("parallel", runner.numParallel),
		slog.Int("pid", runner.pid),
		slog.String("model", runner.modelPath),
	)
	if runner.Options != nil {
		attrs = append(attrs, slog.Int("num_ctx", runner.Options.NumCtx))
	}
	return slog.GroupValue(attrs...)
}
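
LogValue implements the standard library's slog.LogValuer interface, so every slog call above that passes "runner", runner expands the runner into a structured attribute group lazily, only when the record is actually emitted. A minimal standalone illustration of the interface (hypothetical type, not from this diff):

package main

import "log/slog"

type user struct {
	name string
	id   int
}

// LogValue is invoked lazily by slog when the value is logged.
func (u user) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("name", u.name),
		slog.Int("id", u.id),
	)
}

func main() {
	// The default text handler renders the group as user.name=ada user.id=7.
	slog.Info("login", "user", user{name: "ada", id: 7})
}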

type ByDurationAndName []*runnerRef

func (a ByDurationAndName) Len() int { return len(a) }
@@ -801,12 +843,12 @@ func (s *Scheduler) findRunnerToUnload() *runnerRef {
		rc := runner.refCount
		runner.refMu.Unlock()
		if rc == 0 {
			slog.Debug("found an idle runner to unload")
			slog.Debug("found an idle runner to unload", "runner", runner)
			return runner
		}
	}
	// None appear idle, just wait for the one with the shortest duration
	slog.Debug("no idle runners, picking the shortest duration", "count", len(runnerList))
	slog.Debug("no idle runners, picking the shortest duration", "runner_count", len(runnerList), "runner", runnerList[0])
	return runnerList[0]
}

@@ -823,8 +865,8 @@ func (s *Scheduler) unloadAllRunners() {

func (s *Scheduler) expireRunner(model *Model) {
	s.loadedMu.Lock()
	defer s.loadedMu.Unlock()
	runner, ok := s.loaded[model.ModelPath]
	s.loadedMu.Unlock()
	if ok {
		runner.refMu.Lock()
		runner.expiresAt = time.Now()

@@ -26,7 +26,7 @@ func TestMain(m *testing.M) {
}

func TestInitScheduler(t *testing.T) {
	ctx, done := context.WithCancel(context.Background())
	ctx, done := context.WithCancel(t.Context())
	defer done()
	s := InitScheduler(ctx)
	s.loadedMu.Lock()
@@ -35,7 +35,7 @@ func TestInitScheduler(t *testing.T) {
}

func TestLoad(t *testing.T) {
	ctx, done := context.WithTimeout(context.Background(), 20*time.Millisecond)
	ctx, done := context.WithTimeout(t.Context(), 20*time.Millisecond)
	defer done()
	s := InitScheduler(ctx)
	var f *ggml.GGML // value not used in tests
@@ -126,7 +126,7 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
		"tokenizer.ggml.tokens": []string{" "},
		"tokenizer.ggml.scores": []float32{0},
		"tokenizer.ggml.token_type": []int32{0},
	}, []ggml.Tensor{
	}, []*ggml.Tensor{
		{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
		{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
	}))
@@ -148,7 +148,6 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
		successCh: make(chan *runnerRef, 1),
		errCh: make(chan error, 1),
	}
	b.req.opts.NumCtx = 4096
	b.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
	return b
}
@@ -168,7 +167,7 @@ func getCpuFn() discover.GpuInfoList {
}

func TestRequestsSameModelSameRequest(t *testing.T) {
	ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
	ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
	defer done()
	s := InitScheduler(ctx)
	s.getGpuFn = getGpuFn
@@ -211,7 +210,7 @@ func TestRequestsSameModelSameRequest(t *testing.T) {
}

func TestRequestsSimpleReloadSameModel(t *testing.T) {
	ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
	ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
	defer done()
	s := InitScheduler(ctx)
	s.getGpuFn = getGpuFn
@@ -259,7 +258,7 @@ func TestRequestsSimpleReloadSameModel(t *testing.T) {
}

func TestRequestsMultipleLoadedModels(t *testing.T) {
	ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
	ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
	defer done()
	s := InitScheduler(ctx)
	s.getGpuFn = getGpuFn
@@ -356,7 +355,7 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
}

func TestGetRunner(t *testing.T) {
	ctx, done := context.WithTimeout(context.Background(), 3*time.Second)
	ctx, done := context.WithTimeout(t.Context(), 3*time.Second)
	defer done()

	a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond})
@@ -409,7 +408,7 @@ func TestGetRunner(t *testing.T) {
}

func TestExpireRunner(t *testing.T) {
	ctx, done := context.WithTimeout(context.Background(), 20*time.Millisecond)
	ctx, done := context.WithTimeout(t.Context(), 20*time.Millisecond)
	defer done()
	s := InitScheduler(ctx)
	req := &LlmRequest{
@@ -456,7 +455,7 @@ func TestExpireRunner(t *testing.T) {

// TODO - add one scenario that triggers the bogus finished event with positive ref count
func TestPrematureExpired(t *testing.T) {
	ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
	ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
	defer done()

	// Same model, same request
@@ -503,7 +502,7 @@ func TestPrematureExpired(t *testing.T) {
}

func TestUseLoadedRunner(t *testing.T) {
	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
	ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
	req := &LlmRequest{
		ctx: ctx,
		opts: api.DefaultOptions(),
@@ -530,7 +529,7 @@ func TestUseLoadedRunner(t *testing.T) {
}

func TestUpdateFreeSpace(t *testing.T) {
	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
	ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
	defer done()
	gpus := discover.GpuInfoList{
		{
@@ -563,7 +562,7 @@ func TestUpdateFreeSpace(t *testing.T) {
}

func TestFilterGPUsWithoutLoadingModels(t *testing.T) {
	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
	ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
	defer done()
	gpus := discover.GpuInfoList{
		{
@@ -597,7 +596,7 @@ func TestFilterGPUsWithoutLoadingModels(t *testing.T) {
}

func TestFindRunnerToUnload(t *testing.T) {
	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
	ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
	defer done()

	r1 := &runnerRef{refCount: 1, sessionDuration: 1, numParallel: 1}
@@ -617,7 +616,7 @@ func TestFindRunnerToUnload(t *testing.T) {
}

func TestNeedsReload(t *testing.T) {
	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
	ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
	defer done()

	llm := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
@@ -664,7 +663,7 @@ func TestNeedsReload(t *testing.T) {
}

func TestUnloadAllRunners(t *testing.T) {
	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
	ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
	defer done()

	llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
@@ -696,7 +695,7 @@ func TestUnload(t *testing.T) {
}

func TestAlreadyCanceled(t *testing.T) {
	ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
	ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
	defer done()
	dctx, done2 := context.WithCancel(ctx)
	done2()
@@ -713,7 +712,7 @@ func TestAlreadyCanceled(t *testing.T) {
}

func TestHomogeneousGPUs(t *testing.T) {
	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
	ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
	defer done()
	s := InitScheduler(ctx)

@@ -793,3 +792,4 @@ func (s *mockLlm) Close() error {
func (s *mockLlm) EstimatedVRAM() uint64 { return s.estimatedVRAM }
func (s *mockLlm) EstimatedTotal() uint64 { return s.estimatedTotal }
func (s *mockLlm) EstimatedVRAMByGPU(gpuid string) uint64 { return s.estimatedVRAMByGPU[gpuid] }
func (s *mockLlm) Pid() int { return -1 }