Merge branch 'ollama:main' into main
@@ -76,6 +76,7 @@ Here are some example models that can be downloaded:

 | Model              | Parameters | Size  | Download                         |
 | ------------------ | ---------- | ----- | -------------------------------- |
+| QwQ                | 32B        | 20GB  | `ollama run qwq`                 |
 | DeepSeek-R1        | 7B         | 4.7GB | `ollama run deepseek-r1`         |
 | DeepSeek-R1        | 671B       | 404GB | `ollama run deepseek-r1:671b`    |
 | Llama 3.3          | 70B        | 43GB  | `ollama run llama3.3`            |
@@ -361,9 +361,9 @@ type CopyRequest struct {

 // PullRequest is the request passed to [Client.Pull].
 type PullRequest struct {
 	Model    string `json:"model"`
-	Insecure bool   `json:"insecure,omitempty"`
-	Username string `json:"username"`
-	Password string `json:"password"`
+	Insecure bool   `json:"insecure,omitempty"` // Deprecated: ignored
+	Username string `json:"username"`           // Deprecated: ignored
+	Password string `json:"password"`           // Deprecated: ignored
 	Stream   *bool  `json:"stream,omitempty"`

 	// Deprecated: set the model name with Model instead
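A minimal sketch (not part of the diff) of how a caller might pull a model after this change, assuming the standard `api` client: only `Model` (and optionally `Stream`) still matter, since the credential fields are now ignored by the server.

```go
package main

import (
	"context"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Insecure, Username, and Password are deprecated and ignored.
	req := &api.PullRequest{Model: "llama3.3"}
	err = client.Pull(context.Background(), req, func(p api.ProgressResponse) error {
		log.Printf("%s: %d/%d", p.Status, p.Completed, p.Total)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```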
@@ -81,9 +81,11 @@ help you keep up to date.

 If you'd like to install or integrate Ollama as a service, a standalone
 `ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI
-and GPU library dependencies for Nvidia and AMD. This allows for embedding
-Ollama in existing applications, or running it as a system service via `ollama
-serve` with tools such as [NSSM](https://nssm.cc/).
+and GPU library dependencies for Nvidia. If you have an AMD GPU, also download
+and extract the additional ROCm package `ollama-windows-amd64-rocm.zip` into the
+same directory. This allows for embedding Ollama in existing applications, or
+running it as a system service via `ollama serve` with tools such as
+[NSSM](https://nssm.cc/).

 > [!NOTE]
 > If you are upgrading from a prior version, you should remove the old directories first.
go.mod (3 changes)

@@ -24,7 +24,7 @@ require (
 	github.com/nlpodyssey/gopickle v0.3.0
 	github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
 	golang.org/x/image v0.22.0
-	gonum.org/v1/gonum v0.15.0
+	golang.org/x/tools v0.30.0
 )

 require (
@@ -44,6 +44,7 @@ require (
 	github.com/xtgo/set v1.0.0 // indirect
 	go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
 	golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
+	gonum.org/v1/gonum v0.15.0 // indirect
 	gorgonia.org/vecf32 v0.9.0 // indirect
 	gorgonia.org/vecf64 v0.9.0 // indirect
 )
go.sum (2 changes)

@@ -309,6 +309,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
 golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
 golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
 golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
+golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
+golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -20,6 +20,7 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
 type Causal struct {
 	DType      ml.DType
 	Capacity   int32
+	causal     bool
 	windowSize int32

 	// config controls mostly backend-specific optimizations
@@ -42,6 +43,12 @@ type Causal struct {
 	// locations in the cache that are needed for this batch
 	curCellRange cellRange

+	// curSequences is the sequences corresponding to this pass's entries in the cache
+	curSequences []int
+
+	// curPositions is the positions corresponding to this pass's entries in the cache
+	curPositions []int32
+
 	// ** cache metadata **

 	// for each possible location in the cache, stores the position and set of sequences
@@ -55,8 +62,8 @@ type Causal struct {

 	shiftFn shiftFn
 	backend ml.Backend
-	cacheCtx     ml.Context
-	keys, values []ml.Tensor
+	ctxs         map[int]ml.Context
+	keys, values map[int]ml.Tensor
 }

 type cacheCell struct {
@@ -70,11 +77,25 @@ type cellRange struct {
 }

 func NewCausalCache(shift shiftFn) *Causal {
-	return &Causal{windowSize: math.MaxInt32, shiftFn: shift}
+	return &Causal{
+		causal:     true,
+		windowSize: math.MaxInt32,
+		shiftFn:    shift,
+		ctxs:       make(map[int]ml.Context),
+		keys:       make(map[int]ml.Tensor),
+		values:     make(map[int]ml.Tensor),
+	}
 }

 func NewSWACache(windowSize int32, shift shiftFn) *Causal {
-	return &Causal{windowSize: windowSize, shiftFn: shift}
+	return &Causal{
+		causal:     true,
+		windowSize: windowSize,
+		shiftFn:    shift,
+		ctxs:       make(map[int]ml.Context),
+		keys:       make(map[int]ml.Tensor),
+		values:     make(map[int]ml.Tensor),
+	}
 }

 func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
@@ -103,7 +124,6 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
 	c.cells = make([]cacheCell, c.Capacity)
 	c.cellRanges = make(map[int]cellRange)
 	c.backend = backend
-	c.cacheCtx = backend.NewContext()
 }

 func (c *Causal) SetConfig(config ml.CacheConfig) {
@@ -115,11 +135,15 @@ func (c *Causal) SetConfig(config ml.CacheConfig) {
 }

 func (c *Causal) Close() {
-	c.cacheCtx.Close()
+	for _, ctx := range c.ctxs {
+		ctx.Close()
+	}
 }

 func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
 	c.curBatchSize = len(positions)
+	c.curSequences = seqs
+	c.curPositions = positions

 	var err error
 	c.curLoc, err = c.findStartLoc()
@@ -158,7 +182,7 @@ func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) err
 		c.cellRanges[seq] = seqRange
 	}

-	c.curMask, err = c.buildMask(ctx, positions, seqs)
+	c.curMask, err = c.buildMask(ctx)

 	return err
 }
@@ -199,7 +223,7 @@ func roundUp(length, pad int) int {
 // Builds a mask of history x batch indicating whether for each token in the batch the
 // token in the history should apply. This is based on both the sequence and causality (the
 // position of the history is not ahead of the token in the batch).
-func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
+func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
 	// Align and pad the two dimensions as required by the backend
 	batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
@@ -211,8 +235,9 @@ func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Te

 	for i := range c.curBatchSize {
 		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
-			if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
-				c.cells[j].pos < positions[i]-c.windowSize {
+			if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
+				(c.causal && c.cells[j].pos > c.curPositions[i]) ||
+				c.cells[j].pos < c.curPositions[i]-c.windowSize {
 				mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
 			}
 		}
@@ -224,13 +249,13 @@ func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Te
 		mask[i] = float32(math.Inf(-1))
 	}

-	maskTensor, err := ctx.FromFloatSlice(mask, length, batchSize)
+	maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
 	if err != nil {
 		return nil, err
 	}

 	if c.config.MaskDType != ml.DTypeF32 {
-		out := ctx.Empty(c.config.MaskDType, maskTensor.Shape()...)
+		out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
 		ctx.Forward(maskTensor.Copy(ctx, out))
 		maskTensor = out
 	}
@@ -239,13 +264,11 @@ func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Te
 }

 func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
-	for i := range c.keys {
-		if c.keys[i] == nil {
+	for i, key := range c.keys {
+		if key == nil {
 			continue
 		}

-		key := c.keys[i]
-
 		kHeadDim := key.Dim(0)
 		numKVHeads := key.Dim(1)
 		rowSize := key.Stride(2)
@@ -305,7 +328,7 @@ func (c *Causal) defrag() {
 		layers++
 	}

-	maxMoves := ctx.MaxTensors() / (6 * layers)
+	maxMoves := ctx.MaxGraphNodes() / (6 * layers)
 	moves := 0

 	var pendingSrc, pendingDst, pendingLen int
@@ -377,14 +400,29 @@ func (c *Causal) defrag() {
 }

 func (c *Causal) SetLayer(layer int) {
-	if layer >= len(c.keys) {
-		c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
-		c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
-	}
-
 	c.curLayer = layer
 }

+// SetCausal enables or disables causal mask generation for subsequent calls to Get.
+// This state carries over to future forward passes. The default value is true.
+//
+// ctx may be set to nil if this is called from outside of a forward pass, for
+// example, when initializing the cache.
+func (c *Causal) SetCausal(ctx ml.Context, causal bool) {
+	if c.causal != causal {
+		c.causal = causal
+
+		if ctx != nil {
+			var err error
+			c.curMask, err = c.buildMask(ctx)
+			if err != nil {
+				// This error should never occur because we have previously built a mask with the same shape
+				panic(fmt.Errorf("SetCausal: %w", err))
+			}
+		}
+	}
+}
+
 func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
 	key := c.keys[c.curLayer]
 	value := c.values[c.curLayer]
@@ -433,13 +471,19 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
 		panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
 	}

-	if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
-		c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
+	if _, ok := c.ctxs[c.curLayer]; !ok {
+		c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
+	}
+
+	if _, ok := c.keys[c.curLayer]; !ok {
+		c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
+	}
+
+	if _, ok := c.values[c.curLayer]; !ok {
 		if c.config.PermutedV {
-			c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
+			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
 		} else {
-			c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
+			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
 		}
 	}
@@ -501,7 +545,7 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
 		}
 	}

-	kShift, err := ctx.FromIntSlice(offsets, len(offsets))
+	kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets))
 	if err != nil {
 		return err
 	}
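To make the storage change above concrete, here is a small self-contained sketch (illustrative only, not from the diff) of the lazy per-layer allocation pattern the cache switches to: slices indexed by layer give way to maps whose entries are created on first use, so each layer's buffer can live in its own context.

```go
package main

import "fmt"

// layerStore mimics the cache's new map-based bookkeeping: nothing is
// preallocated; a layer's buffer is created the first time it is touched.
type layerStore struct {
	bufs map[int][]float32
}

func (s *layerStore) get(layer, size int) []float32 {
	if _, ok := s.bufs[layer]; !ok {
		s.bufs[layer] = make([]float32, size) // allocated on first use
	}
	return s.bufs[layer]
}

func main() {
	s := layerStore{bufs: make(map[int][]float32)}
	fmt.Println(len(s.get(3, 8))) // layer 3 is created lazily; prints 8
}
```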
@@ -303,6 +303,10 @@ func (b *testBackend) NewContext() ml.Context {
 	return &testContext{}
 }

+func (b *testBackend) NewContextSize(int) ml.Context {
+	return &testContext{}
+}
+
 func (b *testBackend) SystemInfo() string {
 	return "not implemented"
 }
@@ -346,11 +350,15 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
 	return out, nil
 }

+func (c *testContext) Input() ml.Context    { return c }
+func (c *testContext) Output() ml.Context   { return c }
+func (c *testContext) Layer(int) ml.Context { return c }
+
 func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }

 func (c *testContext) Compute(...ml.Tensor) {}

-func (c *testContext) MaxTensors() int {
+func (c *testContext) MaxGraphNodes() int {
 	return 10
 }
@@ -35,13 +35,17 @@ type EncoderCache struct {
 	encoderPos int32

 	// ** cache data storage **

-	cacheCtx     ml.Context
-	keys, values []ml.Tensor
+	backend      ml.Backend
+	ctxs         map[int]ml.Context
+	keys, values map[int]ml.Tensor
 }

 func NewEncoderCache() *EncoderCache {
-	return &EncoderCache{}
+	return &EncoderCache{
+		ctxs:   make(map[int]ml.Context),
+		keys:   make(map[int]ml.Tensor),
+		values: make(map[int]ml.Tensor),
+	}
 }

 func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
@@ -57,7 +61,7 @@ func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
 		panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
 	}

-	c.cacheCtx = backend.NewContext()
+	c.backend = backend
 }

 func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
@@ -69,7 +73,9 @@ func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
 }

 func (c *EncoderCache) Close() {
-	c.cacheCtx.Close()
+	for _, ctx := range c.ctxs {
+		ctx.Close()
+	}
 }

 func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
@@ -80,11 +86,6 @@ func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []in
 }

 func (c *EncoderCache) SetLayer(layer int) {
-	if layer >= len(c.keys) {
-		c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
-		c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
-	}
-
 	c.curLayer = layer
 }
@@ -104,9 +105,16 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
 		value = value.Permute(ctx, 1, 2, 0, 3)
 	}

-	if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
-		c.keys[c.curLayer] = c.cacheCtx.Empty(key.DType(), key.Shape()...)
-		c.values[c.curLayer] = c.cacheCtx.Empty(value.DType(), value.Shape()...)
+	if _, ok := c.ctxs[c.curLayer]; !ok {
+		c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
+	}
+
+	if _, ok := c.keys[c.curLayer]; !ok {
+		c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
+	}
+
+	if _, ok := c.values[c.curLayer]; !ok {
+		c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
 	}

 	ctx.Forward(
llama/llama.cpp/src/llama-vocab.cpp (vendored, 2 changes)

@@ -1443,7 +1443,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {

     const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
     if (precompiled_charsmap_keyidx != -1) {
-        size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
+        size_t n_precompiled_charsmap = gguf_get_arr_data_n(ctx, precompiled_charsmap_keyidx);
         const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
         precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
 #ifdef IS_BIG_ENDIAN
llama/patches/0019-fix-string-arr-kv-loading.patch (new file, 64 lines)

@@ -0,0 +1,64 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca <jmorganca@gmail.com>
Date: Wed, 5 Mar 2025 17:41:07 -0800
Subject: [PATCH] fix string arr kv loading

---
 ggml/include/gguf.h | 1 +
 ggml/src/gguf.cpp   | 7 +++++--
 src/llama-vocab.cpp | 2 +-
 3 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/ggml/include/gguf.h b/ggml/include/gguf.h
index 79ee2020..3efb22f0 100644
--- a/ggml/include/gguf.h
+++ b/ggml/include/gguf.h
@@ -114,6 +114,7 @@ extern "C" {
     // get raw pointer to the first element of the array with the given key_id
     // for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference)
     GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id);
+    GGML_API size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id);

     // get ith C string from array with given key_id
     GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp
index ab13669c..f75b923f 100644
--- a/ggml/src/gguf.cpp
+++ b/ggml/src/gguf.cpp
@@ -777,10 +777,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id

 const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) {
     GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
     return ctx->kv[key_id].data.data();
 }

+size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    return ctx->kv[key_id].data.size();
+}
+
 const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) {
     GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING);
@@ -874,7 +878,6 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) {
 const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) {
     GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
-    GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
     return ctx->kv[key_id].data.data();
 }
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index c7ff28be..7a185443 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -1443,7 +1443,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {

     const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
     if (precompiled_charsmap_keyidx != -1) {
-        size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
+        size_t n_precompiled_charsmap = gguf_get_arr_data_n(ctx, precompiled_charsmap_keyidx);
         const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
         precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
 #ifdef IS_BIG_ENDIAN
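In short, as the patch itself shows: `gguf_get_arr_data` previously asserted that the key was not a string-typed array, which blocked loading the tokenizer's precompiled charsmap; the patch drops that assertion (and the matching one in `gguf_get_val_data`) and adds `gguf_get_arr_data_n` so callers can pair the raw byte pointer with the buffer's size in bytes rather than the element count returned by `gguf_get_arr_n`.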
@@ -973,7 +973,7 @@ func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error)
 		return s.llamaModel.Tokenize(content, false, true)
 	}
 	if s.textProcessor != nil {
-		tokens, err := s.textProcessor.Encode(content)
+		tokens, err := s.textProcessor.Encode(content, false)
 		if err != nil {
 			return nil, err
 		}
@@ -24,6 +24,7 @@ type Backend interface {
 	Config() Config
 	Get(name string) Tensor
 	NewContext() Context
+	NewContextSize(size int) Context
 }

 // BackendCacheConfig should be implemented by backends that need special output
@@ -99,8 +100,17 @@ type Context interface {

 	Forward(...Tensor) Context
 	Compute(...Tensor)
-	MaxTensors() int
+	MaxGraphNodes() int
 	Close()
+
+	// Input returns a context appropriate for creating input tensors
+	Input() Context
+
+	// Output returns a context appropriate for creating output tensors
+	Output() Context
+
+	// Layer returns a context appropriate for creating intermediate tensors
+	Layer(int) Context
 }

 type Tensor interface {
@@ -205,7 +215,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
 		return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
 			return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
 		})
-	case DTypeF16:
+	case DTypeF16, DTypeQ80, DTypeQ40:
 		f32 := ctx.Empty(DTypeF32, t.Shape()...)
 		f32 = t.Copy(ctx, f32)
 		return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
@@ -273,5 +283,7 @@ const (
 	DTypeOther DType = iota
 	DTypeF32
 	DTypeF16
+	DTypeQ80
+	DTypeQ40
 	DTypeI32
 )
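A brief sketch (hypothetical code, not part of the diff) of how a forward pass might use the new accessors so each tensor is created on the appropriate buffer type:

```go
// forward is illustrative only; it assumes a ctx obtained from the backend.
func forward(ctx ml.Context, tokens, outIdx []int32) error {
	// token IDs are inputs, so create them on the input buffer type
	inputs, err := ctx.Input().FromIntSlice(tokens, len(tokens))
	if err != nil {
		return err
	}

	// indices of the logits we need back are outputs
	outputs, err := ctx.Output().FromIntSlice(outIdx, len(outIdx))
	if err != nil {
		return err
	}

	// intermediate tensors for layer 0 go on that layer's device
	hidden := ctx.Layer(0).Empty(ml.DTypeF32, 4096)

	_, _, _ = inputs, outputs, hidden
	return nil
}
```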
@@ -9,67 +9,53 @@ package ggml
 import "C"

 import (
+	"errors"
 	"fmt"
 	"io"
 	"log/slog"
+	"maps"
 	"os"
-	"sync"
+	"slices"
+	"strconv"
+	"strings"
+	"unicode"
 	"unsafe"

 	"github.com/ollama/ollama/format"
 	fs "github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/ml"
-	"golang.org/x/sync/errgroup"
-
 	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
+	"golang.org/x/sync/errgroup"
 )

-type device struct {
-	d *C.struct_ggml_backend_device
-}
-
-func (d device) LogValue() slog.Value {
-	var free, total uint64
-	C.ggml_backend_dev_memory(d.d, (*C.size_t)(&free), (*C.size_t)(&total))
-
-	kind := "unknown"
-	switch C.ggml_backend_dev_type(d.d) {
-	case C.GGML_BACKEND_DEVICE_TYPE_CPU:
-		kind = "cpu"
-	case C.GGML_BACKEND_DEVICE_TYPE_GPU:
-		kind = "gpu"
-	case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
-		kind = "accel"
-	}
-
-	return slog.GroupValue(
-		slog.String("name", C.GoString(C.ggml_backend_dev_name(d.d))),
-		slog.String("description", C.GoString(C.ggml_backend_dev_description(d.d))),
-		slog.String("kind", kind),
-		slog.String("free", format.HumanBytes2(free)),
-		slog.String("total", format.HumanBytes2(total)),
-	)
-}
-
-var devices = sync.OnceValue(func() []device {
+func devices() []*C.struct_ggml_backend_device {
 	ggml.OnceLoad()

-	s := make([]device, C.ggml_backend_dev_count())
-	for i := range s {
-		s[i] = device{C.ggml_backend_dev_get(C.size_t(i))}
+	ds := make([]*C.struct_ggml_backend_device, C.ggml_backend_dev_count())
+	for i := range ds {
+		ds[i] = C.ggml_backend_dev_get(C.size_t(i))
 	}

-	return s
-})
+	return ds
+}

 type Backend struct {
+	meta    *fs.GGML
+	sched   *C.struct_ggml_backend_sched
+	tensors map[string]*C.struct_ggml_tensor
+
+	// input is the backend used for inputs
+	input *C.struct_ggml_backend_buffer_type
+
+	// output is the backend used for outputs
+	output *C.struct_ggml_backend_buffer_type
+
+	// layers is the backend used for repeating layers
+	layers map[int]*C.struct_ggml_backend_buffer_type
+
 	flashAttention bool

-	meta       *fs.GGML
-	cpus, gpus []Context
-	tensors    map[string]*Context
-
-	sched *C.struct_ggml_backend_sched
+	// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
+	maxGraphNodes int
 }

 func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
@@ -88,107 +74,310 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 		"num_key_values", len(meta.KV()),
 	)

-	var cpus, gpus []Context
+	type deviceBufferType struct {
+		d   *C.struct_ggml_backend_device
+		bts []*C.struct_ggml_backend_buffer_type
+	}
+
+	var cpus, accels, gpus []*C.struct_ggml_backend_device
 	for _, d := range devices() {
-		switch C.ggml_backend_dev_type(d.d) {
+		switch C.ggml_backend_dev_type(d) {
+		case C.GGML_BACKEND_DEVICE_TYPE_CPU:
+			if len(cpus) == 0 {
+				// only the first cpu device should be used
+				cpus = append(cpus, d)
+			}
+		case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
+			accels = append(accels, d)
+		case C.GGML_BACKEND_DEVICE_TYPE_GPU:
+			gpus = append(gpus, d)
+		}
+	}
+
+	// create list of buffer types for the cpu
+	cpuDeviceBufferType := deviceBufferType{d: C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU)}
+	for _, d := range append(accels, append(gpus, cpus...)...) {
+		switch C.ggml_backend_dev_type(d) {
 		case C.GGML_BACKEND_DEVICE_TYPE_CPU,
 			C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
-			slog.Info("cpu", "device", d)
-			cpus = append(cpus, Context{
-				ctx: C.ggml_init(C.struct_ggml_init_params{
-					mem_size: C.size_t(int(C.ggml_tensor_overhead()) * (len(meta.Tensors().Items()) + 1 + int(meta.KV().BlockCount())*2)),
-					no_alloc: true,
-				}),
-				backend: C.ggml_backend_dev_init(d.d, nil),
-			})
-		case C.GGML_BACKEND_DEVICE_TYPE_GPU:
-			slog.Info("gpu", "device", d)
-			gpus = append(gpus, Context{
-				ctx: C.ggml_init(C.struct_ggml_init_params{
-					mem_size: C.size_t(int(C.ggml_tensor_overhead()) * (len(meta.Tensors().Items()) + 1 + int(meta.KV().BlockCount())*2)),
-					no_alloc: true,
-				}),
-				backend: C.ggml_backend_dev_init(d.d, nil),
-			})
+			cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, C.ggml_backend_dev_buffer_type(d))
 		}
 	}

-	ctxFunc := func(s []Context) (*Context, error) {
-		for _, e := range s {
-			return &e, nil
-		}
-
-		return nil, fmt.Errorf("no devices available")
-	}
-
-	tensors := make(map[*fs.Tensor]*Context, len(meta.Tensors().Items()))
-	for _, t := range meta.Tensors().Items() {
-		c, err := ctxFunc(append(gpus, cpus...))
-		if err != nil {
-			return nil, err
-		}
-
-		func() {
-			tt := C.ggml_new_tensor(c.ctx, t.Kind, C.int(len(t.Shape)), (*C.int64_t)(unsafe.Pointer(&t.Shape[0])))
-
-			cname := C.CString(t.Name)
-			defer C.free(unsafe.Pointer(cname))
-			C.ggml_set_name(tt, cname)
-
-			tensors[t] = c
-		}()
-	}
-
-	for _, b := range append(gpus, cpus...) {
-		C.ggml_backend_alloc_ctx_tensors(b.ctx, b.backend)
-	}
-
-	sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
-
-	var g errgroup.Group
-	for t, c := range tensors {
-		g.Go(func() error {
-			bts := make([]byte, t.Size())
-			n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
-			if err != nil {
-				return err
-			}
-
-			if n != int(t.Size()) {
-				return fmt.Errorf("expected %d bytes, got %d", t.Size(), n)
-			}
-
-			cname := C.CString(t.Name)
-			defer C.free(unsafe.Pointer(cname))
-
-			C.ggml_backend_tensor_set(C.ggml_get_tensor(c.ctx, cname), unsafe.Pointer(&bts[0]), 0, C.size_t(n))
-			return nil
-		})
-	}
-
-	if err := g.Wait(); err != nil {
-		return nil, err
-	}
-
-	backends := make([]*C.struct_ggml_backend, len(gpus)+len(cpus))
-	bufts := make([]*C.struct_ggml_backend_buffer_type, len(gpus)+len(cpus))
-	for i, c := range append(gpus, cpus...) {
-		backends[i] = c.backend
-		bufts[i] = C.ggml_backend_get_default_buffer_type(c.backend)
+	// create list of buffer types for each gpu
+	var gpuDeviceBufferTypes []deviceBufferType
+	for _, d := range gpus {
+		bt := C.ggml_backend_dev_buffer_type(d)
+		gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{
+			d:   d,
+			bts: append([]*C.struct_ggml_backend_buffer_type{bt}, cpuDeviceBufferType.bts...),
+		})
+	}
+
+	useDefaultSplit := true
+	for _, s := range params.TensorSplit {
+		if s != 0 {
+			useDefaultSplit = false
+			break
+		}
+	}
+
+	// calculate splits
+	splits := make([]float32, len(gpus))
+	if useDefaultSplit {
+		// default: split on free memory
+		for i := range splits {
+			var free, total C.size_t
+			C.ggml_backend_dev_memory(gpus[i], &free, &total)
+			splits[i] = float32(free)
+		}
+	} else {
+		splits = params.TensorSplit
+	}
+
+	var sum float32
+	// cumulative sum of all splits
+	for i := range splits {
+		sum += splits[i]
+		splits[i] = sum
+	}
+
+	// normalize splits
+	for i := range splits {
+		splits[i] /= sum
+	}
+
+	// inputs always use cpu
+	input := cpuDeviceBufferType
+
+	blocks := int(meta.KV().BlockCount())
+
+	// define a range of gpu layers. anything outside of this range is assigned to the cpu
+	gpuRangeStart := max(0, blocks-params.NumGPULayers)
+	gpuRangeStop := min(gpuRangeStart+params.NumGPULayers, blocks+1)
+	assignLayer := func(i int) deviceBufferType {
+		if i < gpuRangeStart || i >= gpuRangeStop {
+			return cpuDeviceBufferType
+		}
+
+		index := slices.IndexFunc(splits, func(f float32) bool { return float32(i-gpuRangeStart)/float32(gpuRangeStop-gpuRangeStart) < f })
+		if index < 0 || index >= len(gpuDeviceBufferTypes) {
+			return cpuDeviceBufferType
+		}
+
+		return gpuDeviceBufferTypes[index]
+	}
+
+	// repeating layers are assigned based on their index in reverse order, e.g. i / (block_count + 1)
+	layers := make([]deviceBufferType, blocks)
+	for i := range layers {
+		layers[i] = assignLayer(i)
+	}
+
+	// outputs are assigned iff allowed by splits and configured number of gpu layers
+	output := assignLayer(blocks)
+
+	maxTensors := len(meta.Tensors().Items())
+	maxTensors += 1
+	// each layer has at most 2 extra tensors for rope operations
+	maxTensors += blocks * 2
+
+	type tensor struct {
+		source *fs.Tensor
+		target string
+	}
+
+	// some tensors are mapped to different names so keep a list
+	targets := make(map[string][]string)
+
+	// contexts are shared by tensors of the same buffer type
+	ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context)
+	createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type) *C.struct_ggml_tensor {
+		for _, bt := range bts {
+			if _, ok := ctxs[bt]; !ok {
+				ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{
+					mem_size: C.ggml_tensor_overhead() * C.size_t(maxTensors),
+					no_alloc: true,
+				})
+			}
+
+			targets[t.source.Name] = append(targets[t.source.Name], t.target)
+
+			name := t.source.Name
+			if t.target != "" {
+				name = t.target
+			}
+
+			cname := C.CString(name)
+			defer C.free(unsafe.Pointer(cname))
+			if tt := C.ggml_get_tensor(ctxs[bt], cname); tt != nil {
+				return tt
+			}
+
+			tt := C.ggml_new_tensor(ctxs[bt], t.source.Kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0])))
+			C.ggml_set_name(tt, cname)
+
+			slog.Debug("created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
+			//nolint:staticcheck // TODO: check if buffer type supports this tensor
+			return tt
+		}
+
+		return nil
+	}
+
+	contains := func(s string, parts ...string) bool {
+		split := strings.Split(s, ".")
+		for _, part := range parts {
+			if slices.Contains(split, part) {
+				return true
+			}
+		}
+
+		return false
+	}
+
+	for _, t := range meta.Tensors().Items() {
+		switch {
+		case contains(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
+			createTensor(tensor{source: t}, input.bts)
+		case contains(t.Name, "cls", "output", "output_norm"):
+			createTensor(tensor{source: t}, output.bts)
+		case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
+			// TODO: assign vision tensors to the gpu if possible
+			createTensor(tensor{source: t}, input.bts)
+		default:
+			layerIndex := -1
+			if fields := strings.FieldsFunc(t.Name, func(r rune) bool { return !unicode.IsNumber(r) }); len(fields) > 0 {
+				if i, err := strconv.Atoi(fields[0]); err == nil {
+					layerIndex = i
+				}
+			}
+
+			if layerIndex >= 0 {
+				createTensor(tensor{source: t}, layers[layerIndex].bts)
+			} else {
+				// this is a repeating tensor that doesn't explicitly associated with a layer so
+				// duplicate it for each layer
+				for i, layer := range layers {
+					createTensor(tensor{
+						source: t,
+						target: "blk." + strconv.Itoa(i) + "." + t.Name,
+					}, layer.bts)
+				}
+			}
+		}
+	}
+
+	// allocate buffers for each context
+	bbs := make(map[*C.struct_ggml_context]*C.struct_ggml_backend_buffer, len(ctxs))
+	for bt, c := range ctxs {
+		if C.ggml_get_first_tensor(c) == nil {
+			continue
+		}
+
+		b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
+		C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
+		bbs[c] = b
+	}
+
+	for bs := range maps.Values(bbs) {
+		slog.Info("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)), "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs))))
+	}
+
+	// map tensor names to tensors for easy lookup later
+	tensors := make(map[string]*C.struct_ggml_tensor)
+	for _, c := range ctxs {
+		for t := C.ggml_get_first_tensor(c); t != nil; t = C.ggml_get_next_tensor(c, t) {
+			tensors[C.GoString(C.ggml_get_name(t))] = t
+		}
+	}
+
+	// concurrently read in tensor data. uses a section reader which is safe for concurrent reads
+	sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
+	var g errgroup.Group
+	for _, t := range meta.Tensors().Items() {
+		for _, target := range targets[t.Name] {
+			g.Go(func() error {
+				if target == "" {
+					target = t.Name
+				}
+
+				tt, ok := tensors[target]
+				if !ok {
+					return fmt.Errorf("unassigned tensor: %s", t.Name)
+				}
+
+				bts := make([]byte, t.Size())
+				n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
+				if err != nil {
+					return err
+				}
+
+				if n != len(bts) {
+					return errors.New("short read")
+				}
+
+				C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size()))
+				return nil
+			})
+		}
+	}
+
+	if g.Wait() != nil {
+		return nil, err
+	}
+
+	// map devices to backend buffer types so new tensors can be assigned to the correct device
+	deviceBufferTypes := make(map[*C.struct_ggml_backend_device]*C.struct_ggml_backend_buffer_type)
+
+	// create backends and buffer types used for the compute graph scheduler
+	var schedBackends []*C.struct_ggml_backend
+	var schedBufts []*C.struct_ggml_backend_buffer_type
+	for _, d := range append(gpus, append(accels, cpus...)...) {
+		b := C.ggml_backend_dev_init(d, nil)
+		bt := C.ggml_backend_get_default_buffer_type(b)
+		if d := C.ggml_backend_get_device(b); C.ggml_backend_dev_type(d) == C.GGML_BACKEND_DEVICE_TYPE_CPU && len(gpus) > 0 {
+			// use the first gpu host buffer type for gpu if possible
+			if hbt := C.ggml_backend_dev_host_buffer_type(gpus[0]); hbt != nil {
+				bt = hbt
+			}
+		}
+
+		deviceBufferTypes[d] = bt
+
+		schedBackends = append(schedBackends, b)
+		schedBufts = append(schedBufts, bt)
+
+		slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(b)), "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
+
+		if C.ggml_backend_is_cpu(b) {
+			// set number of threads for cpu backend
+			C.ggml_backend_cpu_set_n_threads(b, C.int(params.NumThreads))
+		}
 	}

+	maxGraphNodes := max(8192, len(meta.Tensors().Items())*5)
 	return &Backend{
 		flashAttention: params.FlashAttention,
 		meta:           meta,
-		cpus:           cpus,
-		gpus:           gpus,
 		tensors:        tensors,
 		sched: C.ggml_backend_sched_new(
-			(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
-			(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
-			C.int(len(backends)),
-			C.size_t(max(8192, len(meta.Tensors().Items())*5)),
+			(*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])),
+			(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
+			C.int(len(schedBackends)),
+			C.size_t(maxGraphNodes),
 			true,
 		),
+		input:  deviceBufferTypes[input.d],
+		output: deviceBufferTypes[output.d],
+		layers: func() map[int]*C.struct_ggml_backend_buffer_type {
+			m := make(map[int]*C.struct_ggml_backend_buffer_type)
+			for i, layer := range layers {
+				m[i] = deviceBufferTypes[layer.d]
+			}
+			return m
+		}(),
+		maxGraphNodes: maxGraphNodes,
 	}, nil
 }
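The split arithmetic above is easy to check in isolation. A standalone sketch (illustrative numbers, not from the diff): free memory per GPU becomes a cumulative share normalized to [0, 1], and a layer's fractional position selects the first GPU whose share exceeds it.

```go
package main

import (
	"fmt"
	"slices"
)

func main() {
	splits := []float32{8, 24} // e.g. free memory per GPU, in GB

	var sum float32
	for i := range splits { // cumulative sum: [8, 32]
		sum += splits[i]
		splits[i] = sum
	}
	for i := range splits { // normalize: [0.25, 1.0]
		splits[i] /= sum
	}

	blocks := 32
	gpuRangeStart, gpuRangeStop := 0, blocks+1 // all layers offloaded
	for _, i := range []int{0, 8, 9, 31} {
		f := float32(i-gpuRangeStart) / float32(gpuRangeStop-gpuRangeStart)
		idx := slices.IndexFunc(splits, func(s float32) bool { return f < s })
		fmt.Printf("layer %2d -> gpu %d\n", i, idx) // layers 0,8 -> gpu 0; 9,31 -> gpu 1
	}
}
```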
@@ -201,36 +390,29 @@ func (b *Backend) Config() ml.Config {
 }

 func (b *Backend) Get(name string) ml.Tensor {
-	cname := C.CString(name)
-	defer C.free(unsafe.Pointer(cname))
-
-	for _, c := range append(b.gpus, b.cpus...) {
-		if t := C.ggml_get_tensor(c.ctx, cname); t != nil {
-			return &Tensor{b: b, t: t}
-		}
+	if t, ok := b.tensors[name]; ok {
+		return &Tensor{b: b, t: t}
 	}

 	return nil
 }

 func (b *Backend) NewContext() ml.Context {
-	nodes := max(8192, len(b.meta.Tensors().Items())*5)
-	c := C.ggml_init(C.struct_ggml_init_params{
-		mem_buffer: nil,
-		mem_size:   C.size_t(nodes)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(nodes), false),
-		no_alloc:   true,
-	})
+	return b.NewContextSize(b.maxGraphNodes)
+}

-	backends := make([]*C.struct_ggml_backend, len(b.gpus)+len(b.cpus))
-	for i, c := range append(b.gpus, b.cpus...) {
-		backends[i] = c.backend
+func (b *Backend) NewContextSize(n int) ml.Context {
+	if n > b.maxGraphNodes {
+		panic(fmt.Errorf("requested number of graph nodes (%v) for new context exceeds maximum (%v)", n, b.maxGraphNodes))
 	}

 	return &Context{
-		b:       b,
-		ctx:     c,
-		backend: backends[0],
-		nodes:   nodes,
+		b:             b,
+		maxGraphNodes: n,
+		ctx: C.ggml_init(C.struct_ggml_init_params{
+			mem_size: C.size_t(n)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(n), false),
+			no_alloc: true,
+		}),
 	}
 }
@@ -243,17 +425,60 @@ func (b *Backend) CacheConfig() ml.CacheConfig {
 }

 type Context struct {
-	b       *Backend
-	ctx     *C.struct_ggml_context
-	backend *C.struct_ggml_backend
+	b *Backend

+	ctx   *C.struct_ggml_context
 	graph *C.struct_ggml_cgraph
-	nodes int
+
+	// buft is the buffer type used for new tensors
+	buft *C.struct_ggml_backend_buffer_type
+
+	// maxGraphNodes is the maximum allowed number of graph nodes in this context
+	maxGraphNodes int
 }

+func (c Context) Input() ml.Context {
+	if c.b.input != nil {
+		return &Context{
+			b:             c.b,
+			ctx:           c.ctx,
+			buft:          c.b.input,
+			maxGraphNodes: c.maxGraphNodes,
+		}
+	}
+
+	return &c
+}
+
+func (c Context) Output() ml.Context {
+	if c.b.output != nil {
+		return &Context{
+			b:             c.b,
+			ctx:           c.ctx,
+			buft:          c.b.output,
+			maxGraphNodes: c.maxGraphNodes,
+		}
+	}
+
+	return &c
+}
+
+func (c Context) Layer(i int) ml.Context {
+	if buft, ok := c.b.layers[i]; ok {
+		return &Context{
+			b:             c.b,
+			ctx:           c.ctx,
+			buft:          buft,
+			maxGraphNodes: c.maxGraphNodes,
+		}
+	}
+
+	return &c
+}
+
 func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
 	if c.graph == nil {
-		c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.nodes), false)
+		c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.maxGraphNodes), false)
 	}

 	for _, tensor := range tensors {
@@ -263,7 +488,7 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
 	return c
 }

-func (c *Context) Compute(tensors ...ml.Tensor) {
+func (c Context) Compute(tensors ...ml.Tensor) {
 	C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
 	C.ggml_backend_sched_reset(c.b.sched)
@@ -282,21 +507,48 @@ func (c *Context) Compute(tensors ...ml.Tensor) {
 	}
 }

-func (c *Context) MaxTensors() int {
-	return c.nodes
+func (c Context) MaxGraphNodes() int {
+	return c.maxGraphNodes
 }

 func shapeToGGML(shape []int) *C.int64_t {
 	sh := make([]C.int64_t, len(shape))
 	for i, s := range shape {
-		sh[i] = (C.int64_t)(s)
+		sh[i] = C.int64_t(s)
 	}

 	return &sh[0]
 }

-func newTensor(ctx Context, dtype ml.DType, zero bool, shape []int) ml.Tensor {
-	if len(shape) < 1 || len(shape) > 4 {
+func pad(length, pad C.size_t) C.size_t {
+	return ((length + pad - 1) / pad) * pad
+}
+
+func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
+	if c.buft == nil {
+		panic("set Input, Output, or Layer before creating tensors")
+	}
+
+	var cdtype uint32
+	switch dtype {
+	case ml.DTypeF32:
+		cdtype = C.GGML_TYPE_F32
+	case ml.DTypeF16:
+		cdtype = C.GGML_TYPE_F16
+	case ml.DTypeQ80:
+		cdtype = C.GGML_TYPE_Q8_0
+	case ml.DTypeQ40:
+		cdtype = C.GGML_TYPE_Q4_0
+	case ml.DTypeI32:
+		cdtype = C.GGML_TYPE_I32
+	default:
+		panic("unsupported dtype")
+	}
+
+	if len(shape) < 1 || shape[0] == 0 {
+		var shape C.int64_t = 0
+		return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
+	} else if len(shape) > 4 {
 		panic("unsupported number of dimensions")
 	}
@@ -306,41 +558,28 @@ func newTensor(ctx Context, dtype ml.DType, zero bool, shape []int) ml.Tensor {
 		}
 	}

-	var t *C.struct_ggml_tensor
-	switch dtype {
-	case ml.DTypeF32:
-		t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
-	case ml.DTypeF16:
-		t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
-	case ml.DTypeI32:
-		t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
-	default:
-		panic("unsupported dtype")
-	}
-
-	b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
+	t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape))
+	size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft))
+	b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
 	C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
-	if zero {
-		C.ggml_set_zero(t)
-	}
-	return &Tensor{b: ctx.b, t: t}
+	return &Tensor{b: c.b, t: t}
 }

 func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
-	return newTensor(c, dtype, false, shape)
+	return c.newTensor(dtype, shape)
 }

 func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
-	return newTensor(c, dtype, true, shape)
+	t := c.newTensor(dtype, shape)
+	C.ggml_set_zero(t.(*Tensor).t)
+	return t
 }

-func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) {
+func checkShape[S ~[]E, E any](s S, shape ...int) error {
 	n := len(s)

 	if n == 0 {
-		var shape C.int64_t = 0
-		t := C.ggml_new_tensor(ctx.ctx, dtype, 1, &shape)
-		return &Tensor{b: ctx.b, t: t}, nil
+		return nil
 	}

 	for _, v := range shape {
@@ -348,22 +587,36 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u
 	}

 	if n != 1 {
-		return nil, fmt.Errorf("invalid shape %v for %d elements", shape, len(s))
+		return fmt.Errorf("invalid shape: %v", shape)
 	}

-	t := C.ggml_new_tensor(ctx.ctx, dtype, C.int(len(shape)), shapeToGGML(shape))
-	b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
-	C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
-	C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t))
-	return &Tensor{b: ctx.b, t: t}, nil
+	return nil
 }

 func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
-	return fromSlice(c, s, shape, C.GGML_TYPE_F32)
+	if err := checkShape(s, shape...); err != nil {
+		return nil, err
+	}
+
+	t := c.newTensor(ml.DTypeF32, shape)
+	if len(s) > 0 {
+		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
+	}
+
+	return t, nil
 }

 func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
-	return fromSlice(c, s, shape, C.GGML_TYPE_I32)
+	if err := checkShape(s, shape...); err != nil {
+		return nil, err
+	}
+
+	t := c.newTensor(ml.DTypeI32, shape)
+	if len(s) > 0 {
+		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
+	}
+
+	return t, nil
 }

 func (c *Context) Close() {
@@ -431,6 +684,10 @@ func (t *Tensor) DType() ml.DType {
 		return ml.DTypeF32
 	case C.GGML_TYPE_F16:
 		return ml.DTypeF16
+	case C.GGML_TYPE_Q8_0:
+		return ml.DTypeQ80
+	case C.GGML_TYPE_Q4_0:
+		return ml.DTypeQ40
 	case C.GGML_TYPE_I32:
 		return ml.DTypeI32
 	default:
ml/backend/ggml/ggml/include/gguf.h (vendored, 1 change)

@@ -114,6 +114,7 @@ extern "C" {
     // get raw pointer to the first element of the array with the given key_id
     // for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference)
     GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id);
+    GGML_API size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id);

     // get ith C string from array with given key_id
     GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
ml/backend/ggml/ggml/src/gguf.cpp (vendored, 7 changes)

@@ -777,10 +777,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id

 const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) {
     GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
     return ctx->kv[key_id].data.data();
 }

+size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    return ctx->kv[key_id].data.size();
+}
+
 const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) {
     GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING);
@@ -874,7 +878,6 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) {
 const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) {
     GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
-    GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
     return ctx->kv[key_id].data.data();
 }
@@ -3,7 +3,6 @@ package model
 import (
 	"errors"
 	"fmt"
-	"image"
 	_ "image/jpeg"
 	_ "image/png"
 	"log/slog"
@@ -22,14 +21,40 @@ import (
 	_ "github.com/ollama/ollama/ml/backend"
 )

+// Input represents one token in the input stream
+type Input struct {
+	// Token is a single element of text.
+	Token int32
+
+	// Multimodal is opaque data representing a non-text
+	// element such as an image (or part of one if the image
+	// can be processed in pieces). It may be either together
+	// with Token or on its own.
+	Multimodal any
+
+	// MultimodalHash is a unique representation of the data
+	// stored in Multimodal, used for caching and comparing
+	// equality.
+	MultimodalHash uint64
+}
+
+// MultimodalIndex is a multimodal element (such as an image)
+// together with an index into the slice of Inputs with the
+// corresponding token. Note that the index is not the same
+// as the position - to find that use the index with the
+// Positions slice.
+type MultimodalIndex struct {
+	Index      int
+	Multimodal any
+}
+
 // Options contains the inputs for a model forward pass
 type Options struct {
-	Inputs    []int32
-	Positions []int32
-	Sequences []int
-	Outputs   []int32
-
-	Images []image.Image
+	Inputs     []int32
+	Multimodal []MultimodalIndex
+	Positions  []int32
+	Sequences  []int
+	Outputs    []int32
 }

 type config struct {
@@ -59,6 +84,37 @@ type Model interface {
 	Config() config
 }

+// MultimodalProcessor must be implemented by multimodal models.
+type MultimodalProcessor interface {
+	// EncodeMultimodal processes a single input (such as an image) and
+	// generates an output (typically an embedding) that can be used by the model.
+	//
+	// The return value is most typically an ml.Tensor, however, different
+	// type are possible, such as an object containing a tensor plus
+	// additional metadata, a slice of tensors or even just the original input.
+	//
+	// The result may be cached by the runner.
+	EncodeMultimodal(ml.Context, []byte) (any, error)
+
+	// PostTokenize is called after tokenization to allow the model to edit the
+	// input stream to correctly arrange multimodal elements.
+	//
+	// The input is a slice of tokens with the results of EncodeMultimodal interleaved
+	// in the order that the user provided them. Each element of the slice will be
+	// either a single token or single multimodal object.
+	//
+	// The model must ensure that inputs are stored according to how they will be
+	// processed and stored in the cache. For example, Llava-style models should insert
+	// placeholder tokens equal to the feature size of the corresponding image with
+	// the image itself attached to and split across these tokens. When Forward is called
+	// a partial subset of these tokens may be submitted according to the batch size.
+	//
+	// This function is also responsible for updating MultimodalHash for any Multimodal
+	// that is modified to ensure that there is a unique hash value that accurately
+	// represents the contents.
+	PostTokenize(ml.Context, []Input) ([]Input, error)
+}
+
 var models = make(map[string]func(ml.Config) (Model, error))

 // Register registers a model constructor for the given architecture
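For orientation, a skeleton (hypothetical model with stub bodies, not part of the diff) of what satisfying the new interface looks like:

```go
package mymodel

import (
	"errors"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model"
)

// Model is a stand-in multimodal model used only to illustrate the interface.
type Model struct{}

func (m *Model) EncodeMultimodal(ctx ml.Context, data []byte) (any, error) {
	// decode data (e.g. an image) and return something the runner can cache,
	// typically an ml.Tensor produced by a vision encoder
	return nil, errors.New("not implemented")
}

func (m *Model) PostTokenize(ctx ml.Context, inputs []model.Input) ([]model.Input, error) {
	// arrange multimodal elements in the token stream, e.g. insert placeholder
	// tokens, and update MultimodalHash wherever Multimodal data is modified
	return inputs, nil
}

// compile-time check that Model implements the interface
var _ model.MultimodalProcessor = (*Model)(nil)
```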
@@ -12,7 +12,6 @@ import (
)

type Options struct {
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32
@@ -66,10 +65,11 @@ func New(c ml.Config) (model.Model, error) {
}

type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
}

func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
@@ -78,11 +78,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten

q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)

k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)

v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -95,7 +95,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}

func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
}

type MLP struct {
@@ -138,17 +138,17 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
}

func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}

positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}

outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
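The pattern in the hunk above is consistent: tensors that feed the graph (token IDs, positions) move from the generic context onto ctx.Input(), while the indices that select returned logits move onto ctx.Output(). A hedged restatement of the call shape (only the Context methods come from this diff; the variable names are illustrative):

// Graph inputs come from the input context; logit selection from the output context.
tokenTensor, err := ctx.Input().FromIntSlice(tokenIDs, len(tokenIDs))
if err != nil {
	return nil, err
}
outputTensor, err := ctx.Output().FromIntSlice(outputIdxs, len(outputIdxs))
if err != nil {
	return nil, err
}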
@@ -1,7 +1,12 @@
package mllama

import (
"bytes"
"encoding/binary"
"fmt"
"hash/fnv"
"image"
"slices"

"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
@@ -56,54 +61,92 @@ func New(c ml.Config) (model.Model, error) {
return &m, nil
}

func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}

f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, err
}

pixelValues, err := ctx.Input().FromFloatSlice(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
m.ImageProcessor.maxNumTiles,
)
if err != nil {
return nil, err
}

aspectRatio, err := ctx.Input().FromIntSlice([]int32{int32(aspectRatioID)}, 1)
if err != nil {
return nil, err
}

positions := make([]int32, 1601)
for i := range positions {
positions[i] = int32(i)
}

positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
if err != nil {
return nil, err
}

crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
return m.Projector.Forward(ctx, crossAttentionStates), nil
}

func (m *Model) PostTokenize(ctx ml.Context, inputs []model.Input) ([]model.Input, error) {
var images []model.Input
fnvHash := fnv.New64a()

for i := range inputs {
if inputs[i].Multimodal == nil {
if len(images) > 0 {
inputs[i].Multimodal = images[0].Multimodal
inputs[i].MultimodalHash = images[0].MultimodalHash
for j := 1; j < len(images); j++ {
inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
fnvHash.Reset()
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
inputs[i].MultimodalHash = fnvHash.Sum64()
}
images = nil
}
} else {
images = append(images, inputs[i])
inputs[i].Token = -1
}
}

inputs = slices.DeleteFunc(inputs, func(input model.Input) bool { return input.Token == -1 })

return inputs, nil
}
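PostTokenize above folds the hashes of all concatenated images into a single value with FNV-1a, so equal sequences of images yield equal hashes. The combining step in isolation, as a self-contained sketch (the helper name is made up):

package main

import (
	"encoding/binary"
	"fmt"
	"hash/fnv"
)

// combineHashes mirrors the Reset/Write/Write/Sum64 sequence used above:
// it deterministically mixes two 64-bit hashes into one.
func combineHashes(a, b uint64) uint64 {
	h := fnv.New64a()
	binary.Write(h, binary.NativeEndian, a)
	binary.Write(h, binary.NativeEndian, b)
	return h.Sum64()
}

func main() {
	fmt.Println(combineHashes(1, 2)) // stable across runs for the same inputs
	fmt.Println(combineHashes(2, 1)) // order matters, so this differs
}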
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
var crossAttentionStates ml.Tensor
if opts.Images != nil {
f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(opts.Images[0])
if err != nil {
return nil, err
}

pixelValues, err := ctx.FromFloatSlice(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
m.ImageProcessor.maxNumTiles,
)
if err != nil {
return nil, err
}

aspectRatio, err := ctx.FromIntSlice([]int32{int32(aspectRatioID)}, 1)
if err != nil {
return nil, err
}

positions := make([]int32, 1601)
for i := range positions {
positions[i] = int32(i)
}

positionIDs, err := ctx.FromIntSlice(positions, len(positions))
if err != nil {
return nil, err
}

crossAttentionStates = m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates)
if opts.Multimodal != nil {
crossAttentionStates = opts.Multimodal[0].Multimodal.(ml.Tensor)
}

inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}

positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}

outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}

@@ -10,10 +10,11 @@ import (
)

type TextSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
}

func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
@@ -22,11 +23,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m

query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)

key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)

value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -39,8 +40,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
}

func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the causal cache, which are just the self attention layers
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
}

return key, nil
}

type TextMLP struct {
@@ -191,8 +195,6 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
}

type TextModelOptions struct {
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`

hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32

@@ -19,7 +19,7 @@ const (
)

type TextProcessor interface {
Encode(string) ([]int32, error)
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool
}
@@ -144,7 +144,7 @@ type merge struct {
runes []rune
}

func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range bpe.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
@@ -177,7 +177,6 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
slog.Debug("encoded", "text", frag.value, "ids", frag.ids, "special", true)
continue
}

@@ -201,7 +200,6 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
// short circuit if the fragment is in the vocabulary
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
ids = append(ids, id)
slog.Debug("encoded", "text", sb.String(), "ids", []int32{id})
continue
}

@@ -275,14 +273,13 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
// TODO: handle the edge case where the rune isn't in the vocabulary
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id)
slog.Debug("encoded", "text", string(merge.runes), "ids", []int32{id})
}
}
}
}
}

if len(ids) > 0 {
if addSpecial && len(ids) > 0 {
if bpe.vocab.AddBOS {
if ids[0] == bpe.vocab.BOS {
slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS)
@@ -329,6 +326,5 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
}
}

slog.Debug("decoded", "ids", ids, "text", sb.String())
return sb.String(), nil
}
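With the addSpecial parameter, the caller now decides whether BOS/EOS handling applies. A hedged usage sketch (tokenizer construction omitted; a BytePairEncoding value is assumed to be in scope):

// Whole prompt: request special-token handling, so BOS is prepended when vocab.AddBOS is set.
ids, err := tokenizer.Encode("hello world", true)

// Interior fragment: skip special-token handling so no BOS is injected mid-sequence.
fragIDs, err := tokenizer.Encode(" continued text", false)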
@@ -74,7 +74,7 @@ func TestLlama(t *testing.T) {
t.Run("simple", func(t *testing.T) {
t.Parallel()

ids, err := tokenizer.Encode("hello world")
ids, err := tokenizer.Encode("hello world", true)
if err != nil {
t.Error(err)
}
@@ -92,7 +92,7 @@ func TestLlama(t *testing.T) {
t.Errorf("got %q, want hello world", s)
}

ids, err = tokenizer.Encode("hello <|end_of_text|>")
ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
if err != nil {
t.Error(err)
}
@@ -126,7 +126,7 @@ func TestLlama(t *testing.T) {
}

for s, want := range cases {
ids, err := tokenizer.Encode(s)
ids, err := tokenizer.Encode(s, true)
if err != nil {
t.Error(err)
}
@@ -152,7 +152,7 @@ func TestLlama(t *testing.T) {
}

for _, want := range cases {
ids, err := tokenizer.Encode(want)
ids, err := tokenizer.Encode(want, true)
if err != nil {
t.Error(err)
}
@@ -176,7 +176,7 @@ func TestLlama(t *testing.T) {
}

for s, want := range cases {
ids, err := tokenizer.Encode(s)
ids, err := tokenizer.Encode(s, true)
if err != nil {
t.Fatal(err)
}
@@ -222,7 +222,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
b.ResetTimer()
for range b.N {
_, err := tokenizer.Encode(string(bts))
_, err := tokenizer.Encode(string(bts), true)
if err != nil {
b.Fatal(err)
}
@@ -230,7 +230,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
})

b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
ids, err := tokenizer.Encode(string(bts))
ids, err := tokenizer.Encode(string(bts), true)
if err != nil {
b.Fatal(err)
}

@@ -5,7 +5,6 @@ import (
"fmt"
"log/slog"
"math"
"reflect"
"time"

"github.com/ollama/ollama/kvcache"
@@ -39,10 +38,7 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
slots := make([]InputCacheSlot, numSlots)

for i := range slots {
slots[i] = InputCacheSlot{
Id: i,
Inputs: make([]input, 0),
}
slots[i] = InputCacheSlot{Id: i}
}

cache := model.Config().Cache
@@ -62,9 +58,9 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
func kvCacheTypeFromStr(s string) ml.DType {
switch s {
case "q8_0":
panic("kv cache quantization not yet implemented")
return ml.DTypeQ80
case "q4_0":
panic("kv cache quantization not yet implemented")
return ml.DTypeQ40
default:
return ml.DTypeF16
}
@@ -83,7 +79,7 @@ type InputCacheSlot struct {
Id int

// Inputs that are stored in the KV cache
Inputs []input
Inputs []model.Input

// is this cache actively being processed as part of a sequence?
InUse bool
@@ -92,7 +88,7 @@ type InputCacheSlot struct {
lastUsed time.Time
}

func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) {
func (c *InputCache) LoadCacheSlot(prompt []model.Input, cachePrompt bool) (*InputCacheSlot, []model.Input, error) {
var slot *InputCacheSlot
var numPast int32
var err error
@@ -143,7 +139,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCach
return slot, prompt, nil
}

func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) {
func (c *InputCache) findLongestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
longest := int32(-1)
var longestSlot *InputCacheSlot

@@ -166,7 +162,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int3
return longestSlot, longest, nil
}

func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) {
func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
oldest := time.Now()
var oldestSlot *InputCacheSlot

@@ -202,7 +198,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32,
if longest > 0 && longestSlot != oldestSlot {
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
len(longestSlot.Inputs))
oldestSlot.Inputs = make([]input, longest)
oldestSlot.Inputs = make([]model.Input, longest)
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
if c.cache != nil {
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
@@ -212,7 +208,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32,
return oldestSlot, longest, nil
}

func countCommonPrefix(a []input, b []input) int32 {
func countCommonPrefix(a []model.Input, b []model.Input) int32 {
var count int32

for i := range a {
@@ -220,7 +216,7 @@ func countCommonPrefix(a []input, b []input) int32 {
break
}

if !reflect.DeepEqual(a[i], b[i]) {
if a[i].Token != b[i].Token || a[i].MultimodalHash != b[i].MultimodalHash {
break
}
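Assembled from the hunk above, the whole comparison loop would read roughly as follows (a sketch; the bounds check on b is assumed from the surrounding context lines):

func countCommonPrefix(a []model.Input, b []model.Input) int32 {
	var count int32
	for i := range a {
		if i >= len(b) {
			break
		}
		// Token ID plus multimodal hash replaces reflect.DeepEqual, which was
		// both slower and unreliable for tensor-valued Multimodal fields.
		if a[i].Token != b[i].Token || a[i].MultimodalHash != b[i].MultimodalHash {
			break
		}
		count++
	}
	return count
}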
@@ -4,6 +4,8 @@ import (
"image"
"testing"
"time"

"github.com/ollama/ollama/model"
)

func TestCountCommon(t *testing.T) {
@@ -13,44 +15,50 @@ func TestCountCommon(t *testing.T) {

tests := []struct {
name string
t1 []input
t2 []input
t1 []model.Input
t2 []model.Input
expected int32
}{
{
name: "Equal",
t1: []input{{token: 1}, {token: 2}, {token: 3}},
t2: []input{{token: 1}, {token: 2}, {token: 3}},
t1: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 3,
},
{
name: "Prefix",
t1: []input{{token: 1}},
t2: []input{{token: 1}, {token: 2}, {token: 3}},
t1: []model.Input{{Token: 1}},
t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 1,
},
{
name: "Image Prefix",
t1: []input{{image: imgA}},
t2: []input{{image: imgA}, {image: imgB}, {image: imgC}},
t1: []model.Input{{Multimodal: imgA, MultimodalHash: 1}},
t2: []model.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
expected: 1,
},
{
name: "Mixed",
t1: []input{{token: 1}, {image: imgA}},
t2: []input{{token: 1}, {image: imgA}, {token: 5}},
t1: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
t2: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
expected: 2,
},
{
name: "Mixed, Same Length",
t1: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
t2: []model.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
expected: 1,
},
{
name: "Empty",
t1: []input{},
t2: []input{{token: 1}, {token: 2}, {token: 3}},
t1: []model.Input{},
t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 0,
},
{
name: "Both Empty",
t1: []input{},
t2: []input{},
t1: []model.Input{},
t2: []model.Input{},
expected: 0,
},
}
@@ -74,7 +82,7 @@ func TestFindCacheSlot(t *testing.T) {
tests := []struct {
name string
cache InputCache
prompt []input
prompt []model.Input
longest expected
best expected
}{
@@ -83,18 +91,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{},
Inputs: []model.Input{},
InUse: false,
lastUsed: time.Time{},
},
{
Id: 1,
Inputs: []input{},
Inputs: []model.Input{},
InUse: false,
lastUsed: time.Time{},
},
}},
prompt: []input{{token: 1}},
prompt: []model.Input{{Token: 1}},
longest: expected{result: 0, len: 0},
best: expected{result: 0, len: 0},
},
@@ -103,18 +111,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{{token: 1}},
Inputs: []model.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input{{token: 1}, {token: 2}},
Inputs: []model.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []input{{token: 1}, {token: 2}},
prompt: []model.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 2},
best: expected{result: 1, len: 2},
},
@@ -123,18 +131,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{{token: 1}, {token: 2}},
Inputs: []model.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input{},
Inputs: []model.Input{},
InUse: false,
lastUsed: time.Time{},
},
}},
prompt: []input{{token: 2}},
prompt: []model.Input{{Token: 2}},
longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0},
},
@@ -144,19 +152,19 @@ func TestFindCacheSlot(t *testing.T) {
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{{token: 1}, {token: 2}},
Inputs: []model.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input{},
Inputs: []model.Input{},
InUse: false,
lastUsed: time.Time{},
},
},
},
prompt: []input{{token: 1}},
prompt: []model.Input{{Token: 1}},
longest: expected{result: 0, len: 1},
best: expected{result: 1, len: 1},
},
@@ -165,18 +173,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{{token: 1}},
Inputs: []model.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input{{token: 1}, {token: 2}},
Inputs: []model.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []input{{token: 2}, {token: 3}},
prompt: []model.Input{{Token: 2}, {Token: 3}},
longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0},
},
@@ -185,18 +193,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{{token: 1}, {token: 2}},
Inputs: []model.Input{{Token: 1}, {Token: 2}},
InUse: true,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input{{token: 1}},
Inputs: []model.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []input{{token: 1}, {token: 2}},
prompt: []model.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 1},
best: expected{result: 1, len: 2},
},
@@ -1,13 +1,12 @@
package ollamarunner

import (
"bytes"
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"image"
"hash/maphash"
"log"
"log/slog"
"net"
@@ -33,22 +32,19 @@ import (
_ "github.com/ollama/ollama/model/models"
)

// input is an element of the prompt to process, either a token or an image
type input struct {
token int32

image image.Image
}

type Sequence struct {
// ctx for allocating tensors that last the lifetime of the sequence, such as
// multimodal embeddings
ctx ml.Context

// batch index
iBatch int

// prompt inputs left to evaluate
inputs []input
inputs []model.Input

// inputs that have been added to a batch but not yet submitted to Forward
pendingInputs []input
pendingInputs []model.Input

// tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string
@@ -101,8 +97,9 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
s.ready.Wait()

startTime := time.Now()
ctx := s.model.Backend().NewContext()

inputs, err := s.inputs(prompt, images)
inputs, err := s.inputs(ctx, prompt, images)
if err != nil {
return nil, fmt.Errorf("failed to process inputs: %w", err)
} else if len(inputs) == 0 {
@@ -128,6 +125,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// TODO(jessegross): Ingest cached history for grammar

return &Sequence{
ctx: ctx,
inputs: inputs,
numPromptInputs: len(inputs),
startProcessingTime: startTime,
@@ -146,28 +144,31 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images
func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
var inputs []input
func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]model.Input, error) {
var inputs []model.Input
var parts []string
var matches [][]string

// TODO(jessegross): This can sometimes trigger for matching text in the
// user's prompt. We previously tried to avoid it by only looking for images
// on image models. We don't have a clear indication now but it would be better
// to properly escape it in any case.
re := regexp.MustCompile(`\[img-(\d+)\]`)
parts = re.Split(prompt, -1)
matches = re.FindAllStringSubmatch(prompt, -1)
multimodalProcessor, visionModel := s.model.(model.MultimodalProcessor)

if visionModel {
re := regexp.MustCompile(`\[img-(\d+)\]`)
parts = re.Split(prompt, -1)
matches = re.FindAllStringSubmatch(prompt, -1)
} else {
parts = []string{prompt}
}

postTokenize := false
for i, part := range parts {
// text - tokenize
tokens, err := s.model.(model.TextProcessor).Encode(part)
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
if err != nil {
return nil, err
}

for _, t := range tokens {
inputs = append(inputs, input{token: t})
inputs = append(inputs, model.Input{Token: t})
}

// image - decode and store
@@ -186,12 +187,25 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
return nil, fmt.Errorf("invalid image index: %d", n)
}

image, _, err := image.Decode(bytes.NewReader(images[imageIndex].Data))
imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
if err != nil {
return nil, err
}

inputs = append(inputs, input{image: image})
s.multimodalHash.Reset()
_, _ = s.multimodalHash.Write(images[imageIndex].Data)
imageHash := s.multimodalHash.Sum64()

inputs = append(inputs, model.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
postTokenize = true
}
}

if visionModel && postTokenize {
var err error
inputs, err = multimodalProcessor.PostTokenize(ctx, inputs)
if err != nil {
return nil, err
}
}
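The [img-<n>] handling above splits the prompt with a regular expression and pairs each gap with a captured image index. The same split/match pattern in isolation (prompt text is made up):

package main

import (
	"fmt"
	"regexp"
)

func main() {
	re := regexp.MustCompile(`\[img-(\d+)\]`)
	prompt := "describe [img-0] and compare it with [img-1]"

	parts := re.Split(prompt, -1)                   // text between tags: ["describe ", " and compare it with ", ""]
	matches := re.FindAllStringSubmatch(prompt, -1) // captured indices: [["[img-0]" "0"], ["[img-1]" "1"]]
	fmt.Println(parts)
	fmt.Println(matches)
}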
@@ -238,6 +252,10 @@ type Server struct {

// next sequence for prompt processing to avoid starvation
nextSeq int

// multimodalHash generates hashes for comparing equality
// of non-text data
multimodalHash maphash.Hash
}

func (s *Server) allNil() bool {
@@ -283,6 +301,7 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
close(seq.responses)
close(seq.embedding)
seq.cache.InUse = false
seq.ctx.Close()
s.seqs[seqIndex] = nil
s.seqsSem.Release(1)
}
@@ -311,7 +330,6 @@ func (s *Server) processBatch() error {
defer s.mu.Unlock()

var options model.Options
imgSeq := -1

seqIdx := s.nextSeq - 1
for range s.seqs {
@@ -330,7 +348,7 @@ func (s *Server) processBatch() error {

if !s.cache.enabled {
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
seq.cache.Inputs = []input{}
seq.cache.Inputs = []model.Input{}
}

for i, input := range seq.inputs {
@@ -349,25 +367,21 @@ func (s *Server) processBatch() error {
break
}

// TODO(jessegross): Image inputs need to be rethought - it doesn't
// work well for different types of models or multiple sequences
if input.image != nil {
if len(seq.pendingInputs) != len(options.Images) {
break
}

if imgSeq != seqIdx && imgSeq != -1 {
s.nextSeq = seqIdx
break
}

imgSeq = seqIdx
options.Images = append(options.Images, input.image)
seq.pendingInputs = append(seq.pendingInputs, input)
continue
// TODO(jessegross): This is a workaround for generating an attention mask and also providing a hint
// to the encoder cache.
//
// Break the batch when switching from text to images so that images are always at the beginning.
if input.Multimodal != nil && !(len(seq.pendingInputs) == 0 ||
(len(options.Multimodal) > 0 && options.Multimodal[len(options.Multimodal)-1].Index == len(options.Inputs)-1)) {
s.nextSeq = seqIdx
break
}

options.Inputs = append(options.Inputs, input.Token)
if input.Multimodal != nil {
options.Multimodal = append(options.Multimodal, model.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: input.Multimodal})
}

options.Inputs = append(options.Inputs, input.token)
options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
options.Sequences = append(options.Sequences, seq.cache.Id)

@@ -403,7 +417,7 @@ func (s *Server) processBatch() error {
// After calling Forward, pending inputs are now in the cache
if len(seq.pendingInputs) > 0 {
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
seq.pendingInputs = []input{}
seq.pendingInputs = []model.Input{}
}

// don't sample prompt processing
@@ -422,6 +436,7 @@ func (s *Server) processBatch() error {
// if done processing the prompt, generate an embedding and return
if seq.embeddingOnly {
// TODO(jessegross): Embedding support
slog.Warn("generation of embedding outputs not yet supported")
s.removeSequence(i, "")
continue
}
@@ -449,7 +464,7 @@ func (s *Server) processBatch() error {
return err
}

seq.inputs = []input{{token: token}}
seq.inputs = []model.Input{{Token: token}}

seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "")
@@ -575,11 +590,23 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}

sampler := sample.NewSampler(
req.Temperature,
req.TopK,
req.TopP,
req.MinP,
req.Seed,
)

if req.Grammar != "" {
panic("grammars are not yet supported")
}

seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.NumPredict,
stop: req.Stop,
numKeep: int32(req.NumKeep),
sampler: sample.Greedy(), // TODO: add support for different samplers when performance is optimized
sampler: sampler,
embedding: false,
})
if err != nil {

@@ -2,76 +2,103 @@ package sample

import (
"errors"
"math"

"golang.org/x/exp/rand"
"gonum.org/v1/gonum/stat/sampleuv"
"math/rand/v2"
"slices"
)

// Sampler is not thread-safe. Each goroutine should have its own instance
type Sampler interface {
Sample([]float32) (int32, error)
}

// logit represents information about a single token during sampling
type logit struct {
id int32 // The token's unique identifier
value float32 // The raw logit or probability from the model
}

type weighted struct {
src rand.Source
transforms []Transform
rng *rand.Rand
tokens []logit
topK int
topP float32
minP float32
temperature float32
}

// TODO(parthsareen): remove uv sample dependency https://github.com/ollama/ollama/issues/9279
func Weighted(seed *uint64, transforms ...Transform) Sampler {
var src rand.Source
if seed != nil {
src = rand.NewSource(*seed)
func (s *weighted) Sample(logits []float32) (int32, error) {
if len(s.tokens) < len(logits) {
s.tokens = make([]logit, len(logits))
}
return weighted{src: src, transforms: transforms}
}

func (s weighted) Sample(logits []float32) (int32, error) {
logits64 := make([]float64, len(logits))
tokens := s.tokens[:len(logits)]

for i, v := range logits {
logits64[i] = float64(v)
tokens[i].id = int32(i)
tokens[i].value = v
}

for _, t := range s.transforms {
logits64 = t.Apply(logits64)
// Tokens are sorted by logits in TopK or SortTokens
if s.topK > 0 {
tokens = topK(tokens, s.topK)
} else {
sortLogits(tokens)
}

logitsCopy := make([]float64, 0, len(logits))
indices := make([]int, 0, len(logits))
for i, logit := range logits64 {
if !math.IsInf(logit, -1) {
logitsCopy = append(logitsCopy, logit)
indices = append(indices, i)
tokens = temperature(tokens, s.temperature)
tokens = softmax(tokens)

tokens = topP(tokens, s.topP)
tokens = minP(tokens, s.minP)

if len(tokens) == 0 {
return -1, errors.New("no valid logits found for weighted sampling")
}

var r float32
if s.rng != nil {
r = s.rng.Float32()
} else {
r = rand.Float32()
}

// Calculate cumulative sum of probabilities
var sum float32
for i := range tokens {
sum += tokens[i].value
tokens[i].value = sum
}
r *= tokens[len(tokens)-1].value

idx, _ := slices.BinarySearchFunc(tokens, r, func(token logit, target float32) int {
// Compare cumulative probabilities
if token.value < target {
return -1
}
// First token that exceeds target
return 1
})

if idx >= len(tokens) {
idx = len(tokens) - 1
}

if len(logitsCopy) == 0 {
return -1, errors.New("no valid logits found for weighed sampling")
}

probs := softmax(logitsCopy)
w := sampleuv.NewWeighted(probs, s.src)
if idx, ok := w.Take(); ok {
return int32(indices[idx]), nil
}
return -1, errors.New("weighted sampler failed, no valid token found")
return tokens[idx].id, nil
}
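Sample above turns normalized probabilities into a cumulative distribution in place and then binary-searches it for a uniform draw. The same idea on a toy distribution (self-contained; the values are made up):

package main

import (
	"fmt"
	"math/rand/v2"
	"slices"
)

func main() {
	probs := []float32{0.1, 0.2, 0.3, 0.4} // already normalized

	// Build the cumulative distribution.
	var sum float32
	cum := make([]float32, len(probs))
	for i, p := range probs {
		sum += p
		cum[i] = sum
	}

	// Scale the draw by the final cumulative value to absorb rounding error.
	r := rand.Float32() * cum[len(cum)-1]

	// First index whose cumulative value reaches r; the comparator never
	// returns 0, so BinarySearchFunc yields the insertion point.
	idx, _ := slices.BinarySearchFunc(cum, r, func(c, target float32) int {
		if c < target {
			return -1
		}
		return 1
	})
	if idx >= len(cum) {
		idx = len(cum) - 1
	}
	fmt.Println("sampled index:", idx)
}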
type greedy struct{}

func Greedy() Sampler {
return greedy{}
}

// Sample returns the index of the maximum value in logits.
// Greedy sample returns the index of the maximum value in logits.
func (s greedy) Sample(logits []float32) (int32, error) {
if len(logits) == 0 {
return -1, errors.New("no logits provided for greedy sampling")
}

maxIdx := 0
for i := range logits {
if logits[i] > logits[maxIdx] {
maxVal := logits[0]
for i := 1; i < len(logits); i++ {
if logits[i] > maxVal {
maxVal = logits[i]
maxIdx = i
}
}
@@ -80,41 +107,40 @@ func (s greedy) Sample(logits []float32) (int32, error) {
}

// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) (Sampler, error) {
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) Sampler {
if temperature == 0 {
return Greedy(), nil
return &greedy{}
}

if temperature < 0 || temperature > 2 {
return nil, errors.New("temperature must be between 0 and 2")
var rng *rand.Rand
if seed != -1 {
// PCG requires two parameters: sequence and stream
// Use original seed for sequence
sequence := uint64(seed)
// Use golden ratio hash to generate statistically independent seeds
rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
}
temperature = max(temperature, 1)

if topP < 0.0 {
topP = 0.0
}
if topP >= 1.0 {
topP = 1.0
}

transforms := []Transform{Temperature(temperature)}

if topK != 0 {
if topK <= 0 {
return nil, errors.New("topK must be greater than 0")
}
transforms = append(transforms, TopK(topK))
if minP < 0.0 {
minP = 0.0
}
if minP >= 1.0 {
minP = 1.0
}

if topP != 0 {
if topP < 0 || topP >= 1 {
return nil, errors.New("topP must be between 0 and 1")
}
transforms = append(transforms, TopP(topP))
return &weighted{
rng: rng,
topK: topK,
topP: topP,
minP: minP,
temperature: temperature,
}

if minP != 0 {
if minP < 0 || minP >= 1 {
return nil, errors.New("minP must be between 0 and 1")
}
transforms = append(transforms, MinP(minP))
}

if seed >= 0 {
seed64 := uint64(seed)
return Weighted(&seed64, transforms...), nil
}
return Weighted(nil, transforms...), nil
}
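NewSampler derives the two PCG parameters from a single seed by XORing with the golden-ratio constant. A quick check that a seeded stream is reproducible (constant copied from the diff; everything else illustrative):

package main

import (
	"fmt"
	"math/rand/v2"
)

func main() {
	seed := uint64(42)
	a := rand.New(rand.NewPCG(seed, seed^0x9E3779B9))
	b := rand.New(rand.NewPCG(seed, seed^0x9E3779B9))
	fmt.Println(a.Float32() == b.Float32()) // true: identical seeds give identical streams
}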
sample/samplers_benchmark_test.go (new file, 104 lines)
@@ -0,0 +1,104 @@
package sample

import (
"fmt"
"math/rand"
"testing"
)

func BenchmarkWeightedSampler(b *testing.B) {
sizes := []int{10, 100, 1000, 10000}

for _, size := range sizes {
b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) {
logits := make([]float32, size)
for i := range logits {
logits[i] = float32(rand.Float64()*10 - 5)
}

sampler := NewSampler(0.8, 0, 0, 0, 42)
b.ResetTimer()
for b.Loop() {
_, err := sampler.Sample(logits)
if err != nil {
b.Fatalf("Sampling failed: %v", err)
}
}
})
}

configs := []struct {
name string
temperature float32
topK int
topP float32
minP float32
seed int
}{
{"Greedy", 0, -1, 0, 0, -1},
{"Temperature", 0.8, -1, 0, 0, -1},
{"TopK", 0.8, 50, 0, 0, -1},
{"TopP", 0.8, -1, 0.9, 0, -1},
{"MinP", 0.8, -1, 0, 0.05, -1},
{"WithSeed", 0.8, 50, 0, 0, 42},
}

// Fixed size for common vocab size
size := 128000
logits := make([]float32, size)
for i := range logits {
logits[i] = float32(rand.Float64()*10 - 5)
}

for _, tc := range configs {
b.Run("Config"+tc.name, func(b *testing.B) {
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed)
sampler.Sample(logits)

b.ResetTimer()

for b.Loop() {
_, err := sampler.Sample(logits)
if err != nil {
b.Fatalf("Sampling failed: %v", err)
}
}
})
}

// Test with combined transforms separately - topK influences performance greatly
b.Run("TransformCombined", func(b *testing.B) {
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42)
b.ResetTimer()

for b.Loop() {
_, err := sampler.Sample(logits)
if err != nil {
b.Fatalf("Sampling failed: %v", err)
}
}
})
}

func BenchmarkGreedySampler(b *testing.B) {
sizes := []int{10, 100, 1000, 10000, 100000}

for _, size := range sizes {
b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) {
logits := make([]float32, size)
for i := range logits {
logits[i] = float32(rand.Float64()*10 - 5)
}

sampler := NewSampler(0, -1, 0, 0, -1)
b.ResetTimer()

for b.Loop() {
_, err := sampler.Sample(logits)
if err != nil {
b.Fatalf("Sampling failed: %v", err)
}
}
})
}
}
@@ -1,15 +1,14 @@
package sample

import (
"math"
"math/rand/v2"
"testing"

"github.com/google/go-cmp/cmp"
)

func TestWeighted(t *testing.T) {
got, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))})
logits := []float32{-10, 3, -10, -10}
sampler := NewSampler(0, 0, 0, 0, 0)
got, err := sampler.Sample(logits)
if err != nil {
t.Error(err)
return
@@ -19,64 +18,19 @@ func TestWeighted(t *testing.T) {
t.Errorf("index mismatch: want %d, got %d", want, got)
}

got, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))})
if err == nil {
t.Error("expected error for no valid tokens, got index", got)
}

seed := uint64(42)
got, err = Weighted(&seed).Sample([]float32{1, 2, 3, 4})
logits = []float32{-100, -10, 0, 10}
sampler = NewSampler(0, 0, 0, 0, 0)
got, err = sampler.Sample(logits)
if err != nil {
t.Error(err)
return
}
// With seed 42, we expect a consistent sample
want = int32(3) // This will be deterministic due to the seed
want = int32(3) // Should pick highest probability with this r value
if want != got {
t.Errorf("index mismatch: want %d, got %d", want, got)
}
}

type testTransform struct {
id int
callOrder *[]int
}

func (ts *testTransform) Apply(logits []float64) []float64 {
if ts.callOrder != nil {
*ts.callOrder = append(*ts.callOrder, ts.id)
}
return logits
}

func TestSample(t *testing.T) {
input := []float32{1, 2, 3, 4}

var callOrder []int
mock1 := &testTransform{
id: 1,
callOrder: &callOrder,
}
mock2 := &testTransform{
id: 2,
callOrder: &callOrder,
}
mock3 := &testTransform{
id: 3,
callOrder: &callOrder,
}

_, err := Weighted(nil, mock1, mock2, mock3).Sample(input)
if err != nil {
t.Error(err)
return
}
wantOrder := []int{1, 2, 3}
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
t.Errorf("call order mismatch (-want +got):\n%s", diff)
}
}

func TestNewSampler(t *testing.T) {
tests := []struct {
name string
@@ -85,75 +39,41 @@ func TestNewSampler(t *testing.T) {
topP float32
minP float32
seed int
wantErr bool
wantGreedy bool // Instead of wantErr, check if we get greedy sampler
}{
{
name: "no transforms",
// temperature is 0, so greedy should be used
wantErr: false,
},
{
name: "temperature",
temperature: 0.5,
wantErr: false,
wantGreedy: false,
},
{
name: "invalid temperature negative",
temperature: -1,
wantErr: true,
},
{
name: "invalid temperature too high",
temperature: 2.1,
wantErr: true,
name: "zero temperature - greedy",
temperature: 0,
wantGreedy: true,
},
{
name: "top k",
temperature: 0.1,
topK: 10,
temperature: 0.8,
wantErr: false,
},
{
name: "invalid top k negative",
topK: -1,
temperature: 0.8,
wantErr: true,
wantGreedy: false,
},
{
name: "top p",
temperature: 0.1,
topP: 0.9,
temperature: 0.8,
wantErr: false,
},
{
name: "invalid top p negative",
topP: -0.1,
temperature: 0.8,
wantErr: true,
},
{
name: "invalid top p one",
topP: 1.0,
temperature: 0.8,
wantErr: true,
wantGreedy: false,
},
{
name: "min p",
temperature: 0.1,
minP: 0.2,
temperature: 0.8,
wantErr: false,
wantGreedy: false,
},
{
name: "invalid min p negative",
minP: -0.1,
temperature: 0.8,
wantErr: true,
},
{
name: "invalid min p one",
minP: 1.0,
temperature: 0.8,
wantErr: true,
name: "seed - weighted",
temperature: 0.1,
seed: 42,
wantGreedy: false,
},
{
name: "default values",
@@ -162,16 +82,16 @@ func TestNewSampler(t *testing.T) {
topP: 0.9,
minP: 0.0,
seed: 0,
wantErr: false,
wantGreedy: false,
},
{
name: "all zeroes",
name: "all zeroes - greedy",
temperature: 0.0,
topK: 0,
topP: 0.0,
minP: 0.0,
seed: 0,
wantErr: false, // all zeroes means no transforms
wantGreedy: true,
},
{
name: "all transforms",
@@ -180,33 +100,28 @@ func TestNewSampler(t *testing.T) {
topP: 0.95,
minP: 0.1,
seed: 42,
wantErr: false,
wantGreedy: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
if (err != nil) != tt.wantErr {
t.Errorf("NewSampler() error = %v, wantErr %v", err, tt.wantErr)
sampler := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
_, isGreedy := sampler.(*greedy)
if isGreedy != tt.wantGreedy {
t.Errorf("NewSampler() got greedy = %v, want %v", isGreedy, tt.wantGreedy)
}
})
}
}

func BenchmarkSample(b *testing.B) {
transforms := []Transform{
Temperature(0.5),
TopK(10),
TopP(0.9),
MinP(0.2),
}

weighted := NewSampler(0.5, 10, 0.9, 0.2, -1)
samplers := map[string]Sampler{
"Greedy": Greedy(),
"Weighted": Weighted(nil, transforms...),
"Greedy": NewSampler(0, 0, 0, 0, 0), // Use NewSampler with temp=0 for greedy
"Weighted": weighted,
}

// Generate random logits for benchmarking
logits := make([]float32, 1<<16)
for i := range logits {
logits[i] = rand.Float32()
@@ -215,7 +130,7 @@ func BenchmarkSample(b *testing.B) {
for name, s := range samplers {
b.Run(name, func(b *testing.B) {
b.ResetTimer()
for range b.N {
for b.Loop() {
if _, err := s.Sample(logits); err != nil {
b.Error(err)
}

@@ -1,120 +1,203 @@
package sample

import (
"cmp"
"math"
"slices"

pq "github.com/emirpasic/gods/v2/queues/priorityqueue"
)

type Transform interface {
Apply([]float64) []float64
}

// TODO(parthsareen): potentially cache softmax values
func softmax(logits []float64) []float64 {
var sum float64
probs := make([]float64, len(logits))
for i, v := range logits {
probs[i] = math.Exp(v)
sum += probs[i]
func softmax(ts []logit) []logit {
var sum float32
for i, v := range ts {
ts[i].value = float32(math.Exp(float64(v.value)))
sum += ts[i].value
}

for i := range probs {
probs[i] /= sum
for i := range ts {
ts[i].value /= sum
}

return probs
return ts
}

type Temperature float64
func temperature(ti []logit, t float32) []logit {
if t == 1 {
return ti
}

func (t Temperature) Apply(logits []float64) []float64 {
temp := math.Max(float64(t), 1e-7)
temp := max(t, 1e-7)
maxLogit := float32(math.Inf(-1))
for _, token := range ti {
if token.value > maxLogit {
maxLogit = token.value
}
}

// subtracting max logit to avoid under/overflow
maxLogit := slices.Max(logits)
for i := range logits {
logits[i] = (logits[i] - maxLogit) / temp
for i := range ti {
ti[i].value = (ti[i].value - maxLogit) / temp
}

return logits
return ti
}
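The max-logit subtraction in temperature above is the standard overflow guard: softmax is invariant under shifting every logit by a constant, so subtracting the maximum changes nothing mathematically while keeping math.Exp in range. A small check of that invariance (float64 for clarity; not part of the diff):

package main

import (
	"fmt"
	"math"
)

// softmax64 is a plain reference softmax used only to demonstrate shift invariance.
func softmax64(xs []float64) []float64 {
	out := make([]float64, len(xs))
	var sum float64
	for i, x := range xs {
		out[i] = math.Exp(x)
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum
	}
	return out
}

func main() {
	fmt.Println(softmax64([]float64{1, 2, 3}))
	fmt.Println(softmax64([]float64{1 - 3, 2 - 3, 3 - 3})) // same probabilities after subtracting the max
}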
type logitMap struct {
index int
logit float64
}

type TopK int

// TODO(parthsareen): avoid having to check all logits after this transform
func (k TopK) Apply(logits []float64) []float64 {
if int(k) >= len(logits) {
return logits
}
q := pq.NewWith(func(a, b logitMap) int {
return -cmp.Compare(a.logit, b.logit)
})

for i, logit := range logits {
q.Enqueue(logitMap{index: i, logit: logit})
}

validLogits := make(map[int]float64)
for range k {
logitMap, _ := q.Dequeue()
validLogits[logitMap.index] = logitMap.logit
}

for i := range logits {
if _, ok := validLogits[i]; !ok {
logits[i] = math.Inf(-1)
}
}

return logits
}

type TopP float64

func (p TopP) Apply(logits []float64) []float64 {
probs := softmax(logits)
indices := make([]int, len(probs))
for i := range indices {
indices[i] = i
}

// sort in descending order
slices.SortFunc(indices, func(i, j int) int {
return cmp.Compare(probs[j], probs[i])
})

var sum float64
for i, idx := range indices {
sum += probs[idx]
if sum > float64(p) {
for _, idx := range indices[i+1:] {
logits[idx] = math.Inf(-1)
}
// siftDown maintains a min-heap property by recursively moving larger elements down the heap.
//
// The heap is represented as an array where for any node at index i:
// - Left child is at index 2i + 1
// - Right child is at index 2i + 2
// - Parent is at index (i-1)/2
//
// The function compares a node with its children and:
// 1. Finds the smallest value between the node and its children
// 2. If the node is not the smallest, swaps it with its smallest child
// 3. Continues this process down the affected path until the min-heap property is restored
func siftDown(data []logit, start, end int) {
root := start
for {
child := 2*root + 1
if child >= end {
break
}
// Find smaller child (we want min heap)
if child+1 < end && data[child+1].value < data[child].value {
child++
}
// Exit if root is already smaller than children
if data[root].value <= data[child].value {
break
}
// Swap with smaller child and continue
data[root], data[child] = data[child], data[root]
root = child
}
return logits
}

type MinP float64
// topK limits the number of tokens considered to the k highest logits
func topK(ts []logit, k int) []logit {
if k >= len(ts) {
return ts
}
// Heapify + siftDown - O(nlog(k))
// Build min-heap of first k elements
heap := ts[:k]
for i := k/2 - 1; i >= 0; i-- {
siftDown(heap, i, k)
}

func (p MinP) Apply(logits []float64) []float64 {
probs := softmax(logits)
threshold := slices.Max(probs) * float64(p)

for i, prob := range probs {
if prob < threshold {
logits[i] = math.Inf(-1)
// Process remaining elements - if larger than heap root, replace root
for i := k; i < len(ts); i++ {
if ts[i].value > heap[0].value {
heap[0] = ts[i]
siftDown(heap, 0, k)
}
}

return logits
slices.Reverse(heap)

ts = heap
return ts
}
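A short usage sketch for the heap-based topK above (logit values are arbitrary):

// Keep the 2 largest of 4 logits; the slice is truncated to length k,
// with the surviving entries drawn from the min-heap of the k largest values.
ts := []logit{{id: 0, value: 0.1}, {id: 1, value: 2.5}, {id: 2, value: -1}, {id: 3, value: 1.0}}
ts = topK(ts, 2) // keeps ids 1 and 3 (values 2.5 and 1.0)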
// topP limits tokens to those with cumulative probability p
func topP(ts []logit, p float32) []logit {
if p == 1.0 {
return ts
}

// Find cutoff index where cumulative sum exceeds p
var sum float32
for i, t := range ts {
sum += t.value
if sum > float32(p) {
ts = ts[:i+1]
return ts
}
}

return ts
}

// minP filters out tokens whose probability falls below the fraction p of the maximum probability
func minP(ts []logit, p float32) []logit {
if p == 1.0 {
return ts
}

maxProb := float32(math.Inf(-1))
for _, token := range ts {
if token.value > maxProb {
maxProb = token.value
}
}

threshold := maxProb * float32(p)

// Filter tokens in-place
validTokens := ts[:0]
for i, token := range ts {
if token.value >= threshold {
validTokens = append(validTokens, ts[i])
}
}

ts = validTokens
return ts
}

// TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584
// Counting sort implementation to sort tokens by logits
func sortLogits(tokens []logit) {
if len(tokens) <= 1 {
return
}

// Find max/min in a single pass
minLogit, maxLogit := tokens[0].value, tokens[0].value
for _, t := range tokens[1:] {
if t.value < minLogit {
minLogit = t.value
} else if t.value > maxLogit {
maxLogit = t.value
}
}

// Calculate scaling to map to uint32 range
logitRange := maxLogit - minLogit
if logitRange < 1e-6 {
return // All values effectively equal
}

// Count frequencies directly from tokens
const maxInt = (1 << 24) - 1 // Use 24 bits for good granularity
var counts [256]int // For first byte

// First pass: count frequencies
for _, t := range tokens {
// Map to [0, maxInt] range
score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt)
counts[score>>16]++
}

// Calculate offsets
var offset int
for i := range counts {
count := counts[i]
counts[i] = offset
offset += count
}

// Second pass: place elements in correct position
output := make([]logit, len(tokens))
// Track current positions
countsCopy := counts

for i, t := range tokens {
score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt)

pos := countsCopy[score>>16]
countsCopy[score>>16]++
output[len(tokens)-1-pos] = tokens[i]
}

copy(tokens, output)
}
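sortLogits above buckets each logit by the top 8 bits of a 24-bit fixed-point key, giving 256 counting-sort buckets. Just the key computation, as a sketch with illustrative values:

const maxInt = (1 << 24) - 1 // same 24-bit granularity as above
minLogit, maxLogit := float32(-5), float32(5)
logitRange := maxLogit - minLogit

v := float32(1.25)
score := min(uint32((v-minLogit)*float32(maxInt)/logitRange), maxInt)
bucket := score >> 16 // one of 256 buckets; larger logits land in higher buckets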
@@ -4,77 +4,182 @@ import (
"math"
"math/rand/v2"
"testing"

"github.com/google/go-cmp/cmp"
)

func TestTemperature(t *testing.T) {
got := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0})
want := []float64{-4, -10, 0, -14, -6, -12, -8}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("logits mismatch (-want +got):\n%s", diff)
// Helper to convert float64 slice to logit slice
func toLogits(values []float64) []logit {
tokens := make([]logit, len(values))
for i, v := range values {
tokens[i] = logit{
id: int32(i),
value: float32(v),
}
}
return tokens
}

// Helper to compare logit slices
func compareLogits(t *testing.T, name string, want []float64, got []logit) {
t.Helper()
if len(want) != len(got) {
t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got))
return
}
for i := range want {
if math.Abs(float64(got[i].value)-want[i]) > 1e-6 {
t.Errorf("%s: index %d: want %f, got %f", name, i, want[i], got[i].value)
}
}
}

func TestSoftmax(t *testing.T) {
got := softmax([]float64{-3, -2, -1, 0, 1, 2, 4})
func TestTemperature(t *testing.T) {
input := []float64{2, -1, 4, -3, 1, -2, 0}
want := []float64{-4, -10, 0, -14, -6, -12, -8} // (logit - max logit) / temp

want := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("probs mismatch (-want +got):\n%s", diff)
got := temperature(toLogits(input), 0.5)
compareLogits(t, "Temperature", want, got)
}

func TestSoftmax(t *testing.T) {
input := []float64{-3, -2, -1, 0, 1, 2, 4}
got := softmax(toLogits(input))

// Check probabilities sum to 1
var sum float32
for _, token := range got {
sum += token.value
}
if math.Abs(float64(sum)-1.0) > 1e-6 {
t.Errorf("probabilities don't sum to 1: got %f", sum)
}

// Check relative ordering is preserved
for i := 1; i < len(got); i++ {
if got[i].value < got[i-1].value {
t.Errorf("probability ordering not preserved at index %d", i)
}
}
}

func TestTopK(t *testing.T) {
got := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("logits mismatch (-want +got):\n%s", diff)
}
input := []float64{-3, -2, -1, 0, 1, 2, 4}

got = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})

want = []float64{-3, -2, -1, 0, 1, 2, 4}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("logits mismatch (-want +got):\n%s", diff)
// Test k=3
got := topK(toLogits(input), 3)
if len(got) != 3 {
t.Errorf("topK(3): wrong length: want 3, got %d", len(got))
}
// Should keep highest 3 values: 4, 2, 1
want := []float64{4, 2, 1}
compareLogits(t, "topK(3)", want, got)

// Test k > len
got = topK(toLogits(input), 10)
compareLogits(t, "topK(10)", input, got)
}

func TestTopP(t *testing.T) {
got := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("logits mismatch (-want +got):\n%s", diff)
input := []float64{-3, -2, -1, 0, 1, 2, 4}
tokens := toLogits(input)

// First apply temperature and softmax to get probabilities
tokens = temperature(tokens, 1)
tokens = softmax(tokens)
sortLogits(tokens)

// Then apply topP
got := topP(tokens, 0.95)

// Should keep tokens until cumsum > 0.95
if len(got) > 3 {
t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
t.Logf("got: %v", got)
}
}

func TestMinP(t *testing.T) {
got := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3})
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("logits mismatch (-want +got):\n%s", diff)
input := []float64{-3, -2, -1, 0, 1, 2, 4, 3}
tokens := toLogits(input)

// First apply temperature and softmax
tokens = temperature(tokens, 1)
tokens = softmax(tokens)

// Then apply minP
got := minP(tokens, 0.2)

// Should keep tokens with prob >= 0.2 * max_prob
if len(got) > 3 {
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
}
}

func BenchmarkTransform(b *testing.B) {
transforms := map[string]Transform{
"Temperature": Temperature(0.5),
"TopK": TopK(10),
"TopP": TopP(0.9),
"MinP": MinP(0.2),
func TestSortLogits(t *testing.T) {
input := []float64{3, 1, 4, 2, -1, 0, -2}
tokens := toLogits(input)
|
||||
|
||||
sortLogits(tokens)
|
||||
|
||||
for i := 1; i < len(tokens); i++ {
|
||||
if tokens[i].value > tokens[i-1].value {
|
||||
t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
|
||||
i, tokens[i].value, tokens[i-1].value)
|
||||
}
|
||||
}
|
||||
|
||||
logits := make([]float64, 1<<16)
|
||||
for i := range logits {
|
||||
logits[i] = rand.Float64()
|
||||
}
|
||||
|
||||
for name, transform := range transforms {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
transform.Apply(logits)
|
||||
}
|
||||
})
|
||||
}
|
||||
want := []float64{4, 3, 2, 1, 0, -1, -2}
|
||||
compareLogits(t, "sortLogits", want, tokens)
|
||||
}
|
||||
|
||||
func BenchmarkTransforms(b *testing.B) {
|
||||
// Generate random logits
|
||||
tokens := make([]logit, 1<<16)
|
||||
for i := range tokens {
|
||||
tokens[i] = logit{
|
||||
id: int32(i),
|
||||
value: rand.Float32(),
|
||||
}
|
||||
}
|
||||
|
||||
tokensCopy := make([]logit, len(tokens))
|
||||
|
||||
b.Run("Temperature", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
copy(tokensCopy, tokens)
|
||||
temperature(tokensCopy, 0.5)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("TopK", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
copy(tokensCopy, tokens)
|
||||
topK(tokensCopy, 10)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("TopP", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
copy(tokensCopy, tokens)
|
||||
topP(tokensCopy, 0.9)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("MinP", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
copy(tokensCopy, tokens)
|
||||
minP(tokensCopy, 0.2)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("SortTokens", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
copy(tokensCopy, tokens)
|
||||
sortLogits(tokensCopy)
|
||||
}
|
||||
})
|
||||
}
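
These benchmarks use for b.Loop() from Go 1.24's testing package rather than the classic for i := 0; i < b.N; i++ loop; b.Loop also resets the timer on its first call, so the explicit b.ResetTimer() calls should be redundant but harmless. A minimal sketch of the same copy-then-measure shape, with a smaller hypothetical input size:

func BenchmarkSortSketch(b *testing.B) {
    src := make([]logit, 1<<10) // 1<<10 is a hypothetical smaller input
    for i := range src {
        src[i] = logit{id: int32(i), value: rand.Float32()}
    }
    work := make([]logit, len(src))
    for b.Loop() {
        copy(work, src)  // restore the input each iteration
        sortLogits(work) // so only the transform itself is measured
    }
}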

@@ -45,9 +45,9 @@ import (

// Errors
var (
    // ErrManifestNotFound is returned when a manifest is not found in the
    // ErrModelNotFound is returned when a manifest is not found in the
    // cache or registry.
    ErrManifestNotFound = errors.New("manifest not found")
    ErrModelNotFound = errors.New("model not found")

    // ErrManifestInvalid is returned when a manifest found in a local or
    // remote cache is invalid.

@@ -114,7 +114,18 @@ type Error struct {
}

func (e *Error) Error() string {
    return fmt.Sprintf("registry responded with status %d: %s %s", e.Status, e.Code, e.Message)
    var b strings.Builder
    b.WriteString("registry responded with status ")
    b.WriteString(strconv.Itoa(e.Status))
    if e.Code != "" {
        b.WriteString(": code ")
        b.WriteString(e.Code)
    }
    if e.Message != "" {
        b.WriteString(": ")
        b.WriteString(e.Message)
    }
    return b.String()
}
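
With this change the code and message parts are appended only when present, instead of always being formatted (the old Sprintf printed stray spaces when fields were empty). Hypothetical inputs and the strings the new builder produces:

// Sketch of (*Error).Error() outputs under the new format:
//   &Error{Status: 999, Message: "boom"}
//       -> "registry responded with status 999: boom"
//   &Error{Status: 404, Code: "MANIFEST_UNKNOWN", Message: "manifest unknown"}
//       -> "registry responded with status 404: code MANIFEST_UNKNOWN: manifest unknown"
//   &Error{Status: 500}
//       -> "registry responded with status 500"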

func (e *Error) LogValue() slog.Value {

@@ -355,7 +366,7 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
        n.Model(),
        l.Digest,
    )
    res, err := r.doOK(ctx, "POST", startURL, nil)
    res, err := r.send(ctx, "POST", startURL, nil)
    if err != nil {
        return err
    }

@@ -379,7 +390,7 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
    }
    req.ContentLength = l.Size

    res, err = doOK(r.client(), req)
    res, err = sendRequest(r.client(), req)
    if err == nil {
        res.Body.Close()
    }

@@ -399,7 +410,7 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
        n.Model(),
        n.Tag(),
    )
    res, err := r.doOK(ctx, "PUT", path, bytes.NewReader(m.Data))
    res, err := r.send(ctx, "PUT", path, bytes.NewReader(m.Data))
    if err == nil {
        res.Body.Close()
    }

@@ -448,10 +459,15 @@ func (r *Registry) Pull(ctx context.Context, name string) error {

    t := traceFromContext(ctx)

    var g errgroup.Group
    g, ctx := errgroup.WithContext(ctx)
    g.SetLimit(r.maxStreams())

    for _, l := range m.Layers {
    layers := m.Layers
    if m.Config != nil && m.Config.Digest.IsValid() {
        layers = append(layers, m.Config)
    }

    for _, l := range layers {
        if exists(l) {
            t.update(l, l.Size, ErrCached)
            continue

@@ -468,7 +484,9 @@ func (r *Registry) Pull(ctx context.Context, name string) error {

    if l.Size <= r.maxChunkingThreshold() {
        g.Go(func() error {
            res, err := doOK(r.client(), req)
            // TODO(bmizerany): retry/backoff like below in
            // the chunking case
            res, err := sendRequest(r.client(), req)
            if err != nil {
                return err
            }

@@ -494,19 +512,21 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
    // fire an initial request to get the final URL and
    // then use that URL for the chunk requests.
    req.Header.Set("Range", "bytes=0-0")
    res, err := doOK(r.client(), req)
    res, err := sendRequest(r.client(), req)
    if err != nil {
        return err
    }
    res.Body.Close()
    req = res.Request.WithContext(req.Context())

    streamNo := 0
    tws := make([]*bufio.Writer, r.maxStreams()-1)
    wp := writerPool{size: r.maxChunkSize()}

    for chunk := range chunks.Of(l.Size, r.maxChunkSize()) {
        if ctx.Err() != nil {
            break
        }

        ticket := q.Take()
        bufIdx := streamNo % len(tws)
        streamNo++
        g.Go(func() (err error) {
            defer func() {
                if err != nil {

@@ -520,23 +540,18 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
    if err != nil {
        return err
    }

    err := func() error {
        req := req.Clone(req.Context())
        req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
        res, err := doOK(r.client(), req)
        res, err := sendRequest(r.client(), req)
        if err != nil {
            return err
        }
        defer res.Body.Close()

        tw := tws[bufIdx]
        if tw == nil {
            tw = bufio.NewWriterSize(nil, int(r.maxChunkSize()))
            tws[bufIdx] = tw
        }
        tw := wp.get()
        tw.Reset(ticket)
        defer tw.Reset(nil) // release ticket
        defer wp.put(tw)

        _, err = io.CopyN(tw, res.Body, chunk.Size())
        if err != nil {
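
The chunked path above first sends a Range: bytes=0-0 probe so any redirects are resolved once, then clones the request per chunk with a real byte range against the final URL. A minimal sketch of that probe-then-range pattern, independent of the registry types (blobURL is hypothetical):

package main

import (
    "fmt"
    "net/http"
)

// fetchFirstChunk resolves redirects with a zero-length range probe, then
// requests one byte range from the final URL.
func fetchFirstChunk(blobURL string) (*http.Response, error) {
    req, err := http.NewRequest("GET", blobURL, nil)
    if err != nil {
        return nil, err
    }
    req.Header.Set("Range", "bytes=0-0")
    res, err := http.DefaultClient.Do(req)
    if err != nil {
        return nil, err
    }
    res.Body.Close()

    // res.Request records the final, post-redirect URL.
    req = res.Request.Clone(req.Context())
    req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", 0, 1<<20-1)) // first 1 MiB
    return http.DefaultClient.Do(req)
}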

@@ -595,6 +610,9 @@ type Manifest struct {
    Name string `json:"-"` // the canonical name of the model
    Data []byte `json:"-"` // the raw data of the manifest
    Layers []*Layer `json:"layers"`

    // For legacy reasons, we still have to download the config layer.
    Config *Layer `json:"config"`
}

var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000")

@@ -678,7 +696,7 @@ func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
    data, err := os.ReadFile(c.GetFile(d))
    if err != nil {
        if errors.Is(err, fs.ErrNotExist) {
            return nil, fmt.Errorf("%w: %s", ErrManifestNotFound, name)
            return nil, fmt.Errorf("%w: %s", ErrModelNotFound, name)
        }
        return nil, err
    }

@@ -701,7 +719,7 @@ func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error)
        manifestURL = fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), d)
    }

    res, err := r.doOK(ctx, "GET", manifestURL, nil)
    res, err := r.send(ctx, "GET", manifestURL, nil)
    if err != nil {
        return nil, err
    }

@@ -726,7 +744,7 @@ func (r *Registry) client() *http.Client {
}

// newRequest constructs a new request, ready to use, with the given method,
// url, and body, presigned with client Key and UserAgent.
// url, and body, pre-signed with client [Key] and [UserAgent].
func (r *Registry) newRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {
    req, err := http.NewRequestWithContext(ctx, method, url, body)
    if err != nil {

@@ -745,11 +763,17 @@ func (r *Registry) newRequest(ctx context.Context, method, url string, body io.R
    return req, nil
}

// doOK makes a request with the given client and request, and returns the
// sendRequest makes a request with the given client and request, and returns the
// response if the status code is 200. If the status code is not 200, an Error
// is parsed from the response body and returned. If any other error occurs, it
// is returned.
func doOK(c *http.Client, r *http.Request) (*http.Response, error) {
func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error) {
    defer func() {
        if err != nil {
            err = fmt.Errorf("request error %s: %w", r.URL, err)
        }
    }()

    if r.URL.Scheme == "https+insecure" {
        // TODO(bmizerany): clone client.Transport, set
        // InsecureSkipVerify, etc.

@@ -792,20 +816,26 @@ func doOK(c *http.Client, r *http.Request) (*http.Response, error) {
        // Use the raw body if we can't parse it as an error object.
        re.Message = string(out)
    }

    // coerce MANIFEST_UNKNOWN to ErrModelNotFound
    if strings.EqualFold(re.Code, "MANIFEST_UNKNOWN") {
        return nil, ErrModelNotFound
    }

    re.Status = res.StatusCode
    return nil, &re
}
return res, nil
}
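
Because the coercion happens inside sendRequest, callers can detect a missing model with errors.Is instead of matching on status codes or strings; the pull handler further down relies on exactly this. A small hypothetical caller-side sketch:

func resolveOrReport(ctx context.Context, r *ollama.Registry, name string) {
    _, err := r.Resolve(ctx, name)
    switch {
    case errors.Is(err, ollama.ErrModelNotFound):
        fmt.Printf("model %q not found\n", name)
    case err != nil:
        fmt.Printf("resolve failed: %v\n", err)
    }
}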

// doOK is a convenience method for making a request with newRequest and
// passing it to doOK with r.client().
func (r *Registry) doOK(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
// send is a convenience method for making a request with newRequest and
// passing it to send with r.client().
func (r *Registry) send(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
    req, err := r.newRequest(ctx, method, path, body)
    if err != nil {
        return nil, err
    }
    return doOK(r.client(), req)
    return sendRequest(r.client(), req)
}

// makeAuthToken creates an Ollama auth token for the given private key.

@@ -960,3 +990,28 @@ func splitExtended(s string) (scheme, name, digest string) {
    }
    return scheme, s, digest
}

type writerPool struct {
    size int64 // set by the caller

    mu sync.Mutex
    ws []*bufio.Writer
}

func (p *writerPool) get() *bufio.Writer {
    p.mu.Lock()
    defer p.mu.Unlock()
    if len(p.ws) == 0 {
        return bufio.NewWriterSize(nil, int(p.size))
    }
    w := p.ws[len(p.ws)-1]
    p.ws = p.ws[:len(p.ws)-1]
    return w
}

func (p *writerPool) put(w *bufio.Writer) {
    p.mu.Lock()
    defer p.mu.Unlock()
    w.Reset(nil)
    p.ws = append(p.ws, w)
}
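
The writerPool introduced here replaces the fixed tws slice: get hands out a *bufio.Writer sized to the pool's buffer size (allocating only when the free list is empty), and put resets the writer before storing it so pooled writers never pin a chunk destination. A sketch of the intended call pattern, with dst standing in for the ticket used above:

func writeChunk(wp *writerPool, dst io.Writer, data []byte) error {
    tw := wp.get()
    tw.Reset(dst)    // point the pooled writer at this chunk's destination
    defer wp.put(tw) // put resets the writer, so it retains no reference

    if _, err := tw.Write(data); err != nil {
        return err
    }
    return tw.Flush() // flush before the writer goes back to the pool
}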

@@ -608,7 +608,7 @@ func TestInsecureSkipVerify(t *testing.T) {
    url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
    _, err := rc.Resolve(t.Context(), url)
    if err == nil || !strings.Contains(err.Error(), "failed to verify") {
        t.Errorf("err = %v; want cert verifiction failure", err)
        t.Errorf("err = %v; want cert verification failure", err)
    }

    url = fmt.Sprintf("https+insecure://%s/%s", s.Listener.Addr(), name)

@@ -13,9 +13,13 @@ type Trace struct {
    // Update is called during [Registry.Push] and [Registry.Pull] to
    // report the progress of blob uploads and downloads.
    //
    // It is called once at the beginning of the download with a zero n and
    // then once per read operation with the number of bytes read so far,
    // and an error if any.
    // The n argument is the number of bytes transferred so far, and err is
    // any error that has occurred. If n == 0, and err is nil, the download
    // or upload has just started. If err is [ErrCached], the download or
    // upload has been skipped because the blob is already present in the
    // local cache or remote registry, respectively. Otherwise, if err is
    // non-nil, the download or upload has failed. When l.Size == n, and
    // err is nil, the download or upload has completed.
    //
    // A function assigned must be safe for concurrent use. The function is
    // called synchronously and so should not block or take long to run.
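
The rewritten comment pins down the callback protocol: n == 0 with nil err marks the start, err == ErrCached marks a skipped blob, n == l.Size with nil err marks completion, and any other non-nil err is a failure. A hedged sketch of an Update function that follows it:

// Sketch of a Trace consumer following the documented protocol.
trace := &ollama.Trace{
    Update: func(l *ollama.Layer, n int64, err error) {
        switch {
        case errors.Is(err, ollama.ErrCached):
            fmt.Printf("%s: already present, skipped\n", l.Digest)
        case err != nil:
            fmt.Printf("%s: failed: %v\n", l.Digest, err)
        case n == 0:
            fmt.Printf("%s: starting (%d bytes)\n", l.Digest, l.Size)
        case n == l.Size:
            fmt.Printf("%s: done\n", l.Digest)
        default:
            fmt.Printf("%s: %d of %d bytes\n", l.Digest, n, l.Size)
        }
    },
}
ctx := ollama.WithTrace(context.Background(), trace)
_ = ctx // pass ctx to Registry.Pull or Registry.Push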

@@ -7,10 +7,14 @@ import (
    "cmp"
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "log/slog"
    "net/http"
    "sync"
    "time"

    "github.com/ollama/ollama/server/internal/cache/blob"
    "github.com/ollama/ollama/server/internal/client/ollama"
)

@@ -109,6 +113,8 @@ func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
    switch r.URL.Path {
    case "/api/delete":
        return false, s.handleDelete(rec, r)
    case "/api/pull":
        return false, s.handlePull(rec, r)
    default:
        if s.Fallback != nil {
            s.Fallback.ServeHTTP(rec, r)

@@ -214,6 +220,97 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
    return s.Prune()
}

type progressUpdateJSON struct {
    Status string `json:"status"`
    Digest blob.Digest `json:"digest,omitempty,omitzero"`
    Total int64 `json:"total,omitempty,omitzero"`
    Completed int64 `json:"completed,omitempty,omitzero"`
}

func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
    if r.Method != "POST" {
        return errMethodNotAllowed
    }

    p, err := decodeUserJSON[*params](r.Body)
    if err != nil {
        return err
    }

    maybeFlush := func() {
        fl, _ := w.(http.Flusher)
        if fl != nil {
            fl.Flush()
        }
    }
    defer maybeFlush()

    var mu sync.Mutex
    enc := json.NewEncoder(w)
    enc.Encode(progressUpdateJSON{Status: "pulling manifest"})

    ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
        Update: func(l *ollama.Layer, n int64, err error) {
            mu.Lock()
            defer mu.Unlock()

            // TODO(bmizerany): coalesce these updates; writing per
            // update is expensive
            enc.Encode(progressUpdateJSON{
                Digest: l.Digest,
                Status: "pulling",
                Total: l.Size,
                Completed: n,
            })
        },
    })

    done := make(chan error, 1)
    go func() {
        // TODO(bmizerany): continue to support non-streaming responses
        done <- s.Client.Pull(ctx, p.model())
    }()

    func() {
        t := time.NewTicker(100 * time.Millisecond)
        defer t.Stop()
        for {
            select {
            case <-t.C:
                mu.Lock()
                maybeFlush()
                mu.Unlock()
            case err := <-done:
                if err != nil {
                    var status string
                    if errors.Is(err, ollama.ErrModelNotFound) {
                        status = fmt.Sprintf("error: model %q not found", p.model())
                        enc.Encode(progressUpdateJSON{Status: status})
                    } else {
                        status = fmt.Sprintf("error: %v", err)
                        enc.Encode(progressUpdateJSON{Status: status})
                    }
                    return
                }

                // These final updates are not strictly necessary, because they have
                // already happened at this point. Our pull handler code used to do
                // these steps after, not during, the pull, and they were slow, so we
                // wanted to provide feedback to users what was happening. For now, we
                // keep them to not jar users who are used to seeing them. We can phase
                // them out with a new and nicer UX later. One without progress bars
                // and digests that no one cares about.
                enc.Encode(progressUpdateJSON{Status: "verifying layers"})
                enc.Encode(progressUpdateJSON{Status: "writing manifest"})
                enc.Encode(progressUpdateJSON{Status: "success"})
                return
            }
        }
    }()

    return nil
}

func decodeUserJSON[T any](r io.Reader) (T, error) {
    var v T
    err := json.NewDecoder(r).Decode(&v)
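
handlePull streams newline-delimited JSON and flushes on a 100 ms ticker, so clients should decode updates as they arrive rather than waiting for the full body. A hypothetical client-side sketch for the /api/pull progress stream:

func readPullProgress(body io.Reader) error {
    type update struct {
        Status    string `json:"status"`
        Digest    string `json:"digest"`
        Total     int64  `json:"total"`
        Completed int64  `json:"completed"`
    }
    dec := json.NewDecoder(body)
    for {
        var u update
        err := dec.Decode(&u)
        if err == io.EOF {
            return nil
        }
        if err != nil {
            return err
        }
        if u.Total > 0 {
            fmt.Printf("%s %s: %d/%d\n", u.Status, u.Digest, u.Completed, u.Total)
        } else {
            fmt.Println(u.Status)
        }
    }
}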

@@ -1,17 +1,27 @@
package registry

import (
    "bytes"
    "context"
    "encoding/json"
    "fmt"
    "io"
    "io/fs"
    "net"
    "net/http"
    "net/http/httptest"
    "os"
    "regexp"
    "strings"
    "sync"
    "testing"

    "github.com/ollama/ollama/server/internal/cache/blob"
    "github.com/ollama/ollama/server/internal/client/ollama"
    "github.com/ollama/ollama/server/internal/testutil"
    "golang.org/x/tools/txtar"

    _ "embed"
)

type panicTransport struct{}

@@ -30,7 +40,7 @@ type bytesResetter interface {
    Reset()
}

func newTestServer(t *testing.T) *Local {
func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local {
    t.Helper()
    dir := t.TempDir()
    err := os.CopyFS(dir, os.DirFS("testdata/models"))

@@ -41,10 +51,25 @@ func newTestServer(t *testing.T) *Local {
    if err != nil {
        t.Fatal(err)
    }

    client := panicOnRoundTrip
    if upstreamRegistry != nil {
        s := httptest.NewTLSServer(upstreamRegistry)
        t.Cleanup(s.Close)
        tr := s.Client().Transport.(*http.Transport).Clone()
        tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
            var d net.Dialer
            return d.DialContext(ctx, "tcp", s.Listener.Addr().String())
        }
        client = &http.Client{Transport: tr}
    }

    rc := &ollama.Registry{
        Cache: c,
        HTTPClient: panicOnRoundTrip,
        HTTPClient: client,
        Mask: "example.com/library/_:latest",
    }

    l := &Local{
        Client: rc,
        Logger: testutil.Slogger(t),

@@ -85,7 +110,7 @@ func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) {
func TestServerDelete(t *testing.T) {
    check := testutil.Checker(t)

    s := newTestServer(t)
    s := newTestServer(t, nil)

    _, err := s.Client.ResolveLocal("smol")
    check(err)

@@ -127,8 +152,105 @@ func TestServerDelete(t *testing.T) {
    }
}

//go:embed testdata/registry.txt
var registryTXT []byte

var registryFS = sync.OnceValue(func() fs.FS {
    // Txtar gets hung up on \r\n line endings, so we need to convert them
    // to \n when parsing the txtar on Windows.
    data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
    a := txtar.Parse(data)
    fmt.Printf("%q\n", a.Comment)
    fsys, err := txtar.FS(a)
    if err != nil {
        panic(err)
    }
    return fsys
})

func TestServerPull(t *testing.T) {
    modelsHandler := http.FileServerFS(registryFS())
    s := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
        switch r.URL.Path {
        case "/v2/library/BOOM/manifests/latest":
            w.WriteHeader(999)
            io.WriteString(w, `{"error": "boom"}`)
        case "/v2/library/unknown/manifests/latest":
            w.WriteHeader(404)
            io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
        default:
            t.Logf("serving file: %s", r.URL.Path)
            modelsHandler.ServeHTTP(w, r)
        }
    })

    checkResponse := func(got *httptest.ResponseRecorder, wantlines string) {
        t.Helper()

        if got.Code != 200 {
            t.Fatalf("Code = %d; want 200", got.Code)
        }
        gotlines := got.Body.String()
        t.Logf("got:\n%s", gotlines)
        for want := range strings.Lines(wantlines) {
            want = strings.TrimSpace(want)
            want, unwanted := strings.CutPrefix(want, "!")
            want = strings.TrimSpace(want)
            if !unwanted && !strings.Contains(gotlines, want) {
                t.Fatalf("! missing %q in body", want)
            }
            if unwanted && strings.Contains(gotlines, want) {
                t.Fatalf("! unexpected %q in body", want)
            }
        }
    }

    got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
    checkResponse(got, `
        {"status":"pulling manifest"}
        {"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
    `)

    got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
    checkResponse(got, `
        {"status":"pulling manifest"}
        {"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
        {"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
        {"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
        {"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
        {"status":"verifying layers"}
        {"status":"writing manifest"}
        {"status":"success"}
    `)

    got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
    checkResponse(got, `
        {"status":"pulling manifest"}
        {"status":"error: model \"unknown\" not found"}
    `)

    got = s.send(t, "DELETE", "/api/pull", `{"model": "smol"}`)
    checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")

    got = s.send(t, "POST", "/api/pull", `!`)
    checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")

    got = s.send(t, "POST", "/api/pull", ``)
    checkErrorResponse(t, got, 400, "bad_request", "empty request body")

    got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
    checkResponse(got, `
        {"status":"pulling manifest"}
        {"status":"error: invalid or missing name: \"\""}

        !verifying
        !writing
        !success
    `)
}

func TestServerUnknownPath(t *testing.T) {
    s := newTestServer(t)
    s := newTestServer(t, nil)
    got := s.send(t, "DELETE", "/api/unknown", `{}`)
    checkErrorResponse(t, got, 404, "not_found", "not found")
}

server/internal/registry/testdata/registry.txt (new file, vendored, 22 lines)
@@ -0,0 +1,22 @@
-- v2/library/smol/manifests/latest --
{
    "schemaVersion": 2,
    "mediaType": "application/vnd.docker.distribution.manifest.v2+json",
    "config": {
        "mediaType": "application/vnd.docker.container.image.v1+json",
        "digest": "sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356",
        "size": 3
    },
    "layers": [
        {
            "mediaType": "application/vnd.ollama.image.model",
            "digest": "sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312",
            "size": 5
        }
    ]
}

-- v2/library/smol/blobs/sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312 --
GGUF
-- v2/library/smol/blobs/sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356 --
{}
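
The new testdata file uses the txtar format: each "-- name --" marker starts a file, and txtar.FS exposes the archive as an fs.FS, which is what lets http.FileServerFS above serve a fake upstream registry straight from the embedded archive. A minimal sketch of that round trip:

// Minimal sketch: parse a txtar archive and read one file back out.
a := txtar.Parse([]byte("-- hello.txt --\nhi\n"))
fsys, err := txtar.FS(a)
if err != nil {
    panic(err)
}
data, err := fs.ReadFile(fsys, "hello.txt")
if err != nil {
    panic(err)
}
fmt.Printf("%s", data) // prints "hi\n"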

@@ -42,6 +42,12 @@ import (
    "github.com/ollama/ollama/version"
)

func experimentEnabled(name string) bool {
    return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
}

var useClient2 = experimentEnabled("client2")
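
The new registry client is therefore opt-in: useClient2 is computed once at package init, and is true only when "client2" appears in the comma-separated OLLAMA_EXPERIMENT environment variable. A small sketch of the check in isolation (the second experiment name is hypothetical):

// OLLAMA_EXPERIMENT is split on commas and each element matched exactly.
os.Setenv("OLLAMA_EXPERIMENT", "client2,speculative")
fmt.Println(experimentEnabled("client2")) // true
fmt.Println(experimentEnabled("batch"))   // false
// Note: useClient2 is evaluated at package init, so the variable must be
// set before the process starts for the experiment to take effect.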

var mode string = gin.DebugMode

type Server struct {

@@ -1173,6 +1179,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
    r.HEAD("/api/tags", s.ListHandler)
    r.GET("/api/tags", s.ListHandler)
    r.POST("/api/show", s.ShowHandler)
    r.DELETE("/api/delete", s.DeleteHandler)

    // Create
    r.POST("/api/create", s.CreateHandler)

@@ -1194,16 +1201,19 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
    r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
    r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)

    // wrap old with new
    rs := &registry.Local{
        Client: rc,
        Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
        Fallback: r,
    if rc != nil {
        // wrap old with new
        rs := &registry.Local{
            Client: rc,
            Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
            Fallback: r,

        Prune: PruneLayers,
            Prune: PruneLayers,
        }
        return rs, nil
    }

    return rs, nil
    return r, nil
}

func Serve(ln net.Listener) error {

@@ -1258,15 +1268,20 @@ func Serve(ln net.Listener) error {

    s := &Server{addr: ln.Addr()}

    rc, err := ollama.DefaultRegistry()
    if err != nil {
        return err
    var rc *ollama.Registry
    if useClient2 {
        var err error
        rc, err = ollama.DefaultRegistry()
        if err != nil {
            return err
        }
    }

    h, err := s.GenerateRoutes(rc)
    if err != nil {
        return err
    }

    http.Handle("/", h)

    ctx, done := context.WithCancel(context.Background())