Merge branch 'main' into drifkin/array-head-count-simple

Devon Rifkin
2025-05-08 11:46:52 -07:00
committed by GitHub
156 changed files with 6327 additions and 3282 deletions

View File

@@ -15,6 +15,7 @@ import (
"path/filepath"
"slices"
"strings"
"sync/atomic"
"github.com/gin-gonic/gin"
@@ -23,7 +24,6 @@ import (
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
@@ -425,9 +425,14 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.ProgressResponse)) (*layerGGML, error) {
ft := layer.GGML.KV().FileType()
fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantizeType)})
want, err := ggml.ParseFileType(quantizeType)
var doneBytes atomic.Uint64
totalBytes := uint64(layer.Size) - layer.GGML.Tensors().Offset
fnWrap := func(n uint64) {
done := doneBytes.Add(n)
progress := float32(done) / float32(totalBytes)
fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantizeType), Digest: "0", Total: layer.Size, Completed: int64(progress * float32(layer.Size))})
}
ftype, err := ggml.ParseFileType(quantizeType)
if err != nil {
return nil, err
}
@@ -436,6 +441,11 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
if err != nil {
return nil, err
}
fp, err := os.Open(blob)
if err != nil {
return nil, err
}
defer fp.Close()
temp, err := os.CreateTemp(filepath.Dir(blob), quantizeType)
if err != nil {
@@ -444,15 +454,15 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
defer temp.Close()
defer os.Remove(temp.Name())
if err := llama.Quantize(blob, temp.Name(), uint32(want)); err != nil {
if err := quantize(fp, temp, layer.GGML, ftype, fnWrap); err != nil {
return nil, err
}
temp.Seek(0, io.SeekStart)
fn(api.ProgressResponse{Status: "verifying conversion"})
newLayer, err := NewLayer(temp, layer.MediaType)
if err != nil {
return nil, err
}
if _, err := temp.Seek(0, io.SeekStart); err != nil {
return nil, err
}
@@ -462,7 +472,6 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
slog.Error(fmt.Sprintf("error decoding ggml: %s\n", err))
return nil, err
}
return &layerGGML{newLayer, f}, nil
}
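For context on the new progress plumbing above: quantization only walks the tensor data, so fnWrap rescales quantized-byte counts onto the size of the whole layer blob before reporting Completed. A minimal, self-contained sketch of that arithmetic (the sizes here are made up, not taken from a real model):

package main

import (
    "fmt"
    "sync/atomic"
)

func main() {
    // Hypothetical sizes: the layer blob is header/KV data plus tensor data.
    layerSize := int64(1_000_000) // full GGUF blob size (layer.Size)
    tensorOffset := uint64(4_096) // bytes before the tensor data (Tensors().Offset)
    totalBytes := uint64(layerSize) - tensorOffset

    var doneBytes atomic.Uint64
    report := func(n uint64) {
        done := doneBytes.Add(n)                    // tensor bytes quantized so far
        frac := float32(done) / float32(totalBytes) // fraction of tensor data done
        completed := int64(frac * float32(layerSize))
        fmt.Printf("completed %d of %d bytes\n", completed, layerSize)
    }
    report(512_000) // roughly halfway through the tensor data
}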

View File

@@ -106,6 +106,11 @@ func (m *Model) Capabilities() []model.Capability {
capabilities = append(capabilities, model.CapabilityInsert)
}
// Check for vision capability in projector-based models
if len(m.ProjectorPaths) > 0 {
capabilities = append(capabilities, model.CapabilityVision)
}
return capabilities
}
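The projector check above is what lets callers discover vision support. A standalone analogue of the logic and how a caller might query it (the capability type here is a stand-in for model.Capability, not the real package):

package main

import (
    "fmt"
    "slices"
)

type capability string

const capabilityVision capability = "vision"

// capabilities mirrors the projector-based check added to (*Model).Capabilities.
func capabilities(projectorPaths []string) []capability {
    var caps []capability
    if len(projectorPaths) > 0 {
        caps = append(caps, capabilityVision)
    }
    return caps
}

func main() {
    caps := capabilities([]string{"/tmp/projector.gguf"}) // hypothetical path
    fmt.Println(slices.Contains(caps, capabilityVision))  // true
}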

View File

@@ -3,6 +3,7 @@ package server
import (
"bytes"
"encoding/binary"
"errors"
"os"
"path/filepath"
"strings"
@@ -91,11 +92,7 @@ func createMockGGUFData(architecture string, vision bool) []byte {
func TestModelCapabilities(t *testing.T) {
// Create a temporary directory for test files
tempDir, err := os.MkdirTemp("", "model_capabilities_test")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)
tempDir := t.TempDir()
// Create different types of mock model files
completionModelPath := filepath.Join(tempDir, "model.bin")
@@ -104,21 +101,13 @@ func TestModelCapabilities(t *testing.T) {
// Create a simple model file for tests that don't depend on GGUF content
simpleModelPath := filepath.Join(tempDir, "simple_model.bin")
err = os.WriteFile(completionModelPath, createMockGGUFData("llama", false), 0o644)
if err != nil {
t.Fatalf("Failed to create completion model file: %v", err)
}
err = os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644)
if err != nil {
t.Fatalf("Failed to create completion model file: %v", err)
}
err = os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644)
if err != nil {
t.Fatalf("Failed to create embedding model file: %v", err)
}
err = os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644)
if err != nil {
t.Fatalf("Failed to create simple model file: %v", err)
if err := errors.Join(
os.WriteFile(completionModelPath, createMockGGUFData("llama", false), 0o644),
os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644),
os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644),
os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644),
); err != nil {
t.Fatalf("Failed to create model files: %v", err)
}
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
@@ -236,27 +225,18 @@ func TestModelCapabilities(t *testing.T) {
func TestModelCheckCapabilities(t *testing.T) {
// Create a temporary directory for test files
tempDir, err := os.MkdirTemp("", "model_check_capabilities_test")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)
tempDir := t.TempDir()
visionModelPath := filepath.Join(tempDir, "vision_model.bin")
simpleModelPath := filepath.Join(tempDir, "model.bin")
embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
err = os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644)
if err != nil {
t.Fatalf("Failed to create simple model file: %v", err)
}
err = os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644)
if err != nil {
t.Fatalf("Failed to create vision model file: %v", err)
}
err = os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644)
if err != nil {
t.Fatalf("Failed to create embedding model file: %v", err)
if err := errors.Join(
os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644),
os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644),
os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644),
); err != nil {
t.Fatalf("Failed to create model files: %v", err)
}
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
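The errors.Join refactor above collapses four (and three) separate error checks into one. errors.Join returns nil only when every argument is nil, so the single t.Fatalf still fires on any failed write; the difference is that all writes are attempted and their errors are aggregated rather than stopping at the first. A small runnable illustration (the paths are placeholders):

package main

import (
    "errors"
    "fmt"
    "os"
)

func main() {
    err := errors.Join(
        os.WriteFile(os.TempDir()+"/ok.bin", []byte("data"), 0o644),
        os.WriteFile("/no/such/dir/model.bin", []byte("data"), 0o644), // fails
    )
    fmt.Println(err != nil) // true: the joined error carries the failure
}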

View File

@@ -1,224 +0,0 @@
// Package safetensors provides a reader for safetensors directories and files.
package safetensors
import (
"encoding/json"
"fmt"
"io"
"io/fs"
"iter"
"slices"
"strconv"
"strings"
)
// Tensor represents a single tensor in a safetensors file.
//
// Its zero value is not valid. Use [Model.Tensors] to get valid tensors.
//
// It is not safe for use across multiple goroutines.
type Tensor struct {
name string
dataType string
shape []int64
fsys fs.FS
fname string // entry name in fsys
offset int64
size int64
}
type Model struct {
fsys fs.FS
}
func Read(fsys fs.FS) (*Model, error) {
return &Model{fsys: fsys}, nil
}
func (m *Model) Tensors() iter.Seq2[*Tensor, error] {
return func(yield func(*Tensor, error) bool) {
entries, err := fs.Glob(m.fsys, "*.safetensors")
if err != nil {
yield(nil, err)
return
}
for _, e := range entries {
tt, err := m.readTensors(e)
if err != nil {
yield(nil, err)
return
}
for _, t := range tt {
if !yield(t, nil) {
return
}
}
}
}
}
func (m *Model) readTensors(fname string) ([]*Tensor, error) {
f, err := m.fsys.Open(fname)
if err != nil {
return nil, err
}
defer f.Close()
finfo, err := f.Stat()
if err != nil {
return nil, err
}
headerSize, err := readInt64(f)
if err != nil {
return nil, err
}
data := make([]byte, headerSize)
_, err = io.ReadFull(f, data)
if err != nil {
return nil, err
}
var raws map[string]json.RawMessage
if err := json.Unmarshal(data, &raws); err != nil {
return nil, err
}
endOfHeader := 8 + headerSize // 8 bytes for header size plus the header itself
// TODO(bmizerany): do something with metadata? This could be another
// header read if needed. We also need to figure out if the metadata is
// present in only one .safetensors file or if each file may have its
// own and whether it needs to follow each tensor. Currently, I (bmizerany)
// am only seeing it show up with one entry for the file type, which is
// always "pt".
tt := make([]*Tensor, 0, len(raws))
for name, raw := range raws {
if name == "__metadata__" {
// TODO(bmizerany): do something with metadata?
continue
}
var v struct {
DataType string `json:"dtype"`
Shape []int64 `json:"shape"`
Offsets []int64 `json:"data_offsets"`
}
if err := json.Unmarshal(raw, &v); err != nil {
return nil, fmt.Errorf("error unmarshalling layer %q: %w", name, err)
}
if len(v.Offsets) != 2 {
return nil, fmt.Errorf("invalid offsets for %q: %v", name, v.Offsets)
}
// TODO(bmizerany): after collecting, validate all offsets make
// tensors contiguous?
begin := endOfHeader + v.Offsets[0]
end := endOfHeader + v.Offsets[1]
if err := checkBeginEnd(finfo.Size(), begin, end); err != nil {
return nil, err
}
// TODO(bmizerany): just yield.. don't be silly and make a slice :)
tt = append(tt, &Tensor{
name: name,
dataType: v.DataType,
shape: v.Shape,
fsys: m.fsys,
fname: fname,
offset: begin,
size: end - begin,
})
}
return tt, nil
}
func checkBeginEnd(size, begin, end int64) error {
if begin < 0 {
return fmt.Errorf("begin must not be negative: %d", begin)
}
if end < 0 {
return fmt.Errorf("end must not be negative: %d", end)
}
if end < begin {
return fmt.Errorf("end must be >= begin: %d < %d", end, begin)
}
if end > size {
return fmt.Errorf("end must be <= size: %d > %d", end, size)
}
return nil
}
func readInt64(r io.Reader) (int64, error) {
var v uint64
var buf [8]byte
if _, err := io.ReadFull(r, buf[:]); err != nil {
return 0, err
}
for i := range buf {
v |= uint64(buf[i]) << (8 * i)
}
return int64(v), nil
}
type Shape []int64
func (s Shape) String() string {
var b strings.Builder
b.WriteByte('[')
for i, v := range s {
if i > 0 {
b.WriteByte(',')
}
b.WriteString(strconv.FormatInt(v, 10))
}
b.WriteByte(']')
return b.String()
}
func (t *Tensor) Name() string { return t.name }
func (t *Tensor) DataType() string { return t.dataType }
func (t *Tensor) Size() int64 { return t.size }
func (t *Tensor) Shape() Shape { return slices.Clone(t.shape) }
func (t *Tensor) Reader() (io.ReadCloser, error) {
f, err := t.fsys.Open(t.fname)
if err != nil {
return nil, err
}
r := newSectionReader(f, t.offset, t.size)
rc := struct {
io.Reader
io.Closer
}{r, f}
return rc, nil
}
// newSectionReader returns an io.Reader that reads n bytes from r starting at
// offset. It is a convenience function for creating an io.SectionReader when r
// may not be an io.ReaderAt.
//
// If r is an io.ReaderAt, an io.SectionReader over [offset, offset+n) is
// returned directly. Otherwise, if r is an io.ReadSeeker, it is seeked to
// offset and wrapped in an io.LimitReader of n bytes.
//
// As a slow path, when r implements neither interface, the first offset bytes
// are read and discarded before returning an io.LimitReader of n bytes.
func newSectionReader(r io.Reader, offset, n int64) io.Reader {
if r, ok := r.(io.ReaderAt); ok {
return io.NewSectionReader(r, offset, n)
}
if r, ok := r.(io.ReadSeeker); ok {
r.Seek(offset, io.SeekStart)
return io.LimitReader(r, n)
}
// Discard to offset and return a limited reader.
_, err := io.CopyN(io.Discard, r, offset)
if err != nil {
return nil
}
return io.LimitReader(r, n)
}
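For reference, the file layout the removed reader assumed is: 8 bytes of little-endian header length, a JSON header mapping tensor names to dtype/shape/data_offsets, then the raw tensor bytes, with offsets measured from the first byte after the header. A self-contained sketch of that framing using a made-up one-tensor header:

package main

import (
    "bytes"
    "encoding/binary"
    "encoding/json"
    "fmt"
)

func main() {
    // Build a tiny, made-up safetensors payload in memory.
    header := []byte(`{"t.weight":{"dtype":"F32","shape":[2,2],"data_offsets":[0,16]}}`)
    var file bytes.Buffer
    binary.Write(&file, binary.LittleEndian, uint64(len(header)))
    file.Write(header)
    file.Write(make([]byte, 16)) // 2x2 float32 tensor data (zeros)

    // Parse it back the same way readTensors above did.
    var size uint64
    binary.Read(&file, binary.LittleEndian, &size)
    raws := make(map[string]json.RawMessage)
    json.Unmarshal(file.Next(int(size)), &raws)
    fmt.Println("header bytes:", size, "tensors:", len(raws))
}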

View File

@@ -1,375 +0,0 @@
package main
import (
"bytes"
"cmp"
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"log"
"mime"
"net/http"
"os"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/cmd/opp/internal/safetensors"
"golang.org/x/sync/errgroup"
)
var stdout io.Writer = os.Stdout
const usage = `Opp is a tool for pushing and pulling Ollama models.
Usage:
opp [flags] <push|pull|import>
Commands:
push Upload a model to the Ollama server.
pull Download a model from the Ollama server.
import Import a model from a local safetensor directory.
Examples:
# Pull a model from the Ollama server.
opp pull library/llama3.2:latest
# Push a model to the Ollama server.
opp push username/my_model:8b
# Import a model from a local safetensor directory.
opp import /path/to/safetensor
Environment Variables:
OLLAMA_MODELS
The directory where models are pushed and pulled from
(default ~/.ollama/models).
`
func main() {
flag.Usage = func() {
fmt.Fprint(os.Stderr, usage)
}
flag.Parse()
ctx := context.Background()
err := func() error {
switch cmd := flag.Arg(0); cmd {
case "pull":
rc, err := ollama.DefaultRegistry()
if err != nil {
log.Fatal(err)
}
return cmdPull(ctx, rc)
case "push":
rc, err := ollama.DefaultRegistry()
if err != nil {
log.Fatal(err)
}
return cmdPush(ctx, rc)
case "import":
c, err := ollama.DefaultCache()
if err != nil {
log.Fatal(err)
}
return cmdImport(ctx, c)
default:
if cmd == "" {
flag.Usage()
} else {
fmt.Fprintf(os.Stderr, "unknown command %q\n", cmd)
}
os.Exit(2)
return errors.New("unreachable")
}
}()
if err != nil {
fmt.Fprintf(os.Stderr, "opp: %v\n", err)
os.Exit(1)
}
}
func cmdPull(ctx context.Context, rc *ollama.Registry) error {
model := flag.Arg(1)
if model == "" {
flag.Usage()
os.Exit(1)
}
tr := http.DefaultTransport.(*http.Transport).Clone()
// TODO(bmizerany): configure transport?
rc.HTTPClient = &http.Client{Transport: tr}
var mu sync.Mutex
p := make(map[blob.Digest][2]int64) // digest -> [total, downloaded]
var pb bytes.Buffer
printProgress := func() {
pb.Reset()
mu.Lock()
for d, s := range p {
// Write progress to a buffer first to avoid blocking
// on stdout while holding the lock.
stamp := time.Now().Format("2006/01/02 15:04:05")
fmt.Fprintf(&pb, "%s %s pulling %d/%d (%.1f%%)\n", stamp, d.Short(), s[1], s[0], 100*float64(s[1])/float64(s[0]))
if s[0] == s[1] {
delete(p, d)
}
}
mu.Unlock()
io.Copy(stdout, &pb)
}
ctx = ollama.WithTrace(ctx, &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
if err != nil && !errors.Is(err, ollama.ErrCached) {
fmt.Fprintf(stdout, "opp: pull %s ! %v\n", l.Digest.Short(), err)
return
}
mu.Lock()
p[l.Digest] = [2]int64{l.Size, n}
mu.Unlock()
},
})
errc := make(chan error)
go func() {
errc <- rc.Pull(ctx, model)
}()
t := time.NewTicker(time.Second)
defer t.Stop()
for {
select {
case <-t.C:
printProgress()
case err := <-errc:
printProgress()
return err
}
}
}
func cmdPush(ctx context.Context, rc *ollama.Registry) error {
args := flag.Args()[1:]
flag := flag.NewFlagSet("push", flag.ExitOnError)
flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: opp push <model>\n")
flag.PrintDefaults()
}
flag.Parse(args)
model := flag.Arg(0)
if model == "" {
return fmt.Errorf("missing model argument")
}
from := cmp.Or(*flagFrom, model)
m, err := rc.ResolveLocal(from)
if err != nil {
return err
}
ctx = ollama.WithTrace(ctx, &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
switch {
case errors.Is(err, ollama.ErrCached):
fmt.Fprintf(stdout, "opp: uploading %s %d (existed)", l.Digest.Short(), n)
case err != nil:
fmt.Fprintf(stdout, "opp: uploading %s %d ! %v\n", l.Digest.Short(), n, err)
case n == 0:
l := m.Layer(l.Digest)
mt, p, _ := mime.ParseMediaType(l.MediaType)
mt, _ = strings.CutPrefix(mt, "application/vnd.ollama.image.")
switch mt {
case "tensor":
fmt.Fprintf(stdout, "opp: uploading tensor %s %s\n", l.Digest.Short(), p["name"])
default:
fmt.Fprintf(stdout, "opp: uploading %s %s\n", l.Digest.Short(), l.MediaType)
}
}
},
})
return rc.Push(ctx, model, &ollama.PushParams{
From: from,
})
}
type trackingReader struct {
io.Reader
n *atomic.Int64
}
func (r *trackingReader) Read(p []byte) (n int, err error) {
n, err = r.Reader.Read(p)
r.n.Add(int64(n))
return n, err
}
func cmdImport(ctx context.Context, c *blob.DiskCache) error {
args := flag.Args()[1:]
flag := flag.NewFlagSet("import", flag.ExitOnError)
flagAs := flag.String("as", "", "Import using the provided name.")
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: opp import <SafetensorDir>\n")
flag.PrintDefaults()
}
flag.Parse(args)
if *flagAs == "" {
return fmt.Errorf("missing -as flag")
}
as := ollama.CompleteName(*flagAs)
dir := cmp.Or(flag.Arg(0), ".")
fmt.Fprintf(os.Stderr, "Reading %s\n", dir)
m, err := safetensors.Read(os.DirFS(dir))
if err != nil {
return err
}
var total int64
var tt []*safetensors.Tensor
for t, err := range m.Tensors() {
if err != nil {
return err
}
tt = append(tt, t)
total += t.Size()
}
var n atomic.Int64
done := make(chan error)
go func() {
layers := make([]*ollama.Layer, len(tt))
var g errgroup.Group
g.SetLimit(runtime.GOMAXPROCS(0))
var ctxErr error
for i, t := range tt {
if ctx.Err() != nil {
// The context may cancel AFTER we exit the
// loop, and so if we use ctx.Err() after the
// loop we may report it as the error that
// broke the loop, when it was not. This can
// manifest as a false-negative, leading the
// user to think their import failed when it
// did not, so capture it if and only if we
// exit the loop because of a ctx.Err() and
// report it.
ctxErr = ctx.Err()
break
}
g.Go(func() (err error) {
rc, err := t.Reader()
if err != nil {
return err
}
defer rc.Close()
tr := &trackingReader{rc, &n}
d, err := c.Import(tr, t.Size())
if err != nil {
return err
}
if err := rc.Close(); err != nil {
return err
}
layers[i] = &ollama.Layer{
Digest: d,
Size: t.Size(),
MediaType: mime.FormatMediaType("application/vnd.ollama.image.tensor", map[string]string{
"name": t.Name(),
"dtype": t.DataType(),
"shape": t.Shape().String(),
}),
}
return nil
})
}
done <- func() error {
if err := errors.Join(g.Wait(), ctxErr); err != nil {
return err
}
m := &ollama.Manifest{Layers: layers}
data, err := json.MarshalIndent(m, "", " ")
if err != nil {
return err
}
d := blob.DigestFromBytes(data)
err = blob.PutBytes(c, d, data)
if err != nil {
return err
}
return c.Link(as, d)
}()
}()
fmt.Fprintf(stdout, "Importing %d tensors from %s\n", len(tt), dir)
csiHideCursor(stdout)
defer csiShowCursor(stdout)
csiSavePos(stdout)
writeProgress := func() {
csiRestorePos(stdout)
nn := n.Load()
fmt.Fprintf(stdout, "Imported %s/%s bytes (%d%%)%s\n",
formatNatural(nn),
formatNatural(total),
nn*100/total,
ansiClearToEndOfLine,
)
}
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
writeProgress()
case err := <-done:
writeProgress()
fmt.Println()
fmt.Println("Successfully imported", as)
return err
}
}
}
func formatNatural(n int64) string {
switch {
case n < 1024:
return fmt.Sprintf("%d B", n)
case n < 1024*1024:
return fmt.Sprintf("%.1f KB", float64(n)/1024)
case n < 1024*1024*1024:
return fmt.Sprintf("%.1f MB", float64(n)/(1024*1024))
default:
return fmt.Sprintf("%.1f GB", float64(n)/(1024*1024*1024))
}
}
const ansiClearToEndOfLine = "\033[K"
func csiSavePos(w io.Writer) { fmt.Fprint(w, "\033[s") }
func csiRestorePos(w io.Writer) { fmt.Fprint(w, "\033[u") }
func csiHideCursor(w io.Writer) { fmt.Fprint(w, "\033[?25l") }
func csiShowCursor(w io.Writer) { fmt.Fprint(w, "\033[?25h") }

View File

@@ -3,7 +3,6 @@
package backoff
import (
"context"
"testing"
"testing/synctest"
"time"
@@ -29,7 +28,7 @@ func TestLoopAllocs(t *testing.T) {
}
func BenchmarkLoop(b *testing.B) {
ctx := context.Background()
ctx := b.Context()
synctest.Run(func() {
for n := range Loop(ctx, 100*time.Millisecond) {
if n == b.N {

View File

@@ -64,7 +64,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
}
defer blob.Close()
f, _, err := ggml.Decode(blob, 1024)
f, _, err := ggml.Decode(blob, -1)
if err != nil {
return nil, err
}

View File

@@ -1,7 +1,6 @@
package server
import (
"os"
"path/filepath"
"testing"
@@ -11,9 +10,7 @@ import (
func TestGetBlobsPath(t *testing.T) {
// GetBlobsPath expects an actual directory to exist
dir, err := os.MkdirTemp("", "ollama-test")
require.NoError(t, err)
defer os.RemoveAll(dir)
tempDir := t.TempDir()
tests := []struct {
name string
@@ -24,19 +21,19 @@ func TestGetBlobsPath(t *testing.T) {
{
"empty digest",
"",
filepath.Join(dir, "blobs"),
filepath.Join(tempDir, "blobs"),
nil,
},
{
"valid with colon",
"sha256:456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
filepath.Join(dir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
nil,
},
{
"valid with dash",
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
filepath.Join(dir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
nil,
},
{
@@ -60,7 +57,7 @@ func TestGetBlobsPath(t *testing.T) {
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv("OLLAMA_MODELS", dir)
t.Setenv("OLLAMA_MODELS", tempDir)
got, err := GetBlobsPath(tc.digest)

View File

@@ -2,7 +2,6 @@ package server
import (
"bytes"
"context"
"image"
"image/png"
"testing"
@@ -318,7 +317,7 @@ func TestChatPrompt(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
model := tt.model
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
if tt.error == nil && err != nil {
t.Fatal(err)
} else if tt.error != nil && err != tt.error {

server/quantization.go Normal file (274 lines added)
View File

@@ -0,0 +1,274 @@
package server
import (
"fmt"
"io"
"log/slog"
"maps"
"os"
"strings"
"unsafe"
fsggml "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml/backend/ggml"
)
type quantizer struct {
*os.File
offset uint64
from, to *fsggml.Tensor
progressFn func(n uint64)
}
func (q quantizer) WriteTo(w io.Writer) (int64, error) {
quantize := q.from.Kind != q.to.Kind
sr := io.NewSectionReader(q, int64(q.offset), int64(q.from.Size()))
if !quantize {
n, err := io.Copy(w, sr)
q.progressFn(q.from.Size())
return n, err
}
data, err := io.ReadAll(sr)
if err != nil {
slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
return 0, fmt.Errorf("unable to read tensor %s from %s: %s", q.from.Name, q.Name(), err)
}
var f32s []float32
newType := fsggml.TensorType(q.to.Kind)
if fsggml.TensorType(q.from.Kind) == fsggml.TensorTypeF32 {
f32s = unsafe.Slice((*float32)(unsafe.Pointer(&data[0])), q.from.Elements())
} else {
f32s = ggml.ConvertToF32(data, q.from.Kind, q.from.Elements())
}
data = ggml.Quantize(newType, f32s, q.from.Shape)
n, err := w.Write(data)
q.progressFn(q.from.Size())
return int64(n), err
}
type quantizeState struct {
nAttnV int // Number of attn_*v* weight tensors
nFfnDown int // Number of ffn_down tensors
iAttnV int // Running counter of number of attn_v tensors that have been processed
iFfnDown int // Running counter of number of ffn_down tensors that have been processed
hasOutput bool // used to figure out if a model shares tok_embd with the output weight
}
func useMoreBits(iLayer, nLayers int) bool {
return iLayer < (nLayers/8) || iLayer >= 7*nLayers/8 || (iLayer-nLayers/8)%3 == 2
}
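useMoreBits, part of the logic ported from llama.cpp here, promotes roughly the first and last eighth of the layers, plus every third layer in between, to a higher-precision type. A quick worked check for a hypothetical 32-layer model:

package main

import "fmt"

// Same expression as useMoreBits above, copied for a standalone check.
func useMoreBits(iLayer, nLayers int) bool {
    return iLayer < (nLayers/8) || iLayer >= 7*nLayers/8 || (iLayer-nLayers/8)%3 == 2
}

func main() {
    // Prints 0-3, every third layer from 6 through 27, and 28-31.
    for i := range 32 {
        if useMoreBits(i, 32) {
            fmt.Print(i, " ")
        }
    }
    fmt.Println()
}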
func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType, name string, shape []uint64, ftype fsggml.FileType) fsggml.TensorType {
// Ported from llama_tensor_get_type, removed unsupported quantization types
nExperts := max(1, kv.Uint("expert_count", 0))
if name == "output.weight" || name == "output_norm.weight" || (!qs.hasOutput && name == "token_embd.weight") {
nx := shape[0]
qk_k := newType.BlockSize()
if nx%qk_k != 0 {
newType = fsggml.TensorTypeQ8_0
} else if newType != fsggml.TensorTypeQ8_0 {
newType = fsggml.TensorTypeQ6_K
}
} else if strings.Contains(name, "attn_v.weight") {
if ftype == fsggml.FileTypeQ2_K {
if kv.GQA() >= 4 {
newType = fsggml.TensorTypeQ4_K
} else {
newType = fsggml.TensorTypeQ3_K
}
} else if ftype == fsggml.FileTypeQ2_K_S && kv.GQA() >= 4 {
newType = fsggml.TensorTypeQ4_K
} else if ftype == fsggml.FileTypeQ3_K_M {
if qs.iAttnV < 2 {
newType = fsggml.TensorTypeQ5_K
} else {
newType = fsggml.TensorTypeQ4_K
}
} else if ftype == fsggml.FileTypeQ3_K_L {
newType = fsggml.TensorTypeQ5_K
} else if (ftype == fsggml.FileTypeQ4_K_M || ftype == fsggml.FileTypeQ5_K_M) &&
useMoreBits(qs.iAttnV, qs.nAttnV) {
newType = fsggml.TensorTypeQ6_K
} else if ftype == fsggml.FileTypeQ4_K_S && qs.iAttnV < 4 {
newType = fsggml.TensorTypeQ5_K
}
// TODO
// if (qs.model.type == LLM_TYPE_70B) {
// // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is
// // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with
// // nearly negligible increase in model size by quantizing this tensor with more bits:
// if (newType == GGML_TYPE_Q3_K || newType == GGML_TYPE_Q4_K) newType = GGML_TYPE_Q5_K;
// }
if nExperts == 8 {
// for the 8-expert model, bumping this to Q8_0 trades just ~128MB
newType = fsggml.TensorTypeQ8_0
}
qs.iAttnV++
} else if strings.Contains(name, "attn_k.weight") {
if nExperts == 8 {
// for the 8-expert model, bumping this to Q8_0 trades just ~128MB
newType = fsggml.TensorTypeQ8_0
}
} else if strings.Contains(name, "ffn_down") {
iLayer := qs.iFfnDown
n_layer := qs.nFfnDown
if ftype == fsggml.FileTypeQ2_K {
newType = fsggml.TensorTypeQ3_K
} else if ftype == fsggml.FileTypeQ2_K_S {
if iLayer < n_layer/8 {
newType = fsggml.TensorTypeQ4_K
}
} else if ftype == fsggml.FileTypeQ3_K_M {
if iLayer < n_layer/16 {
newType = fsggml.TensorTypeQ5_K
} else if useMoreBits(iLayer, n_layer) {
newType = fsggml.TensorTypeQ4_K
} else {
newType = fsggml.TensorTypeQ3_K
}
} else if ftype == fsggml.FileTypeQ3_K_L {
newType = fsggml.TensorTypeQ5_K
} else if ftype == fsggml.FileTypeQ4_K_M {
if useMoreBits(iLayer, n_layer) {
newType = fsggml.TensorTypeQ6_K
}
} else if ftype == fsggml.FileTypeQ5_K_M && useMoreBits(iLayer, n_layer) {
newType = fsggml.TensorTypeQ6_K
} else if ftype == fsggml.FileTypeQ4_K_S && iLayer < n_layer/8 {
newType = fsggml.TensorTypeQ5_K
}
qs.iFfnDown++
} else if strings.Contains(name, "attn_output.weight") {
if nExperts == 8 {
if ftype == fsggml.FileTypeQ2_K || ftype == fsggml.FileTypeQ3_K_S || ftype == fsggml.FileTypeQ3_K_M ||
ftype == fsggml.FileTypeQ4_K_S || ftype == fsggml.FileTypeQ4_K_M {
newType = fsggml.TensorTypeQ5_K
}
} else {
if ftype == fsggml.FileTypeQ2_K {
newType = fsggml.TensorTypeQ3_K
} else if ftype == fsggml.FileTypeQ3_K_M {
newType = fsggml.TensorTypeQ4_K
} else if ftype == fsggml.FileTypeQ3_K_L {
newType = fsggml.TensorTypeQ5_K
}
}
} else if strings.Contains(name, "attn_qkv.weight") {
if ftype == fsggml.FileTypeQ3_K_M || ftype == fsggml.FileTypeQ3_K_L {
newType = fsggml.TensorTypeQ4_K
} else if ftype == fsggml.FileTypeQ4_K_M {
newType = fsggml.TensorTypeQ5_K
} else if ftype == fsggml.FileTypeQ5_K_M {
newType = fsggml.TensorTypeQ6_K
}
}
if newType.IsQuantized() {
nx := shape[0]
ny := uint64(1)
if len(shape) > 1 {
ny = shape[1]
}
qk_k := newType.BlockSize()
if nx%qk_k != 0 {
slog.Warn(fmt.Sprintf("tensor cols %d x %d are not divisible by %d, required for %s. Falling back to quantization %s", nx, ny, qk_k, newType.String(), fsggml.TensorTypeF16.String()))
newType = fsggml.TensorTypeF16
}
}
return newType
}
func quantize(in, out *os.File, orig *fsggml.GGML, newFileType fsggml.FileType, progressFn func(n uint64)) error {
kv := maps.Clone(orig.KV())
kv["general.file_type"] = newFileType
// kv["general.quantization_version"] = ggml.QuantizationVersion()
qs := &quantizeState{}
// Build up the quantize state so newType can adjust types
layerCount := 0
for k, l := range orig.Tensors().GroupLayers() {
if strings.HasPrefix(k, "blk.") {
layerCount++
}
for _, tensor := range l {
if strings.Contains(tensor.Name, "attn_v.weight") ||
strings.Contains(tensor.Name, "attn_qkv.weight") ||
strings.Contains(tensor.Name, "attn_kv_b.weight") {
qs.nAttnV++
} else if tensor.Name == "output.weight" {
qs.hasOutput = true
}
}
}
qs.nFfnDown = layerCount
origTensors := orig.Tensors().Items()
outputTensors := make([]*fsggml.Tensor, len(origTensors))
for i, tensor := range origTensors {
tensor := tensor
newType := newType(tensor, kv, qs, newFileType)
newTensor := &fsggml.Tensor{
Name: tensor.Name,
Shape: tensor.Shape,
Kind: uint32(newType),
}
outputTensors[i] = newTensor
outputTensors[i].WriterTo = quantizer{
File: in,
offset: orig.Tensors().Offset + tensor.Offset,
from: tensor,
to: newTensor,
progressFn: progressFn,
}
}
return fsggml.WriteGGUF(out, kv, outputTensors)
}
func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.FileType) fsggml.TensorType {
defaultType := ftype.ToTensorType()
name := t.Name
quantize := strings.HasSuffix(name, "weight")
// don't quantize vision stuff
quantize = quantize && (!strings.Contains(name, "v.") || strings.Contains(name, "_v."))
quantize = quantize && !strings.Contains(name, "mm.")
// quantize only 2D and 3D tensors (experts)
quantize = quantize && (len(t.Shape) >= 2)
// do not quantize norm tensors
quantize = quantize && !strings.Contains(name, "_norm.weight")
// do not quantize expert gating tensors
quantize = quantize && !strings.Contains(name, "ffn_gate_inp.weight")
// do not quantize positional embeddings and token types (BERT)
quantize = quantize && (name != "position_embd.weight")
quantize = quantize && (name != "token_types.weight")
// do not quantize Mamba's small yet 2D weights
// NOTE: can't use LLM_TN here because the layer number is not known
quantize = quantize && !strings.Contains(name, "ssm_conv1d.weight")
// do not quantize RWKV's time_mix_first tensors
quantize = quantize && !strings.Contains(name, "time_mix_first.weight")
quantize = quantize && !strings.Contains(name, "time_mix_w1.weight")
quantize = quantize && !strings.Contains(name, "time_mix_w2.weight")
quantize = quantize && !strings.Contains(name, "time_mix_decay_w1.weight")
quantize = quantize && !strings.Contains(name, "time_mix_decay_w2.weight")
quantize = quantize && !strings.Contains(name, "time_mix_lerp_fused.weight")
// do not quantize relative position bias (T5)
quantize = quantize && !strings.Contains(name, "attn_rel_b.weight")
newType := fsggml.TensorType(t.Kind)
if quantize {
// get more optimal quantization type based on the tensor shape, layer, etc.
newType = getTensorNewType(kv, qs, defaultType, t.Name, t.Shape, ftype)
if newType != defaultType {
slog.Debug("tensor quantization adjusted for better quality", "name", t.Name, "requested", defaultType, "quantization", newType)
}
}
return newType
}
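For orientation, a sketch of how this entry point is driven, mirroring how quantizeLayer above and TestQuantizeModel below use it. Written as if it lived alongside quantize in package server; the paths and target type are placeholders, not part of this change:

// Sketch only: decode a source GGUF and quantize it to a new file type.
func quantizeFile(srcPath, dstPath, target string) error {
    in, err := os.Open(srcPath)
    if err != nil {
        return err
    }
    defer in.Close()

    meta, _, err := fsggml.Decode(in, -1) // same Decode call the tests use
    if err != nil {
        return err
    }

    ftype, err := fsggml.ParseFileType(target) // e.g. "Q4_K" (placeholder)
    if err != nil {
        return err
    }

    out, err := os.Create(dstPath)
    if err != nil {
        return err
    }
    defer out.Close()

    // The callback receives source-tensor byte counts as tensors are written.
    return quantize(in, out, meta, ftype, func(n uint64) {})
}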

server/quantization_test.go Normal file (882 lines added)
View File

@@ -0,0 +1,882 @@
package server
import (
"bytes"
"fmt"
"math"
"os"
"strings"
"testing"
fsggml "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml/backend/ggml"
)
func TestGetTensorNewType(t *testing.T) {
cases := []struct {
name string
kv map[string]any
qs quantizeState
newType fsggml.TensorType
tensor_name string
shape []uint64
ftype fsggml.FileType
expected fsggml.TensorType
expectedPanic string
}{
{
name: "output_unsupported",
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "output.weight",
shape: []uint64{100, 100},
ftype: fsggml.FileTypeF32,
expected: fsggml.TensorTypeF16,
},
{
name: "output_Q8",
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "output.weight",
shape: []uint64{1024, 1024},
ftype: fsggml.FileTypeF32,
expected: fsggml.TensorTypeQ6_K,
},
{
name: "attn_v.weight_q4_k",
kv: map[string]any{
"general.architecture": "foo",
"foo.attention.head_count": uint32(4),
"foo.attention.head_count_kv": uint32(1),
},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_v.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ2_K,
expected: fsggml.TensorTypeQ4_K,
},
{
name: "attn_v.weight_q3_k",
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_v.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ2_K,
expected: fsggml.TensorTypeQ3_K,
},
{
name: "attn_v.weight_q2_k_s_q4_k",
kv: map[string]any{
"general.architecture": "foo",
"foo.attention.head_count": uint32(4),
"foo.attention.head_count_kv": uint32(1),
},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_v.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ2_K_S,
expected: fsggml.TensorTypeQ4_K,
},
{
name: "attn_v.weight_q3_k_m",
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_v.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ3_K_M,
expected: fsggml.TensorTypeQ5_K,
},
{
name: "attn_v.weight_q3_k_m_i",
qs: quantizeState{
iAttnV: 2,
},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_v.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ3_K_M,
expected: fsggml.TensorTypeQ4_K,
},
{
name: "attn_v.weight_q3_k_l",
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_v.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ3_K_L,
expected: fsggml.TensorTypeQ5_K,
},
{
name: "attn_v.weight_q4_k_m",
qs: quantizeState{
iAttnV: 2,
nAttnV: 3 * 8,
},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_v.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ4_K_M,
expected: fsggml.TensorTypeQ6_K,
},
{
name: "attn_v.weight_q4_k_s",
qs: quantizeState{},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_v.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ4_K_S,
expected: fsggml.TensorTypeQ5_K,
},
{
name: "attn_v.weight_8_expert",
qs: quantizeState{},
kv: map[string]any{
"general.architecture": "foo",
"foo.expert_count": uint32(8),
},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_v.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeF32,
expected: fsggml.TensorTypeQ8_0,
},
{
name: "attn_k.weight_8_expert",
qs: quantizeState{},
kv: map[string]any{
"general.architecture": "foo",
"foo.expert_count": uint32(8),
},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_k.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeF32,
expected: fsggml.TensorTypeQ8_0,
},
{
name: "ffn_down_q2_k",
qs: quantizeState{},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "ffn_down",
shape: []uint64{256},
ftype: fsggml.FileTypeQ2_K,
expected: fsggml.TensorTypeQ3_K,
},
{
name: "ffn_down_q2_k_s",
qs: quantizeState{},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "ffn_down",
shape: []uint64{256},
ftype: fsggml.FileTypeQ2_K_S,
expected: fsggml.TensorTypeQ4_0,
},
{
name: "ffn_down_q2_k_s_layers",
qs: quantizeState{
iFfnDown: 2,
nFfnDown: 3 * 8,
},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "ffn_down",
shape: []uint64{256},
ftype: fsggml.FileTypeQ2_K_S,
expected: fsggml.TensorTypeQ4_K,
},
{
name: "ffn_down_q3_k_m_base",
qs: quantizeState{
iFfnDown: 1,
nFfnDown: 8,
},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "ffn_down",
shape: []uint64{256},
ftype: fsggml.FileTypeQ3_K_M,
expected: fsggml.TensorTypeQ3_K,
},
{
name: "ffn_down_q3_k_m_16",
qs: quantizeState{
iFfnDown: 2,
nFfnDown: 3 * 16,
},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "ffn_down",
shape: []uint64{256},
ftype: fsggml.FileTypeQ3_K_M,
expected: fsggml.TensorTypeQ5_K,
},
{
name: "ffn_down_q3_k_m_8",
qs: quantizeState{
iFfnDown: 2,
nFfnDown: 3 * 8,
},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "ffn_down",
shape: []uint64{256},
ftype: fsggml.FileTypeQ3_K_M,
expected: fsggml.TensorTypeQ4_K,
},
{
name: "ffn_down_q3_k_l",
qs: quantizeState{},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "ffn_down",
shape: []uint64{256},
ftype: fsggml.FileTypeQ3_K_L,
expected: fsggml.TensorTypeQ5_K,
},
{
name: "ffn_down_q4_k_m",
qs: quantizeState{
iFfnDown: 1,
nFfnDown: 8,
},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "ffn_down",
shape: []uint64{256},
ftype: fsggml.FileTypeQ4_K_M,
expected: fsggml.TensorTypeQ4_0,
},
{
name: "ffn_down_q4_k_m_6",
qs: quantizeState{
iFfnDown: 2,
nFfnDown: 3 * 8,
},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "ffn_down",
shape: []uint64{256},
ftype: fsggml.FileTypeQ4_K_M,
expected: fsggml.TensorTypeQ6_K,
},
{
name: "ffn_down_q5_k_m",
qs: quantizeState{
iFfnDown: 2,
nFfnDown: 3 * 8,
},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "ffn_down",
shape: []uint64{256},
ftype: fsggml.FileTypeQ5_K_M,
expected: fsggml.TensorTypeQ6_K,
},
{
name: "ffn_down_q4_k_s",
qs: quantizeState{
iFfnDown: 2,
nFfnDown: 3 * 8,
},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "ffn_down",
shape: []uint64{256},
ftype: fsggml.FileTypeQ4_K_S,
expected: fsggml.TensorTypeQ5_K,
},
{
name: "attn_output.weight_8_expert",
qs: quantizeState{},
kv: map[string]any{
"general.architecture": "foo",
"foo.expert_count": uint32(8),
},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_output.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ2_K,
expected: fsggml.TensorTypeQ5_K,
},
{
name: "attn_output.weight_q2",
qs: quantizeState{},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_output.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ2_K,
expected: fsggml.TensorTypeQ3_K,
},
{
name: "attn_output.weight_q3_k_m",
qs: quantizeState{},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_output.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ3_K_M,
expected: fsggml.TensorTypeQ4_K,
},
{
name: "attn_output.weight_q3_k_l",
qs: quantizeState{},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_output.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ3_K_L,
expected: fsggml.TensorTypeQ5_K,
},
{
name: "attn_qkv.weight_q3_k_m",
qs: quantizeState{},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_qkv.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ3_K_M,
expected: fsggml.TensorTypeQ4_K,
},
{
name: "attn_qkv.weight_q4_k_m",
qs: quantizeState{},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_qkv.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ4_K_M,
expected: fsggml.TensorTypeQ5_K,
},
{
name: "attn_qkv.weight_q5_k_m",
qs: quantizeState{},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_qkv.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ5_K_M,
expected: fsggml.TensorTypeQ6_K,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
if tt.expectedPanic != "" {
defer func() {
e := recover()
if !strings.Contains(fmt.Sprintf("%v", e), tt.expectedPanic) {
t.Fatalf("incorrect panic\ngot: %v\nexpected: %s", e, tt.expectedPanic)
}
}()
} else {
defer func() {
e := recover()
if e != nil {
t.Fatalf("hit unexpected panic %v", e)
}
}()
}
ret := getTensorNewType(tt.kv, &tt.qs, tt.newType, tt.tensor_name, tt.shape, tt.ftype)
if ret != tt.expected {
t.Fatalf("incorrect type returned\ngot: %d\nexpected: %d", ret, tt.expected)
}
})
}
}
func TestQuantizeModel(t *testing.T) {
cases := []struct {
name string
kv map[string]any
tensors []*fsggml.Tensor
newType string
expectedTensorTypes map[string]fsggml.TensorType
}{
{
name: "f16_q4_k",
kv: map[string]any{
"general.architecture": "foo",
},
tensors: []*fsggml.Tensor{
{
Name: "blk.0.attn.weight", Kind: uint32(fsggml.TensorTypeF16),
Offset: uint64(0), Shape: []uint64{512, 2},
WriterTo: bytes.NewReader(
append(append(append(quantBytes[fsggml.TensorTypeF16], quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...),
),
},
{
Name: "output.weight", Kind: uint32(fsggml.TensorTypeF16),
Offset: uint64(0), Shape: []uint64{256, 4},
WriterTo: bytes.NewReader(
append(append(append(quantBytes[fsggml.TensorTypeF16], quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...),
),
},
},
newType: "Q4_K",
expectedTensorTypes: map[string]fsggml.TensorType{
"blk.0.attn.weight": fsggml.TensorTypeQ4_K,
"output.weight": fsggml.TensorTypeQ6_K,
},
},
{
name: "f32_q4_k",
kv: map[string]any{
"general.architecture": "foo",
},
tensors: []*fsggml.Tensor{
{
Name: "blk.0.attn_v.weight", Kind: uint32(fsggml.TensorTypeF32),
Offset: uint64(0), Shape: []uint64{512, 2},
WriterTo: bytes.NewReader(
append(append(append(quantBytes[fsggml.TensorTypeF32], quantBytes[fsggml.TensorTypeF32]...), quantBytes[fsggml.TensorTypeF32]...), quantBytes[fsggml.TensorTypeF32]...),
),
},
{
Name: "output.weight", Kind: uint32(fsggml.TensorTypeF32),
Offset: uint64(0), Shape: []uint64{512},
WriterTo: bytes.NewReader(append(quantBytes[fsggml.TensorTypeF32], quantBytes[fsggml.TensorTypeF32]...)),
},
},
newType: "Q4_K",
expectedTensorTypes: map[string]fsggml.TensorType{
"blk.0.attn_v.weight": fsggml.TensorTypeQ6_K,
"output.weight": fsggml.TensorTypeF32,
},
},
{
name: "f16_q8_0",
kv: map[string]any{
"general.architecture": "foo",
},
tensors: []*fsggml.Tensor{
{
Name: "blk.0.attn.weight", Kind: uint32(fsggml.TensorTypeF16),
Offset: uint64(0), Shape: []uint64{32, 16, 2},
WriterTo: bytes.NewReader(
append(append(append(quantBytes[fsggml.TensorTypeF16], quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...),
),
},
{
Name: "output.weight", Kind: uint32(fsggml.TensorTypeF16),
Offset: uint64(0), Shape: []uint64{256, 4},
WriterTo: bytes.NewReader(
append(append(append(quantBytes[fsggml.TensorTypeF16], quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...),
),
},
},
newType: "Q8_0",
expectedTensorTypes: map[string]fsggml.TensorType{
"blk.0.attn.weight": fsggml.TensorTypeQ8_0,
"output.weight": fsggml.TensorTypeQ8_0,
},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
f, err := os.CreateTemp(t.TempDir(), tt.name)
if err != nil {
t.Fatal(err.Error())
}
defer f.Close()
err = fsggml.WriteGGUF(f, tt.kv, tt.tensors)
if err != nil {
t.Fatalf("failed to create initial model: %s", err)
}
fp, err := os.Open(f.Name())
if err != nil {
t.Fatal(err.Error())
}
defer fp.Close()
meta, _, err := fsggml.Decode(fp, -1)
if err != nil {
t.Fatal(err.Error())
}
progressCalled := false
progress := func(n uint64) {
// fmt.Fprintf(os.Stderr, "progress: %f\n", p)
progressCalled = true
}
tmp, err := os.CreateTemp(t.TempDir(), tt.name+".out")
if err != nil {
t.Fatal(err.Error())
}
defer tmp.Close()
ftype, err := fsggml.ParseFileType(tt.newType)
if err != nil {
t.Fatal(err.Error())
}
err = quantize(fp, tmp, meta, ftype, progress)
if err != nil {
t.Fatalf("error during quantize: %s", err)
}
if !progressCalled {
t.Fatalf("progress was not reported")
}
// Now attempt to load it back and make sure types match expected
fpNew, err := os.Open(tmp.Name())
if err != nil {
t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err)
}
defer fpNew.Close()
newMeta, _, err := fsggml.Decode(fpNew, -1)
if err != nil {
t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err)
}
tensors := newMeta.Tensors()
for _, l := range tensors.GroupLayers() {
for _, tensor := range l {
if fsggml.TensorType(tensor.Kind) != tt.expectedTensorTypes[tensor.Name] {
t.Fatalf("incorrect output type for %s\ngot:%s\nexpected:%s", tensor.Name, fsggml.TensorType(tensor.Kind), tt.expectedTensorTypes[tensor.Name])
}
}
}
})
}
}
func TestConvertToF32(t *testing.T) {
expected := make([]float32, 256)
for i := range expected {
expected[i] = float32(i)
}
for dtype, data := range quantBytes {
// Skip the no-op
if dtype == fsggml.TensorTypeF32 {
continue
}
t.Run(dtype.String(), func(t *testing.T) {
fp32 := ggml.ConvertToF32(data, uint32(dtype), 256)
similarity := cosineSimilarity(expected, fp32)
if similarity < 0.999 {
t.Fatalf("Results not similar enough: %s %f", dtype.String(), similarity)
}
})
}
}
func dotProduct[V float32 | float64](v1, v2 []V) V {
var result V = 0
for i := range v1 {
result += v1[i] * v2[i]
}
return result
}
func magnitude[V float32 | float64](v []V) V {
var result V = 0
for _, val := range v {
result += val * val
}
return V(math.Sqrt(float64(result)))
}
func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
return dotProduct(v1, v2) / (magnitude(v1) * magnitude(v2))
}
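These helpers compute cosine similarity, dot(v1, v2) / (|v1| * |v2|): values near 1 mean the dequantized data points in nearly the same direction as the original ramp, which is why TestConvertToF32 uses 0.999 as its pass threshold. For illustration only, identical vectors score 1 up to float32 rounding:

// Illustration only: identical vectors give a similarity of ~1.0.
fmt.Println(cosineSimilarity([]float32{1, 2, 3}, []float32{1, 2, 3}))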
// Precomputed quantized data - arange 256
// # For gguf-py supported types
// import gguf
// import numpy as np
// print(repr(gguf.quantize(np.arange(256, dtype=np.float16), gguf.GGMLQuantizationType.Q4_0)))
//
// For types not supported by gguf-py converted via ggml_fp32_to_fp16_row and quantize_XXX
//
// data := make([]byte, 256*2)
// fp32 := make([]float32, 256)
// for i := range 256 {
// fp32[i] = float32(i)
// }
// l := C.quantize_q6_K((*C.float)(&fp32[0]), unsafe.Pointer(&data[0]), 1, 256, nil)
// for i := range data[:int(l)] {
// fmt.Printf("%d, ", data[i])
// }
var (
quantBytes = map[fsggml.TensorType][]byte{
fsggml.TensorTypeQ4_0: {
192, 195, 72, 72, 55, 55, 55, 55, 38, 38, 38, 38, 21,
21, 21, 21, 4, 4, 224, 199, 36, 36, 36, 36, 19, 19,
19, 19, 19, 19, 19, 19, 2, 2, 2, 2, 240, 201, 19,
19, 18, 18, 18, 18, 18, 18, 18, 18, 2, 2, 2, 2,
1, 1, 240, 203, 18, 18, 18, 18, 18, 18, 18, 18, 1,
1, 1, 1, 1, 1, 1, 1, 248, 204, 18, 18, 17, 17,
17, 17, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 248,
205, 17, 17, 17, 17, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 248, 206, 17, 17, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 248, 207, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1,
},
fsggml.TensorTypeQ4_1: {
34, 64, 0, 0, 128, 128, 145, 145, 162, 162, 179, 179, 196,
196, 213, 213, 230, 230, 247, 247, 34, 64, 0, 80, 128, 128,
145, 145, 162, 162, 179, 179, 196, 196, 213, 213, 230, 230, 247,
247, 34, 64, 0, 84, 128, 128, 145, 145, 162, 162, 179, 179,
196, 196, 213, 213, 230, 230, 247, 247, 34, 64, 0, 86, 128,
128, 145, 145, 162, 162, 179, 179, 196, 196, 213, 213, 230, 230,
247, 247, 34, 64, 0, 88, 128, 128, 145, 145, 162, 162, 179,
179, 196, 196, 213, 213, 230, 230, 247, 247, 34, 64, 0, 89,
128, 128, 145, 145, 162, 162, 179, 179, 196, 196, 213, 213, 230,
230, 247, 247, 34, 64, 0, 90, 128, 128, 145, 145, 162, 162,
179, 179, 196, 196, 213, 213, 230, 230, 247, 247, 34, 64, 0,
91, 128, 128, 145, 145, 162, 162, 179, 179, 196, 196, 213, 213,
230, 230, 247, 247,
},
fsggml.TensorTypeQ5_0: {
192, 191, 1, 0, 0, 0, 128, 127, 127, 110, 110, 93, 93,
76, 76, 59, 59, 42, 42, 25, 25, 8, 224, 195, 0, 0,
0, 0, 72, 72, 55, 55, 55, 55, 38, 38, 38, 38, 21,
21, 21, 21, 4, 4, 240, 197, 0, 0, 0, 0, 53, 37,
37, 37, 37, 36, 36, 20, 20, 20, 20, 19, 19, 3, 3,
3, 240, 199, 0, 0, 0, 0, 36, 36, 36, 36, 19, 19,
19, 19, 19, 19, 19, 19, 2, 2, 2, 2, 248, 200, 0,
0, 0, 0, 35, 19, 19, 19, 19, 19, 19, 18, 18, 18,
18, 2, 2, 2, 2, 2, 248, 201, 0, 0, 0, 0, 19,
19, 18, 18, 18, 18, 18, 18, 18, 18, 2, 2, 2, 2,
1, 1, 248, 202, 0, 0, 0, 0, 18, 18, 18, 18, 18,
18, 18, 18, 18, 2, 2, 1, 1, 1, 1, 1, 248, 203,
0, 0, 0, 0, 18, 18, 18, 18, 18, 18, 18, 18, 1,
1, 1, 1, 1, 1, 1, 1,
},
fsggml.TensorTypeQ5_1: {
0, 60, 0, 0, 0, 0, 255, 255, 0, 17, 34, 51, 68,
85, 102, 119, 136, 153, 170, 187, 204, 221, 238, 255, 0, 60,
0, 80, 0, 0, 255, 255, 0, 17, 34, 51, 68, 85, 102,
119, 136, 153, 170, 187, 204, 221, 238, 255, 0, 60, 0, 84,
0, 0, 255, 255, 0, 17, 34, 51, 68, 85, 102, 119, 136,
153, 170, 187, 204, 221, 238, 255, 0, 60, 0, 86, 0, 0,
255, 255, 0, 17, 34, 51, 68, 85, 102, 119, 136, 153, 170,
187, 204, 221, 238, 255, 0, 60, 0, 88, 0, 0, 255, 255,
0, 17, 34, 51, 68, 85, 102, 119, 136, 153, 170, 187, 204,
221, 238, 255, 0, 60, 0, 89, 0, 0, 255, 255, 0, 17,
34, 51, 68, 85, 102, 119, 136, 153, 170, 187, 204, 221, 238,
255, 0, 60, 0, 90, 0, 0, 255, 255, 0, 17, 34, 51,
68, 85, 102, 119, 136, 153, 170, 187, 204, 221, 238, 255, 0,
60, 0, 91, 0, 0, 255, 255, 0, 17, 34, 51, 68, 85,
102, 119, 136, 153, 170, 187, 204, 221, 238, 255,
},
fsggml.TensorTypeQ8_0: {
208, 51, 0, 4, 8, 12, 16, 20, 25, 29, 33, 37, 41,
45, 49, 53, 57, 61, 66, 70, 74, 78, 82, 86, 90, 94,
98, 102, 107, 111, 115, 119, 123, 127, 240, 55, 65, 67, 69,
71, 73, 75, 77, 79, 81, 83, 85, 87, 89, 91, 93, 95,
97, 99, 101, 103, 105, 107, 109, 111, 113, 115, 117, 119, 121,
123, 125, 127, 252, 57, 86, 87, 88, 90, 91, 92, 94, 95,
96, 98, 99, 100, 102, 103, 104, 106, 107, 108, 110, 111, 112,
114, 115, 116, 118, 119, 120, 122, 123, 124, 126, 127, 0, 60,
96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108,
109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,
122, 123, 124, 125, 126, 127, 2, 61, 102, 103, 104, 105, 105,
106, 107, 108, 109, 109, 110, 111, 112, 113, 113, 114, 115, 116,
117, 117, 118, 119, 120, 121, 121, 122, 123, 124, 125, 125, 126,
127, 4, 62, 106, 107, 108, 108, 109, 110, 110, 111, 112, 112,
113, 114, 114, 115, 116, 116, 117, 118, 118, 119, 120, 120, 121,
122, 122, 123, 124, 124, 125, 126, 126, 127, 6, 63, 109, 110,
110, 111, 112, 112, 113, 113, 114, 114, 115, 116, 116, 117, 117,
118, 118, 119, 120, 120, 121, 121, 122, 122, 123, 124, 124, 125,
125, 126, 126, 127, 4, 64, 112, 112, 113, 113, 114, 114, 115,
115, 116, 116, 117, 117, 118, 118, 119, 119, 120, 120, 121, 121,
122, 122, 123, 123, 124, 124, 125, 125, 126, 126, 127, 127,
},
fsggml.TensorTypeBF16: {
0, 0, 128, 63, 0, 64, 64, 64, 128, 64, 160, 64, 192,
64, 224, 64, 0, 65, 16, 65, 32, 65, 48, 65, 64, 65,
80, 65, 96, 65, 112, 65, 128, 65, 136, 65, 144, 65, 152,
65, 160, 65, 168, 65, 176, 65, 184, 65, 192, 65, 200, 65,
208, 65, 216, 65, 224, 65, 232, 65, 240, 65, 248, 65, 0,
66, 4, 66, 8, 66, 12, 66, 16, 66, 20, 66, 24, 66,
28, 66, 32, 66, 36, 66, 40, 66, 44, 66, 48, 66, 52,
66, 56, 66, 60, 66, 64, 66, 68, 66, 72, 66, 76, 66,
80, 66, 84, 66, 88, 66, 92, 66, 96, 66, 100, 66, 104,
66, 108, 66, 112, 66, 116, 66, 120, 66, 124, 66, 128, 66,
130, 66, 132, 66, 134, 66, 136, 66, 138, 66, 140, 66, 142,
66, 144, 66, 146, 66, 148, 66, 150, 66, 152, 66, 154, 66,
156, 66, 158, 66, 160, 66, 162, 66, 164, 66, 166, 66, 168,
66, 170, 66, 172, 66, 174, 66, 176, 66, 178, 66, 180, 66,
182, 66, 184, 66, 186, 66, 188, 66, 190, 66, 192, 66, 194,
66, 196, 66, 198, 66, 200, 66, 202, 66, 204, 66, 206, 66,
208, 66, 210, 66, 212, 66, 214, 66, 216, 66, 218, 66, 220,
66, 222, 66, 224, 66, 226, 66, 228, 66, 230, 66, 232, 66,
234, 66, 236, 66, 238, 66, 240, 66, 242, 66, 244, 66, 246,
66, 248, 66, 250, 66, 252, 66, 254, 66, 0, 67, 1, 67,
2, 67, 3, 67, 4, 67, 5, 67, 6, 67, 7, 67, 8,
67, 9, 67, 10, 67, 11, 67, 12, 67, 13, 67, 14, 67,
15, 67, 16, 67, 17, 67, 18, 67, 19, 67, 20, 67, 21,
67, 22, 67, 23, 67, 24, 67, 25, 67, 26, 67, 27, 67,
28, 67, 29, 67, 30, 67, 31, 67, 32, 67, 33, 67, 34,
67, 35, 67, 36, 67, 37, 67, 38, 67, 39, 67, 40, 67,
41, 67, 42, 67, 43, 67, 44, 67, 45, 67, 46, 67, 47,
67, 48, 67, 49, 67, 50, 67, 51, 67, 52, 67, 53, 67,
54, 67, 55, 67, 56, 67, 57, 67, 58, 67, 59, 67, 60,
67, 61, 67, 62, 67, 63, 67, 64, 67, 65, 67, 66, 67,
67, 67, 68, 67, 69, 67, 70, 67, 71, 67, 72, 67, 73,
67, 74, 67, 75, 67, 76, 67, 77, 67, 78, 67, 79, 67,
80, 67, 81, 67, 82, 67, 83, 67, 84, 67, 85, 67, 86,
67, 87, 67, 88, 67, 89, 67, 90, 67, 91, 67, 92, 67,
93, 67, 94, 67, 95, 67, 96, 67, 97, 67, 98, 67, 99,
67, 100, 67, 101, 67, 102, 67, 103, 67, 104, 67, 105, 67,
106, 67, 107, 67, 108, 67, 109, 67, 110, 67, 111, 67, 112,
67, 113, 67, 114, 67, 115, 67, 116, 67, 117, 67, 118, 67,
119, 67, 120, 67, 121, 67, 122, 67, 123, 67, 124, 67, 125,
67, 126, 67, 127, 67,
},
fsggml.TensorTypeF16: {
0, 0, 0, 60, 0, 64, 0, 66, 0, 68, 0, 69, 0, 70, 0, 71, 0,
72, 128, 72, 0, 73, 128, 73, 0, 74, 128, 74, 0, 75, 128, 75,
0, 76, 64, 76, 128, 76, 192, 76, 0, 77, 64, 77, 128, 77, 192,
77, 0, 78, 64, 78, 128, 78, 192, 78, 0, 79, 64, 79, 128, 79,
192, 79, 0, 80, 32, 80, 64, 80, 96, 80, 128, 80, 160, 80,
192, 80, 224, 80, 0, 81, 32, 81, 64, 81, 96, 81, 128, 81,
160, 81, 192, 81, 224, 81, 0, 82, 32, 82, 64, 82, 96, 82,
128, 82, 160, 82, 192, 82, 224, 82, 0, 83, 32, 83, 64, 83,
96, 83, 128, 83, 160, 83, 192, 83, 224, 83, 0, 84, 16, 84,
32, 84, 48, 84, 64, 84, 80, 84, 96, 84, 112, 84, 128, 84,
144, 84, 160, 84, 176, 84, 192, 84, 208, 84, 224, 84, 240,
84, 0, 85, 16, 85, 32, 85, 48, 85, 64, 85, 80, 85, 96, 85,
112, 85, 128, 85, 144, 85, 160, 85, 176, 85, 192, 85, 208,
85, 224, 85, 240, 85, 0, 86, 16, 86, 32, 86, 48, 86, 64,
86, 80, 86, 96, 86, 112, 86, 128, 86, 144, 86, 160, 86,
176, 86, 192, 86, 208, 86, 224, 86, 240, 86, 0, 87, 16,
87, 32, 87, 48, 87, 64, 87, 80, 87, 96, 87, 112, 87, 128,
87, 144, 87, 160, 87, 176, 87, 192, 87, 208, 87, 224, 87,
240, 87, 0, 88, 8, 88, 16, 88, 24, 88, 32, 88, 40, 88,
48, 88, 56, 88, 64, 88, 72, 88, 80, 88, 88, 88, 96, 88,
104, 88, 112, 88, 120, 88, 128, 88, 136, 88, 144, 88, 152,
88, 160, 88, 168, 88, 176, 88, 184, 88, 192, 88, 200, 88,
208, 88, 216, 88, 224, 88, 232, 88, 240, 88, 248, 88, 0,
89, 8, 89, 16, 89, 24, 89, 32, 89, 40, 89, 48, 89, 56, 89,
64, 89, 72, 89, 80, 89, 88, 89, 96, 89, 104, 89, 112, 89,
120, 89, 128, 89, 136, 89, 144, 89, 152, 89, 160, 89, 168,
89, 176, 89, 184, 89, 192, 89, 200, 89, 208, 89, 216, 89,
224, 89, 232, 89, 240, 89, 248, 89, 0, 90, 8, 90, 16, 90,
24, 90, 32, 90, 40, 90, 48, 90, 56, 90, 64, 90, 72, 90, 80,
90, 88, 90, 96, 90, 104, 90, 112, 90, 120, 90, 128, 90,
136, 90, 144, 90, 152, 90, 160, 90, 168, 90, 176, 90, 184,
90, 192, 90, 200, 90, 208, 90, 216, 90, 224, 90, 232, 90,
240, 90, 248, 90, 0, 91, 8, 91, 16, 91, 24, 91, 32, 91, 40,
91, 48, 91, 56, 91, 64, 91, 72, 91, 80, 91, 88, 91, 96, 91,
104, 91, 112, 91, 120, 91, 128, 91, 136, 91, 144, 91, 152,
91, 160, 91, 168, 91, 176, 91, 184, 91, 192, 91, 200, 91,
208, 91, 216, 91, 224, 91, 232, 91, 240, 91, 248, 91,
},
fsggml.TensorTypeF32: {
0, 0, 0, 0, 0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128,
64, 0, 0, 160, 64, 0, 0, 192, 64, 0, 0, 224, 64, 0, 0, 0, 65, 0,
0, 16, 65, 0, 0, 32, 65, 0, 0, 48, 65, 0, 0, 64, 65, 0, 0, 80, 65,
0, 0, 96, 65, 0, 0, 112, 65, 0, 0, 128, 65, 0, 0, 136, 65, 0, 0,
144, 65, 0, 0, 152, 65, 0, 0, 160, 65, 0, 0, 168, 65, 0, 0, 176,
65, 0, 0, 184, 65, 0, 0, 192, 65, 0, 0, 200, 65, 0, 0, 208, 65, 0,
0, 216, 65, 0, 0, 224, 65, 0, 0, 232, 65, 0, 0, 240, 65, 0, 0, 248,
65, 0, 0, 0, 66, 0, 0, 4, 66, 0, 0, 8, 66, 0, 0, 12, 66, 0, 0, 16,
66, 0, 0, 20, 66, 0, 0, 24, 66, 0, 0, 28, 66, 0, 0, 32, 66, 0, 0,
36, 66, 0, 0, 40, 66, 0, 0, 44, 66, 0, 0, 48, 66, 0, 0, 52, 66, 0,
0, 56, 66, 0, 0, 60, 66, 0, 0, 64, 66, 0, 0, 68, 66, 0, 0, 72, 66,
0, 0, 76, 66, 0, 0, 80, 66, 0, 0, 84, 66, 0, 0, 88, 66, 0, 0, 92, 66,
0, 0, 96, 66, 0, 0, 100, 66, 0, 0, 104, 66, 0, 0, 108, 66, 0, 0, 112,
66, 0, 0, 116, 66, 0, 0, 120, 66, 0, 0, 124, 66, 0, 0, 128, 66, 0, 0,
130, 66, 0, 0, 132, 66, 0, 0, 134, 66, 0, 0, 136, 66, 0, 0, 138, 66,
0, 0, 140, 66, 0, 0, 142, 66, 0, 0, 144, 66, 0, 0, 146, 66, 0, 0, 148,
66, 0, 0, 150, 66, 0, 0, 152, 66, 0, 0, 154, 66, 0, 0, 156, 66, 0, 0,
158, 66, 0, 0, 160, 66, 0, 0, 162, 66, 0, 0, 164, 66, 0, 0, 166, 66,
0, 0, 168, 66, 0, 0, 170, 66, 0, 0, 172, 66, 0, 0, 174, 66, 0, 0, 176,
66, 0, 0, 178, 66, 0, 0, 180, 66, 0, 0, 182, 66, 0, 0, 184, 66, 0, 0,
186, 66, 0, 0, 188, 66, 0, 0, 190, 66, 0, 0, 192, 66, 0, 0, 194, 66, 0,
0, 196, 66, 0, 0, 198, 66, 0, 0, 200, 66, 0, 0, 202, 66, 0, 0, 204, 66,
0, 0, 206, 66, 0, 0, 208, 66, 0, 0, 210, 66, 0, 0, 212, 66, 0, 0, 214, 66,
0, 0, 216, 66, 0, 0, 218, 66, 0, 0, 220, 66, 0, 0, 222, 66, 0, 0, 224, 66,
0, 0, 226, 66, 0, 0, 228, 66, 0, 0, 230, 66, 0, 0, 232, 66, 0, 0, 234, 66,
0, 0, 236, 66, 0, 0, 238, 66, 0, 0, 240, 66, 0, 0, 242, 66, 0, 0, 244, 66,
0, 0, 246, 66, 0, 0, 248, 66, 0, 0, 250, 66, 0, 0, 252, 66, 0, 0, 254, 66,
0, 0, 0, 67, 0, 0, 1, 67, 0, 0, 2, 67, 0, 0, 3, 67, 0, 0, 4, 67, 0, 0, 5, 67,
0, 0, 6, 67, 0, 0, 7, 67, 0, 0, 8, 67, 0, 0, 9, 67, 0, 0, 10, 67, 0, 0, 11,
67, 0, 0, 12, 67, 0, 0, 13, 67, 0, 0, 14, 67, 0, 0, 15, 67, 0, 0, 16, 67,
0, 0, 17, 67, 0, 0, 18, 67, 0, 0, 19, 67, 0, 0, 20, 67, 0, 0, 21, 67, 0, 0,
22, 67, 0, 0, 23, 67, 0, 0, 24, 67, 0, 0, 25, 67, 0, 0, 26, 67, 0, 0, 27,
67, 0, 0, 28, 67, 0, 0, 29, 67, 0, 0, 30, 67, 0, 0, 31, 67, 0, 0, 32, 67,
0, 0, 33, 67, 0, 0, 34, 67, 0, 0, 35, 67, 0, 0, 36, 67, 0, 0, 37, 67, 0, 0,
38, 67, 0, 0, 39, 67, 0, 0, 40, 67, 0, 0, 41, 67, 0, 0, 42, 67, 0, 0, 43, 67,
0, 0, 44, 67, 0, 0, 45, 67, 0, 0, 46, 67, 0, 0, 47, 67, 0, 0, 48, 67, 0, 0,
49, 67, 0, 0, 50, 67, 0, 0, 51, 67, 0, 0, 52, 67, 0, 0, 53, 67, 0, 0, 54, 67,
0, 0, 55, 67, 0, 0, 56, 67, 0, 0, 57, 67, 0, 0, 58, 67, 0, 0, 59, 67, 0, 0,
60, 67, 0, 0, 61, 67, 0, 0, 62, 67, 0, 0, 63, 67, 0, 0, 64, 67, 0, 0, 65, 67,
0, 0, 66, 67, 0, 0, 67, 67, 0, 0, 68, 67, 0, 0, 69, 67, 0, 0, 70, 67, 0, 0, 71,
67, 0, 0, 72, 67, 0, 0, 73, 67, 0, 0, 74, 67, 0, 0, 75, 67, 0, 0, 76, 67, 0,
0, 77, 67, 0, 0, 78, 67, 0, 0, 79, 67, 0, 0, 80, 67, 0, 0, 81, 67, 0, 0, 82,
67, 0, 0, 83, 67, 0, 0, 84, 67, 0, 0, 85, 67, 0, 0, 86, 67, 0, 0, 87, 67, 0,
0, 88, 67, 0, 0, 89, 67, 0, 0, 90, 67, 0, 0, 91, 67, 0, 0, 92, 67, 0, 0, 93,
67, 0, 0, 94, 67, 0, 0, 95, 67, 0, 0, 96, 67, 0, 0, 97, 67, 0, 0, 98, 67, 0,
0, 99, 67, 0, 0, 100, 67, 0, 0, 101, 67, 0, 0, 102, 67, 0, 0, 103, 67, 0, 0,
104, 67, 0, 0, 105, 67, 0, 0, 106, 67, 0, 0, 107, 67, 0, 0, 108, 67, 0, 0, 109,
67, 0, 0, 110, 67, 0, 0, 111, 67, 0, 0, 112, 67, 0, 0, 113, 67, 0, 0, 114, 67,
0, 0, 115, 67, 0, 0, 116, 67, 0, 0, 117, 67, 0, 0, 118, 67, 0, 0, 119, 67, 0,
0, 120, 67, 0, 0, 121, 67, 0, 0, 122, 67, 0, 0, 123, 67, 0, 0, 124, 67, 0, 0,
125, 67, 0, 0, 126, 67, 0, 0, 127, 67,
},
fsggml.TensorTypeQ4_K: {
52, 52, 0, 0, 136, 208, 216, 223, 0, 0, 0, 0, 8, 0, 8, 15, 128,
128, 129, 129, 146, 146, 147, 147, 164, 164, 165, 165, 166, 182,
183, 183, 184, 200, 201, 201, 202, 218, 218, 219, 219, 236, 236,
237, 237, 254, 254, 255, 202, 202, 202, 203, 203, 203, 219, 219,
219, 220, 220, 220, 220, 220, 236, 237, 237, 237, 237, 237,
237, 237, 238, 254, 254, 254, 254, 254, 255, 255, 255, 255, 220,
220, 220, 220, 221, 221, 221, 221, 221, 221, 221, 237, 237, 237,
238, 238, 238, 238, 238, 238, 238, 238, 238, 254, 254, 255, 255,
255, 255, 255, 255, 255, 237, 237, 237, 237, 237, 237, 237, 238,
238, 238, 238, 238, 238, 238, 238, 238, 254, 254, 254, 254, 254,
254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
},
fsggml.TensorTypeQ2_K: {
1, 2, 3, 3, 4, 5, 7, 7, 8, 9, 10, 11, 12, 13, 14, 15, 184, 184,
184, 185, 249, 249, 249, 249, 249, 250, 250, 254, 254, 254, 254,
255, 253, 253, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 171, 69, 0, 0,
},
fsggml.TensorTypeQ5_K: {
32, 48, 0, 0, 136, 208, 216, 223, 0, 0, 0, 0, 8, 0, 7, 15, 254,
254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
254, 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 0, 1, 2, 19, 20, 37, 38, 55, 56, 73, 74,
91, 92, 109, 110, 127, 112, 128, 129, 146, 147, 164, 165, 182, 183,
200, 201, 218, 219, 236, 237, 254, 133, 133, 149, 150, 150, 150,
167, 167, 167, 168, 184, 184, 185, 185, 201, 202, 202, 202, 219,
219, 219, 219, 236, 236, 236, 237, 253, 253, 254, 254, 254, 255,
169, 169, 169, 169, 186, 186, 186, 186, 186, 187, 187, 203, 203,
203, 204, 204, 204, 220, 220, 221, 221, 221, 221, 237, 237, 238,
238, 238, 238, 254, 255, 255, 203, 203, 203, 204, 204, 204, 204,
204, 220, 220, 220, 221, 221, 221, 221, 221, 237, 237, 238, 238,
238, 238, 238, 238, 254, 255, 255, 255, 255, 255, 255, 255,
},
fsggml.TensorTypeQ6_K: {
96, 110, 92, 90, 88, 70, 68, 50, 48, 46, 44, 42, 24, 22, 4, 2, 80,
95, 78, 77, 76, 59, 58, 57, 40, 39, 38, 21, 20, 19, 2, 1, 75, 75,
74, 57, 57, 56, 55, 39, 38, 37, 21, 20, 20, 19, 2, 2, 72, 55, 55,
54, 54, 37, 37, 36, 36, 19, 19, 18, 18, 1, 1, 0, 35, 35, 35, 35,
34, 18, 18, 18, 17, 17, 17, 1, 1, 0, 0, 0, 35, 35, 34, 34, 18,
18, 18, 17, 17, 17, 17, 1, 0, 0, 0, 0, 35, 35, 35, 19, 19, 18, 18,
18, 18, 18, 1, 1, 1, 1, 1, 1, 34, 34, 18, 18, 18, 18, 17, 17, 17,
17, 1, 1, 0, 0, 0, 0, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 248, 240, 231, 224, 216, 208, 200, 192, 184, 176,
166, 160, 152, 144, 136, 128, 235, 43,
},
fsggml.TensorTypeQ3_K: {
1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 20, 20, 23, 23, 7, 7, 6, 6, 6, 2,
1, 1, 1, 1, 0, 0, 22, 22, 6, 6, 5, 5, 5, 1, 1, 1, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 238, 204, 170, 136, 102, 68,
34, 1, 5, 5, 5, 5, 189, 63,
},
}
)
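The expected-output slices above are raw byte dumps: in the long slice shown first, the repeating four-byte groups (for example 0, 0, 16, 65, the little-endian encoding of float32(9.0)) suggest packed float32 data, while the quantized entries (Q4_K, Q2_K, ...) are opaque block encodings. A minimal standalone sketch, under that float32 assumption, for decoding such a slice back into readable values (decodeF32 is a hypothetical helper, not part of the test file):

package main

import (
	"encoding/binary"
	"fmt"
	"math"
)

// decodeF32 interprets a byte slice as consecutive little-endian float32 values.
func decodeF32(b []byte) []float32 {
	out := make([]float32, 0, len(b)/4)
	for i := 0; i+4 <= len(b); i += 4 {
		bits := binary.LittleEndian.Uint32(b[i : i+4])
		out = append(out, math.Float32frombits(bits))
	}
	return out
}

func main() {
	// 0,0,16,65 and 0,0,32,65 are the little-endian encodings of 9.0 and 10.0.
	fmt.Println(decodeF32([]byte{0, 0, 16, 65, 0, 0, 32, 65})) // [9 10]
}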

View File

@@ -18,6 +18,7 @@ import (
"os"
"os/signal"
"path/filepath"
"regexp"
"slices"
"strings"
"syscall"
@@ -1169,6 +1170,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
corsConfig.AllowOrigins = envconfig.AllowedOrigins()
r := gin.Default()
r.HandleMethodNotAllowed = true
r.Use(
cors.New(corsConfig),
allowedHostsMiddleware(s.addr),
@@ -1512,6 +1514,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
if req.Messages[0].Role != "system" && m.System != "" {
msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
}
msgs = filterThinkTags(msgs, m)
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
if err != nil {
@@ -1640,3 +1643,23 @@ func handleScheduleError(c *gin.Context, name string, err error) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
}
var thinkTagRegexp = regexp.MustCompile(`<think>(?s).*?</think>(\n)*`)
func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
finalUserIndex := -1
for i, msg := range msgs {
if msg.Role == "user" {
finalUserIndex = i
}
}
for i, msg := range msgs {
if msg.Role == "assistant" && i < finalUserIndex {
msgs[i].Content = thinkTagRegexp.ReplaceAllString(msg.Content, "")
}
}
}
return msgs
}
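For a quick sense of what thinkTagRegexp does, here is a minimal standalone sketch outside the server package: (?s) lets .*? span newlines inside the tag, and the trailing (\n)* also swallows blank lines after the closing tag.

package main

import (
	"fmt"
	"regexp"
)

var thinkTagRegexp = regexp.MustCompile(`<think>(?s).*?</think>(\n)*`)

func main() {
	msg := "<think>reasoning...\nmore reasoning</think>\n\nfinal answer"
	// Strips the think block and the blank lines that follow it.
	fmt.Printf("%q\n", thinkTagRegexp.ReplaceAllString(msg, "")) // "final answer"
}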

View File

@@ -24,7 +24,7 @@ import (
var stream bool = false
func createBinFile(t *testing.T, kv map[string]any, ti []ggml.Tensor) (string, string) {
func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string, string) {
t.Helper()
t.Setenv("OLLAMA_MODELS", cmp.Or(os.Getenv("OLLAMA_MODELS"), t.TempDir()))

View File

@@ -87,7 +87,7 @@ func TestGenerateChat(t *testing.T) {
},
}
go s.sched.Run(context.TODO())
go s.sched.Run(t.Context())
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
@@ -99,7 +99,7 @@ func TestGenerateChat(t *testing.T) {
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []ggml.Tensor{
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
@@ -158,7 +158,7 @@ func TestGenerateChat(t *testing.T) {
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(0),
}, []ggml.Tensor{})
}, []*ggml.Tensor{})
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "bert",
Files: map[string]string{"bert.gguf": digest},
@@ -299,9 +299,6 @@ func TestGenerateChat(t *testing.T) {
{Role: "user", Content: "Hello!"},
},
Stream: &stream,
Options: map[string]any{
"num_ctx": 1024,
},
})
if w.Code != http.StatusOK {
@@ -324,9 +321,6 @@ func TestGenerateChat(t *testing.T) {
{Role: "user", Content: "Hello!"},
},
Stream: &stream,
Options: map[string]any{
"num_ctx": 1024,
},
})
if w.Code != http.StatusOK {
@@ -350,9 +344,6 @@ func TestGenerateChat(t *testing.T) {
{Role: "user", Content: "Help me write tests."},
},
Stream: &stream,
Options: map[string]any{
"num_ctx": 1024,
},
})
if w.Code != http.StatusOK {
@@ -640,7 +631,7 @@ func TestGenerate(t *testing.T) {
},
}
go s.sched.Run(context.TODO())
go s.sched.Run(t.Context())
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
@@ -652,7 +643,7 @@ func TestGenerate(t *testing.T) {
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []ggml.Tensor{
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
@@ -707,7 +698,7 @@ func TestGenerate(t *testing.T) {
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(0),
}, []ggml.Tensor{})
}, []*ggml.Tensor{})
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "bert",

View File

@@ -15,6 +15,7 @@ import (
"net/http/httptest"
"os"
"path/filepath"
"reflect"
"sort"
"strings"
"testing"
@@ -473,14 +474,24 @@ func TestRoutes(t *testing.T) {
t.Fatalf("failed to read response body: %v", err)
}
var retrieveResp api.RetrieveModelResponse
err = json.Unmarshal(body, &retrieveResp)
var m openai.Model
err = json.Unmarshal(body, &m)
if err != nil {
t.Fatalf("failed to unmarshal response body: %v", err)
}
if retrieveResp.Id != "show-model" || retrieveResp.OwnedBy != "library" {
t.Errorf("expected model 'show-model' owned by 'library', got %v", retrieveResp)
if m.Id != "show-model" || m.OwnedBy != "library" {
t.Errorf("expected model 'show-model' owned by 'library', got %v", m)
}
},
},
{
Name: "Method Not Allowed",
Method: http.MethodGet,
Path: "/api/show",
Expected: func(t *testing.T, resp *http.Response) {
if resp.StatusCode != 405 {
t.Errorf("expected status code 405, got %d", resp.StatusCode)
}
},
},
@@ -516,7 +527,7 @@ func TestRoutes(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
u := httpSrv.URL + tc.Path
req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
req, err := http.NewRequestWithContext(t.Context(), tc.Method, u, nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
@@ -746,3 +757,128 @@ func TestNormalize(t *testing.T) {
})
}
}
func TestFilterThinkTags(t *testing.T) {
type testCase struct {
msgs []api.Message
want []api.Message
model *Model
}
testCases := []testCase{
{
msgs: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
{Role: "user", Content: "What is the answer?"},
},
want: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "abc"},
{Role: "user", Content: "What is the answer?"},
},
model: &Model{
Config: ConfigV2{
ModelFamily: "qwen3",
},
},
},
		// with newlines inside the think tag and newlines after
{
msgs: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking... \n\nabout \nthe answer</think>\n\nabc\ndef"},
{Role: "user", Content: "What is the answer?"},
},
want: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "abc\ndef"},
{Role: "user", Content: "What is the answer?"},
},
model: &Model{
Config: ConfigV2{
ModelFamily: "qwen3",
},
},
},
		// should leave thinking tags alone if they come after the last user message
{
msgs: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking...</think>after"},
{Role: "user", Content: "What is the answer?"},
{Role: "assistant", Content: "<think>thinking again</think>hjk"},
{Role: "assistant", Content: "<think>thinking yet again</think>hjk"},
},
want: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "after"},
{Role: "user", Content: "What is the answer?"},
{Role: "assistant", Content: "<think>thinking again</think>hjk"},
{Role: "assistant", Content: "<think>thinking yet again</think>hjk"},
},
model: &Model{
Config: ConfigV2{
ModelFamily: "qwen3",
},
},
},
{
// shouldn't strip anything because the model family isn't one of the hardcoded ones
msgs: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
{Role: "user", Content: "What is the answer?"},
},
want: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
{Role: "user", Content: "What is the answer?"},
},
model: &Model{
Config: ConfigV2{
ModelFamily: "llama3",
},
},
},
{
			// model named with a "deepseek-r1:" prefix
msgs: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
{Role: "user", Content: "What is the answer?"},
},
want: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "abc"},
{Role: "user", Content: "What is the answer?"},
},
model: &Model{
Name: "registry.ollama.ai/library/deepseek-r1:latest",
ShortName: "deepseek-r1:7b",
Config: ConfigV2{},
},
},
}
for i, tc := range testCases {
filtered := filterThinkTags(tc.msgs, tc.model)
if !reflect.DeepEqual(filtered, tc.want) {
t.Errorf("messages differ for case %d:", i)
for i := range tc.want {
if i >= len(filtered) {
t.Errorf(" missing message %d: %+v", i, tc.want[i])
continue
}
if !reflect.DeepEqual(filtered[i], tc.want[i]) {
t.Errorf(" message %d:\n want: %+v\n got: %+v", i, tc.want[i], filtered[i])
}
}
if len(filtered) > len(tc.want) {
for i := len(tc.want); i < len(filtered); i++ {
t.Errorf(" extra message %d: %+v", i, filtered[i])
}
}
}
}
}

View File

@@ -81,6 +81,10 @@ func InitScheduler(ctx context.Context) *Scheduler {
// context must be canceled to decrement ref count and release the runner
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
if opts.NumCtx < 4 {
opts.NumCtx = 4
}
req := &LlmRequest{
ctx: c,
model: model,
@@ -110,11 +114,6 @@ func (s *Scheduler) Run(ctx context.Context) {
}()
}
const (
defaultContextLength = 4096
smallGpuContextLength = 2048
)
func (s *Scheduler) processPending(ctx context.Context) {
for {
select {
@@ -148,6 +147,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
s.loadedMu.Unlock()
if runner != nil {
if runner.needsReload(ctx, pending) {
slog.Debug("reloading", "runner", runner)
runnerToExpire = runner
} else {
// Runner is usable, return it
@@ -167,17 +167,6 @@ func (s *Scheduler) processPending(ctx context.Context) {
gpus = s.getGpuFn()
}
if pending.origNumCtx == -1 {
if len(gpus) == 1 && gpus[0].Library != "cpu" && gpus[0].TotalMemory <= 4096*1024*1024 {
slog.Info("GPU is small, limiting default context window", "num_ctx", smallGpuContextLength)
pending.opts.NumCtx = smallGpuContextLength
pending.origNumCtx = smallGpuContextLength
} else {
pending.opts.NumCtx = defaultContextLength
pending.origNumCtx = defaultContextLength
}
}
if envconfig.MaxRunners() <= 0 {
// No user specified MaxRunners, so figure out what automatic setting to use
// If all GPUs have reliable free memory reporting, defaultModelsPerGPU * the number of GPUs
@@ -294,7 +283,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
}
// Trigger an expiration to unload once it's done
runnerToExpire.refMu.Lock()
slog.Debug("resetting model to expire immediately to make room", "modelPath", runnerToExpire.modelPath, "refCount", runnerToExpire.refCount)
slog.Debug("resetting model to expire immediately to make room", "runner", runnerToExpire, "refCount", runnerToExpire.refCount)
if runnerToExpire.expireTimer != nil {
runnerToExpire.expireTimer.Stop()
runnerToExpire.expireTimer = nil
@@ -307,13 +296,13 @@ func (s *Scheduler) processPending(ctx context.Context) {
// Wait for the unload to happen
// Note: at this point we're queueing up all incoming requests, even if they were for
// a different model that's loaded and not scheduled to be removed.
slog.Debug("waiting for pending requests to complete and unload to occur", "modelPath", runnerToExpire.modelPath)
slog.Debug("waiting for pending requests to complete and unload to occur", "runner", runnerToExpire)
select {
case <-ctx.Done():
slog.Debug("shutting down scheduler pending loop")
return
case <-s.unloadedCh:
slog.Debug("unload completed", "modelPath", runnerToExpire.modelPath)
slog.Debug("unload completed", "runner", runnerToExpire)
continue
}
}
@@ -343,16 +332,16 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
runner.refCount--
if runner.refCount <= 0 {
if runner.sessionDuration <= 0 {
slog.Debug("runner with zero duration has gone idle, expiring to unload", "modelPath", runner.modelPath)
slog.Debug("runner with zero duration has gone idle, expiring to unload", "runner", runner)
if runner.expireTimer != nil {
runner.expireTimer.Stop()
runner.expireTimer = nil
}
s.expiredCh <- runner
} else if runner.expireTimer == nil {
slog.Debug("runner with non-zero duration has gone idle, adding timer", "modelPath", runner.modelPath, "duration", runner.sessionDuration)
slog.Debug("runner with non-zero duration has gone idle, adding timer", "runner", runner, "duration", runner.sessionDuration)
runner.expireTimer = time.AfterFunc(runner.sessionDuration, func() {
slog.Debug("timer expired, expiring to unload", "modelPath", runner.modelPath)
slog.Debug("timer expired, expiring to unload", "runner", runner)
runner.refMu.Lock()
defer runner.refMu.Unlock()
if runner.expireTimer != nil {
@@ -363,18 +352,18 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
})
runner.expiresAt = time.Now().Add(runner.sessionDuration)
} else {
slog.Debug("runner with non-zero duration has gone idle, resetting timer", "modelPath", runner.modelPath, "duration", runner.sessionDuration)
slog.Debug("runner with non-zero duration has gone idle, resetting timer", "runner", runner, "duration", runner.sessionDuration)
runner.expireTimer.Reset(runner.sessionDuration)
runner.expiresAt = time.Now().Add(runner.sessionDuration)
}
}
slog.Debug("after processing request finished event", "modelPath", runner.modelPath, "refCount", runner.refCount)
slog.Debug("after processing request finished event", "runner", runner, "refCount", runner.refCount)
runner.refMu.Unlock()
case runner := <-s.expiredCh:
slog.Debug("runner expired event received", "modelPath", runner.modelPath)
slog.Debug("runner expired event received", "runner", runner)
runner.refMu.Lock()
if runner.refCount > 0 {
slog.Debug("expired event with positive ref count, retrying", "modelPath", runner.modelPath, "refCount", runner.refCount)
slog.Debug("expired event with positive ref count, retrying", "runner", runner, "refCount", runner.refCount)
go func(runner *runnerRef) {
// We can't unload yet, but want to as soon as the current request completes
// So queue up another expired event
@@ -386,17 +375,29 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
}
s.loadedMu.Lock()
slog.Debug("got lock to unload", "modelPath", runner.modelPath)
finished := runner.waitForVRAMRecovery()
runner.unload()
delete(s.loaded, runner.modelPath)
s.loadedMu.Unlock()
slog.Debug("runner released", "modelPath", runner.modelPath)
runner.refMu.Unlock()
<-finished
slog.Debug("sending an unloaded event", "modelPath", runner.modelPath)
s.unloadedCh <- struct{}{}
slog.Debug("got lock to unload expired event", "runner", runner)
runnerToUnload := s.loaded[runner.modelPath]
if runnerToUnload == nil {
// If runnerToUnload is nil, we already processed an event and
// unloaded it. This double unload can happen if the initial
// request is canceled and we're trying to load another model
// that requires this one to be evicted, or the settings change
// and require a reload
s.loadedMu.Unlock()
runner.refMu.Unlock()
slog.Debug("duplicate expired event, ignoring", "runner", runner)
} else {
slog.Debug("starting background wait for VRAM recovery", "runner", runner)
finished := runner.waitForVRAMRecovery()
runner.unload()
delete(s.loaded, runner.modelPath)
s.loadedMu.Unlock()
slog.Debug("runner terminated and removed from list, blocking for VRAM recovery", "runner", runner)
<-finished
runner.refMu.Unlock()
slog.Debug("sending an unloaded event", "runner", runner)
s.unloadedCh <- struct{}{}
}
}
}
}
@@ -418,7 +419,7 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
pending.successCh <- runner
go func() {
<-pending.ctx.Done()
slog.Debug("context for request finished")
slog.Debug("context for request finished", "runner", runner)
finished <- pending
}()
}
@@ -453,12 +454,19 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
estimatedVRAM: llama.EstimatedVRAM(),
estimatedTotal: llama.EstimatedTotal(),
loading: true,
refCount: 1,
pid: llama.Pid(),
}
runner.numParallel = numParallel
runner.refMu.Lock()
runner.refMu.Lock() // hold lock until running or aborted
s.loadedMu.Lock()
if oldRunner, ok := s.loaded[req.model.ModelPath]; ok {
// Shouldn't happen, but safeguard against leaking a runner
slog.Warn("model was still loaded", "old_runner", oldRunner, "new_runner", runner)
oldRunner.refMu.Lock()
oldRunner.unload()
oldRunner.refMu.Unlock()
}
s.loaded[req.model.ModelPath] = runner
slog.Info("loaded runners", "count", len(s.loaded))
s.loadedMu.Unlock()
@@ -467,13 +475,16 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
defer runner.refMu.Unlock()
if err = llama.WaitUntilRunning(req.ctx); err != nil {
slog.Error("error loading llama server", "error", err)
runner.refCount--
req.errCh <- err
slog.Debug("triggering expiration for failed load", "model", runner.modelPath)
slog.Debug("triggering expiration for failed load", "runner", runner)
s.expiredCh <- runner
return
}
slog.Debug("finished setting up runner", "model", req.model.ModelPath)
slog.Debug("finished setting up", "runner", runner)
if runner.pid < 0 {
runner.pid = llama.Pid()
}
runner.refCount++
runner.loading = false
go func() {
<-req.ctx.Done()
@@ -491,7 +502,12 @@ func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
}
predMap := map[predKey]uint64{} // Sum up the total predicted usage per GPU for all runners
s.loadedMu.Lock()
runners := make([]*runnerRef, 0, len(s.loaded))
for _, r := range s.loaded {
runners = append(runners, r)
}
s.loadedMu.Unlock()
for _, r := range runners {
r.refMu.Lock()
if r.llama != nil {
for _, gpu := range allGpus {
@@ -502,7 +518,6 @@ func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
}
r.refMu.Unlock()
}
s.loadedMu.Unlock()
// Now that we've summed up all the GPU usage predictions across all the loaded runners, update the gpu list
for i := range allGpus {
@@ -549,12 +564,11 @@ func (s *Scheduler) filterGPUsWithoutLoadingModels(allGpus discover.GpuInfoList)
// TODO consolidate sched_types.go
type runnerRef struct {
refMu sync.Mutex
// refCond sync.Cond // Signaled on transition from 1 -> 0 refCount
refMu sync.Mutex
refCount uint // prevent unloading if > 0
// unloading bool // set to true when we are trying to unload the runner
llama llm.LlamaServer
pid int
loading bool // True only during initial load, then false forever
gpus discover.GpuInfoList // Recorded at time of provisioning
estimatedVRAM uint64
@@ -639,6 +653,7 @@ func (runner *runnerRef) waitForVRAMRecovery() chan any {
(len(runner.gpus) == 1 && (runner.gpus[0].Library == "cpu" || runner.gpus[0].Library == "metal")) ||
(runtime.GOOS == "windows" && runner.gpus[0].Library != "cuda") {
finished <- struct{}{}
slog.Debug("no need to wait for VRAM recovery", "runner", runner)
return finished
}
start := time.Now()
@@ -657,7 +672,7 @@ func (runner *runnerRef) waitForVRAMRecovery() chan any {
for {
<-ticker.C
if time.Now().After(expiresAt) {
slog.Warn("gpu VRAM usage didn't recover within timeout", "seconds", time.Since(start).Seconds(), "model", runner.modelPath)
slog.Warn("gpu VRAM usage didn't recover within timeout", "seconds", time.Since(start).Seconds(), "runner", runner)
finished <- struct{}{}
}
@@ -670,7 +685,7 @@ func (runner *runnerRef) waitForVRAMRecovery() chan any {
}
// If we're within ~80% of the estimated memory usage recovered, bail out
if float32(freeMemoryNow-freeMemoryBefore) > float32(runner.estimatedVRAM)*0.8 {
slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()), "model", runner.modelPath)
slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()), "runner", runner)
finished <- struct{}{}
return
}
@@ -679,6 +694,33 @@ func (runner *runnerRef) waitForVRAMRecovery() chan any {
return finished
}
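As a rough illustration of the ~80% convergence check in waitForVRAMRecovery above, the sketch below restates the comparison with made-up numbers (the helper and the values are illustrative only, not part of the scheduler):

package main

import "fmt"

// vramRecovered mirrors the convergence test: enough of the estimated VRAM
// has returned to the free pool since the unload started.
func vramRecovered(freeBefore, freeNow, estimatedVRAM uint64) bool {
	return float32(freeNow-freeBefore) > float32(estimatedVRAM)*0.8
}

func main() {
	gib := uint64(1 << 30)
	// With an 8 GiB estimate, the threshold is 6.4 GiB of recovered free VRAM.
	fmt.Println(vramRecovered(2*gib, 8*gib, 8*gib)) // false: only 6 GiB recovered
	fmt.Println(vramRecovered(2*gib, 9*gib, 8*gib)) // true: 7 GiB recovered
}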
func (runner *runnerRef) LogValue() slog.Value {
if runner == nil {
return slog.StringValue("nil")
}
attrs := []slog.Attr{}
if runner.model != nil {
attrs = append(attrs, slog.String("name", runner.model.Name))
}
if len(runner.gpus) > 0 {
attrs = append(attrs,
slog.String("inference", runner.gpus[0].Library),
slog.Int("devices", len(runner.gpus)),
)
}
attrs = append(attrs,
slog.String("size", format.HumanBytes2(runner.estimatedTotal)),
slog.String("vram", format.HumanBytes2(runner.estimatedVRAM)),
slog.Int("parallel", runner.numParallel),
slog.Int("pid", runner.pid),
slog.String("model", runner.modelPath),
)
if runner.Options != nil {
attrs = append(attrs, slog.Int("num_ctx", runner.Options.NumCtx))
}
return slog.GroupValue(attrs...)
}
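Because runnerRef now implements slog.LogValuer via LogValue, every "runner", runner attribute in the log calls above expands into a structured group. A minimal sketch of the same pattern with a stand-in type (not the real runnerRef):

package main

import (
	"log/slog"
	"os"
)

type runner struct {
	name string
	pid  int
}

// LogValue lets slog render the struct as a named attribute group.
func (r *runner) LogValue() slog.Value {
	if r == nil {
		return slog.StringValue("nil")
	}
	return slog.GroupValue(
		slog.String("name", r.name),
		slog.Int("pid", r.pid),
	)
}

func main() {
	logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
	// Prints something like: msg="runner expired event received" runner.name=llama3 runner.pid=1234
	logger.Debug("runner expired event received", "runner", &runner{name: "llama3", pid: 1234})
}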
type ByDurationAndName []*runnerRef
func (a ByDurationAndName) Len() int { return len(a) }
@@ -801,12 +843,12 @@ func (s *Scheduler) findRunnerToUnload() *runnerRef {
rc := runner.refCount
runner.refMu.Unlock()
if rc == 0 {
slog.Debug("found an idle runner to unload")
slog.Debug("found an idle runner to unload", "runner", runner)
return runner
}
}
// None appear idle, just wait for the one with the shortest duration
slog.Debug("no idle runners, picking the shortest duration", "count", len(runnerList))
slog.Debug("no idle runners, picking the shortest duration", "runner_count", len(runnerList), "runner", runnerList[0])
return runnerList[0]
}
@@ -823,8 +865,8 @@ func (s *Scheduler) unloadAllRunners() {
func (s *Scheduler) expireRunner(model *Model) {
s.loadedMu.Lock()
defer s.loadedMu.Unlock()
runner, ok := s.loaded[model.ModelPath]
s.loadedMu.Unlock()
if ok {
runner.refMu.Lock()
runner.expiresAt = time.Now()

View File

@@ -26,7 +26,7 @@ func TestMain(m *testing.M) {
}
func TestInitScheduler(t *testing.T) {
ctx, done := context.WithCancel(context.Background())
ctx, done := context.WithCancel(t.Context())
defer done()
s := InitScheduler(ctx)
s.loadedMu.Lock()
@@ -35,7 +35,7 @@ func TestInitScheduler(t *testing.T) {
}
func TestLoad(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 20*time.Millisecond)
ctx, done := context.WithTimeout(t.Context(), 20*time.Millisecond)
defer done()
s := InitScheduler(ctx)
var f *ggml.GGML // value not used in tests
@@ -126,7 +126,7 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
"tokenizer.ggml.tokens": []string{" "},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []ggml.Tensor{
}, []*ggml.Tensor{
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
}))
@@ -148,7 +148,6 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
}
b.req.opts.NumCtx = 4096
b.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
return b
}
@@ -168,7 +167,7 @@ func getCpuFn() discover.GpuInfoList {
}
func TestRequestsSameModelSameRequest(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
@@ -211,7 +210,7 @@ func TestRequestsSameModelSameRequest(t *testing.T) {
}
func TestRequestsSimpleReloadSameModel(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
@@ -259,7 +258,7 @@ func TestRequestsSimpleReloadSameModel(t *testing.T) {
}
func TestRequestsMultipleLoadedModels(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
@@ -356,7 +355,7 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
}
func TestGetRunner(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 3*time.Second)
ctx, done := context.WithTimeout(t.Context(), 3*time.Second)
defer done()
a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond})
@@ -409,7 +408,7 @@ func TestGetRunner(t *testing.T) {
}
func TestExpireRunner(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 20*time.Millisecond)
ctx, done := context.WithTimeout(t.Context(), 20*time.Millisecond)
defer done()
s := InitScheduler(ctx)
req := &LlmRequest{
@@ -456,7 +455,7 @@ func TestExpireRunner(t *testing.T) {
// TODO - add one scenario that triggers the bogus finished event with positive ref count
func TestPrematureExpired(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
// Same model, same request
@@ -503,7 +502,7 @@ func TestPrematureExpired(t *testing.T) {
}
func TestUseLoadedRunner(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
req := &LlmRequest{
ctx: ctx,
opts: api.DefaultOptions(),
@@ -530,7 +529,7 @@ func TestUseLoadedRunner(t *testing.T) {
}
func TestUpdateFreeSpace(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
defer done()
gpus := discover.GpuInfoList{
{
@@ -563,7 +562,7 @@ func TestUpdateFreeSpace(t *testing.T) {
}
func TestFilterGPUsWithoutLoadingModels(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
defer done()
gpus := discover.GpuInfoList{
{
@@ -597,7 +596,7 @@ func TestFilterGPUsWithoutLoadingModels(t *testing.T) {
}
func TestFindRunnerToUnload(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
defer done()
r1 := &runnerRef{refCount: 1, sessionDuration: 1, numParallel: 1}
@@ -617,7 +616,7 @@ func TestFindRunnerToUnload(t *testing.T) {
}
func TestNeedsReload(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
defer done()
llm := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
@@ -664,7 +663,7 @@ func TestNeedsReload(t *testing.T) {
}
func TestUnloadAllRunners(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
defer done()
llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
@@ -696,7 +695,7 @@ func TestUnload(t *testing.T) {
}
func TestAlreadyCanceled(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
dctx, done2 := context.WithCancel(ctx)
done2()
@@ -713,7 +712,7 @@ func TestAlreadyCanceled(t *testing.T) {
}
func TestHomogeneousGPUs(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
defer done()
s := InitScheduler(ctx)
@@ -793,3 +792,4 @@ func (s *mockLlm) Close() error {
func (s *mockLlm) EstimatedVRAM() uint64 { return s.estimatedVRAM }
func (s *mockLlm) EstimatedTotal() uint64 { return s.estimatedTotal }
func (s *mockLlm) EstimatedVRAMByGPU(gpuid string) uint64 { return s.estimatedVRAMByGPU[gpuid] }
func (s *mockLlm) Pid() int { return -1 }