Merge branch 'ollama:main' into main

likelovewant
2025-03-09 13:49:03 +08:00
committed by GitHub
37 changed files with 1869 additions and 790 deletions


@@ -76,6 +76,7 @@ Here are some example models that can be downloaded:
  | Model              | Parameters | Size  | Download                         |
  | ------------------ | ---------- | ----- | -------------------------------- |
+ | QwQ                | 32B        | 20GB  | `ollama run qwq`                 |
  | DeepSeek-R1        | 7B         | 4.7GB | `ollama run deepseek-r1`         |
  | DeepSeek-R1        | 671B       | 404GB | `ollama run deepseek-r1:671b`    |
  | Llama 3.3          | 70B        | 43GB  | `ollama run llama3.3`            |


@@ -361,9 +361,9 @@ type CopyRequest struct {
// PullRequest is the request passed to [Client.Pull].
type PullRequest struct {
	Model    string `json:"model"`
-	Insecure bool   `json:"insecure,omitempty"`
-	Username string `json:"username"`
-	Password string `json:"password"`
+	Insecure bool   `json:"insecure,omitempty"` // Deprecated: ignored
+	Username string `json:"username"`           // Deprecated: ignored
+	Password string `json:"password"`           // Deprecated: ignored
	Stream   *bool  `json:"stream,omitempty"`
// Deprecated: set the model name with Model instead // Deprecated: set the model name with Model instead
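Only `Model` (and optionally `Stream`) still matter when pulling; the credential fields are accepted but ignored. A minimal sketch of a pull through the Go client, assuming the `api.ClientFromEnvironment`, `Client.Pull`, and `api.ProgressResponse` signatures from the ollama `api` package:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Only Model (and optionally Stream) matter now; Insecure, Username and
	// Password are deprecated and ignored by the server.
	req := &api.PullRequest{Model: "qwq"}

	err = client.Pull(context.Background(), req, func(p api.ProgressResponse) error {
		fmt.Println(p.Status)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```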


@@ -81,9 +81,11 @@ help you keep up to date.
If you'd like to install or integrate Ollama as a service, a standalone
`ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI
- and GPU library dependencies for Nvidia and AMD. This allows for embedding
- Ollama in existing applications, or running it as a system service via `ollama
- serve` with tools such as [NSSM](https://nssm.cc/).
+ and GPU library dependencies for Nvidia. If you have an AMD GPU, also download
+ and extract the additional ROCm package `ollama-windows-amd64-rocm.zip` into the
+ same directory. This allows for embedding Ollama in existing applications, or
+ running it as a system service via `ollama serve` with tools such as
+ [NSSM](https://nssm.cc/).

> [!NOTE]
> If you are upgrading from a prior version, you should remove the old directories first.

go.mod

@@ -24,7 +24,7 @@ require (
	github.com/nlpodyssey/gopickle v0.3.0
	github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
	golang.org/x/image v0.22.0
-	gonum.org/v1/gonum v0.15.0
+	golang.org/x/tools v0.30.0
)

require (

@@ -44,6 +44,7 @@ require (
	github.com/xtgo/set v1.0.0 // indirect
	go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
	golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
+	gonum.org/v1/gonum v0.15.0 // indirect
	gorgonia.org/vecf32 v0.9.0 // indirect
	gorgonia.org/vecf64 v0.9.0 // indirect
)

go.sum

@@ -309,6 +309,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
+ golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
+ golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=


@@ -20,6 +20,7 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
type Causal struct {
	DType    ml.DType
	Capacity int32
+	causal   bool

	windowSize int32

	// config controls mostly backend-specific optimizations
@@ -42,6 +43,12 @@ type Causal struct {
	// locations in the cache that are needed for this batch
	curCellRange cellRange

+	// curSequences is the sequences corresponding to this pass's entries in the cache
+	curSequences []int

+	// curPositions is the positions corresponding to this pass's entries in the cache
+	curPositions []int32

	// ** cache metadata **

	// for each possible location in the cache, stores the position and set of sequences
@@ -55,8 +62,8 @@ type Causal struct {
	shiftFn      shiftFn
	backend      ml.Backend
-	cacheCtx     ml.Context
-	keys, values []ml.Tensor
+	ctxs         map[int]ml.Context
+	keys, values map[int]ml.Tensor
}

type cacheCell struct {
@@ -70,11 +77,25 @@ type cellRange struct {
}

func NewCausalCache(shift shiftFn) *Causal {
-	return &Causal{windowSize: math.MaxInt32, shiftFn: shift}
+	return &Causal{
+		causal:     true,
+		windowSize: math.MaxInt32,
+		shiftFn:    shift,
+		ctxs:       make(map[int]ml.Context),
+		keys:       make(map[int]ml.Tensor),
+		values:     make(map[int]ml.Tensor),
+	}
}

func NewSWACache(windowSize int32, shift shiftFn) *Causal {
-	return &Causal{windowSize: windowSize, shiftFn: shift}
+	return &Causal{
+		causal:     true,
+		windowSize: windowSize,
+		shiftFn:    shift,
+		ctxs:       make(map[int]ml.Context),
+		keys:       make(map[int]ml.Tensor),
+		values:     make(map[int]ml.Tensor),
+	}
}

func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
@@ -103,7 +124,6 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
	c.cells = make([]cacheCell, c.Capacity)
	c.cellRanges = make(map[int]cellRange)
	c.backend = backend
-	c.cacheCtx = backend.NewContext()
}

func (c *Causal) SetConfig(config ml.CacheConfig) {
@@ -115,11 +135,15 @@ func (c *Causal) SetConfig(config ml.CacheConfig) {
}

func (c *Causal) Close() {
-	c.cacheCtx.Close()
+	for _, ctx := range c.ctxs {
+		ctx.Close()
+	}
}

func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
	c.curBatchSize = len(positions)
+	c.curSequences = seqs
+	c.curPositions = positions

	var err error
	c.curLoc, err = c.findStartLoc()
@@ -158,7 +182,7 @@ func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) err
		c.cellRanges[seq] = seqRange
	}

-	c.curMask, err = c.buildMask(ctx, positions, seqs)
+	c.curMask, err = c.buildMask(ctx)
	return err
}
@@ -199,7 +223,7 @@ func roundUp(length, pad int) int {
// Builds a mask of history x batch indicating whether for each token in the batch the
// token in the history should apply. This is based on both the sequence and causality (the
// position of the history is not ahead of the token in the batch).
-func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
+func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
	// Align and pad the two dimensions as required by the backend
	batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
@@ -211,8 +235,9 @@ func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Te
	for i := range c.curBatchSize {
		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
-			if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
-				c.cells[j].pos < positions[i]-c.windowSize {
+			if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
+				(c.causal && c.cells[j].pos > c.curPositions[i]) ||
+				c.cells[j].pos < c.curPositions[i]-c.windowSize {
				mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
			}
		}
@@ -224,13 +249,13 @@ func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Te
		mask[i] = float32(math.Inf(-1))
	}

-	maskTensor, err := ctx.FromFloatSlice(mask, length, batchSize)
+	maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
	if err != nil {
		return nil, err
	}

	if c.config.MaskDType != ml.DTypeF32 {
-		out := ctx.Empty(c.config.MaskDType, maskTensor.Shape()...)
+		out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
		ctx.Forward(maskTensor.Copy(ctx, out))
		maskTensor = out
	}
@@ -239,13 +264,11 @@ func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Te
}

func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
-	for i := range c.keys {
-		if c.keys[i] == nil {
+	for i, key := range c.keys {
+		if key == nil {
			continue
		}

-		key := c.keys[i]
		kHeadDim := key.Dim(0)
		numKVHeads := key.Dim(1)
		rowSize := key.Stride(2)
@@ -305,7 +328,7 @@ func (c *Causal) defrag() {
		layers++
	}

-	maxMoves := ctx.MaxTensors() / (6 * layers)
+	maxMoves := ctx.MaxGraphNodes() / (6 * layers)
	moves := 0

	var pendingSrc, pendingDst, pendingLen int
@@ -377,14 +400,29 @@ func (c *Causal) defrag() {
}

func (c *Causal) SetLayer(layer int) {
-	if layer >= len(c.keys) {
-		c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
-		c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
-	}
-
	c.curLayer = layer
}
// SetCausal enables or disables causal mask generation for subsequent calls to Get.
// This state carries over to future forward passes. The default value is true.
//
// ctx may be set to nil if this is called from outside of a forward pass, for
// example, when initializing the cache.
func (c *Causal) SetCausal(ctx ml.Context, causal bool) {
if c.causal != causal {
c.causal = causal
if ctx != nil {
var err error
c.curMask, err = c.buildMask(ctx)
if err != nil {
// This error should never occur because we have previously built a mask with the same shape
panic(fmt.Errorf("SetCausal: %w", err))
}
}
}
}
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	key := c.keys[c.curLayer]
	value := c.values[c.curLayer]
@@ -433,13 +471,19 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize)) panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
} }
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil { if _, ok := c.ctxs[c.curLayer]; !ok {
c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity)) c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
}
if _, ok := c.keys[c.curLayer]; !ok {
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
}
if _, ok := c.values[c.curLayer]; !ok {
if c.config.PermutedV { if c.config.PermutedV {
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads) c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
} else { } else {
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity)) c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
} }
} }
@@ -501,7 +545,7 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
		}
	}

-	kShift, err := ctx.FromIntSlice(offsets, len(offsets))
+	kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets))
	if err != nil {
		return err
	}
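The mask condition above now folds in three checks: sequence membership, causality (which `SetCausal` can toggle between forward passes), and the sliding-window limit. A standalone sketch of the same rule, using hypothetical plain-Go stand-ins for the cache cells:

```go
package main

import (
	"fmt"
	"math"
	"slices"
)

// cell mirrors the cached token metadata used when building the mask.
type cell struct {
	pos       int32
	sequences []int
}

// masked reports whether history cell c should be hidden from the batch
// token at position pos in sequence seq, following the buildMask logic:
// wrong sequence, future position (only when causal), or outside the window.
func masked(c cell, seq int, pos, windowSize int32, causal bool) bool {
	return !slices.Contains(c.sequences, seq) ||
		(causal && c.pos > pos) ||
		c.pos < pos-windowSize
}

func main() {
	history := []cell{
		{pos: 0, sequences: []int{0}},
		{pos: 1, sequences: []int{0}},
		{pos: 2, sequences: []int{1}},
	}
	for _, c := range history {
		v := float32(0)
		if masked(c, 0, 1, math.MaxInt32, true) {
			v = float32(math.Inf(-1))
		}
		fmt.Println(c.pos, v)
	}
}
```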


@@ -303,6 +303,10 @@ func (b *testBackend) NewContext() ml.Context {
	return &testContext{}
}

func (b *testBackend) NewContextSize(int) ml.Context {
	return &testContext{}
}

func (b *testBackend) SystemInfo() string {
	return "not implemented"
}
@@ -346,11 +350,15 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
	return out, nil
}

func (c *testContext) Input() ml.Context    { return c }
func (c *testContext) Output() ml.Context   { return c }
func (c *testContext) Layer(int) ml.Context { return c }

func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
func (c *testContext) Compute(...ml.Tensor)            {}

-func (c *testContext) MaxTensors() int {
+func (c *testContext) MaxGraphNodes() int {
	return 10
}


@@ -35,13 +35,17 @@ type EncoderCache struct {
	encoderPos int32

	// ** cache data storage **

+	backend      ml.Backend
-	cacheCtx     ml.Context
-	keys, values []ml.Tensor
+	ctxs         map[int]ml.Context
+	keys, values map[int]ml.Tensor
}

func NewEncoderCache() *EncoderCache {
-	return &EncoderCache{}
+	return &EncoderCache{
+		ctxs:   make(map[int]ml.Context),
+		keys:   make(map[int]ml.Tensor),
+		values: make(map[int]ml.Tensor),
+	}
}

func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
@@ -57,7 +61,7 @@ func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding)) panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
} }
c.cacheCtx = backend.NewContext() c.backend = backend
} }
func (c *EncoderCache) SetConfig(config ml.CacheConfig) { func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
@@ -69,7 +73,9 @@ func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
}

func (c *EncoderCache) Close() {
-	c.cacheCtx.Close()
+	for _, ctx := range c.ctxs {
+		ctx.Close()
+	}
}

func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
@@ -80,11 +86,6 @@ func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []in
}

func (c *EncoderCache) SetLayer(layer int) {
-	if layer >= len(c.keys) {
-		c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
-		c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
-	}
-
	c.curLayer = layer
}
@@ -104,9 +105,16 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
		value = value.Permute(ctx, 1, 2, 0, 3)
	}

-	if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
-		c.keys[c.curLayer] = c.cacheCtx.Empty(key.DType(), key.Shape()...)
-		c.values[c.curLayer] = c.cacheCtx.Empty(value.DType(), value.Shape()...)
+	if _, ok := c.ctxs[c.curLayer]; !ok {
+		c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
+	}

+	if _, ok := c.keys[c.curLayer]; !ok {
+		c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
+	}

+	if _, ok := c.values[c.curLayer]; !ok {
+		c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
	}

	ctx.Forward(
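Both caches now share the same pattern: per-layer contexts and tensors live in maps and are allocated lazily on the first `Put` for a layer rather than sized up front. A minimal sketch of that lazy-allocation idiom, with hypothetical types:

```go
package main

import "fmt"

// layerStore lazily allocates one buffer per layer, mirroring how the
// caches now create a small per-layer ml.Context on first use.
type layerStore struct {
	bufs map[int][]float32
}

func newLayerStore() *layerStore {
	return &layerStore{bufs: make(map[int][]float32)}
}

func (s *layerStore) put(layer, size int) []float32 {
	if _, ok := s.bufs[layer]; !ok {
		// allocate only when a layer is actually touched
		s.bufs[layer] = make([]float32, size)
	}
	return s.bufs[layer]
}

func main() {
	s := newLayerStore()
	s.put(3, 8)
	s.put(3, 8)
	fmt.Println(len(s.bufs)) // 1: layer 3 allocated exactly once
}
```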


@@ -1443,7 +1443,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
    const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
    if (precompiled_charsmap_keyidx != -1) {
-        size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
+        size_t n_precompiled_charsmap = gguf_get_arr_data_n(ctx, precompiled_charsmap_keyidx);
        const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
        precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
#ifdef IS_BIG_ENDIAN


@@ -0,0 +1,64 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca <jmorganca@gmail.com>
Date: Wed, 5 Mar 2025 17:41:07 -0800
Subject: [PATCH] fix string arr kv loading
---
ggml/include/gguf.h | 1 +
ggml/src/gguf.cpp | 7 +++++--
src/llama-vocab.cpp | 2 +-
3 files changed, 7 insertions(+), 3 deletions(-)
diff --git a/ggml/include/gguf.h b/ggml/include/gguf.h
index 79ee2020..3efb22f0 100644
--- a/ggml/include/gguf.h
+++ b/ggml/include/gguf.h
@@ -114,6 +114,7 @@ extern "C" {
// get raw pointer to the first element of the array with the given key_id
// for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference)
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id);
+ GGML_API size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id);
// get ith C string from array with given key_id
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp
index ab13669c..f75b923f 100644
--- a/ggml/src/gguf.cpp
+++ b/ggml/src/gguf.cpp
@@ -777,10 +777,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id
const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
- GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
return ctx->kv[key_id].data.data();
}
+size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ return ctx->kv[key_id].data.size();
+}
+
const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING);
@@ -874,7 +878,6 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) {
const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
- GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
return ctx->kv[key_id].data.data();
}
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index c7ff28be..7a185443 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -1443,7 +1443,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
if (precompiled_charsmap_keyidx != -1) {
- size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
+ size_t n_precompiled_charsmap = gguf_get_arr_data_n(ctx, precompiled_charsmap_keyidx);
const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
#ifdef IS_BIG_ENDIAN


@@ -973,7 +973,7 @@ func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error)
		return s.llamaModel.Tokenize(content, false, true)
	}

	if s.textProcessor != nil {
-		tokens, err := s.textProcessor.Encode(content)
+		tokens, err := s.textProcessor.Encode(content, false)
		if err != nil {
			return nil, err
		}
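The extra boolean mirrors the llama path above, where `Tokenize(content, false, true)` avoids adding special tokens for plain tokenization requests. A rough sketch of the interface this implies; the real `TextProcessor` lives in the model package and may differ:

```go
package llm // sketch only, not the actual file

// TextProcessor is an assumed shape of the model package's tokenizer
// interface; the real definition may carry more methods.
type TextProcessor interface {
	// Encode tokenizes s; addSpecial controls whether special tokens
	// (e.g. BOS/EOS) are added around the result.
	Encode(s string, addSpecial bool) ([]int32, error)
	Decode(tokens []int32) (string, error)
}

// Tokenize requests want the raw tokens for the given text, so the runner
// now calls Encode with addSpecial=false.
func tokenize(tp TextProcessor, content string) ([]int32, error) {
	return tp.Encode(content, false)
}
```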


@@ -24,6 +24,7 @@ type Backend interface {
	Config() Config
	Get(name string) Tensor
	NewContext() Context
+	NewContextSize(size int) Context
}

// BackendCacheConfig should be implemented by backends that need special output
@@ -99,8 +100,17 @@ type Context interface {
	Forward(...Tensor) Context
	Compute(...Tensor)
-	MaxTensors() int
+	MaxGraphNodes() int
	Close()

+	// Input returns a context appropriate for creating input tensors
+	Input() Context

+	// Output returns a context appropriate for creating output tensors
+	Output() Context

+	// Layer returns a context appropriate for creating intermediate tensors
+	Layer(int) Context
}

type Tensor interface {
@@ -205,7 +215,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
		return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
			return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
		})
-	case DTypeF16:
+	case DTypeF16, DTypeQ80, DTypeQ40:
		f32 := ctx.Empty(DTypeF32, t.Shape()...)
		f32 = t.Copy(ctx, f32)
		return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
@@ -273,5 +283,7 @@ const (
	DTypeOther DType = iota
	DTypeF32
	DTypeF16
+	DTypeQ80
+	DTypeQ40
	DTypeI32
)
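`Input`, `Output`, and `Layer` let callers pick the context (and therefore the device and buffer type) a new tensor is created in, and the quantized cache types `DTypeQ80`/`DTypeQ40` join the dump switch. A hedged sketch of how a forward pass is expected to use the new accessors, mirroring the llama and mllama changes later in this diff:

```go
package sketch // illustrative only, not part of the ml package

import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model"
)

// buildInputs shows how a Forward implementation is expected to pick a
// context per tensor kind with the new accessors.
func buildInputs(ctx ml.Context, opts model.Options) (inputs, positions, outputs ml.Tensor, err error) {
	// token ids and positions are per-batch inputs
	if inputs, err = ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)); err != nil {
		return nil, nil, nil, err
	}
	if positions, err = ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions)); err != nil {
		return nil, nil, nil, err
	}
	// output indices belong with the output tensors
	if outputs, err = ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs)); err != nil {
		return nil, nil, nil, err
	}
	return inputs, positions, outputs, nil
}
```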


@@ -9,67 +9,53 @@ package ggml
import "C" import "C"
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"maps"
"os" "os"
"sync" "slices"
"strconv"
"strings"
"unicode"
"unsafe" "unsafe"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
fs "github.com/ollama/ollama/fs/ggml" fs "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"golang.org/x/sync/errgroup"
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src" ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
"golang.org/x/sync/errgroup"
) )
-type device struct {
-	d *C.struct_ggml_backend_device
-}

-func (d device) LogValue() slog.Value {
-	var free, total uint64
-	C.ggml_backend_dev_memory(d.d, (*C.size_t)(&free), (*C.size_t)(&total))

-	kind := "unknown"
-	switch C.ggml_backend_dev_type(d.d) {
-	case C.GGML_BACKEND_DEVICE_TYPE_CPU:
-		kind = "cpu"
-	case C.GGML_BACKEND_DEVICE_TYPE_GPU:
-		kind = "gpu"
-	case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
-		kind = "accel"
-	}

-	return slog.GroupValue(
-		slog.String("name", C.GoString(C.ggml_backend_dev_name(d.d))),
-		slog.String("description", C.GoString(C.ggml_backend_dev_description(d.d))),
-		slog.String("kind", kind),
-		slog.String("free", format.HumanBytes2(free)),
-		slog.String("total", format.HumanBytes2(total)),
-	)
-}

-var devices = sync.OnceValue(func() []device {
+func devices() []*C.struct_ggml_backend_device {
	ggml.OnceLoad()

-	s := make([]device, C.ggml_backend_dev_count())
-	for i := range s {
-		s[i] = device{C.ggml_backend_dev_get(C.size_t(i))}
+	ds := make([]*C.struct_ggml_backend_device, C.ggml_backend_dev_count())
+	for i := range ds {
+		ds[i] = C.ggml_backend_dev_get(C.size_t(i))
	}

-	return s
-})
+	return ds
+}
type Backend struct {
-	flashAttention bool

	meta *fs.GGML

-	cpus, gpus []Context
-	tensors    map[string]*Context

	sched *C.struct_ggml_backend_sched

+	tensors map[string]*C.struct_ggml_tensor

+	// input is the backend used for inputs
+	input *C.struct_ggml_backend_buffer_type

+	// output is the backend used for outputs
+	output *C.struct_ggml_backend_buffer_type

+	// layers is the backend used for repeating layers
+	layers map[int]*C.struct_ggml_backend_buffer_type

+	flashAttention bool

+	// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
+	maxGraphNodes int
}
func New(r *os.File, params ml.BackendParams) (ml.Backend, error) { func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
@@ -88,107 +74,310 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
"num_key_values", len(meta.KV()), "num_key_values", len(meta.KV()),
) )
var cpus, gpus []Context type deviceBufferType struct {
d *C.struct_ggml_backend_device
bts []*C.struct_ggml_backend_buffer_type
}
var cpus, accels, gpus []*C.struct_ggml_backend_device
for _, d := range devices() { for _, d := range devices() {
switch C.ggml_backend_dev_type(d.d) { switch C.ggml_backend_dev_type(d) {
case C.GGML_BACKEND_DEVICE_TYPE_CPU:
if len(cpus) == 0 {
// only the first cpu device should be used
cpus = append(cpus, d)
}
case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
accels = append(accels, d)
case C.GGML_BACKEND_DEVICE_TYPE_GPU:
gpus = append(gpus, d)
}
}
// create list of buffer types for the cpu
cpuDeviceBufferType := deviceBufferType{d: C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU)}
for _, d := range append(accels, append(gpus, cpus...)...) {
switch C.ggml_backend_dev_type(d) {
case C.GGML_BACKEND_DEVICE_TYPE_CPU, case C.GGML_BACKEND_DEVICE_TYPE_CPU,
C.GGML_BACKEND_DEVICE_TYPE_ACCEL: C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
slog.Info("cpu", "device", d) cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, C.ggml_backend_dev_buffer_type(d))
cpus = append(cpus, Context{
ctx: C.ggml_init(C.struct_ggml_init_params{
mem_size: C.size_t(int(C.ggml_tensor_overhead()) * (len(meta.Tensors().Items()) + 1 + int(meta.KV().BlockCount())*2)),
no_alloc: true,
}),
backend: C.ggml_backend_dev_init(d.d, nil),
})
case C.GGML_BACKEND_DEVICE_TYPE_GPU:
slog.Info("gpu", "device", d)
gpus = append(gpus, Context{
ctx: C.ggml_init(C.struct_ggml_init_params{
mem_size: C.size_t(int(C.ggml_tensor_overhead()) * (len(meta.Tensors().Items()) + 1 + int(meta.KV().BlockCount())*2)),
no_alloc: true,
}),
backend: C.ggml_backend_dev_init(d.d, nil),
})
} }
} }
ctxFunc := func(s []Context) (*Context, error) { // create list of buffer types for each gpu
for _, e := range s { var gpuDeviceBufferTypes []deviceBufferType
return &e, nil for _, d := range gpus {
} bt := C.ggml_backend_dev_buffer_type(d)
gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{
return nil, fmt.Errorf("no devices available") d: d,
} bts: append([]*C.struct_ggml_backend_buffer_type{bt}, cpuDeviceBufferType.bts...),
tensors := make(map[*fs.Tensor]*Context, len(meta.Tensors().Items()))
for _, t := range meta.Tensors().Items() {
c, err := ctxFunc(append(gpus, cpus...))
if err != nil {
return nil, err
}
func() {
tt := C.ggml_new_tensor(c.ctx, t.Kind, C.int(len(t.Shape)), (*C.int64_t)(unsafe.Pointer(&t.Shape[0])))
cname := C.CString(t.Name)
defer C.free(unsafe.Pointer(cname))
C.ggml_set_name(tt, cname)
tensors[t] = c
}()
}
for _, b := range append(gpus, cpus...) {
C.ggml_backend_alloc_ctx_tensors(b.ctx, b.backend)
}
sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
var g errgroup.Group
for t, c := range tensors {
g.Go(func() error {
bts := make([]byte, t.Size())
n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
if err != nil {
return err
}
if n != int(t.Size()) {
return fmt.Errorf("expected %d bytes, got %d", t.Size(), n)
}
cname := C.CString(t.Name)
defer C.free(unsafe.Pointer(cname))
C.ggml_backend_tensor_set(C.ggml_get_tensor(c.ctx, cname), unsafe.Pointer(&bts[0]), 0, C.size_t(n))
return nil
}) })
} }
if err := g.Wait(); err != nil { useDefaultSplit := true
for _, s := range params.TensorSplit {
if s != 0 {
useDefaultSplit = false
break
}
}
// calculate splits
splits := make([]float32, len(gpus))
if useDefaultSplit {
// default: split on free memory
for i := range splits {
var free, total C.size_t
C.ggml_backend_dev_memory(gpus[i], &free, &total)
splits[i] = float32(free)
}
} else {
splits = params.TensorSplit
}
var sum float32
// cumulative sum of all splits
for i := range splits {
sum += splits[i]
splits[i] = sum
}
// normalize splits
for i := range splits {
splits[i] /= sum
}
// inputs always use cpu
input := cpuDeviceBufferType
blocks := int(meta.KV().BlockCount())
// define a range of gpu layers. anything outside of this range is assigned to the cpu
gpuRangeStart := max(0, blocks-params.NumGPULayers)
gpuRangeStop := min(gpuRangeStart+params.NumGPULayers, blocks+1)
assignLayer := func(i int) deviceBufferType {
if i < gpuRangeStart || i >= gpuRangeStop {
return cpuDeviceBufferType
}
index := slices.IndexFunc(splits, func(f float32) bool { return float32(i-gpuRangeStart)/float32(gpuRangeStop-gpuRangeStart) < f })
if index < 0 || index >= len(gpuDeviceBufferTypes) {
return cpuDeviceBufferType
}
return gpuDeviceBufferTypes[index]
}
// repeating layers are assigned based on their index in reverse order, e.g. i / (block_count + 1)
layers := make([]deviceBufferType, blocks)
for i := range layers {
layers[i] = assignLayer(i)
}
// outputs are assigned iff allowed by splits and configured number of gpu layers
output := assignLayer(blocks)
maxTensors := len(meta.Tensors().Items())
maxTensors += 1
// each layer has at most 2 extra tensors for rope operations
maxTensors += blocks * 2
type tensor struct {
source *fs.Tensor
target string
}
// some tensors are mapped to different names so keep a list
targets := make(map[string][]string)
// contexts are shared by tensors of the same buffer type
ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context)
createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type) *C.struct_ggml_tensor {
for _, bt := range bts {
if _, ok := ctxs[bt]; !ok {
ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{
mem_size: C.ggml_tensor_overhead() * C.size_t(maxTensors),
no_alloc: true,
})
}
targets[t.source.Name] = append(targets[t.source.Name], t.target)
name := t.source.Name
if t.target != "" {
name = t.target
}
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
if tt := C.ggml_get_tensor(ctxs[bt], cname); tt != nil {
return tt
}
tt := C.ggml_new_tensor(ctxs[bt], t.source.Kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0])))
C.ggml_set_name(tt, cname)
slog.Debug("created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
//nolint:staticcheck // TODO: check if buffer type supports this tensor
return tt
}
return nil
}
contains := func(s string, parts ...string) bool {
split := strings.Split(s, ".")
for _, part := range parts {
if slices.Contains(split, part) {
return true
}
}
return false
}
for _, t := range meta.Tensors().Items() {
switch {
case contains(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
createTensor(tensor{source: t}, input.bts)
case contains(t.Name, "cls", "output", "output_norm"):
createTensor(tensor{source: t}, output.bts)
case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
// TODO: assign vision tensors to the gpu if possible
createTensor(tensor{source: t}, input.bts)
default:
layerIndex := -1
if fields := strings.FieldsFunc(t.Name, func(r rune) bool { return !unicode.IsNumber(r) }); len(fields) > 0 {
if i, err := strconv.Atoi(fields[0]); err == nil {
layerIndex = i
}
}
if layerIndex >= 0 {
createTensor(tensor{source: t}, layers[layerIndex].bts)
} else {
// this is a repeating tensor that doesn't explicitly associated with a layer so
// duplicate it for each layer
for i, layer := range layers {
createTensor(tensor{
source: t,
target: "blk." + strconv.Itoa(i) + "." + t.Name,
}, layer.bts)
}
}
}
}
// allocate buffers for each context
bbs := make(map[*C.struct_ggml_context]*C.struct_ggml_backend_buffer, len(ctxs))
for bt, c := range ctxs {
if C.ggml_get_first_tensor(c) == nil {
continue
}
b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
bbs[c] = b
}
for bs := range maps.Values(bbs) {
slog.Info("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)), "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs))))
}
// map tensor names to tensors for easy lookup later
tensors := make(map[string]*C.struct_ggml_tensor)
for _, c := range ctxs {
for t := C.ggml_get_first_tensor(c); t != nil; t = C.ggml_get_next_tensor(c, t) {
tensors[C.GoString(C.ggml_get_name(t))] = t
}
}
// concurrently read in tensor data. uses a section reader which is safe for concurrent reads
sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
var g errgroup.Group
for _, t := range meta.Tensors().Items() {
for _, target := range targets[t.Name] {
g.Go(func() error {
if target == "" {
target = t.Name
}
tt, ok := tensors[target]
if !ok {
return fmt.Errorf("unassigned tensor: %s", t.Name)
}
bts := make([]byte, t.Size())
n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
if err != nil {
return err
}
if n != len(bts) {
return errors.New("short read")
}
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size()))
return nil
})
}
}
if g.Wait() != nil {
return nil, err return nil, err
} }
backends := make([]*C.struct_ggml_backend, len(gpus)+len(cpus)) // map devices to backend buffer types so new tensors can be assigned to the correct device
bufts := make([]*C.struct_ggml_backend_buffer_type, len(gpus)+len(cpus)) deviceBufferTypes := make(map[*C.struct_ggml_backend_device]*C.struct_ggml_backend_buffer_type)
for i, c := range append(gpus, cpus...) {
backends[i] = c.backend // create backends and buffer types used for the compute graph scheduler
bufts[i] = C.ggml_backend_get_default_buffer_type(c.backend) var schedBackends []*C.struct_ggml_backend
var schedBufts []*C.struct_ggml_backend_buffer_type
for _, d := range append(gpus, append(accels, cpus...)...) {
b := C.ggml_backend_dev_init(d, nil)
bt := C.ggml_backend_get_default_buffer_type(b)
if d := C.ggml_backend_get_device(b); C.ggml_backend_dev_type(d) == C.GGML_BACKEND_DEVICE_TYPE_CPU && len(gpus) > 0 {
// use the first gpu host buffer type for gpu if possible
if hbt := C.ggml_backend_dev_host_buffer_type(gpus[0]); hbt != nil {
bt = hbt
}
}
deviceBufferTypes[d] = bt
schedBackends = append(schedBackends, b)
schedBufts = append(schedBufts, bt)
slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(b)), "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
if C.ggml_backend_is_cpu(b) {
// set number of threads for cpu backend
C.ggml_backend_cpu_set_n_threads(b, C.int(params.NumThreads))
}
} }
maxGraphNodes := max(8192, len(meta.Tensors().Items())*5)
return &Backend{ return &Backend{
flashAttention: params.FlashAttention, flashAttention: params.FlashAttention,
meta: meta, meta: meta,
cpus: cpus, tensors: tensors,
gpus: gpus,
sched: C.ggml_backend_sched_new( sched: C.ggml_backend_sched_new(
(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])), (*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])),
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])), (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
C.int(len(backends)), C.int(len(schedBackends)),
C.size_t(max(8192, len(meta.Tensors().Items())*5)), C.size_t(maxGraphNodes),
true, true,
), ),
input: deviceBufferTypes[input.d],
output: deviceBufferTypes[output.d],
layers: func() map[int]*C.struct_ggml_backend_buffer_type {
m := make(map[int]*C.struct_ggml_backend_buffer_type)
for i, layer := range layers {
m[i] = deviceBufferTypes[layer.d]
}
return m
}(),
maxGraphNodes: maxGraphNodes,
}, nil }, nil
} }
@@ -201,36 +390,29 @@ func (b *Backend) Config() ml.Config {
}

func (b *Backend) Get(name string) ml.Tensor {
-	cname := C.CString(name)
-	defer C.free(unsafe.Pointer(cname))
-	for _, c := range append(b.gpus, b.cpus...) {
-		if t := C.ggml_get_tensor(c.ctx, cname); t != nil {
-			return &Tensor{b: b, t: t}
-		}
+	if t, ok := b.tensors[name]; ok {
+		return &Tensor{b: b, t: t}
	}

	return nil
}

func (b *Backend) NewContext() ml.Context {
-	nodes := max(8192, len(b.meta.Tensors().Items())*5)
-	c := C.ggml_init(C.struct_ggml_init_params{
-		mem_buffer: nil,
-		mem_size:   C.size_t(nodes)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(nodes), false),
-		no_alloc:   true,
-	})

-	backends := make([]*C.struct_ggml_backend, len(b.gpus)+len(b.cpus))
-	for i, c := range append(b.gpus, b.cpus...) {
-		backends[i] = c.backend
-	}
+	return b.NewContextSize(b.maxGraphNodes)
+}

+func (b *Backend) NewContextSize(n int) ml.Context {
+	if n > b.maxGraphNodes {
+		panic(fmt.Errorf("requested number of graph nodes (%v) for new context exceeds maximum (%v)", n, b.maxGraphNodes))
+	}

	return &Context{
		b: b,
-		ctx:     c,
-		backend: backends[0],
-		nodes:   nodes,
+		maxGraphNodes: n,
+		ctx: C.ggml_init(C.struct_ggml_init_params{
+			mem_size: C.size_t(n)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(n), false),
+			no_alloc: true,
+		}),
	}
}
@@ -243,17 +425,60 @@ func (b *Backend) CacheConfig() ml.CacheConfig {
}

type Context struct {
	b *Backend

	ctx   *C.struct_ggml_context
	graph *C.struct_ggml_cgraph
-	backend *C.struct_ggml_backend
-	nodes   int

+	// buft is the buffer type used for new tensors
+	buft *C.struct_ggml_backend_buffer_type

+	// maxGraphNodes is the maximum allowed number of graph nodes in this context
+	maxGraphNodes int
}
func (c Context) Input() ml.Context {
if c.b.input != nil {
return &Context{
b: c.b,
ctx: c.ctx,
buft: c.b.input,
maxGraphNodes: c.maxGraphNodes,
}
}
return &c
}
func (c Context) Output() ml.Context {
if c.b.output != nil {
return &Context{
b: c.b,
ctx: c.ctx,
buft: c.b.output,
maxGraphNodes: c.maxGraphNodes,
}
}
return &c
}
func (c Context) Layer(i int) ml.Context {
if buft, ok := c.b.layers[i]; ok {
return &Context{
b: c.b,
ctx: c.ctx,
buft: buft,
maxGraphNodes: c.maxGraphNodes,
}
}
return &c
}

func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
	if c.graph == nil {
-		c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.nodes), false)
+		c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.maxGraphNodes), false)
	}

	for _, tensor := range tensors {
@@ -263,7 +488,7 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
	return c
}

-func (c *Context) Compute(tensors ...ml.Tensor) {
+func (c Context) Compute(tensors ...ml.Tensor) {
	C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
	C.ggml_backend_sched_reset(c.b.sched)
@@ -282,21 +507,48 @@ func (c *Context) Compute(tensors ...ml.Tensor) {
	}
}

-func (c *Context) MaxTensors() int {
-	return c.nodes
+func (c Context) MaxGraphNodes() int {
+	return c.maxGraphNodes
}
func shapeToGGML(shape []int) *C.int64_t {
	sh := make([]C.int64_t, len(shape))
	for i, s := range shape {
-		sh[i] = (C.int64_t)(s)
+		sh[i] = C.int64_t(s)
	}

	return &sh[0]
}
func newTensor(ctx Context, dtype ml.DType, zero bool, shape []int) ml.Tensor { func pad(length, pad C.size_t) C.size_t {
if len(shape) < 1 || len(shape) > 4 { return ((length + pad - 1) / pad) * pad
}
func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
if c.buft == nil {
panic("set Input, Output, or Layer before creating tensors")
}
var cdtype uint32
switch dtype {
case ml.DTypeF32:
cdtype = C.GGML_TYPE_F32
case ml.DTypeF16:
cdtype = C.GGML_TYPE_F16
case ml.DTypeQ80:
cdtype = C.GGML_TYPE_Q8_0
case ml.DTypeQ40:
cdtype = C.GGML_TYPE_Q4_0
case ml.DTypeI32:
cdtype = C.GGML_TYPE_I32
default:
panic("unsupported dtype")
}
if len(shape) < 1 || shape[0] == 0 {
var shape C.int64_t = 0
return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
} else if len(shape) > 4 {
panic("unsupported number of dimensions") panic("unsupported number of dimensions")
} }
@@ -306,41 +558,28 @@ func newTensor(ctx Context, dtype ml.DType, zero bool, shape []int) ml.Tensor {
} }
} }
var t *C.struct_ggml_tensor t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape))
switch dtype { size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft))
case ml.DTypeF32: b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
case ml.DTypeF16:
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
case ml.DTypeI32:
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
default:
panic("unsupported dtype")
}
b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b)) C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
if zero { return &Tensor{b: c.b, t: t}
C.ggml_set_zero(t)
}
return &Tensor{b: ctx.b, t: t}
} }
func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor { func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
return newTensor(c, dtype, false, shape) return c.newTensor(dtype, shape)
} }
func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
return newTensor(c, dtype, true, shape) t := c.newTensor(dtype, shape)
C.ggml_set_zero(t.(*Tensor).t)
return t
} }
func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) { func checkShape[S ~[]E, E any](s S, shape ...int) error {
n := len(s) n := len(s)
if n == 0 { if n == 0 {
var shape C.int64_t = 0 return nil
t := C.ggml_new_tensor(ctx.ctx, dtype, 1, &shape)
return &Tensor{b: ctx.b, t: t}, nil
} }
for _, v := range shape { for _, v := range shape {
@@ -348,22 +587,36 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u
} }
if n != 1 { if n != 1 {
return nil, fmt.Errorf("invalid shape %v for %d elements", shape, len(s)) return fmt.Errorf("invalid shape: %v", shape)
} }
t := C.ggml_new_tensor(ctx.ctx, dtype, C.int(len(shape)), shapeToGGML(shape)) return nil
b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t))
return &Tensor{b: ctx.b, t: t}, nil
} }
func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
return fromSlice(c, s, shape, C.GGML_TYPE_F32) if err := checkShape(s, shape...); err != nil {
return nil, err
}
t := c.newTensor(ml.DTypeF32, shape)
if len(s) > 0 {
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
}
return t, nil
} }
func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
return fromSlice(c, s, shape, C.GGML_TYPE_I32) if err := checkShape(s, shape...); err != nil {
return nil, err
}
t := c.newTensor(ml.DTypeI32, shape)
if len(s) > 0 {
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
}
return t, nil
} }
func (c *Context) Close() { func (c *Context) Close() {
@@ -431,6 +684,10 @@ func (t *Tensor) DType() ml.DType {
		return ml.DTypeF32
	case C.GGML_TYPE_F16:
		return ml.DTypeF16
+	case C.GGML_TYPE_Q8_0:
+		return ml.DTypeQ80
+	case C.GGML_TYPE_Q4_0:
+		return ml.DTypeQ40
	case C.GGML_TYPE_I32:
		return ml.DTypeI32
	default:
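Earlier in this file's diff, `New` assigns each repeating layer to a device by turning the tensor split (free memory by default) into a normalized cumulative distribution and mapping the layer's relative index into it. A small standalone illustration of that assignment rule; `assign` is a hypothetical helper, not the backend code itself:

```go
package main

import "fmt"

// assign maps block i of n GPU-resident blocks onto a device index using a
// cumulative, normalized split, mirroring the assignLayer logic in New().
func assign(i, n int, splits []float32) int {
	// cumulative sum of the per-device shares
	cum := make([]float32, len(splits))
	var sum float32
	for j, s := range splits {
		sum += s
		cum[j] = sum
	}
	// normalize, then pick the first bucket whose share covers this block
	frac := float32(i) / float32(n)
	for j := range cum {
		cum[j] /= sum
		if frac < cum[j] {
			return j
		}
	}
	return len(cum) - 1
}

func main() {
	// e.g. two GPUs with 16 GB and 8 GB free -> roughly 2/3 vs 1/3 of blocks
	for i := 0; i < 6; i++ {
		fmt.Println("block", i, "-> gpu", assign(i, 6, []float32{16, 8}))
	}
}
```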


@@ -114,6 +114,7 @@ extern "C" {
    // get raw pointer to the first element of the array with the given key_id
    // for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference)
    GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id);
+   GGML_API size_t       gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id);

    // get ith C string from array with given key_id
    GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);


@@ -777,10 +777,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id
const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) {
    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-   GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
    return ctx->kv[key_id].data.data();
}

+size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    return ctx->kv[key_id].data.size();
+}

const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) {
    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
    GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING);
@@ -874,7 +878,6 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) {
const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) {
    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
-   GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
    return ctx->kv[key_id].data.data();
}


@@ -3,7 +3,6 @@ package model
import (
	"errors"
	"fmt"
-	"image"
	_ "image/jpeg"
	_ "image/png"
	"log/slog"
@@ -22,14 +21,40 @@ import (
_ "github.com/ollama/ollama/ml/backend" _ "github.com/ollama/ollama/ml/backend"
) )
// Input represents one token in the input stream
type Input struct {
// Token is a single element of text.
Token int32
// Multimodal is opaque data representing a non-text
// element such as an image (or part of one if the image
// can be processed in pieces). It may be either together
// with Token or on its own.
Multimodal any
// MultimodalHash is a unique representation of the data
// stored in Multimodal, used for caching and comparing
// equality.
MultimodalHash uint64
}
// MultimodalIndex is a multimodal element (such as an image)
// together with an index into the slice of Inputs with the
// corresponding token. Note that the index is not the same
// as the position - to find that use the index with the
// Positions slice.
type MultimodalIndex struct {
Index int
Multimodal any
}
// Options contains the inputs for a model forward pass
type Options struct {
	Inputs     []int32
+	Multimodal []MultimodalIndex
	Positions  []int32
	Sequences  []int
	Outputs    []int32

-	Images []image.Image
}

type config struct {
@@ -59,6 +84,37 @@ type Model interface {
	Config() config
}
// MultimodalProcessor must be implemented by multimodal models.
type MultimodalProcessor interface {
// EncodeMultimodal processes a single input (such as an image) and
// generates an output (typically an embedding) that can be used by the model.
//
// The return value is most typically an ml.Tensor, however, different
// type are possible, such as an object containing a tensor plus
// additional metadata, a slice of tensors or even just the original input.
//
// The result may be cached by the runner.
EncodeMultimodal(ml.Context, []byte) (any, error)
// PostTokenize is called after tokenization to allow the model to edit the
// input stream to correctly arrange multimodal elements.
//
// The input is a slice of tokens with the results of EncodeMultimodal interleaved
// in the order that the user provided them. Each element of the slice will be
// either a single token or single multimodal object.
//
// The model must ensure that inputs are stored according to how they will be
// processed and stored in the cache. For example, Llava-style models should insert
// placeholder tokens equal to the feature size of the corresponding image with
// the image itself attached to and split across these tokens. When Forward is called
// a partial subset of these tokens may be submitted according to the batch size.
//
// This function is also responsible for updating MultimodalHash for any Multimodal
// that is modified to ensure that there is a unique hash value that accurately
// represents the contents.
PostTokenize(ml.Context, []Input) ([]Input, error)
}
var models = make(map[string]func(ml.Config) (Model, error))

// Register registers a model constructor for the given architecture
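A hedged sketch of what a `PostTokenize` implementation looks like for a Llava-style model, following the interface comment above: each image is expanded into placeholder tokens with the encoded embedding attached to the first one. The placeholder count and token id are illustrative, and the `ml.Context` parameter is omitted for brevity; the sketch uses the `Input` type defined above:

```go
// Sketch of a PostTokenize implementation for a Llava-style model.
// featureSize (placeholder tokens per image) and the placeholder token id
// are illustrative only.
func postTokenize(inputs []Input, featureSize int) ([]Input, error) {
	var out []Input
	for _, in := range inputs {
		if in.Multimodal == nil {
			out = append(out, in)
			continue
		}

		// expand the image element into featureSize placeholder tokens,
		// attaching the embedding and its hash to the first one
		for i := range featureSize {
			tok := Input{Token: 0} // hypothetical placeholder token id
			if i == 0 {
				tok.Multimodal = in.Multimodal
				tok.MultimodalHash = in.MultimodalHash
			}
			out = append(out, tok)
		}
	}

	return out, nil
}
```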


@@ -12,7 +12,6 @@ import (
)

type Options struct {
-	RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`

	hiddenSize, numHeads, numKVHeads int
	eps, ropeBase, ropeScale         float32
	ropeDim                          uint32
@@ -66,10 +65,11 @@ func New(c ml.Config) (model.Model, error) {
}

type SelfAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
+	RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
}

func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
@@ -78,11 +78,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
	q := sa.Query.Forward(ctx, hiddenState)
	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)

	k := sa.Key.Forward(ctx, hiddenState)
	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)

	v := sa.Value.Forward(ctx, hiddenState)
	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -95,7 +95,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}

func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil
+	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
}

type MLP struct {
@@ -138,17 +138,17 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
}

func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
-	inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
+	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
	if err != nil {
		return nil, err
	}

-	positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
+	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
	if err != nil {
		return nil, err
	}

-	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
	if err != nil {
		return nil, err
	}


@@ -1,7 +1,12 @@
package mllama

import (
+	"bytes"
+	"encoding/binary"
	"fmt"
+	"hash/fnv"
+	"image"
+	"slices"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
@@ -56,54 +61,92 @@ func New(c ml.Config) (model.Model, error) {
	return &m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, err
}
pixelValues, err := ctx.Input().FromFloatSlice(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
m.ImageProcessor.maxNumTiles,
)
if err != nil {
return nil, err
}
aspectRatio, err := ctx.Input().FromIntSlice([]int32{int32(aspectRatioID)}, 1)
if err != nil {
return nil, err
}
positions := make([]int32, 1601)
for i := range positions {
positions[i] = int32(i)
}
positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
if err != nil {
return nil, err
}
crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
return m.Projector.Forward(ctx, crossAttentionStates), nil
}
func (m *Model) PostTokenize(ctx ml.Context, inputs []model.Input) ([]model.Input, error) {
var images []model.Input
fnvHash := fnv.New64a()
for i := range inputs {
if inputs[i].Multimodal == nil {
if len(images) > 0 {
inputs[i].Multimodal = images[0].Multimodal
inputs[i].MultimodalHash = images[0].MultimodalHash
for j := 1; j < len(images); j++ {
inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
fnvHash.Reset()
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
inputs[i].MultimodalHash = fnvHash.Sum64()
}
images = nil
}
} else {
images = append(images, inputs[i])
inputs[i].Token = -1
}
}
inputs = slices.DeleteFunc(inputs, func(input model.Input) bool { return input.Token == -1 })
return inputs, nil
}
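
A minimal standalone sketch of the FNV-1a hash-combining step PostTokenize uses above when concatenating image embeddings: the running hash and the next image's hash are written through the same hasher so the combined fingerprint depends on both values and their order. The combineHashes helper is illustrative, not part of the package.

package main

import (
	"encoding/binary"
	"fmt"
	"hash/fnv"
)

// combineHashes folds two 64-bit hashes into one, mirroring the FNV-1a
// combining step applied when image embeddings are concatenated above.
func combineHashes(a, b uint64) uint64 {
	h := fnv.New64a()
	binary.Write(h, binary.NativeEndian, a)
	binary.Write(h, binary.NativeEndian, b)
	return h.Sum64()
}

func main() {
	imgA, imgB := uint64(0x1111), uint64(0x2222)
	combined := combineHashes(imgA, imgB)
	// The result is order-sensitive, so [A, B] and [B, A] hash differently
	// (with overwhelming probability).
	fmt.Println(combined != combineHashes(imgB, imgA))
}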
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
var crossAttentionStates ml.Tensor var crossAttentionStates ml.Tensor
if opts.Images != nil { if opts.Multimodal != nil {
f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(opts.Images[0]) crossAttentionStates = opts.Multimodal[0].Multimodal.(ml.Tensor)
if err != nil {
return nil, err
}
pixelValues, err := ctx.FromFloatSlice(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
m.ImageProcessor.maxNumTiles,
)
if err != nil {
return nil, err
}
aspectRatio, err := ctx.FromIntSlice([]int32{int32(aspectRatioID)}, 1)
if err != nil {
return nil, err
}
positions := make([]int32, 1601)
for i := range positions {
positions[i] = int32(i)
}
positionIDs, err := ctx.FromIntSlice(positions, len(positions))
if err != nil {
return nil, err
}
crossAttentionStates = m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates)
} }
inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs)) inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions)) positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil { if err != nil {
return nil, err return nil, err
} }
outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs)) outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -10,10 +10,11 @@ import (
) )
type TextSelfAttention struct { type TextSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"` Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"` Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"` Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"` Output *nn.Linear `gguf:"attn_output"`
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
} }
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
@@ -22,11 +23,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
query := sa.Query.Forward(ctx, hiddenState) query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale) query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
key := sa.Key.Forward(ctx, hiddenState) key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale) key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
value := sa.Value.Forward(ctx, hiddenState) value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -39,8 +40,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
} }
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the causal cache, which are just the self attention layers if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
}
return key, nil
} }
type TextMLP struct { type TextMLP struct {
@@ -191,8 +195,6 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
} }
type TextModelOptions struct { type TextModelOptions struct {
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
hiddenSize, numHeads, numKVHeads int hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32 eps, ropeBase, ropeScale float32
ropeDim uint32 ropeDim uint32

View File

@@ -19,7 +19,7 @@ const (
) )
type TextProcessor interface { type TextProcessor interface {
Encode(string) ([]int32, error) Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error) Decode([]int32) (string, error)
Is(int32, Special) bool Is(int32, Special) bool
} }
@@ -144,7 +144,7 @@ type merge struct {
runes []rune runes []rune
} }
func (bpe BytePairEncoding) Encode(s string) ([]int32, error) { func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}} fragments := []fragment{{value: s}}
for _, special := range bpe.vocab.SpecialVocabulary() { for _, special := range bpe.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently // TODO: process special tokens concurrently
@@ -177,7 +177,6 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
for _, frag := range fragments { for _, frag := range fragments {
if len(frag.ids) > 0 { if len(frag.ids) > 0 {
ids = append(ids, frag.ids...) ids = append(ids, frag.ids...)
slog.Debug("encoded", "text", frag.value, "ids", frag.ids, "special", true)
continue continue
} }
@@ -201,7 +200,6 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
// short circuit if the fragment is in the vocabulary // short circuit if the fragment is in the vocabulary
if id := bpe.vocab.Encode(sb.String()); id >= 0 { if id := bpe.vocab.Encode(sb.String()); id >= 0 {
ids = append(ids, id) ids = append(ids, id)
slog.Debug("encoded", "text", sb.String(), "ids", []int32{id})
continue continue
} }
@@ -275,14 +273,13 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
// TODO: handle the edge case where the rune isn't in the vocabulary // TODO: handle the edge case where the rune isn't in the vocabulary
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 { if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id) ids = append(ids, id)
slog.Debug("encoded", "text", string(merge.runes), "ids", []int32{id})
} }
} }
} }
} }
} }
if len(ids) > 0 { if addSpecial && len(ids) > 0 {
if bpe.vocab.AddBOS { if bpe.vocab.AddBOS {
if ids[0] == bpe.vocab.BOS { if ids[0] == bpe.vocab.BOS {
slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS) slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS)
@@ -329,6 +326,5 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
} }
} }
slog.Debug("decoded", "ids", ids, "text", sb.String())
return sb.String(), nil return sb.String(), nil
} }
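
The addSpecial flag threaded through Encode above lets callers attach BOS/EOS once per prompt rather than once per fragment. A rough standalone sketch of that calling pattern, with a hypothetical encodeParts helper and a fake tokenizer standing in for BytePairEncoding:

package main

import "fmt"

// encodeParts sketches how a prompt split around image tags might be tokenized:
// special tokens are requested only for the first part, so the assembled
// sequence carries a single BOS even though Encode runs once per part.
func encodeParts(parts []string, encode func(s string, addSpecial bool) []int32) []int32 {
	var ids []int32
	for i, p := range parts {
		ids = append(ids, encode(p, i == 0)...)
	}
	return ids
}

func main() {
	fakeEncode := func(s string, addSpecial bool) []int32 {
		out := []int32{}
		if addSpecial {
			out = append(out, 1) // pretend 1 is the BOS id
		}
		for range s {
			out = append(out, 42) // one dummy token per rune
		}
		return out
	}
	fmt.Println(encodeParts([]string{"hello ", "[img]", " world"}, fakeEncode))
}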

View File

@@ -74,7 +74,7 @@ func TestLlama(t *testing.T) {
t.Run("simple", func(t *testing.T) { t.Run("simple", func(t *testing.T) {
t.Parallel() t.Parallel()
ids, err := tokenizer.Encode("hello world") ids, err := tokenizer.Encode("hello world", true)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@@ -92,7 +92,7 @@ func TestLlama(t *testing.T) {
t.Errorf("got %q, want hello world", s) t.Errorf("got %q, want hello world", s)
} }
ids, err = tokenizer.Encode("hello <|end_of_text|>") ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@@ -126,7 +126,7 @@ func TestLlama(t *testing.T) {
} }
for s, want := range cases { for s, want := range cases {
ids, err := tokenizer.Encode(s) ids, err := tokenizer.Encode(s, true)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@@ -152,7 +152,7 @@ func TestLlama(t *testing.T) {
} }
for _, want := range cases { for _, want := range cases {
ids, err := tokenizer.Encode(want) ids, err := tokenizer.Encode(want, true)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@@ -176,7 +176,7 @@ func TestLlama(t *testing.T) {
} }
for s, want := range cases { for s, want := range cases {
ids, err := tokenizer.Encode(s) ids, err := tokenizer.Encode(s, true)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -222,7 +222,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) { b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for range b.N { for range b.N {
_, err := tokenizer.Encode(string(bts)) _, err := tokenizer.Encode(string(bts), true)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@@ -230,7 +230,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
}) })
b.Run("decode"+strconv.Itoa(n), func(b *testing.B) { b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
ids, err := tokenizer.Encode(string(bts)) ids, err := tokenizer.Encode(string(bts), true)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }

View File

@@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"math" "math"
"reflect"
"time" "time"
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/kvcache"
@@ -39,10 +38,7 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
slots := make([]InputCacheSlot, numSlots) slots := make([]InputCacheSlot, numSlots)
for i := range slots { for i := range slots {
slots[i] = InputCacheSlot{ slots[i] = InputCacheSlot{Id: i}
Id: i,
Inputs: make([]input, 0),
}
} }
cache := model.Config().Cache cache := model.Config().Cache
@@ -62,9 +58,9 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
func kvCacheTypeFromStr(s string) ml.DType { func kvCacheTypeFromStr(s string) ml.DType {
switch s { switch s {
case "q8_0": case "q8_0":
panic("kv cache quantization not yet implemented") return ml.DTypeQ80
case "q4_0": case "q4_0":
panic("kv cache quantization not yet implemented") return ml.DTypeQ40
default: default:
return ml.DTypeF16 return ml.DTypeF16
} }
@@ -83,7 +79,7 @@ type InputCacheSlot struct {
Id int Id int
// Inputs that are stored in the KV cache // Inputs that are stored in the KV cache
Inputs []input Inputs []model.Input
// is this cache actively being processed as part of a sequence? // is this cache actively being processed as part of a sequence?
InUse bool InUse bool
@@ -92,7 +88,7 @@ type InputCacheSlot struct {
lastUsed time.Time lastUsed time.Time
} }
func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) { func (c *InputCache) LoadCacheSlot(prompt []model.Input, cachePrompt bool) (*InputCacheSlot, []model.Input, error) {
var slot *InputCacheSlot var slot *InputCacheSlot
var numPast int32 var numPast int32
var err error var err error
@@ -143,7 +139,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCach
return slot, prompt, nil return slot, prompt, nil
} }
func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) { func (c *InputCache) findLongestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
longest := int32(-1) longest := int32(-1)
var longestSlot *InputCacheSlot var longestSlot *InputCacheSlot
@@ -166,7 +162,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int3
return longestSlot, longest, nil return longestSlot, longest, nil
} }
func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) { func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
oldest := time.Now() oldest := time.Now()
var oldestSlot *InputCacheSlot var oldestSlot *InputCacheSlot
@@ -202,7 +198,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32,
if longest > 0 && longestSlot != oldestSlot { if longest > 0 && longestSlot != oldestSlot {
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total", slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
len(longestSlot.Inputs)) len(longestSlot.Inputs))
oldestSlot.Inputs = make([]input, longest) oldestSlot.Inputs = make([]model.Input, longest)
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest]) copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
if c.cache != nil { if c.cache != nil {
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest) c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
@@ -212,7 +208,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32,
return oldestSlot, longest, nil return oldestSlot, longest, nil
} }
func countCommonPrefix(a []input, b []input) int32 { func countCommonPrefix(a []model.Input, b []model.Input) int32 {
var count int32 var count int32
for i := range a { for i := range a {
@@ -220,7 +216,7 @@ func countCommonPrefix(a []input, b []input) int32 {
break break
} }
if !reflect.DeepEqual(a[i], b[i]) { if a[i].Token != b[i].Token || a[i].MultimodalHash != b[i].MultimodalHash {
break break
} }

View File

@@ -4,6 +4,8 @@ import (
"image" "image"
"testing" "testing"
"time" "time"
"github.com/ollama/ollama/model"
) )
func TestCountCommon(t *testing.T) { func TestCountCommon(t *testing.T) {
@@ -13,44 +15,50 @@ func TestCountCommon(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
t1 []input t1 []model.Input
t2 []input t2 []model.Input
expected int32 expected int32
}{ }{
{ {
name: "Equal", name: "Equal",
t1: []input{{token: 1}, {token: 2}, {token: 3}}, t1: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t2: []input{{token: 1}, {token: 2}, {token: 3}}, t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 3, expected: 3,
}, },
{ {
name: "Prefix", name: "Prefix",
t1: []input{{token: 1}}, t1: []model.Input{{Token: 1}},
t2: []input{{token: 1}, {token: 2}, {token: 3}}, t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 1, expected: 1,
}, },
{ {
name: "Image Prefix", name: "Image Prefix",
t1: []input{{image: imgA}}, t1: []model.Input{{Multimodal: imgA, MultimodalHash: 1}},
t2: []input{{image: imgA}, {image: imgB}, {image: imgC}}, t2: []model.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
expected: 1, expected: 1,
}, },
{ {
name: "Mixed", name: "Mixed",
t1: []input{{token: 1}, {image: imgA}}, t1: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
t2: []input{{token: 1}, {image: imgA}, {token: 5}}, t2: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
expected: 2, expected: 2,
}, },
{
name: "Mixed, Same Length",
t1: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
t2: []model.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
expected: 1,
},
{ {
name: "Empty", name: "Empty",
t1: []input{}, t1: []model.Input{},
t2: []input{{token: 1}, {token: 2}, {token: 3}}, t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 0, expected: 0,
}, },
{ {
name: "Both Empty", name: "Both Empty",
t1: []input{}, t1: []model.Input{},
t2: []input{}, t2: []model.Input{},
expected: 0, expected: 0,
}, },
} }
@@ -74,7 +82,7 @@ func TestFindCacheSlot(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
cache InputCache cache InputCache
prompt []input prompt []model.Input
longest expected longest expected
best expected best expected
}{ }{
@@ -83,18 +91,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input{}, Inputs: []model.Input{},
InUse: false, InUse: false,
lastUsed: time.Time{}, lastUsed: time.Time{},
}, },
{ {
Id: 1, Id: 1,
Inputs: []input{}, Inputs: []model.Input{},
InUse: false, InUse: false,
lastUsed: time.Time{}, lastUsed: time.Time{},
}, },
}}, }},
prompt: []input{{token: 1}}, prompt: []model.Input{{Token: 1}},
longest: expected{result: 0, len: 0}, longest: expected{result: 0, len: 0},
best: expected{result: 0, len: 0}, best: expected{result: 0, len: 0},
}, },
@@ -103,18 +111,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input{{token: 1}}, Inputs: []model.Input{{Token: 1}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input{{token: 1}, {token: 2}}, Inputs: []model.Input{{Token: 1}, {Token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-2 * time.Second), lastUsed: time.Now().Add(-2 * time.Second),
}, },
}}, }},
prompt: []input{{token: 1}, {token: 2}}, prompt: []model.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 2}, longest: expected{result: 1, len: 2},
best: expected{result: 1, len: 2}, best: expected{result: 1, len: 2},
}, },
@@ -123,18 +131,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input{{token: 1}, {token: 2}}, Inputs: []model.Input{{Token: 1}, {Token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input{}, Inputs: []model.Input{},
InUse: false, InUse: false,
lastUsed: time.Time{}, lastUsed: time.Time{},
}, },
}}, }},
prompt: []input{{token: 2}}, prompt: []model.Input{{Token: 2}},
longest: expected{result: 0, len: 0}, longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0}, best: expected{result: 1, len: 0},
}, },
@@ -144,19 +152,19 @@ func TestFindCacheSlot(t *testing.T) {
slots: []InputCacheSlot{ slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input{{token: 1}, {token: 2}}, Inputs: []model.Input{{Token: 1}, {Token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input{}, Inputs: []model.Input{},
InUse: false, InUse: false,
lastUsed: time.Time{}, lastUsed: time.Time{},
}, },
}, },
}, },
prompt: []input{{token: 1}}, prompt: []model.Input{{Token: 1}},
longest: expected{result: 0, len: 1}, longest: expected{result: 0, len: 1},
best: expected{result: 1, len: 1}, best: expected{result: 1, len: 1},
}, },
@@ -165,18 +173,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input{{token: 1}}, Inputs: []model.Input{{Token: 1}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input{{token: 1}, {token: 2}}, Inputs: []model.Input{{Token: 1}, {Token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-2 * time.Second), lastUsed: time.Now().Add(-2 * time.Second),
}, },
}}, }},
prompt: []input{{token: 2}, {token: 3}}, prompt: []model.Input{{Token: 2}, {Token: 3}},
longest: expected{result: 0, len: 0}, longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0}, best: expected{result: 1, len: 0},
}, },
@@ -185,18 +193,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input{{token: 1}, {token: 2}}, Inputs: []model.Input{{Token: 1}, {Token: 2}},
InUse: true, InUse: true,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input{{token: 1}}, Inputs: []model.Input{{Token: 1}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-2 * time.Second), lastUsed: time.Now().Add(-2 * time.Second),
}, },
}}, }},
prompt: []input{{token: 1}, {token: 2}}, prompt: []model.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 1}, longest: expected{result: 1, len: 1},
best: expected{result: 1, len: 2}, best: expected{result: 1, len: 2},
}, },

View File

@@ -1,13 +1,12 @@
package ollamarunner package ollamarunner
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"image" "hash/maphash"
"log" "log"
"log/slog" "log/slog"
"net" "net"
@@ -33,22 +32,19 @@ import (
_ "github.com/ollama/ollama/model/models" _ "github.com/ollama/ollama/model/models"
) )
// input is an element of the prompt to process, either a token or an image
type input struct {
token int32
image image.Image
}
type Sequence struct { type Sequence struct {
// ctx for allocating tensors that last the lifetime of the sequence, such as
// multimodal embeddings
ctx ml.Context
// batch index // batch index
iBatch int iBatch int
// prompt inputs left to evaluate // prompt inputs left to evaluate
inputs []input inputs []model.Input
// inputs that have been added to a batch but not yet submitted to Forward // inputs that have been added to a batch but not yet submitted to Forward
pendingInputs []input pendingInputs []model.Input
// tokens that have been generated but not returned yet (e.g. for stop sequences) // tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string pendingResponses []string
@@ -101,8 +97,9 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
s.ready.Wait() s.ready.Wait()
startTime := time.Now() startTime := time.Now()
ctx := s.model.Backend().NewContext()
inputs, err := s.inputs(prompt, images) inputs, err := s.inputs(ctx, prompt, images)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to process inputs: %w", err) return nil, fmt.Errorf("failed to process inputs: %w", err)
} else if len(inputs) == 0 { } else if len(inputs) == 0 {
@@ -128,6 +125,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// TODO(jessegross): Ingest cached history for grammar // TODO(jessegross): Ingest cached history for grammar
return &Sequence{ return &Sequence{
ctx: ctx,
inputs: inputs, inputs: inputs,
numPromptInputs: len(inputs), numPromptInputs: len(inputs),
startProcessingTime: startTime, startProcessingTime: startTime,
@@ -146,28 +144,31 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// inputs processes the prompt and images into a list of inputs // inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and // by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images // decoding images
func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) { func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]model.Input, error) {
var inputs []input var inputs []model.Input
var parts []string var parts []string
var matches [][]string var matches [][]string
// TODO(jessegross): This can sometimes trigger for matching text in the multimodalProcessor, visionModel := s.model.(model.MultimodalProcessor)
// user's prompt. We previously tried to avoid it by only looking for images
// on image models. We don't have a clear indication now but it would be better
// to properly escape it in any case.
re := regexp.MustCompile(`\[img-(\d+)\]`)
parts = re.Split(prompt, -1)
matches = re.FindAllStringSubmatch(prompt, -1)
if visionModel {
re := regexp.MustCompile(`\[img-(\d+)\]`)
parts = re.Split(prompt, -1)
matches = re.FindAllStringSubmatch(prompt, -1)
} else {
parts = []string{prompt}
}
postTokenize := false
for i, part := range parts { for i, part := range parts {
// text - tokenize // text - tokenize
tokens, err := s.model.(model.TextProcessor).Encode(part) tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, t := range tokens { for _, t := range tokens {
inputs = append(inputs, input{token: t}) inputs = append(inputs, model.Input{Token: t})
} }
// image - decode and store // image - decode and store
@@ -186,12 +187,25 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
return nil, fmt.Errorf("invalid image index: %d", n) return nil, fmt.Errorf("invalid image index: %d", n)
} }
image, _, err := image.Decode(bytes.NewReader(images[imageIndex].Data)) imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
if err != nil { if err != nil {
return nil, err return nil, err
} }
inputs = append(inputs, input{image: image}) s.multimodalHash.Reset()
_, _ = s.multimodalHash.Write(images[imageIndex].Data)
imageHash := s.multimodalHash.Sum64()
inputs = append(inputs, model.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
postTokenize = true
}
}
if visionModel && postTokenize {
var err error
inputs, err = multimodalProcessor.PostTokenize(ctx, inputs)
if err != nil {
return nil, err
} }
} }
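
A small standalone sketch of the hash/maphash fingerprinting pattern used above to give each image a MultimodalHash (the fingerprint helper is illustrative, not the runner's API): one reusable hasher is reset per input, so hashes computed by the same process are directly comparable.

package main

import (
	"fmt"
	"hash/maphash"
)

func main() {
	var h maphash.Hash // one reusable hasher, as in the runner above

	fingerprint := func(data []byte) uint64 {
		h.Reset()
		h.Write(data)
		return h.Sum64()
	}

	a := fingerprint([]byte("image bytes A"))
	b := fingerprint([]byte("image bytes A"))
	c := fingerprint([]byte("image bytes B"))

	// Hashes from the same maphash.Hash (same seed) are comparable with each
	// other: identical inputs collide, different inputs almost surely do not.
	fmt.Println(a == b, a == c) // true false
}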
@@ -238,6 +252,10 @@ type Server struct {
// next sequence for prompt processing to avoid starvation // next sequence for prompt processing to avoid starvation
nextSeq int nextSeq int
// multimodalHash generates hashes for comparing equality
// of non-text data
multimodalHash maphash.Hash
} }
func (s *Server) allNil() bool { func (s *Server) allNil() bool {
@@ -283,6 +301,7 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
close(seq.responses) close(seq.responses)
close(seq.embedding) close(seq.embedding)
seq.cache.InUse = false seq.cache.InUse = false
seq.ctx.Close()
s.seqs[seqIndex] = nil s.seqs[seqIndex] = nil
s.seqsSem.Release(1) s.seqsSem.Release(1)
} }
@@ -311,7 +330,6 @@ func (s *Server) processBatch() error {
defer s.mu.Unlock() defer s.mu.Unlock()
var options model.Options var options model.Options
imgSeq := -1
seqIdx := s.nextSeq - 1 seqIdx := s.nextSeq - 1
for range s.seqs { for range s.seqs {
@@ -330,7 +348,7 @@ func (s *Server) processBatch() error {
if !s.cache.enabled { if !s.cache.enabled {
seq.inputs = append(seq.cache.Inputs, seq.inputs...) seq.inputs = append(seq.cache.Inputs, seq.inputs...)
seq.cache.Inputs = []input{} seq.cache.Inputs = []model.Input{}
} }
for i, input := range seq.inputs { for i, input := range seq.inputs {
@@ -349,25 +367,21 @@ func (s *Server) processBatch() error {
break break
} }
// TODO(jessegross): Image inputs need to be rethought - it's // TODO(jessegross): This is a workaround for generating an attention mask and also providing a hint
// it doesn't work well for different types of models or multiple sequences // to the encoder cache.
if input.image != nil { //
if len(seq.pendingInputs) != len(options.Images) { // Break the batch when switching from text to images so that images are always at the beginning.
break if input.Multimodal != nil && !(len(seq.pendingInputs) == 0 ||
} (len(options.Multimodal) > 0 && options.Multimodal[len(options.Multimodal)-1].Index == len(options.Inputs)-1)) {
s.nextSeq = seqIdx
if imgSeq != seqIdx && imgSeq != -1 { break
s.nextSeq = seqIdx }
break
} options.Inputs = append(options.Inputs, input.Token)
if input.Multimodal != nil {
imgSeq = seqIdx options.Multimodal = append(options.Multimodal, model.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: input.Multimodal})
options.Images = append(options.Images, input.image)
seq.pendingInputs = append(seq.pendingInputs, input)
continue
} }
options.Inputs = append(options.Inputs, input.token)
options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs))) options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
options.Sequences = append(options.Sequences, seq.cache.Id) options.Sequences = append(options.Sequences, seq.cache.Id)
@@ -403,7 +417,7 @@ func (s *Server) processBatch() error {
// After calling Forward, pending inputs are now in the cache // After calling Forward, pending inputs are now in the cache
if len(seq.pendingInputs) > 0 { if len(seq.pendingInputs) > 0 {
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...) seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
seq.pendingInputs = []input{} seq.pendingInputs = []model.Input{}
} }
// don't sample prompt processing // don't sample prompt processing
@@ -422,6 +436,7 @@ func (s *Server) processBatch() error {
// if done processing the prompt, generate an embedding and return // if done processing the prompt, generate an embedding and return
if seq.embeddingOnly { if seq.embeddingOnly {
// TODO(jessegross): Embedding support // TODO(jessegross): Embedding support
slog.Warn("generation of embedding outputs not yet supported")
s.removeSequence(i, "") s.removeSequence(i, "")
continue continue
} }
@@ -449,7 +464,7 @@ func (s *Server) processBatch() error {
return err return err
} }
seq.inputs = []input{{token: token}} seq.inputs = []model.Input{{Token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece) seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "") sequence := strings.Join(seq.pendingResponses, "")
@@ -575,11 +590,23 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return return
} }
sampler := sample.NewSampler(
req.Temperature,
req.TopK,
req.TopP,
req.MinP,
req.Seed,
)
if req.Grammar != "" {
panic("grammars are not yet supported")
}
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.NumPredict, numPredict: req.NumPredict,
stop: req.Stop, stop: req.Stop,
numKeep: int32(req.NumKeep), numKeep: int32(req.NumKeep),
sampler: sample.Greedy(), // TODO: add support for different samplers when performance is optimized sampler: sampler,
embedding: false, embedding: false,
}) })
if err != nil { if err != nil {

View File

@@ -2,76 +2,103 @@ package sample
import ( import (
"errors" "errors"
"math" "math/rand/v2"
"slices"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/stat/sampleuv"
) )
// Sampler is not thread-safe. Each goroutine should have its own instance
type Sampler interface { type Sampler interface {
Sample([]float32) (int32, error) Sample([]float32) (int32, error)
} }
// logit represents information about a single token during sampling
type logit struct {
id int32 // The token's unique identifier
value float32 // The raw logit or probability from the model
}
type weighted struct { type weighted struct {
src rand.Source rng *rand.Rand
transforms []Transform tokens []logit
topK int
topP float32
minP float32
temperature float32
} }
// TODO(parthsareen): remove uv sample dependency https://github.com/ollama/ollama/issues/9279 func (s *weighted) Sample(logits []float32) (int32, error) {
func Weighted(seed *uint64, transforms ...Transform) Sampler { if len(s.tokens) < len(logits) {
var src rand.Source s.tokens = make([]logit, len(logits))
if seed != nil {
src = rand.NewSource(*seed)
} }
return weighted{src: src, transforms: transforms}
}
func (s weighted) Sample(logits []float32) (int32, error) { tokens := s.tokens[:len(logits)]
logits64 := make([]float64, len(logits))
for i, v := range logits { for i, v := range logits {
logits64[i] = float64(v) tokens[i].id = int32(i)
tokens[i].value = v
} }
for _, t := range s.transforms { // Tokens are sorted by logits in TopK or SortTokens
logits64 = t.Apply(logits64) if s.topK > 0 {
tokens = topK(tokens, s.topK)
} else {
sortLogits(tokens)
} }
logitsCopy := make([]float64, 0, len(logits)) tokens = temperature(tokens, s.temperature)
indices := make([]int, 0, len(logits)) tokens = softmax(tokens)
for i, logit := range logits64 {
if !math.IsInf(logit, -1) { tokens = topP(tokens, s.topP)
logitsCopy = append(logitsCopy, logit) tokens = minP(tokens, s.minP)
indices = append(indices, i)
if len(tokens) == 0 {
return -1, errors.New("no valid logits found for weighted sampling")
}
var r float32
if s.rng != nil {
r = s.rng.Float32()
} else {
r = rand.Float32()
}
// Calculate cumulative sum of probabilities
var sum float32
for i := range tokens {
sum += tokens[i].value
tokens[i].value = sum
}
r *= tokens[len(tokens)-1].value
idx, _ := slices.BinarySearchFunc(tokens, r, func(token logit, target float32) int {
// Compare cumulative probabilities
if token.value < target {
return -1
} }
// First token that exceeds target
return 1
})
if idx >= len(tokens) {
idx = len(tokens) - 1
} }
if len(logitsCopy) == 0 { return tokens[idx].id, nil
return -1, errors.New("no valid logits found for weighed sampling")
}
probs := softmax(logitsCopy)
w := sampleuv.NewWeighted(probs, s.src)
if idx, ok := w.Take(); ok {
return int32(indices[idx]), nil
}
return -1, errors.New("weighted sampler failed, no valid token found")
} }
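
The weighted Sample above draws a token by building a cumulative sum over the softmaxed values, scaling a uniform random number by the total, and binary-searching for the first bucket that reaches it. A condensed standalone sketch of that inverse-CDF step, using a placeholder tok type rather than the package's logit struct:

package main

import (
	"fmt"
	"math/rand/v2"
	"slices"
)

type tok struct {
	id   int32
	prob float32
}

// pick draws one token id in proportion to prob, mirroring the cumulative-sum
// plus binary-search approach of the weighted sampler above.
func pick(tokens []tok, rng *rand.Rand) int32 {
	cum := make([]float32, len(tokens))
	var sum float32
	for i, t := range tokens {
		sum += t.prob
		cum[i] = sum
	}
	r := rng.Float32() * sum // scale so un-normalized probabilities also work
	idx, _ := slices.BinarySearchFunc(cum, r, func(c, target float32) int {
		if c < target {
			return -1
		}
		return 1 // first cumulative value that reaches the target
	})
	if idx >= len(tokens) {
		idx = len(tokens) - 1
	}
	return tokens[idx].id
}

func main() {
	rng := rand.New(rand.NewPCG(42, 42^0x9E3779B9))
	tokens := []tok{{0, 0.1}, {1, 0.2}, {2, 0.7}}
	counts := map[int32]int{}
	for range 10000 {
		counts[pick(tokens, rng)]++
	}
	fmt.Println(counts) // roughly proportional to 0.1 / 0.2 / 0.7
}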
type greedy struct{} type greedy struct{}
func Greedy() Sampler { // Greedy sample returns the index of the maximum value in logits.
return greedy{}
}
// Sample returns the index of the maximum value in logits.
func (s greedy) Sample(logits []float32) (int32, error) { func (s greedy) Sample(logits []float32) (int32, error) {
if len(logits) == 0 { if len(logits) == 0 {
return -1, errors.New("no logits provided for greedy sampling") return -1, errors.New("no logits provided for greedy sampling")
} }
maxIdx := 0 maxIdx := 0
for i := range logits { maxVal := logits[0]
if logits[i] > logits[maxIdx] { for i := 1; i < len(logits); i++ {
if logits[i] > maxVal {
maxVal = logits[i]
maxIdx = i maxIdx = i
} }
} }
@@ -80,41 +107,40 @@ func (s greedy) Sample(logits []float32) (int32, error) {
} }
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278 // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) (Sampler, error) { func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) Sampler {
if temperature == 0 { if temperature == 0 {
return Greedy(), nil return &greedy{}
} }
if temperature < 0 || temperature > 2 { var rng *rand.Rand
return nil, errors.New("temperature must be between 0 and 2") if seed != -1 {
// PCG requires two parameters: sequence and stream
// Use original seed for sequence
sequence := uint64(seed)
// Use golden ratio hash to generate statistically independent seeds
rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
}
temperature = max(temperature, 1)
if topP < 0.0 {
topP = 0.0
}
if topP >= 1.0 {
topP = 1.0
} }
transforms := []Transform{Temperature(temperature)} if minP < 0.0 {
minP = 0.0
if topK != 0 { }
if topK <= 0 { if minP >= 1.0 {
return nil, errors.New("topK must be greater than 0") minP = 1.0
}
transforms = append(transforms, TopK(topK))
} }
if topP != 0 { return &weighted{
if topP < 0 || topP >= 1 { rng: rng,
return nil, errors.New("topP must be between 0 and 1") topK: topK,
} topP: topP,
transforms = append(transforms, TopP(topP)) minP: minP,
temperature: temperature,
} }
if minP != 0 {
if minP < 0 || minP >= 1 {
return nil, errors.New("minP must be between 0 and 1")
}
transforms = append(transforms, MinP(minP))
}
if seed >= 0 {
seed64 := uint64(seed)
return Weighted(&seed64, transforms...), nil
}
return Weighted(nil, transforms...), nil
} }
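
A short standalone sketch of the PCG seeding scheme used above: the seed doubles as the sequence parameter, and the stream parameter is derived with a golden-ratio XOR so the two inputs differ while remaining reproducible.

package main

import (
	"fmt"
	"math/rand/v2"
)

func main() {
	seed := uint64(42)
	rng := rand.New(rand.NewPCG(seed, seed^0x9E3779B9))

	// The same seed always reproduces the same draw, which is what makes
	// seeded sampling deterministic across runs.
	fmt.Println(rng.Float32() == rand.New(rand.NewPCG(seed, seed^0x9E3779B9)).Float32()) // true
}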

View File

@@ -0,0 +1,104 @@
package sample
import (
"fmt"
"math/rand"
"testing"
)
func BenchmarkWeightedSampler(b *testing.B) {
sizes := []int{10, 100, 1000, 10000}
for _, size := range sizes {
b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) {
logits := make([]float32, size)
for i := range logits {
logits[i] = float32(rand.Float64()*10 - 5)
}
sampler := NewSampler(0.8, 0, 0, 0, 42)
b.ResetTimer()
for b.Loop() {
_, err := sampler.Sample(logits)
if err != nil {
b.Fatalf("Sampling failed: %v", err)
}
}
})
}
configs := []struct {
name string
temperature float32
topK int
topP float32
minP float32
seed int
}{
{"Greedy", 0, -1, 0, 0, -1},
{"Temperature", 0.8, -1, 0, 0, -1},
{"TopK", 0.8, 50, 0, 0, -1},
{"TopP", 0.8, -1, 0.9, 0, -1},
{"MinP", 0.8, -1, 0, 0.05, -1},
{"WithSeed", 0.8, 50, 0, 0, 42},
}
// Fixed size for common vocab size
size := 128000
logits := make([]float32, size)
for i := range logits {
logits[i] = float32(rand.Float64()*10 - 5)
}
for _, tc := range configs {
b.Run("Config"+tc.name, func(b *testing.B) {
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed)
sampler.Sample(logits)
b.ResetTimer()
for b.Loop() {
_, err := sampler.Sample(logits)
if err != nil {
b.Fatalf("Sampling failed: %v", err)
}
}
})
}
// Test with combined transforms separately - topK influences performance greatly
b.Run("TransformCombined", func(b *testing.B) {
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42)
b.ResetTimer()
for b.Loop() {
_, err := sampler.Sample(logits)
if err != nil {
b.Fatalf("Sampling failed: %v", err)
}
}
})
}
func BenchmarkGreedySampler(b *testing.B) {
sizes := []int{10, 100, 1000, 10000, 100000}
for _, size := range sizes {
b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) {
logits := make([]float32, size)
for i := range logits {
logits[i] = float32(rand.Float64()*10 - 5)
}
sampler := NewSampler(0, -1, 0, 0, -1)
b.ResetTimer()
for b.Loop() {
_, err := sampler.Sample(logits)
if err != nil {
b.Fatalf("Sampling failed: %v", err)
}
}
})
}
}
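
The benchmarks above use the for b.Loop() idiom (added in Go 1.24) instead of the older for range b.N loop. A minimal standalone sketch with a placeholder workload:

package sample_test

import "testing"

func sum(xs []float32) float32 {
	var s float32
	for _, x := range xs {
		s += x
	}
	return s
}

func BenchmarkLoopIdiom(b *testing.B) {
	xs := make([]float32, 1024)
	for b.Loop() { // Go 1.24+: replaces "for i := 0; i < b.N; i++" and "for range b.N"
		_ = sum(xs)
	}
}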

View File

@@ -1,15 +1,14 @@
package sample package sample
import ( import (
"math"
"math/rand/v2" "math/rand/v2"
"testing" "testing"
"github.com/google/go-cmp/cmp"
) )
func TestWeighted(t *testing.T) { func TestWeighted(t *testing.T) {
got, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))}) logits := []float32{-10, 3, -10, -10}
sampler := NewSampler(0, 0, 0, 0, 0)
got, err := sampler.Sample(logits)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
@@ -19,64 +18,19 @@ func TestWeighted(t *testing.T) {
t.Errorf("index mismatch: want %d, got %d", want, got) t.Errorf("index mismatch: want %d, got %d", want, got)
} }
got, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))}) logits = []float32{-100, -10, 0, 10}
if err == nil { sampler = NewSampler(0, 0, 0, 0, 0)
t.Error("expected error for no valid tokens, got index", got) got, err = sampler.Sample(logits)
}
seed := uint64(42)
got, err = Weighted(&seed).Sample([]float32{1, 2, 3, 4})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
// With seed 42, we expect a consistent sample want = int32(3) // Should pick highest probability with this r value
want = int32(3) // This will be deterministic due to the seed
if want != got { if want != got {
t.Errorf("index mismatch: want %d, got %d", want, got) t.Errorf("index mismatch: want %d, got %d", want, got)
} }
} }
type testTransform struct {
id int
callOrder *[]int
}
func (ts *testTransform) Apply(logits []float64) []float64 {
if ts.callOrder != nil {
*ts.callOrder = append(*ts.callOrder, ts.id)
}
return logits
}
func TestSample(t *testing.T) {
input := []float32{1, 2, 3, 4}
var callOrder []int
mock1 := &testTransform{
id: 1,
callOrder: &callOrder,
}
mock2 := &testTransform{
id: 2,
callOrder: &callOrder,
}
mock3 := &testTransform{
id: 3,
callOrder: &callOrder,
}
_, err := Weighted(nil, mock1, mock2, mock3).Sample(input)
if err != nil {
t.Error(err)
return
}
wantOrder := []int{1, 2, 3}
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
t.Errorf("call order mismatch (-want +got):\n%s", diff)
}
}
func TestNewSampler(t *testing.T) { func TestNewSampler(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -85,75 +39,41 @@ func TestNewSampler(t *testing.T) {
topP float32 topP float32
minP float32 minP float32
seed int seed int
wantErr bool wantGreedy bool // Instead of wantErr, check if we get greedy sampler
}{ }{
{
name: "no transforms",
// temperature is 0, so greedy should be used
wantErr: false,
},
{ {
name: "temperature", name: "temperature",
temperature: 0.5, temperature: 0.5,
wantErr: false, wantGreedy: false,
}, },
{ {
name: "invalid temperature negative", name: "zero temperature - greedy",
temperature: -1, temperature: 0,
wantErr: true, wantGreedy: true,
},
{
name: "invalid temperature too high",
temperature: 2.1,
wantErr: true,
}, },
{ {
name: "top k", name: "top k",
temperature: 0.1,
topK: 10, topK: 10,
temperature: 0.8, wantGreedy: false,
wantErr: false,
},
{
name: "invalid top k negative",
topK: -1,
temperature: 0.8,
wantErr: true,
}, },
{ {
name: "top p", name: "top p",
temperature: 0.1,
topP: 0.9, topP: 0.9,
temperature: 0.8, wantGreedy: false,
wantErr: false,
},
{
name: "invalid top p negative",
topP: -0.1,
temperature: 0.8,
wantErr: true,
},
{
name: "invalid top p one",
topP: 1.0,
temperature: 0.8,
wantErr: true,
}, },
{ {
name: "min p", name: "min p",
temperature: 0.1,
minP: 0.2, minP: 0.2,
temperature: 0.8, wantGreedy: false,
wantErr: false,
}, },
{ {
name: "invalid min p negative", name: "seed - weighted",
minP: -0.1, temperature: 0.1,
temperature: 0.8, seed: 42,
wantErr: true, wantGreedy: false,
},
{
name: "invalid min p one",
minP: 1.0,
temperature: 0.8,
wantErr: true,
}, },
{ {
name: "default values", name: "default values",
@@ -162,16 +82,16 @@ func TestNewSampler(t *testing.T) {
topP: 0.9, topP: 0.9,
minP: 0.0, minP: 0.0,
seed: 0, seed: 0,
wantErr: false, wantGreedy: false,
}, },
{ {
name: "all zeroes", name: "all zeroes - greedy",
temperature: 0.0, temperature: 0.0,
topK: 0, topK: 0,
topP: 0.0, topP: 0.0,
minP: 0.0, minP: 0.0,
seed: 0, seed: 0,
wantErr: false, // all zeroes means no transforms wantGreedy: true,
}, },
{ {
name: "all transforms", name: "all transforms",
@@ -180,33 +100,28 @@ func TestNewSampler(t *testing.T) {
topP: 0.95, topP: 0.95,
minP: 0.1, minP: 0.1,
seed: 42, seed: 42,
wantErr: false, wantGreedy: false,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
_, err := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed) sampler := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
if (err != nil) != tt.wantErr { _, isGreedy := sampler.(*greedy)
t.Errorf("NewSampler() error = %v, wantErr %v", err, tt.wantErr) if isGreedy != tt.wantGreedy {
t.Errorf("NewSampler() got greedy = %v, want %v", isGreedy, tt.wantGreedy)
} }
}) })
} }
} }
func BenchmarkSample(b *testing.B) { func BenchmarkSample(b *testing.B) {
transforms := []Transform{ weighted := NewSampler(0.5, 10, 0.9, 0.2, -1)
Temperature(0.5),
TopK(10),
TopP(0.9),
MinP(0.2),
}
samplers := map[string]Sampler{ samplers := map[string]Sampler{
"Greedy": Greedy(), "Greedy": NewSampler(0, 0, 0, 0, 0), // Use NewSampler with temp=0 for greedy
"Weighted": Weighted(nil, transforms...), "Weighted": weighted,
} }
// Generate random logits for benchmarking
logits := make([]float32, 1<<16) logits := make([]float32, 1<<16)
for i := range logits { for i := range logits {
logits[i] = rand.Float32() logits[i] = rand.Float32()
@@ -215,7 +130,7 @@ func BenchmarkSample(b *testing.B) {
for name, s := range samplers { for name, s := range samplers {
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for range b.N { for b.Loop() {
if _, err := s.Sample(logits); err != nil { if _, err := s.Sample(logits); err != nil {
b.Error(err) b.Error(err)
} }

View File

@@ -1,120 +1,203 @@
package sample package sample
import ( import (
"cmp"
"math" "math"
"slices" "slices"
pq "github.com/emirpasic/gods/v2/queues/priorityqueue"
) )
type Transform interface { func softmax(ts []logit) []logit {
Apply([]float64) []float64 var sum float32
} for i, v := range ts {
ts[i].value = float32(math.Exp(float64(v.value)))
// TODO(parthsareen): potentially cache softmax values sum += ts[i].value
func softmax(logits []float64) []float64 {
var sum float64
probs := make([]float64, len(logits))
for i, v := range logits {
probs[i] = math.Exp(v)
sum += probs[i]
} }
for i := range probs { for i := range ts {
probs[i] /= sum ts[i].value /= sum
} }
return probs return ts
} }
type Temperature float64 func temperature(ti []logit, t float32) []logit {
if t == 1 {
return ti
}
func (t Temperature) Apply(logits []float64) []float64 { temp := max(t, 1e-7)
temp := math.Max(float64(t), 1e-7) maxLogit := float32(math.Inf(-1))
for _, token := range ti {
if token.value > maxLogit {
maxLogit = token.value
}
}
// subtracting max logit to avoid under/overflow // subtracting max logit to avoid under/overflow
maxLogit := slices.Max(logits) for i := range ti {
for i := range logits { ti[i].value = (ti[i].value - maxLogit) / temp
logits[i] = (logits[i] - maxLogit) / temp
} }
return logits return ti
} }
type logitMap struct { // siftDown maintains a min-heap property by recursively moving larger elements down the heap.
index int //
logit float64 // The heap is represented as an array where for any node at index i:
} // - Left child is at index 2i + 1
// - Right child is at index 2i + 2
type TopK int // - Parent is at index (i-1)/2
//
// TODO(parthsareen): avoid having to check all logits after this transform // The function compares a node with its children and:
func (k TopK) Apply(logits []float64) []float64 { // 1. Finds the smallest value between the node and its children
if int(k) >= len(logits) { // 2. If the node is not the smallest, swaps it with its smallest child
return logits // 3. Continues this process down the affected path until the min-heap property is restored
} func siftDown(data []logit, start, end int) {
q := pq.NewWith(func(a, b logitMap) int { root := start
return -cmp.Compare(a.logit, b.logit) for {
}) child := 2*root + 1
if child >= end {
for i, logit := range logits {
q.Enqueue(logitMap{index: i, logit: logit})
}
validLogits := make(map[int]float64)
for range k {
logitMap, _ := q.Dequeue()
validLogits[logitMap.index] = logitMap.logit
}
for i := range logits {
if _, ok := validLogits[i]; !ok {
logits[i] = math.Inf(-1)
}
}
return logits
}
type TopP float64
func (p TopP) Apply(logits []float64) []float64 {
probs := softmax(logits)
indices := make([]int, len(probs))
for i := range indices {
indices[i] = i
}
// sort in descending order
slices.SortFunc(indices, func(i, j int) int {
return cmp.Compare(probs[j], probs[i])
})
var sum float64
for i, idx := range indices {
sum += probs[idx]
if sum > float64(p) {
for _, idx := range indices[i+1:] {
logits[idx] = math.Inf(-1)
}
break break
} }
// Find smaller child (we want min heap)
if child+1 < end && data[child+1].value < data[child].value {
child++
}
// Exit if root is already smaller than children
if data[root].value <= data[child].value {
break
}
// Swap with smaller child and continue
data[root], data[child] = data[child], data[root]
root = child
} }
return logits
} }
type MinP float64 // topK limits the number of tokens considered to the k highest logits
func topK(ts []logit, k int) []logit {
if k >= len(ts) {
return ts
}
// Heapify + siftDown - O(nlog(k))
// Heapify + siftDown - O(n log k)
// Build min-heap of first k elements
heap := ts[:k]
for i := k/2 - 1; i >= 0; i-- {
siftDown(heap, i, k)
}
func (p MinP) Apply(logits []float64) []float64 { // Process remaining elements - if larger than heap root, replace root
probs := softmax(logits) for i := k; i < len(ts); i++ {
threshold := slices.Max(probs) * float64(p) if ts[i].value > heap[0].value {
heap[0] = ts[i]
for i, prob := range probs { siftDown(heap, 0, k)
if prob < threshold {
logits[i] = math.Inf(-1)
} }
} }
return logits slices.Reverse(heap)
ts = heap
return ts
}
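
topK above keeps the k largest logits with a size-k min-heap: heapify the first k entries, then replace the root whenever a later value beats it, for O(n log k) overall. The same idea expressed with the standard library's container/heap on plain float32s (illustrative only; the package uses its own siftDown):

package main

import (
	"container/heap"
	"fmt"
)

// minHeap is a min-heap of float32 values.
type minHeap []float32

func (h minHeap) Len() int           { return len(h) }
func (h minHeap) Less(i, j int) bool { return h[i] < h[j] }
func (h minHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *minHeap) Push(x any)        { *h = append(*h, x.(float32)) }
func (h *minHeap) Pop() any {
	old := *h
	n := len(old)
	v := old[n-1]
	*h = old[:n-1]
	return v
}

// topKValues returns the k largest values in heap order
// (the smallest of the kept values sits at index 0).
func topKValues(xs []float32, k int) []float32 {
	h := make(minHeap, 0, k)
	heap.Init(&h)
	for _, x := range xs {
		if h.Len() < k {
			heap.Push(&h, x)
		} else if x > h[0] {
			h[0] = x        // replace the current minimum...
			heap.Fix(&h, 0) // ...and restore the heap property, like siftDown above
		}
	}
	return h
}

func main() {
	fmt.Println(topKValues([]float32{-3, -2, -1, 0, 1, 2, 4}, 3)) // [1 2 4]
}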
// topP limits tokens to those with cumulative probability p
func topP(ts []logit, p float32) []logit {
if p == 1.0 {
return ts
}
// Find cutoff index where cumulative sum exceeds p
var sum float32
for i, t := range ts {
sum += t.value
if sum > float32(p) {
ts = ts[:i+1]
return ts
}
}
return ts
}
// minP limits tokens to those with at least p times the maximum probability
func minP(ts []logit, p float32) []logit {
if p == 1.0 {
return ts
}
maxProb := float32(math.Inf(-1))
for _, token := range ts {
if token.value > maxProb {
maxProb = token.value
}
}
threshold := maxProb * float32(p)
// Filter tokens in-place
validTokens := ts[:0]
for i, token := range ts {
if token.value >= threshold {
validTokens = append(validTokens, ts[i])
}
}
ts = validTokens
return ts
}
// TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584
// Conting sort implementation to sort tokens by logits
func sortLogits(tokens []logit) {
if len(tokens) <= 1 {
return
}
// Find max/min in a single pass
minLogit, maxLogit := tokens[0].value, tokens[0].value
for _, t := range tokens[1:] {
if t.value < minLogit {
minLogit = t.value
} else if t.value > maxLogit {
maxLogit = t.value
}
}
// Calculate scaling to map to uint32 range
logitRange := maxLogit - minLogit
if logitRange < 1e-6 {
return // All values effectively equal
}
// Count frequencies directly from tokens
const maxInt = (1 << 24) - 1 // Use 24 bits for good granularity
var counts [256]int // For first byte
// First pass: count frequencies
for _, t := range tokens {
// Map to [0, maxInt] range
score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt)
counts[score>>16]++
}
// Calculate offsets
var offset int
for i := range counts {
count := counts[i]
counts[i] = offset
offset += count
}
// Second pass: place elements in correct position
output := make([]logit, len(tokens))
// Track current positions
countsCopy := counts
for i, t := range tokens {
score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt)
pos := countsCopy[score>>16]
countsCopy[score>>16]++
output[len(tokens)-1-pos] = tokens[i]
}
copy(tokens, output)
} }
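
sortLogits above is a bucketed counting sort: each logit is scaled into a 24-bit score, bucketed by its top 8 bits, and written out in reverse so the result is descending (ordering within a single 256-way bucket is therefore only approximate). A small standalone sketch of the score-and-bucket arithmetic, with made-up min/max values:

package main

import "fmt"

func main() {
	const maxInt = (1 << 24) - 1 // same 24-bit score range as sortLogits above

	minLogit, maxLogit := float32(-5.0), float32(5.0)
	logitRange := maxLogit - minLogit

	for _, v := range []float32{-5, 0, 4.9, 5} {
		score := min(uint32((v-minLogit)*float32(maxInt)/logitRange), maxInt)
		bucket := score >> 16 // top 8 bits select one of the 256 counting buckets
		fmt.Printf("logit %5.1f -> score %8d -> bucket %3d\n", v, score, bucket)
	}
}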

View File

@@ -4,77 +4,182 @@ import (
"math" "math"
"math/rand/v2" "math/rand/v2"
"testing" "testing"
"github.com/google/go-cmp/cmp"
) )
func TestTemperature(t *testing.T) { // Helper to convert float64 slice to logit slice
got := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0}) func toLogits(values []float64) []logit {
want := []float64{-4, -10, 0, -14, -6, -12, -8} tokens := make([]logit, len(values))
if diff := cmp.Diff(want, got); diff != "" { for i, v := range values {
t.Errorf("logits mismatch (-want +got):\n%s", diff) tokens[i] = logit{
id: int32(i),
value: float32(v),
}
}
return tokens
}
// Helper to compare logit slices
func compareLogits(t *testing.T, name string, want []float64, got []logit) {
t.Helper()
if len(want) != len(got) {
t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got))
return
}
for i := range want {
if math.Abs(float64(got[i].value)-want[i]) > 1e-6 {
t.Errorf("%s: index %d: want %f, got %f", name, i, want[i], got[i].value)
}
} }
} }
func TestSoftmax(t *testing.T) { func TestTemperature(t *testing.T) {
got := softmax([]float64{-3, -2, -1, 0, 1, 2, 4}) input := []float64{2, -1, 4, -3, 1, -2, 0}
want := []float64{-4, -10, 0, -14, -6, -12, -8} // (logit - max logit) / temp
want := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085} got := temperature(toLogits(input), 0.5)
if diff := cmp.Diff(want, got); diff != "" { compareLogits(t, "Temperature", want, got)
t.Errorf("probs mismatch (-want +got):\n%s", diff) }
func TestSoftmax(t *testing.T) {
input := []float64{-3, -2, -1, 0, 1, 2, 4}
got := softmax(toLogits(input))
// Check probabilities sum to 1
var sum float32
for _, token := range got {
sum += token.value
}
if math.Abs(float64(sum)-1.0) > 1e-6 {
t.Errorf("probabilities don't sum to 1: got %f", sum)
}
// Check relative ordering is preserved
for i := 1; i < len(got); i++ {
if got[i].value < got[i-1].value {
t.Errorf("probability ordering not preserved at index %d", i)
}
} }
} }
func TestTopK(t *testing.T) {
-    got := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
-    want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4}
-    if diff := cmp.Diff(want, got); diff != "" {
-        t.Errorf("logits mismatch (-want +got):\n%s", diff)
-    }
-    got = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
-    want = []float64{-3, -2, -1, 0, 1, 2, 4}
-    if diff := cmp.Diff(want, got); diff != "" {
-        t.Errorf("logits mismatch (-want +got):\n%s", diff)
-    }
+    input := []float64{-3, -2, -1, 0, 1, 2, 4}
+
+    // Test k=3
+    got := topK(toLogits(input), 3)
+    if len(got) != 3 {
+        t.Errorf("topK(3): wrong length: want 3, got %d", len(got))
+    }
+    // Should keep highest 3 values: 4, 2, 1
+    want := []float64{4, 2, 1}
+    compareLogits(t, "topK(3)", want, got)
+
+    // Test k > len
+    got = topK(toLogits(input), 10)
+    compareLogits(t, "topK(10)", input, got)
}

func TestTopP(t *testing.T) {
-    got := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
-    want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4}
-    if diff := cmp.Diff(want, got); diff != "" {
-        t.Errorf("logits mismatch (-want +got):\n%s", diff)
-    }
+    input := []float64{-3, -2, -1, 0, 1, 2, 4}
+    tokens := toLogits(input)
+
+    // First apply temperature and softmax to get probabilities
+    tokens = temperature(tokens, 1)
+    tokens = softmax(tokens)
+    sortLogits(tokens)
+
+    // Then apply topP
+    got := topP(tokens, 0.95)
+
+    // Should keep tokens until cumsum > 0.95
+    if len(got) > 3 {
+        t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
+        t.Logf("got: %v", got)
+    }
}

func TestMinP(t *testing.T) {
-    got := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3})
-    want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3}
-    if diff := cmp.Diff(want, got); diff != "" {
-        t.Errorf("logits mismatch (-want +got):\n%s", diff)
-    }
+    input := []float64{-3, -2, -1, 0, 1, 2, 4, 3}
+    tokens := toLogits(input)
+
+    // First apply temperature and softmax
+    tokens = temperature(tokens, 1)
+    tokens = softmax(tokens)
+
+    // Then apply minP
+    got := minP(tokens, 0.2)
+
+    // Should keep tokens with prob >= 0.2 * max_prob
+    if len(got) > 3 {
+        t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
+    }
}

-func BenchmarkTransform(b *testing.B) {
-    transforms := map[string]Transform{
-        "Temperature": Temperature(0.5),
-        "TopK":        TopK(10),
-        "TopP":        TopP(0.9),
-        "MinP":        MinP(0.2),
-    }
-
-    logits := make([]float64, 1<<16)
-    for i := range logits {
-        logits[i] = rand.Float64()
-    }
-
-    for name, transform := range transforms {
-        b.Run(name, func(b *testing.B) {
-            b.ResetTimer()
-            for range b.N {
-                transform.Apply(logits)
-            }
-        })
-    }
+func TestSortLogits(t *testing.T) {
+    input := []float64{3, 1, 4, 2, -1, 0, -2}
+    tokens := toLogits(input)
+
+    sortLogits(tokens)
+
+    for i := 1; i < len(tokens); i++ {
+        if tokens[i].value > tokens[i-1].value {
+            t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
+                i, tokens[i].value, tokens[i-1].value)
+        }
+    }
+
+    want := []float64{4, 3, 2, 1, 0, -1, -2}
+    compareLogits(t, "sortLogits", want, tokens)
+}
+
+func BenchmarkTransforms(b *testing.B) {
+    // Generate random logits
+    tokens := make([]logit, 1<<16)
+    for i := range tokens {
+        tokens[i] = logit{
+            id:    int32(i),
+            value: rand.Float32(),
+        }
+    }
+
+    tokensCopy := make([]logit, len(tokens))
+
+    b.Run("Temperature", func(b *testing.B) {
+        b.ResetTimer()
+        for b.Loop() {
+            copy(tokensCopy, tokens)
+            temperature(tokensCopy, 0.5)
+        }
+    })
+
+    b.Run("TopK", func(b *testing.B) {
+        b.ResetTimer()
+        for b.Loop() {
+            copy(tokensCopy, tokens)
+            topK(tokensCopy, 10)
+        }
+    })
+
+    b.Run("TopP", func(b *testing.B) {
+        b.ResetTimer()
+        for b.Loop() {
+            copy(tokensCopy, tokens)
+            topP(tokensCopy, 0.9)
+        }
+    })
+
+    b.Run("MinP", func(b *testing.B) {
+        b.ResetTimer()
+        for b.Loop() {
+            copy(tokensCopy, tokens)
+            minP(tokensCopy, 0.2)
+        }
+    })
+
+    b.Run("SortTokens", func(b *testing.B) {
+        b.ResetTimer()
+        for b.Loop() {
+            copy(tokensCopy, tokens)
+            sortLogits(tokensCopy)
+        }
+    })
}
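For readers skimming this diff: the new tests lean on a small logit token type plus toLogits and compareLogits helpers defined elsewhere in the sampler package and not shown here. A minimal sketch of what such helpers could look like — the names match the calls above, but the bodies are assumptions, not the committed code:

// Sketch only: the real definitions live in the sampler package and may differ.
type logit struct {
    id    int32   // token id
    value float32 // logit, or probability after softmax
}

// toLogits wraps raw values in logit structs with sequential ids.
func toLogits(values []float64) []logit {
    ts := make([]logit, len(values))
    for i, v := range values {
        ts[i] = logit{id: int32(i), value: float32(v)}
    }
    return ts
}

// compareLogits reports a test error if got's values differ from want.
func compareLogits(t *testing.T, name string, want []float64, got []logit) {
    t.Helper()
    if len(want) != len(got) {
        t.Fatalf("%s: length mismatch: want %d, got %d", name, len(want), len(got))
    }
    for i := range want {
        if got[i].value != float32(want[i]) {
            t.Errorf("%s: index %d: want %f, got %f", name, i, want[i], got[i].value)
        }
    }
}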

View File

@@ -45,9 +45,9 @@ import (
// Errors
var (
-	// ErrManifestNotFound is returned when a manifest is not found in the
+	// ErrModelNotFound is returned when a manifest is not found in the
	// cache or registry.
-	ErrManifestNotFound = errors.New("manifest not found")
+	ErrModelNotFound = errors.New("model not found")

	// ErrManifestInvalid is returned when a manifest found in a local or
	// remote cache is invalid.
@@ -114,7 +114,18 @@ type Error struct {
}

func (e *Error) Error() string {
-	return fmt.Sprintf("registry responded with status %d: %s %s", e.Status, e.Code, e.Message)
+	var b strings.Builder
+	b.WriteString("registry responded with status ")
+	b.WriteString(strconv.Itoa(e.Status))
+	if e.Code != "" {
+		b.WriteString(": code ")
+		b.WriteString(e.Code)
+	}
+	if e.Message != "" {
+		b.WriteString(": ")
+		b.WriteString(e.Message)
+	}
+	return b.String()
}

func (e *Error) LogValue() slog.Value {
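Concretely, the rebuilt Error() string now degrades gracefully when the registry omits a code or message; roughly (the values here are illustrative): a response with status 999 and message "boom" renders as "registry responded with status 999: boom", one with status 401, code "UNAUTHORIZED", and message "access denied" renders as "registry responded with status 401: code UNAUTHORIZED: access denied", and a bare status 500 renders as just "registry responded with status 500".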
@@ -355,7 +366,7 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
		n.Model(),
		l.Digest,
	)
-	res, err := r.doOK(ctx, "POST", startURL, nil)
+	res, err := r.send(ctx, "POST", startURL, nil)
	if err != nil {
		return err
	}
@@ -379,7 +390,7 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
	}
	req.ContentLength = l.Size

-	res, err = doOK(r.client(), req)
+	res, err = sendRequest(r.client(), req)
	if err == nil {
		res.Body.Close()
	}
@@ -399,7 +410,7 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
		n.Model(),
		n.Tag(),
	)
-	res, err := r.doOK(ctx, "PUT", path, bytes.NewReader(m.Data))
+	res, err := r.send(ctx, "PUT", path, bytes.NewReader(m.Data))
	if err == nil {
		res.Body.Close()
	}
@@ -448,10 +459,15 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
	t := traceFromContext(ctx)

-	var g errgroup.Group
+	g, ctx := errgroup.WithContext(ctx)
	g.SetLimit(r.maxStreams())

-	for _, l := range m.Layers {
+	layers := m.Layers
+	if m.Config != nil && m.Config.Digest.IsValid() {
+		layers = append(layers, m.Config)
+	}
+	for _, l := range layers {
		if exists(l) {
			t.update(l, l.Size, ErrCached)
			continue
@@ -468,7 +484,9 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
		if l.Size <= r.maxChunkingThreshold() {
			g.Go(func() error {
-				res, err := doOK(r.client(), req)
+				// TODO(bmizerany): retry/backoff like below in
+				// the chunking case
+				res, err := sendRequest(r.client(), req)
				if err != nil {
					return err
				}
@@ -494,19 +512,21 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
			// fire an initial request to get the final URL and
			// then use that URL for the chunk requests.
			req.Header.Set("Range", "bytes=0-0")
-			res, err := doOK(r.client(), req)
+			res, err := sendRequest(r.client(), req)
			if err != nil {
				return err
			}
			res.Body.Close()
			req = res.Request.WithContext(req.Context())

-			streamNo := 0
-			tws := make([]*bufio.Writer, r.maxStreams()-1)
+			wp := writerPool{size: r.maxChunkSize()}

			for chunk := range chunks.Of(l.Size, r.maxChunkSize()) {
+				if ctx.Err() != nil {
+					break
+				}
+
				ticket := q.Take()
-				bufIdx := streamNo % len(tws)
-				streamNo++
				g.Go(func() (err error) {
					defer func() {
						if err != nil {
@@ -520,23 +540,18 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
					if err != nil {
						return err
					}

					err := func() error {
						req := req.Clone(req.Context())
						req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
-						res, err := doOK(r.client(), req)
+						res, err := sendRequest(r.client(), req)
						if err != nil {
							return err
						}
						defer res.Body.Close()

-						tw := tws[bufIdx]
-						if tw == nil {
-							tw = bufio.NewWriterSize(nil, int(r.maxChunkSize()))
-							tws[bufIdx] = tw
-						}
+						tw := wp.get()
						tw.Reset(ticket)
-						defer tw.Reset(nil) // release ticket
+						defer wp.put(tw)

						_, err = io.CopyN(tw, res.Body, chunk.Size())
						if err != nil {
@@ -595,6 +610,9 @@ type Manifest struct {
	Name   string   `json:"-"` // the canonical name of the model
	Data   []byte   `json:"-"` // the raw data of the manifest
	Layers []*Layer `json:"layers"`
+
+	// For legacy reasons, we still have to download the config layer.
+	Config *Layer `json:"config"`
}

var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000")
@@ -678,7 +696,7 @@ func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
	data, err := os.ReadFile(c.GetFile(d))
	if err != nil {
		if errors.Is(err, fs.ErrNotExist) {
-			return nil, fmt.Errorf("%w: %s", ErrManifestNotFound, name)
+			return nil, fmt.Errorf("%w: %s", ErrModelNotFound, name)
		}
		return nil, err
	}
@@ -701,7 +719,7 @@ func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error)
		manifestURL = fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), d)
	}

-	res, err := r.doOK(ctx, "GET", manifestURL, nil)
+	res, err := r.send(ctx, "GET", manifestURL, nil)
	if err != nil {
		return nil, err
	}
@@ -726,7 +744,7 @@ func (r *Registry) client() *http.Client {
}

// newRequest constructs a new request, ready to use, with the given method,
-// url, and body, presigned with client Key and UserAgent.
+// url, and body, pre-signed with client [Key] and [UserAgent].
func (r *Registry) newRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {
	req, err := http.NewRequestWithContext(ctx, method, url, body)
	if err != nil {
@@ -745,11 +763,17 @@ func (r *Registry) newRequest(ctx context.Context, method, url string, body io.R
	return req, nil
}

-// doOK makes a request with the given client and request, and returns the
+// sendRequest makes a request with the given client and request, and returns the
// response if the status code is 200. If the status code is not 200, an Error
// is parsed from the response body and returned. If any other error occurs, it
// is returned.
-func doOK(c *http.Client, r *http.Request) (*http.Response, error) {
+func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error) {
+	defer func() {
+		if err != nil {
+			err = fmt.Errorf("request error %s: %w", r.URL, err)
+		}
+	}()
+
	if r.URL.Scheme == "https+insecure" {
		// TODO(bmizerany): clone client.Transport, set
		// InsecureSkipVerify, etc.
@@ -792,20 +816,26 @@ func doOK(c *http.Client, r *http.Request) (*http.Response, error) {
		// Use the raw body if we can't parse it as an error object.
		re.Message = string(out)
	}
+
+	// coerce MANIFEST_UNKNOWN to ErrModelNotFound
+	if strings.EqualFold(re.Code, "MANIFEST_UNKNOWN") {
+		return nil, ErrModelNotFound
+	}
+
	re.Status = res.StatusCode
	return nil, &re
}
	return res, nil
}

-// doOK is a convenience method for making a request with newRequest and
-// passing it to doOK with r.client().
-func (r *Registry) doOK(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
+// send is a convenience method for making a request with newRequest and
+// passing it to sendRequest with r.client().
+func (r *Registry) send(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
	req, err := r.newRequest(ctx, method, path, body)
	if err != nil {
		return nil, err
	}
-	return doOK(r.client(), req)
+	return sendRequest(r.client(), req)
}
// makeAuthToken creates an Ollama auth token for the given private key. // makeAuthToken creates an Ollama auth token for the given private key.
@@ -960,3 +990,28 @@ func splitExtended(s string) (scheme, name, digest string) {
	}
	return scheme, s, digest
}

type writerPool struct {
	size int64 // set by the caller

	mu sync.Mutex
	ws []*bufio.Writer
}

func (p *writerPool) get() *bufio.Writer {
	p.mu.Lock()
	defer p.mu.Unlock()
	if len(p.ws) == 0 {
		return bufio.NewWriterSize(nil, int(p.size))
	}
	w := p.ws[len(p.ws)-1]
	p.ws = p.ws[:len(p.ws)-1]
	return w
}

func (p *writerPool) put(w *bufio.Writer) {
	p.mu.Lock()
	defer p.mu.Unlock()
	w.Reset(nil)
	p.ws = append(p.ws, w)
}
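A rough sketch of how the pool is meant to be used per chunk, mirroring the Pull loop above; the final Flush and the ticket variable are assumptions, since the surrounding code is elided in this diff:

tw := wp.get()   // reuse or allocate a buffered writer sized to maxChunkSize
tw.Reset(ticket) // direct it at this chunk's destination writer
defer wp.put(tw) // put resets the writer to nil and returns it to the pool

if _, err := io.CopyN(tw, res.Body, chunk.Size()); err != nil {
    return err
}
return tw.Flush()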

View File

@@ -608,7 +608,7 @@ func TestInsecureSkipVerify(t *testing.T) {
	url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
	_, err := rc.Resolve(t.Context(), url)
	if err == nil || !strings.Contains(err.Error(), "failed to verify") {
-		t.Errorf("err = %v; want cert verifiction failure", err)
+		t.Errorf("err = %v; want cert verification failure", err)
	}

	url = fmt.Sprintf("https+insecure://%s/%s", s.Listener.Addr(), name)

View File

@@ -13,9 +13,13 @@ type Trace struct {
// Update is called during [Registry.Push] and [Registry.Pull] to
// report the progress of blob uploads and downloads.
//
-// It is called once at the beginning of the download with a zero n and
-// then once per read operation with the number of bytes read so far,
-// and an error if any.
+// The n argument is the number of bytes transferred so far, and err is
+// any error that has occurred. If n == 0, and err is nil, the download
+// or upload has just started. If err is [ErrCached], the download or
+// upload has been skipped because the blob is already present in the
+// local cache or remote registry, respectively. Otherwise, if err is
+// non-nil, the download or upload has failed. When l.Size == n, and
+// err is nil, the download or upload has completed.
//
// A function assigned must be safe for concurrent use. The function is
// called synchronously and so should not block or take long to run.
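A hedged sketch of a progress callback that follows the state transitions described above; the logging destination and surrounding wiring are illustrative only, not part of this change:

ctx = ollama.WithTrace(ctx, &ollama.Trace{
    Update: func(l *ollama.Layer, n int64, err error) {
        switch {
        case errors.Is(err, ollama.ErrCached):
            log.Printf("%s: skipped (already present)", l.Digest)
        case err != nil:
            log.Printf("%s: failed: %v", l.Digest, err)
        case n == 0:
            log.Printf("%s: starting, %d bytes total", l.Digest, l.Size)
        case n == l.Size:
            log.Printf("%s: complete", l.Digest)
        default:
            log.Printf("%s: %d/%d bytes", l.Digest, n, l.Size)
        }
    },
})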

View File

@@ -7,10 +7,14 @@ import (
"cmp" "cmp"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"io" "io"
"log/slog" "log/slog"
"net/http" "net/http"
"sync"
"time"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/client/ollama"
) )
@@ -109,6 +113,8 @@ func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
	switch r.URL.Path {
	case "/api/delete":
		return false, s.handleDelete(rec, r)
+	case "/api/pull":
+		return false, s.handlePull(rec, r)
	default:
		if s.Fallback != nil {
			s.Fallback.ServeHTTP(rec, r)
@@ -214,6 +220,97 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
	return s.Prune()
}
type progressUpdateJSON struct {
Status string `json:"status"`
Digest blob.Digest `json:"digest,omitempty,omitzero"`
Total int64 `json:"total,omitempty,omitzero"`
Completed int64 `json:"completed,omitempty,omitzero"`
}
func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
if r.Method != "POST" {
return errMethodNotAllowed
}
p, err := decodeUserJSON[*params](r.Body)
if err != nil {
return err
}
maybeFlush := func() {
fl, _ := w.(http.Flusher)
if fl != nil {
fl.Flush()
}
}
defer maybeFlush()
var mu sync.Mutex
enc := json.NewEncoder(w)
enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
mu.Lock()
defer mu.Unlock()
// TODO(bmizerany): coalesce these updates; writing per
// update is expensive
enc.Encode(progressUpdateJSON{
Digest: l.Digest,
Status: "pulling",
Total: l.Size,
Completed: n,
})
},
})
done := make(chan error, 1)
go func() {
// TODO(bmizerany): continue to support non-streaming responses
done <- s.Client.Pull(ctx, p.model())
}()
func() {
t := time.NewTicker(100 * time.Millisecond)
defer t.Stop()
for {
select {
case <-t.C:
mu.Lock()
maybeFlush()
mu.Unlock()
case err := <-done:
if err != nil {
var status string
if errors.Is(err, ollama.ErrModelNotFound) {
status = fmt.Sprintf("error: model %q not found", p.model())
enc.Encode(progressUpdateJSON{Status: status})
} else {
status = fmt.Sprintf("error: %v", err)
enc.Encode(progressUpdateJSON{Status: status})
}
return
}
// These final updates are not strictly necessary, because they have
// already happened at this point. Our pull handler code used to do
// these steps after, not during, the pull, and they were slow, so we
// wanted to provide feedback to users what was happening. For now, we
// keep them to not jar users who are used to seeing them. We can phase
// them out with a new and nicer UX later. One without progress bars
// and digests that no one cares about.
enc.Encode(progressUpdateJSON{Status: "verifying layers"})
enc.Encode(progressUpdateJSON{Status: "writing manifest"})
enc.Encode(progressUpdateJSON{Status: "success"})
return
}
}
}()
return nil
}
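From a client's point of view, this handler emits a stream of newline-delimited JSON progress objects until the pull finishes. A minimal sketch of consuming that stream from Go — the default local address, helper name, and decode struct are assumptions for illustration, not part of this change:

// Assumes imports: encoding/json, errors, fmt, io, net/http, strings.
func watchPull(model string) error {
    body := strings.NewReader(fmt.Sprintf(`{"model": %q}`, model))
    res, err := http.Post("http://localhost:11434/api/pull", "application/json", body)
    if err != nil {
        return err
    }
    defer res.Body.Close()

    dec := json.NewDecoder(res.Body)
    for {
        var u struct {
            Status    string `json:"status"`
            Digest    string `json:"digest"`
            Total     int64  `json:"total"`
            Completed int64  `json:"completed"`
        }
        if err := dec.Decode(&u); err != nil {
            if errors.Is(err, io.EOF) {
                return nil
            }
            return err
        }
        fmt.Printf("%-20s %s %d/%d\n", u.Status, u.Digest, u.Completed, u.Total)
    }
}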
func decodeUserJSON[T any](r io.Reader) (T, error) {
	var v T
	err := json.NewDecoder(r).Decode(&v)

View File

@@ -1,17 +1,27 @@
package registry

import (
+	"bytes"
+	"context"
	"encoding/json"
+	"fmt"
+	"io"
+	"io/fs"
+	"net"
	"net/http"
	"net/http/httptest"
	"os"
	"regexp"
	"strings"
+	"sync"
	"testing"

	"github.com/ollama/ollama/server/internal/cache/blob"
	"github.com/ollama/ollama/server/internal/client/ollama"
	"github.com/ollama/ollama/server/internal/testutil"
+	"golang.org/x/tools/txtar"
+
+	_ "embed"
)

type panicTransport struct{}
@@ -30,7 +40,7 @@ type bytesResetter interface {
	Reset()
}

-func newTestServer(t *testing.T) *Local {
+func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local {
	t.Helper()
	dir := t.TempDir()
	err := os.CopyFS(dir, os.DirFS("testdata/models"))
@@ -41,10 +51,25 @@ func newTestServer(t *testing.T) *Local {
	if err != nil {
		t.Fatal(err)
	}
+
+	client := panicOnRoundTrip
+	if upstreamRegistry != nil {
+		s := httptest.NewTLSServer(upstreamRegistry)
+		t.Cleanup(s.Close)
+		tr := s.Client().Transport.(*http.Transport).Clone()
+		tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
+			var d net.Dialer
+			return d.DialContext(ctx, "tcp", s.Listener.Addr().String())
+		}
+		client = &http.Client{Transport: tr}
+	}
+
	rc := &ollama.Registry{
		Cache:      c,
-		HTTPClient: panicOnRoundTrip,
+		HTTPClient: client,
+		Mask:       "example.com/library/_:latest",
	}

	l := &Local{
		Client: rc,
		Logger: testutil.Slogger(t),
@@ -85,7 +110,7 @@ func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) {
func TestServerDelete(t *testing.T) {
	check := testutil.Checker(t)

-	s := newTestServer(t)
+	s := newTestServer(t, nil)
	_, err := s.Client.ResolveLocal("smol")
	check(err)
@@ -127,8 +152,105 @@ func TestServerDelete(t *testing.T) {
	}
}
//go:embed testdata/registry.txt
var registryTXT []byte
var registryFS = sync.OnceValue(func() fs.FS {
// Txtar gets hung up on \r\n line endings, so we need to convert them
// to \n when parsing the txtar on Windows.
data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
a := txtar.Parse(data)
fmt.Printf("%q\n", a.Comment)
fsys, err := txtar.FS(a)
if err != nil {
panic(err)
}
return fsys
})
func TestServerPull(t *testing.T) {
modelsHandler := http.FileServerFS(registryFS())
s := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/v2/library/BOOM/manifests/latest":
w.WriteHeader(999)
io.WriteString(w, `{"error": "boom"}`)
case "/v2/library/unknown/manifests/latest":
w.WriteHeader(404)
io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
default:
t.Logf("serving file: %s", r.URL.Path)
modelsHandler.ServeHTTP(w, r)
}
})
checkResponse := func(got *httptest.ResponseRecorder, wantlines string) {
t.Helper()
if got.Code != 200 {
t.Fatalf("Code = %d; want 200", got.Code)
}
gotlines := got.Body.String()
t.Logf("got:\n%s", gotlines)
for want := range strings.Lines(wantlines) {
want = strings.TrimSpace(want)
want, unwanted := strings.CutPrefix(want, "!")
want = strings.TrimSpace(want)
if !unwanted && !strings.Contains(gotlines, want) {
t.Fatalf("! missing %q in body", want)
}
if unwanted && strings.Contains(gotlines, want) {
t.Fatalf("! unexpected %q in body", want)
}
}
}
got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
checkResponse(got, `
{"status":"pulling manifest"}
{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
`)
got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
checkResponse(got, `
{"status":"pulling manifest"}
{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
{"status":"verifying layers"}
{"status":"writing manifest"}
{"status":"success"}
`)
got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
checkResponse(got, `
{"status":"pulling manifest"}
{"status":"error: model \"unknown\" not found"}
`)
got = s.send(t, "DELETE", "/api/pull", `{"model": "smol"}`)
checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")
got = s.send(t, "POST", "/api/pull", `!`)
checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")
got = s.send(t, "POST", "/api/pull", ``)
checkErrorResponse(t, got, 400, "bad_request", "empty request body")
got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
checkResponse(got, `
{"status":"pulling manifest"}
{"status":"error: invalid or missing name: \"\""}
!verifying
!writing
!success
`)
}
func TestServerUnknownPath(t *testing.T) {
-	s := newTestServer(t)
+	s := newTestServer(t, nil)
	got := s.send(t, "DELETE", "/api/unknown", `{}`)
	checkErrorResponse(t, got, 404, "not_found", "not found")
}

View File

@@ -0,0 +1,22 @@
-- v2/library/smol/manifests/latest --
{
"schemaVersion": 2,
"mediaType": "application/vnd.docker.distribution.manifest.v2+json",
"config": {
"mediaType": "application/vnd.docker.container.image.v1+json",
"digest": "sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356",
"size": 3
},
"layers": [
{
"mediaType": "application/vnd.ollama.image.model",
"digest": "sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312",
"size": 5
}
]
}
-- v2/library/smol/blobs/sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312 --
GGUF
-- v2/library/smol/blobs/sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356 --
{}

View File

@@ -42,6 +42,12 @@ import (
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
func experimentEnabled(name string) bool {
return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
}
var useClient2 = experimentEnabled("client2")
var mode string = gin.DebugMode var mode string = gin.DebugMode
type Server struct { type Server struct {
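As a usage note, the new registry routes are opt-in: OLLAMA_EXPERIMENT is read as a comma-separated list, so starting the server with that variable containing client2 (for example, OLLAMA_EXPERIMENT=client2 ollama serve) enables the client2 path in the hunks below, while without it rc stays nil and the existing Gin routes are served unchanged.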
@@ -1173,6 +1179,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
	r.HEAD("/api/tags", s.ListHandler)
	r.GET("/api/tags", s.ListHandler)
	r.POST("/api/show", s.ShowHandler)
+	r.DELETE("/api/delete", s.DeleteHandler)

	// Create
	r.POST("/api/create", s.CreateHandler)
@@ -1194,16 +1201,19 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler) r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler) r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
// wrap old with new if rc != nil {
rs := &registry.Local{ // wrap old with new
Client: rc, rs := &registry.Local{
Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default() Client: rc,
Fallback: r, Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
Fallback: r,
Prune: PruneLayers, Prune: PruneLayers,
}
return rs, nil
} }
return rs, nil return r, nil
} }
func Serve(ln net.Listener) error { func Serve(ln net.Listener) error {
@@ -1258,15 +1268,20 @@ func Serve(ln net.Listener) error {
s := &Server{addr: ln.Addr()} s := &Server{addr: ln.Addr()}
rc, err := ollama.DefaultRegistry() var rc *ollama.Registry
if err != nil { if useClient2 {
return err var err error
rc, err = ollama.DefaultRegistry()
if err != nil {
return err
}
} }
h, err := s.GenerateRoutes(rc) h, err := s.GenerateRoutes(rc)
if err != nil { if err != nil {
return err return err
} }
http.Handle("/", h) http.Handle("/", h)
ctx, done := context.WithCancel(context.Background()) ctx, done := context.WithCancel(context.Background())