Compare commits

..

31 Commits

Author SHA1 Message Date
likelovewant
17bb5ea679 Merge branch 'ollama:main' into main 2025-03-23 12:10:05 +08:00
Blake Mizerany
ce929984a3 server/internal/client/ollama: fix file descriptor management in Pull (#9931)
Close chunked writers as soon as downloads complete, rather than
deferring closure until Pull exits. This prevents exhausting file
descriptors when pulling many layers.

Instead of unbounded defers, use a WaitGroup and background goroutine
to close each chunked writer as soon as its downloads finish.

Also rename 'total' to 'received' for clarity.
2025-03-21 16:16:38 -07:00
Michael Yang
4b34930a31 Merge pull request #9897 from ollama/mxyng/chunk-load
ml/backend/ggml: load tensors in 128KiB chunks
2025-03-21 14:47:13 -07:00
Michael Yang
74bd09652d ml/backend/ggml: load tensors in 32KiB chunks 2025-03-21 14:43:52 -07:00
Bruce MacDonald
fb6252d786 benchmark: performance of running ollama server (#8643) 2025-03-21 13:08:20 -07:00
Blake Mizerany
c794fef2f2 server/internal/client/ollama: persist through chunk download errors (#9923) 2025-03-21 13:03:43 -07:00
Parth Sareen
00ebda8cc4 Revert "parser: remove role validation from Modelfile parser" (#9917)
This reverts commit ffbfe833da.
2025-03-21 12:38:09 -07:00
Parth Sareen
d14ce75b95 docs: update final response for /api/chat stream (#9919) 2025-03-21 12:35:47 -07:00
Jesse Gross
2d6eac9084 kvcache: Optimize sliding window attention
Currently sliding window attention allocates and uses the full
context size and just masks out any tokens that are outside of the
window. However, we really only need (roughly) the sliding window
size.

At large context sizes this improves two things:
 - Memory allocated - since the fully context size is allocated up front,
   memory requirements drop substantially. On Gemma3:4b with a 32k
   context window, total memory usage (including weights and non-sliding
   layers) drops from ~20GB to ~8GB.
 - Computation - ranges that are completely outside of the sliding
   window are now removed from the tensors that are returned from the
   cache rather than simply being masked out. This results in more
   efficient processing, scaling with the size of the context that
   has actually been used.

Notable, this does not update the scheduler for any model to be aware of
the smaller memory requirements. This is difficult for Gemma3 because
the layers are heterogeneous between sliding and non-sliding attention.
As a result, while actual memory consumption will be reduced, the
scheduler will over-estimate the requirements of the model. This means
that splitting between GPUs or GPUs and CPUs will still be suboptimal.

Bug #9730
2025-03-21 11:20:19 -07:00
Jesse Gross
3ed7ad3ab3 kvcache: Pass granular cache size into implementations
Currently the runner computes the kv size needed and creates a
cache of that size. This is the context size times number of
parallel sequences.

Cache implementations can make better decisions about their memory
usage, so instead pass in the required capacity, number of sequences
and maximum batch size. For now, the causal cache just uses this to
compute the size in the same way as before.
2025-03-21 11:20:19 -07:00
Patrick Devine
6d1103048e fix: show correct bool value for kv in verbose show information (#9928) 2025-03-21 11:13:54 -07:00
Jesse Gross
0ff28758b3 ollamarunner: Provide mechanism for backends to report loading progress
This enables the runner to report progress back to the Ollama server,
both for showing status to the user and also to prevent the server
from killing the runner if it thinks things have stalled.

Most of the infrastructure was already there, this extends it to
be available to the backends.
2025-03-21 10:44:26 -07:00
Jesse Gross
d3e9ca3eda kvcache: Account for source tensors in defrag operation count
Defragging the KV cache can generate a lot of operations, so we
need to be careful that we don't overflow the number that the graph
can support. We currently account for all of the nodes that we add
to the graph for each move but we also need to include the original
cache tensors as well.

Fixes #9904
2025-03-21 10:42:19 -07:00
Jesse Gross
0fbfcf3c9c model: Pass input tensor instead of raw data to models
Rather than directly giving the input data to models, we can
pass a tensor instead. In the short term, this saves some duplicated
code.

Longer term, we will want to overlap setting up the next batch with
processing of the current one. In this case, we will only have the
shape of tensor but it will not be loaded with data at the time of
graph generation. By passing only a tensor to models now, we set up
this possibility and prevent them from relying on data that they won't
have in the future.

Although the same could be done for Positions and Outputs, in some
cases we either need the raw input data or don't use them at all.
Therefore, for now we leave them as they are and allow models to
convert them to tensors as needed.
2025-03-20 13:28:13 -07:00
Jesse Gross
0c220935bd input: Rename Options to Batch
Options is no longer very descriptive of this struct.
2025-03-20 13:28:13 -07:00
rylativity
ffbfe833da parser: remove role validation from Modelfile parser (#9874)
* updates parser/parser.go to allow arbitrary roles in Modelfile MESSAGE blocks
2025-03-20 13:11:17 -07:00
Parth Sareen
42a14f7f63 sample: add error handling for empty logits (#9740) 2025-03-20 11:11:18 -07:00
Patrick Devine
f8c3dbe5b5 templates: add autotemplate for gemma3 (#9880)
This change allows the gemma3 template to be autodetected during `ollama
create`.
2025-03-20 00:15:30 -07:00
Jesse Gross
b078dd157c gemma2: Remove second call to Rows
Looks like a merge conflict that broke the model.
2025-03-19 17:28:49 -07:00
Blake Mizerany
2ddacd7516 server/internal/client/ollama: confirm all chunksums were received (#9893)
If the chunksums response is missing a chunk, the client should fail
the download. This changes the client to check that all bytes are
accounted for in the chunksums response.

It is possible there are overlaps or gaps in the chunksums response and
so the size is not the only thing left to check, but this provides
enough coverage for now. We may want to check that chunks are contiguous
later.
2025-03-19 14:59:57 -07:00
Jeffrey Morgan
da0e345200 ml: use input context for extracting outputs (#9875) 2025-03-18 18:08:19 -07:00
Bruce MacDonald
df94175a0f ggml: return error on failure to read tensor data (#9872)
When converting a ggml model if there is a failure to read tensor data a nil error value was being returned. It should be assigned to the actual error from reading.
2025-03-18 16:51:33 -07:00
Bruce MacDonald
61a8825216 convert: return name of unsupported architecture (#9862)
When a model's architecture cannot be converted return the name of the unsupported arch in the error message.
2025-03-18 10:38:28 -07:00
likelovewant
a69a1e6e63 Merge remote-tracking branch 'upstream/main' 2025-03-18 18:09:35 +08:00
Michael Yang
021dcf089d Merge pull request #9824 from ollama/mxyng/sched
conditionally enable parallel pipelines
2025-03-17 15:41:37 -07:00
Jesse Gross
bf24498b1e ollamarunner: Check for minBatch of context space when shifting
Models can specify that a group of inputs need to be handled a single
batch. However, context shifting didn't respect this and could trigger
a break anyways. In this case, we should instead trigger a context
shift earlier so that it occurs before the grouped batch.

Note that there still some corner cases:
 - A long prompt that exceeds the context window can get truncated
   in the middle of an image. With the current models, this will
   result in the model not recognizing the image at all, which is
   pretty much the expected result with truncation.
 - The context window is set less than the minimum batch size. The
   only solution to this is to refuse to load the model with these
   settings. However, this can never occur with current models and
   default settings.

Since users are unlikely to run into these scenarios, fixing them is
left as a follow up.
2025-03-17 15:33:16 -07:00
Bruce MacDonald
95e271d98f runner: remove cache prompt flag from ollama runner (#9826)
We do not need to bypass the prompt caching in the ollama runner yet, as
only embedding models needed to bypass the prompt caching. When embedding
models are implemented they can skip initializing this cache completely.
2025-03-17 15:11:15 -07:00
Jeffrey Morgan
364629b8d6 ml/backend/ggml: allocate memory with malloc when loading model (#9822) 2025-03-17 13:32:40 -07:00
Parth Sareen
108fe02165 sample: make mutations in transforms explicit (#9743)
* updated minP to use early exit making use of sorted tokens
2025-03-17 11:24:18 -07:00
Michael Yang
4561fff36e conditionally enable parallel pipelines 2025-03-17 09:46:07 -07:00
Daniel Hiltgen
50b5962042 Add support for ROCm gfx1151 (#9773) 2025-03-17 09:33:57 -07:00
40 changed files with 1111 additions and 357 deletions

View File

@@ -98,7 +98,7 @@ if(CMAKE_HIP_COMPILER)
find_package(hip REQUIRED) find_package(hip REQUIRED)
if(NOT AMDGPU_TARGETS) if(NOT AMDGPU_TARGETS)
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(803|900(:xnack-)|902|906(:xnack-)|90c(:xnack-)|1010(:xnack-)|1011|1012(:xnack-)|103[0-6]|110[0-3]|1150)$") list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(803|900(:xnack-)|902|906(:xnack-)|90c(:xnack-)|1010(:xnack-)|1011(:xnack-)|1012(:xnack-)|103[0-6]|110[0-3]|115[01]|1201)$")
elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX) elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX}) list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
endif() endif()

View File

@@ -56,7 +56,7 @@
"name": "ROCm 6", "name": "ROCm 6",
"inherits": [ "ROCm" ], "inherits": [ "ROCm" ],
"cacheVariables": { "cacheVariables": {
"AMDGPU_TARGETS": "gfx803;gfx902;gfx1011;gfx1030;gfx1031;gfx1032;gfx1034;gfx1035;gfx1036;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx900:xnack-;gfx906:xnack-;gfx90c:xnack-;gfx1010:xnack-;gfx1012:xnack-;" "AMDGPU_TARGETS": "gfx803;gfx902;gfx1030;gfx1031;gfx1032;gfx1034;gfx1035;gfx1036;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1201;gfx900:xnack-;gfx906:xnack-;gfx90c:xnack-;gfx1010:xnack-;gfx1011:xnack-;gfx1012:xnack-;"
} }
} }
], ],

View File

@@ -0,0 +1,178 @@
package benchmark
import (
"context"
"flag"
"fmt"
"testing"
"time"
"github.com/ollama/ollama/api"
)
// Command line flags
var modelFlag string
func init() {
flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
flag.Lookup("m").DefValue = "model"
}
// modelName returns the model name from flags, failing the test if not set
func modelName(b *testing.B) string {
if modelFlag == "" {
b.Fatal("Error: -m flag is required for benchmark tests")
}
return modelFlag
}
type TestCase struct {
name string
prompt string
maxTokens int
}
// runGenerateBenchmark contains the common generate and metrics logic
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
start := time.Now()
var ttft time.Duration
var metrics api.Metrics
err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
if ttft == 0 && resp.Response != "" {
ttft = time.Since(start)
}
if resp.Done {
metrics = resp.Metrics
}
return nil
})
// Report custom metrics as part of the benchmark results
b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")
// Token throughput metrics
promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
b.ReportMetric(promptThroughput, "prompt_tok/s")
b.ReportMetric(genThroughput, "gen_tok/s")
// Token counts
b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
if err != nil {
b.Fatal(err)
}
}
// BenchmarkColdStart runs benchmarks with model loading from cold state
func BenchmarkColdStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
ctx := context.Background()
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
b.StopTimer()
// Ensure model is unloaded before each iteration
unload(client, m, b)
b.StartTimer()
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]interface{}{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// BenchmarkWarmStart runs benchmarks with pre-loaded model
func BenchmarkWarmStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
ctx := context.Background()
// Pre-warm the model
warmup(client, m, tt.prompt, b)
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// setup verifies server and model availability
func setup(b *testing.B) *api.Client {
client, err := api.ClientFromEnvironment()
if err != nil {
b.Fatal(err)
}
if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
b.Fatalf("Model unavailable: %v", err)
}
return client
}
// warmup ensures the model is loaded and warmed up
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
for range 3 {
err := client.Generate(
context.Background(),
&api.GenerateRequest{
Model: model,
Prompt: prompt,
Options: map[string]interface{}{"num_predict": 50, "temperature": 0.1},
},
func(api.GenerateResponse) error { return nil },
)
if err != nil {
b.Logf("Error during model warm-up: %v", err)
}
}
}
// unload forces model unloading using KeepAlive: 0 parameter
func unload(client *api.Client, model string, b *testing.B) {
req := &api.GenerateRequest{
Model: model,
KeepAlive: &api.Duration{Duration: 0},
}
if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
b.Logf("Unload error: %v", err)
}
time.Sleep(1 * time.Second)
}

View File

@@ -703,6 +703,8 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
for _, k := range keys { for _, k := range keys {
var v string var v string
switch vData := resp.ModelInfo[k].(type) { switch vData := resp.ModelInfo[k].(type) {
case bool:
v = fmt.Sprintf("%t", vData)
case string: case string:
v = vData v = vData
case float64: case float64:

View File

@@ -87,6 +87,8 @@ func TestShowInfo(t *testing.T) {
ModelInfo: map[string]any{ ModelInfo: map[string]any{
"general.architecture": "test", "general.architecture": "test",
"general.parameter_count": float64(8_000_000_000), "general.parameter_count": float64(8_000_000_000),
"some.true_bool": true,
"some.false_bool": false,
"test.context_length": float64(1000), "test.context_length": float64(1000),
"test.embedding_length": float64(11434), "test.embedding_length": float64(11434),
}, },
@@ -111,6 +113,8 @@ func TestShowInfo(t *testing.T) {
Metadata Metadata
general.architecture test general.architecture test
general.parameter_count 8e+09 general.parameter_count 8e+09
some.false_bool false
some.true_bool true
test.context_length 1000 test.context_length 1000
test.embedding_length 11434 test.embedding_length 11434

View File

@@ -201,7 +201,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
case "CohereForCausalLM": case "CohereForCausalLM":
conv = &commandrModel{} conv = &commandrModel{}
default: default:
return errors.New("unsupported architecture") return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
} }
if err := json.Unmarshal(bts, conv); err != nil { if err := json.Unmarshal(bts, conv); err != nil {

View File

@@ -558,6 +558,10 @@ Final response:
{ {
"model": "llama3.2", "model": "llama3.2",
"created_at": "2023-08-04T19:22:45.499127Z", "created_at": "2023-08-04T19:22:45.499127Z",
"message": {
"role": "assistant",
"content": ""
},
"done": true, "done": true,
"total_duration": 4883583458, "total_duration": 4883583458,
"load_duration": 1334875, "load_duration": 1334875,

59
docs/benchmark.md Normal file
View File

@@ -0,0 +1,59 @@
# Benchmark
Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.
## When to use
Run these benchmarks when:
- Making changes to the model inference engine
- Modifying model loading/unloading logic
- Changing prompt processing or token generation code
- Implementing a new model architecture
- Testing performance across different hardware setups
## Prerequisites
- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`
## Usage and Examples
>[!NOTE]
>All commands must be run from the root directory of the Ollama project.
Basic syntax:
```bash
go test -bench=. ./benchmark/... -m $MODEL_NAME
```
Required flags:
- `-bench=.`: Run all benchmarks
- `-m`: Model name to benchmark
Optional flags:
- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)
Common usage patterns:
Single benchmark run with a model specified:
```bash
go test -bench=. ./benchmark/... -m llama3.3
```
## Output metrics
The benchmark reports several key metrics:
- `gen_tok/s`: Generated tokens per second
- `prompt_tok/s`: Prompt processing tokens per second
- `ttft_ms`: Time to first token in milliseconds
- `load_ms`: Model load time in milliseconds
- `gen_tokens`: Total tokens generated
- `prompt_tokens`: Total prompt tokens processed
Each benchmark runs two scenarios:
- Cold start: Model is loaded from disk for each test
- Warm start: Model is pre-loaded in memory
Three prompt lengths are tested for each scenario:
- Short prompt (100 tokens)
- Medium prompt (500 tokens)
- Long prompt (1000 tokens)

View File

@@ -43,8 +43,13 @@ type Cache interface {
// ** cache management ** // ** cache management **
// Init sets up runtime parameters // Init sets up runtime parameters.
Init(backend ml.Backend, dtype ml.DType, capacity int32) // backend: Used to allocate cache data storage and execute management operations (such as defrag)
// dtype: The data type for storing cache entries
// maxSequences: The maximum number of sequences stored in the cache - across all batches
// capacity: The number of cache entries to store, per sequence
// maxBatch: The maximum number of tokens that can occur in a single batch
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
// Close closes the cache and frees resources associated with it // Close closes the cache and frees resources associated with it
Close() Close()
@@ -52,7 +57,7 @@ type Cache interface {
// StartForward is called before the start of the model's forward pass. // StartForward is called before the start of the model's forward pass.
// For each token in the coming batch, there must be a corresponding // For each token in the coming batch, there must be a corresponding
// entry in positions and seqs. // entry in positions and seqs.
StartForward(ctx ml.Context, opts input.Options) error StartForward(ctx ml.Context, batch input.Batch) error
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq // CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32) CopyPrefix(srcSeq, dstSeq int, len int32)

View File

@@ -20,7 +20,6 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
// The mask is of shape history size, batch size // The mask is of shape history size, batch size
type Causal struct { type Causal struct {
DType ml.DType DType ml.DType
Capacity int32
windowSize int32 windowSize int32
opts CausalOptions opts CausalOptions
@@ -98,7 +97,7 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
} }
} }
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) { func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
if c.config == nil { if c.config == nil {
var config ml.CacheConfig var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok { if cc, ok := backend.(ml.BackendCacheConfig); ok {
@@ -119,9 +118,16 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
c.config.MaskDType = ml.DTypeF32 c.config.MaskDType = ml.DTypeF32
} }
var cacheSize int
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize)+maxBatch {
cacheSize = maxSequences * capacity
} else {
cacheSize = maxSequences * (int(c.windowSize) + maxBatch)
}
cacheSize = roundUp(cacheSize, c.config.CachePadding)
c.cells = make([]cacheCell, cacheSize)
c.DType = dtype c.DType = dtype
c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
c.cells = make([]cacheCell, c.Capacity)
c.cellRanges = make(map[int]cellRange) c.cellRanges = make(map[int]cellRange)
c.backend = backend c.backend = backend
} }
@@ -140,12 +146,14 @@ func (c *Causal) Close() {
} }
} }
func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error { func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
c.curBatchSize = len(opts.Positions) c.curBatchSize = len(batch.Positions)
c.curSequences = opts.Sequences c.curSequences = batch.Sequences
c.curPositions = opts.Positions c.curPositions = batch.Positions
c.opts.Except = nil c.opts.Except = nil
c.updateSlidingWindow()
var err error var err error
c.curLoc, err = c.findStartLoc() c.curLoc, err = c.findStartLoc()
if errors.Is(err, ErrKvCacheFull) { if errors.Is(err, ErrKvCacheFull) {
@@ -157,8 +165,8 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
} }
c.curCellRange = newRange() c.curCellRange = newRange()
for i, pos := range opts.Positions { for i, pos := range batch.Positions {
seq := opts.Sequences[i] seq := batch.Sequences[i]
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}} c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
@@ -210,7 +218,51 @@ func (c *Causal) findStartLoc() (int, error) {
} }
} }
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity) return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
}
func (c *Causal) updateSlidingWindow() {
if c.windowSize == math.MaxInt32 {
return
}
// create a map of unique sequences to the lowest position in that sequence
lowestPos := make(map[int]int32)
for i := range c.curPositions {
seq := c.curSequences[i]
pos, ok := lowestPos[seq]
if !ok {
pos = c.curPositions[i]
} else if c.curPositions[i] < pos {
pos = c.curPositions[i]
}
lowestPos[seq] = pos
}
// delete any entries that are beyond the window of the oldest position in the sequence
for seq, pos := range lowestPos {
oldRange, ok := c.cellRanges[seq]
if !ok {
continue
}
newRange := newRange()
for i := oldRange.min; i <= oldRange.max; i++ {
if slices.Contains(c.cells[i].sequences, seq) {
if c.cells[i].pos < pos-c.windowSize {
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
} else {
newRange.min = min(newRange.min, i)
newRange.max = max(newRange.max, i)
}
}
}
c.cellRanges[seq] = newRange
}
} }
func roundDown(length, pad int) int { func roundDown(length, pad int) int {
@@ -265,7 +317,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
return maskTensor, nil return maskTensor, nil
} }
func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) { func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
for i, key := range c.keys { for i, key := range c.keys {
if key == nil { if key == nil {
continue continue
@@ -275,8 +327,8 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
numKVHeads := key.Dim(1) numKVHeads := key.Dim(1)
rowSize := key.Stride(2) rowSize := key.Stride(2)
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len) kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len) kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
value := c.values[i] value := c.values[i]
var vSrcView, vDstView ml.Tensor var vSrcView, vDstView ml.Tensor
@@ -284,14 +336,14 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
vHeadDim := value.Dim(1) vHeadDim := value.Dim(1)
elemSize := value.Stride(0) elemSize := value.Stride(0)
vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads) vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads) vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
} else { } else {
vHeadDim := value.Dim(0) vHeadDim := value.Dim(0)
rowSize := value.Stride(2) rowSize := value.Stride(2)
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len) vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len) vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
} }
ctx.Forward( ctx.Forward(
@@ -321,7 +373,8 @@ func (c *Causal) defrag() {
ctx := c.backend.NewContext() ctx := c.backend.NewContext()
// For every move, 6 tensors are required per layer (2 views and a // For every move, 6 tensors are required per layer (2 views and a
// copy for each of k and v). // copy for each of k and v). We also need to refer to the original
// k and v cache tensors - once per layer, not per move.
layers := 0 layers := 0
for _, key := range c.keys { for _, key := range c.keys {
if key == nil { if key == nil {
@@ -330,7 +383,7 @@ func (c *Causal) defrag() {
layers++ layers++
} }
maxMoves := ctx.MaxGraphNodes() / (6 * layers) maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
moves := 0 moves := 0
var pendingSrc, pendingDst, pendingLen int var pendingSrc, pendingDst, pendingLen int
@@ -479,14 +532,14 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
} }
if _, ok := c.keys[c.curLayer]; !ok { if _, ok := c.keys[c.curLayer]; !ok {
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity)) c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells))
} }
if _, ok := c.values[c.curLayer]; !ok { if _, ok := c.values[c.curLayer]; !ok {
if c.config.PermutedV { if c.config.PermutedV {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads) c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads)
} else { } else {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity)) c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells))
} }
} }
@@ -497,7 +550,7 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
elemSize := c.values[c.curLayer].Stride(0) elemSize := c.values[c.curLayer].Stride(0)
value = value.Permute(ctx, 1, 2, 0, 3) value = value.Permute(ctx, 1, 2, 0, 3)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads))) ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
} else { } else {
rowSize := c.values[c.curLayer].Stride(2) rowSize := c.values[c.curLayer].Stride(2)

View File

@@ -25,7 +25,7 @@ func TestStore(t *testing.T) {
cache := NewCausalCache(nil) cache := NewCausalCache(nil)
defer cache.Close() defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16) cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{ tests := []testCase{
{ {
@@ -58,11 +58,11 @@ func TestSWA(t *testing.T) {
cache := NewSWACache(1, nil) cache := NewSWACache(1, nil)
defer cache.Close() defer cache.Close()
cache.Init(backend, ml.DTypeF32, 16) cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{ tests := []testCase{
{ {
name: "SlidingWindow", name: "FirstBatch",
in: []float32{1, 2, 3, 4}, in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4}, inShape: []int{1, 1, 4},
seqs: []int{0, 0, 0, 0}, seqs: []int{0, 0, 0, 0},
@@ -71,6 +71,16 @@ func TestSWA(t *testing.T) {
expectedShape: []int{1, 1, 4}, expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
}, },
{
name: "SecondBatch",
in: []float32{5, 6},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{4, 5},
expected: []float32{5, 6, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
},
} }
testCache(t, backend, cache, tests) testCache(t, backend, cache, tests)
@@ -81,7 +91,7 @@ func TestSequences(t *testing.T) {
cache := NewCausalCache(nil) cache := NewCausalCache(nil)
defer cache.Close() defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16) cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{ tests := []testCase{
{ {
@@ -116,7 +126,7 @@ func TestRemove(t *testing.T) {
}) })
defer cache.Close() defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16) cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{ tests := []testCase{
{ {
@@ -181,7 +191,7 @@ func TestDefrag(t *testing.T) {
}) })
defer cache.Close() defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16) cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{ tests := []testCase{
{ {
@@ -229,7 +239,7 @@ func TestCopy(t *testing.T) {
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil }) cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
defer cache.Close() defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16) cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{ tests := []testCase{
{ {
@@ -270,7 +280,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
context := backend.NewContext() context := backend.NewContext()
defer context.Close() defer context.Close()
err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs}) err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs})
if err != nil { if err != nil {
panic(err) panic(err)
} }

View File

@@ -49,7 +49,7 @@ func NewEncoderCache() *EncoderCache {
} }
} }
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) { func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
if c.config == nil { if c.config == nil {
var config ml.CacheConfig var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok { if cc, ok := backend.(ml.BackendCacheConfig); ok {
@@ -58,6 +58,10 @@ func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
c.config = &config c.config = &config
} }
if maxSequences > 1 {
panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
}
if c.config.CachePadding != 0 && c.config.CachePadding != 1 { if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding)) panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
} }
@@ -79,10 +83,10 @@ func (c *EncoderCache) Close() {
} }
} }
func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error { func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch) error {
// We work with the most recent image // We work with the most recent image
if len(opts.Multimodal) > 0 { if len(batch.Multimodal) > 0 {
c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index] c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
} }
return nil return nil

View File

@@ -23,9 +23,9 @@ func NewWrapperCache(caches ...Cache) *WrapperCache {
} }
} }
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) { func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
for _, cache := range c.caches { for _, cache := range c.caches {
cache.Init(backend, dtype, capacity) cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
} }
} }
@@ -41,14 +41,14 @@ func (c *WrapperCache) Close() {
} }
} }
func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error { func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch) error {
for i, cache := range c.caches { for i, cache := range c.caches {
err := cache.StartForward(ctx, opts) err := cache.StartForward(ctx, batch)
if err != nil { if err != nil {
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
for j := i - 1; j >= 0; j-- { for j := i - 1; j >= 0; j-- {
for k := range opts.Positions { for k := range batch.Positions {
_ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32) _ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
} }
} }
return err return err

View File

@@ -2,6 +2,7 @@ package ml
import ( import (
"bytes" "bytes"
"context"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"os" "os"
@@ -60,6 +61,10 @@ type CacheConfig struct {
// BackendParams controls how the backend loads and executes models // BackendParams controls how the backend loads and executes models
type BackendParams struct { type BackendParams struct {
// Progress is a callback function that allows reporting percentage completion
// of model loading
Progress func(float32)
// NumThreads sets the number of threads to use if running on the CPU // NumThreads sets the number of threads to use if running on the CPU
NumThreads int NumThreads int
@@ -76,9 +81,9 @@ type BackendParams struct {
FlashAttention bool FlashAttention bool
} }
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error)) var backends = make(map[string]func(context.Context, *os.File, BackendParams) (Backend, error))
func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) { func RegisterBackend(name string, f func(context.Context, *os.File, BackendParams) (Backend, error)) {
if _, ok := backends[name]; ok { if _, ok := backends[name]; ok {
panic("backend: backend already registered") panic("backend: backend already registered")
} }
@@ -86,9 +91,9 @@ func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, erro
backends[name] = f backends[name] = f
} }
func NewBackend(f *os.File, params BackendParams) (Backend, error) { func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend, error) {
if backend, ok := backends["ggml"]; ok { if backend, ok := backends["ggml"]; ok {
return backend(f, params) return backend(ctx, f, params)
} }
return nil, fmt.Errorf("unsupported backend") return nil, fmt.Errorf("unsupported backend")

View File

@@ -9,15 +9,17 @@ package ggml
import "C" import "C"
import ( import (
"errors" "context"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"maps" "maps"
"os" "os"
"runtime"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
"sync/atomic"
"unicode" "unicode"
"unsafe" "unsafe"
@@ -58,7 +60,7 @@ type Backend struct {
maxGraphNodes int maxGraphNodes int
} }
func New(r *os.File, params ml.BackendParams) (ml.Backend, error) { func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) {
meta, n, err := fs.Decode(r, -1) meta, n, err := fs.Decode(r, -1)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -297,12 +299,16 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
} }
} }
// concurrently read in tensor data. uses a section reader which is safe for concurrent reads var doneBytes atomic.Uint64
sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset)) totalBytes := uint64(n) - meta.Tensors().Offset
var g errgroup.Group
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(runtime.GOMAXPROCS(0))
for _, t := range meta.Tensors().Items() { for _, t := range meta.Tensors().Items() {
for _, target := range targets[t.Name] { g.Go(func() error {
g.Go(func() error { tts := make([]*C.struct_ggml_tensor, max(1, len(targets[t.Name])))
for i := range tts {
target := targets[t.Name][i]
if target == "" { if target == "" {
target = t.Name target = t.Name
} }
@@ -312,23 +318,44 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
return fmt.Errorf("unassigned tensor: %s", t.Name) return fmt.Errorf("unassigned tensor: %s", t.Name)
} }
bts := make([]byte, t.Size()) tts[i] = tt
n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts) }
sr := io.NewSectionReader(r, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
bts := make([]byte, 128*format.KibiByte)
var s uint64
for s < t.Size() {
n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
if err != nil { if err != nil {
return err return err
} }
if n != len(bts) { for _, tt := range tts {
return errors.New("short read") C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
} }
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size())) s += uint64(n)
return nil
}) if params.Progress != nil {
} done := doneBytes.Add(uint64(n))
params.Progress(float32(done) / float32(totalBytes))
}
}
return nil
})
} }
if g.Wait() != nil { // start a goroutine to cancel the errgroup if the parent context is done
go func() {
<-ctx.Done()
g.Go(func() error {
return ctx.Err()
})
}()
if err := g.Wait(); err != nil {
return nil, err return nil, err
} }
@@ -371,7 +398,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])), (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
C.int(len(schedBackends)), C.int(len(schedBackends)),
C.size_t(maxGraphNodes), C.size_t(maxGraphNodes),
true, C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
), ),
input: deviceBufferTypes[input.d], input: deviceBufferTypes[input.d],
output: deviceBufferTypes[output.d], output: deviceBufferTypes[output.d],

View File

@@ -1,5 +1,7 @@
package input package input
import "github.com/ollama/ollama/ml"
// Input represents one token in the input stream // Input represents one token in the input stream
type Input struct { type Input struct {
// Token is a single element of text. // Token is a single element of text.
@@ -33,11 +35,24 @@ type MultimodalIndex struct {
Multimodal any Multimodal any
} }
// Options contains the inputs for a model forward pass // Batch contains the inputs for a model forward pass
type Options struct { type Batch struct {
Inputs []int32 // Inputs is the input tokens, including placeholders for multimodal inputs.
Inputs ml.Tensor
// Multimodal is a set of multimodal embeddings previously created by
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
// models or for batches without multimodal elements.
Multimodal []MultimodalIndex Multimodal []MultimodalIndex
Positions []int32
Sequences []int // Positions is the position for each Input, relative to its sequence. Equal
Outputs []int32 // in length to Inputs.
Positions []int32
// Sequences is the sequence for each Input. Equal in length to Inputs.
Sequences []int
// Outputs are the set of indicies into Inputs for which output data should
// be returned.
Outputs []int32
} }

View File

@@ -1,6 +1,7 @@
package model package model
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
_ "image/jpeg" _ "image/jpeg"
@@ -26,7 +27,7 @@ var ErrNoVisionModel = errors.New("this model is missing data required for image
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration // Model implements a specific model architecture, defining the forward pass and any model-specific configuration
type Model interface { type Model interface {
Forward(ml.Context, input.Options) (ml.Tensor, error) Forward(ml.Context, input.Batch) (ml.Tensor, error)
Backend() ml.Backend Backend() ml.Backend
Config() config Config() config
@@ -94,14 +95,14 @@ func Register(name string, f func(ml.Config) (Model, error)) {
} }
// New initializes a new model instance with the provided configuration based on the metadata in the model file // New initializes a new model instance with the provided configuration based on the metadata in the model file
func New(modelPath string, params ml.BackendParams) (Model, error) { func New(ctx context.Context, modelPath string, params ml.BackendParams) (Model, error) {
r, err := os.Open(modelPath) r, err := os.Open(modelPath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer r.Close() defer r.Close()
b, err := ml.NewBackend(r, params) b, err := ml.NewBackend(ctx, r, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -280,24 +281,30 @@ func canNil(t reflect.Type) bool {
t.Kind() == reflect.Slice t.Kind() == reflect.Slice
} }
func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) { func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
if len(opts.Positions) != len(opts.Sequences) { if len(batch.Positions) != len(batch.Sequences) {
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences)) return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
} }
if len(opts.Positions) < 1 { if len(batch.Positions) < 1 {
return nil, errors.New("batch size cannot be less than 1") return nil, errors.New("batch size cannot be less than 1")
} }
var err error
batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
if err != nil {
return nil, err
}
cache := m.Config().Cache cache := m.Config().Cache
if cache != nil { if cache != nil {
err := cache.StartForward(ctx, opts) err := cache.StartForward(ctx, batch)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
t, err := m.Forward(ctx, opts) t, err := m.Forward(ctx, batch)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -163,7 +163,7 @@ func TestGetTextProcessor(t *testing.T) {
type notTextProcessorModel struct{} type notTextProcessorModel struct{}
func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) { func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
panic("unimplemented") panic("unimplemented")
} }

View File

@@ -168,23 +168,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
return hiddenState.Add(ctx, residual) return hiddenState.Add(ctx, residual)
} }
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil { if err != nil {
return nil, err return nil, err
} }
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions)) outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs)) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize))) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
if len(m.Layers) == gemma27BLayerCount { if len(m.Layers) == gemma27BLayerCount {
@@ -211,8 +206,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
// final logit softcap // final logit softcap
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap)) hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx) hiddenState = hiddenState.Tanh(ctx)
hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap)) return hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap)), nil
return hiddenState.Rows(ctx, outputs), nil
} }
func init() { func init() {

View File

@@ -139,23 +139,18 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
return result, nil return result, nil
} }
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil { if err != nil {
return nil, err return nil, err
} }
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions)) outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs)) return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
} }
func init() { func init() {

View File

@@ -171,13 +171,13 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
return hiddenState.Add(ctx, residual) return hiddenState.Add(ctx, residual)
} }
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor { func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs) hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize))) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
// set image embeddings // set image embeddings
var except []int var except []int
for _, image := range opts.Multimodal { for _, image := range batch.Multimodal {
visionOutputs := image.Multimodal.(ml.Tensor) visionOutputs := image.Multimodal.(ml.Tensor)
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1)))) ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))

View File

@@ -139,23 +139,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
return hiddenState.Add(ctx, residual) return hiddenState.Add(ctx, residual)
} }
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil { if err != nil {
return nil, err return nil, err
} }
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions)) outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs)) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
for i, layer := range m.Layers { for i, layer := range m.Layers {
m.Cache.SetLayer(i) m.Cache.SetLayer(i)

View File

@@ -135,32 +135,27 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
return inputs, nil return inputs, nil
} }
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var crossAttentionStates ml.Tensor var crossAttentionStates ml.Tensor
if len(opts.Multimodal) > 0 { if len(batch.Multimodal) > 0 {
images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor) images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal.([]ml.Tensor)
if len(images) > 0 { if len(images) > 0 {
crossAttentionStates = images[len(images)-1] crossAttentionStates = images[len(images)-1]
} }
} }
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil { if err != nil {
return nil, err return nil, err
} }
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions)) outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
// TODO: attention mask, cross attention mask // TODO: attention mask, cross attention mask
return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
} }
func init() { func init() {

View File

@@ -31,8 +31,10 @@ type InputCache struct {
cache kvcache.Cache cache kvcache.Cache
} }
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, multiUserCache bool) (*InputCache, error) { func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
if kvSize/int32(numSlots) < 1 { numCtx := kvSize / int32(numSlots)
if numCtx < 1 {
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots) return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
} }
@@ -44,11 +46,11 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
cache := model.Config().Cache cache := model.Config().Cache
if cache != nil { if cache != nil {
cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), kvSize) cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), numSlots, int(numCtx), batchSize)
} }
return &InputCache{ return &InputCache{
numCtx: kvSize / int32(numSlots), numCtx: numCtx,
enabled: cache != nil, enabled: cache != nil,
slots: slots, slots: slots,
multiUserCache: multiUserCache, multiUserCache: multiUserCache,
@@ -89,7 +91,7 @@ type InputCacheSlot struct {
lastUsed time.Time lastUsed time.Time
} }
func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) { func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
var slot *InputCacheSlot var slot *InputCacheSlot
var numPast int32 var numPast int32
var err error var err error
@@ -107,11 +109,6 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp
return nil, nil, err return nil, nil, err
} }
// TODO (brucemacd): cachePrompt is always true for completion, but false for embedding, can this be improved?
if !cachePrompt {
numPast = 0
}
slot.InUse = true slot.InUse = true
slot.lastUsed = time.Now() slot.lastUsed = time.Now()

View File

@@ -297,3 +297,131 @@ func TestShiftDiscard(t *testing.T) {
}) })
} }
} }
func TestLoadCacheSlot(t *testing.T) {
tests := []struct {
name string
cache InputCache
prompt []input.Input
wantErr bool
expectedSlotId int
expectedPrompt int // expected length of remaining prompt
}{
{
name: "Basic cache hit - single user",
cache: InputCache{
multiUserCache: false,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Only token 3 remains
},
{
name: "Basic cache hit - multi user",
cache: InputCache{
multiUserCache: true,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Only token 3 remains
},
{
name: "Exact match - leave one input",
cache: InputCache{
multiUserCache: false,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Should leave 1 token for sampling
},
{
name: "No available slots",
cache: InputCache{
multiUserCache: false,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: true,
lastUsed: time.Now().Add(-time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: true,
expectedSlotId: -1,
expectedPrompt: -1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
// Check error state
if (err != nil) != tt.wantErr {
t.Errorf("LoadCacheSlot() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return // Skip further checks if we expected an error
}
// Verify slot ID
if slot.Id != tt.expectedSlotId {
t.Errorf("LoadCacheSlot() slot ID = %v, expected %v", slot.Id, tt.expectedSlotId)
}
// Verify slot is now marked in use
if !slot.InUse {
t.Errorf("LoadCacheSlot() slot not marked InUse")
}
// Verify remaining prompt length
if len(remainingPrompt) != tt.expectedPrompt {
t.Errorf("LoadCacheSlot() remaining prompt length = %v, expected %v",
len(remainingPrompt), tt.expectedPrompt)
}
})
}
}

View File

@@ -115,6 +115,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
params.numKeep = int32(len(inputs)) params.numKeep = int32(len(inputs))
} }
// TODO(jessegross): We should ensure that we always leave minBatch of context space to shift,
// otherwise we might truncate or split the batch against the model's wishes
// Ensure that at least 1 input can be discarded during shift // Ensure that at least 1 input can be discarded during shift
params.numKeep = min(params.numKeep, s.cache.numCtx-1) params.numKeep = min(params.numKeep, s.cache.numCtx-1)
@@ -345,7 +348,8 @@ func (s *Server) processBatch() error {
} }
defer s.mu.Unlock() defer s.mu.Unlock()
var options input.Options var batchInputs []int32
var batch input.Batch
for i, seq := range s.seqs { for i, seq := range s.seqs {
if seq == nil { if seq == nil {
@@ -366,17 +370,6 @@ func (s *Server) processBatch() error {
batchSize := s.batchSize batchSize := s.batchSize
for j, inp := range seq.inputs { for j, inp := range seq.inputs {
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
if len(seq.pendingInputs) == 0 {
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
return err
}
} else {
break
}
}
// If we are required to put following inputs into a single batch then extend the // If we are required to put following inputs into a single batch then extend the
// batch size. Since we are only extending the size the minimum amount possible, this // batch size. Since we are only extending the size the minimum amount possible, this
// will cause a break if we have pending inputs. // will cause a break if we have pending inputs.
@@ -389,17 +382,31 @@ func (s *Server) processBatch() error {
break break
} }
options.Inputs = append(options.Inputs, inp.Token) // If the sum of our working set (already processed tokens, tokens we added to this
if inp.Multimodal != nil { // batch, required following tokens) exceeds the context size, then trigger a shift
options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal}) // now so we don't have to do one later when we can't break the batch.
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
if len(seq.pendingInputs) != 0 {
break
}
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
return err
}
} }
options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs))) batchInputs = append(batchInputs, inp.Token)
options.Sequences = append(options.Sequences, seq.cache.Id) if inp.Multimodal != nil {
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal})
}
seq.iBatch = len(options.Outputs) batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
batch.Sequences = append(batch.Sequences, seq.cache.Id)
seq.iBatch = len(batch.Outputs)
if j+1 == len(seq.inputs) { if j+1 == len(seq.inputs) {
options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1)) batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
} }
seq.pendingInputs = append(seq.pendingInputs, inp) seq.pendingInputs = append(seq.pendingInputs, inp)
} }
@@ -407,14 +414,14 @@ func (s *Server) processBatch() error {
seq.inputs = seq.inputs[len(seq.pendingInputs):] seq.inputs = seq.inputs[len(seq.pendingInputs):]
} }
if len(options.Inputs) == 0 { if len(batchInputs) == 0 {
return nil return nil
} }
ctx := s.model.Backend().NewContext() ctx := s.model.Backend().NewContext()
defer ctx.Close() defer ctx.Close()
modelOutput, err := model.Forward(ctx, s.model, options) modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
if err != nil { if err != nil {
return fmt.Errorf("failed to decode batch: %w", err) return fmt.Errorf("failed to decode batch: %w", err)
} }
@@ -454,7 +461,7 @@ func (s *Server) processBatch() error {
} }
// sample a token // sample a token
vocabSize := len(logits) / len(options.Outputs) vocabSize := len(logits) / len(batch.Outputs)
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize]) token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
if err != nil { if err != nil {
@@ -590,7 +597,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
found := false found := false
for i, sq := range s.seqs { for i, sq := range s.seqs {
if sq == nil { if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true) seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
if err != nil { if err != nil {
s.mu.Unlock() s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -671,6 +678,7 @@ func (m *multiLPath) String() string {
} }
func (s *Server) loadModel( func (s *Server) loadModel(
ctx context.Context,
mpath string, mpath string,
params ml.BackendParams, params ml.BackendParams,
lpath multiLPath, lpath multiLPath,
@@ -680,7 +688,7 @@ func (s *Server) loadModel(
multiUserCache bool, multiUserCache bool,
) { ) {
var err error var err error
s.model, err = model.New(mpath, params) s.model, err = model.New(ctx, mpath, params)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -692,7 +700,7 @@ func (s *Server) loadModel(
panic("loras are not yet implemented") panic("loras are not yet implemented")
} }
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, multiUserCache) s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -776,6 +784,9 @@ func Execute(args []string) error {
} }
params := ml.BackendParams{ params := ml.BackendParams{
Progress: func(progress float32) {
server.progress = progress
},
NumThreads: *threads, NumThreads: *threads,
NumGPULayers: *numGPULayers, NumGPULayers: *numGPULayers,
MainGPU: *mainGPU, MainGPU: *mainGPU,
@@ -784,13 +795,13 @@ func Execute(args []string) error {
} }
server.ready.Add(1) server.ready.Add(1)
go server.loadModel(*mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
server.cond = sync.NewCond(&server.mu)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
server.cond = sync.NewCond(&server.mu)
go server.run(ctx) go server.run(ctx)
addr := "127.0.0.1:" + strconv.Itoa(*port) addr := "127.0.0.1:" + strconv.Itoa(*port)

View File

@@ -26,6 +26,10 @@ type Sampler struct {
} }
func (s *Sampler) Sample(logits []float32) (int32, error) { func (s *Sampler) Sample(logits []float32) (int32, error) {
if len(logits) == 0 {
return -1, errors.New("sample: no logits provided to sample")
}
tokens := make([]token, len(logits)) tokens := make([]token, len(logits))
for i := range logits { for i := range logits {
tokens[i].id = int32(i) tokens[i].id = int32(i)
@@ -87,19 +91,13 @@ func (s *Sampler) sample(tokens []token) (token, error) {
// topK also sorts the tokens in descending order of logits // topK also sorts the tokens in descending order of logits
tokens = topK(tokens, s.topK) tokens = topK(tokens, s.topK)
tokens = temperature(tokens, s.temperature) // scale and normalize the tokens in place
tokens = softmax(tokens) temperature(tokens, s.temperature)
softmax(tokens)
tokens = topP(tokens, s.topP) tokens = topP(tokens, s.topP)
tokens = minP(tokens, s.minP) tokens = minP(tokens, s.minP)
// TODO: this should fall back to greedy sampling
// or topP, topK values etc should be such that
// there are always tokens to sample from
if len(tokens) == 0 {
return token{}, errors.New("no tokens to sample from")
}
var r float32 var r float32
if s.rng != nil { if s.rng != nil {
r = s.rng.Float32() r = s.rng.Float32()
@@ -122,6 +120,9 @@ func (s *Sampler) sample(tokens []token) (token, error) {
return 1 return 1
}) })
if math.IsNaN(float64(sum)) {
return token{}, errors.New("sample: logits sum to NaN, check model output")
}
return tokens[idx], nil return tokens[idx], nil
} }

View File

@@ -1,6 +1,7 @@
package sample package sample
import ( import (
"math"
"math/rand/v2" "math/rand/v2"
"testing" "testing"
) )
@@ -29,6 +30,29 @@ func TestWeighted(t *testing.T) {
if want != got { if want != got {
t.Errorf("index mismatch: want %d, got %d", want, got) t.Errorf("index mismatch: want %d, got %d", want, got)
} }
// Test very high p
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
// Use extremely small topP to filter out all tokens
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
got, err = sampler.Sample(logits)
if err != nil {
t.Error(err)
return
}
// Should get the token with the highest logit
want = int32(0)
if want != got {
t.Errorf("index mismatch: want %d, got %d", want, got)
}
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
got, err = sampler.Sample(logits)
if err == nil {
t.Errorf("expected error, got %d", got)
return
}
} }
func BenchmarkSample(b *testing.B) { func BenchmarkSample(b *testing.B) {

View File

@@ -26,17 +26,16 @@ func (h *tokenHeap) Pop() any {
} }
// temperature applies scaling to the logits // temperature applies scaling to the logits
func temperature(ts []token, temp float32) []token { func temperature(ts []token, temp float32) {
// Ensure temperature clipping near 0 to avoid numerical instability // Ensure temperature clipping near 0 to avoid numerical instability
temp = max(temp, 1e-7) temp = max(temp, 1e-7)
for i := range ts { for i := range ts {
ts[i].value = ts[i].value / temp ts[i].value = ts[i].value / temp
} }
return ts
} }
// softmax applies normalization to the logits // softmax applies normalization to the logits
func softmax(ts []token) []token { func softmax(ts []token) {
// Find max logit for numerical stability // Find max logit for numerical stability
maxLogit := float32(math.Inf(-1)) maxLogit := float32(math.Inf(-1))
for _, t := range ts { for _, t := range ts {
@@ -56,8 +55,6 @@ func softmax(ts []token) []token {
for i := range ts { for i := range ts {
ts[i].value /= sum ts[i].value /= sum
} }
return ts
} }
// topK limits the number of tokens considered to the k highest logits // topK limits the number of tokens considered to the k highest logits
@@ -99,6 +96,7 @@ func topK(ts []token, k int) []token {
} }
// topP limits tokens to those with cumulative probability p // topP limits tokens to those with cumulative probability p
// requires ts to be sorted in descending order of probabilities
func topP(ts []token, p float32) []token { func topP(ts []token, p float32) []token {
if p == 1.0 { if p == 1.0 {
return ts return ts
@@ -109,37 +107,24 @@ func topP(ts []token, p float32) []token {
for i, t := range ts { for i, t := range ts {
sum += t.value sum += t.value
if sum > float32(p) { if sum > float32(p) {
ts = ts[:i+1] return ts[:i+1]
return ts
} }
} }
return ts return ts
} }
// minP limits tokens to those with cumulative probability p // minP filters tokens with probabilities >= p * max_prob
// requires ts to be sorted in descending order of probabilities
func minP(ts []token, p float32) []token { func minP(ts []token, p float32) []token {
if p == 1.0 { maxProb := ts[0].value
return ts
}
maxProb := float32(math.Inf(-1)) threshold := maxProb * p
for _, token := range ts {
if token.value > maxProb { for i, t := range ts {
maxProb = token.value if t.value < threshold {
return ts[:i]
} }
} }
threshold := maxProb * float32(p)
// Filter tokens in-place
validTokens := ts[:0]
for i, token := range ts {
if token.value >= threshold {
validTokens = append(validTokens, ts[i])
}
}
ts = validTokens
return ts return ts
} }

View File

@@ -34,17 +34,22 @@ func compareLogits(t *testing.T, name string, want []float32, got []token) {
func TestTemperature(t *testing.T) { func TestTemperature(t *testing.T) {
input := []float32{1.0, 4.0, -2.0, 0.0} input := []float32{1.0, 4.0, -2.0, 0.0}
got := temperature(toTokens(input), 0.5) tokens := toTokens(input)
temperature(tokens, 0.5)
want := []float32{2.0, 8.0, -4.0, 0.0} want := []float32{2.0, 8.0, -4.0, 0.0}
compareLogits(t, "temperature(0.5)", want, got) compareLogits(t, "temperature(0.5)", want, tokens)
got = temperature(toTokens(input), 1.0) input = []float32{1.0, 4.0, -2.0, 0.0}
tokens = toTokens(input)
temperature(tokens, 1.0)
want = []float32{1.0, 4.0, -2.0, 0.0} want = []float32{1.0, 4.0, -2.0, 0.0}
compareLogits(t, "temperature(1)", want, got) compareLogits(t, "temperature(1)", want, tokens)
got = temperature(toTokens(input), 0.0) input = []float32{1.0, 4.0, -2.0, 0.0}
tokens = toTokens(input)
temperature(tokens, 0.0)
want = []float32{1e7, 4e7, -2e7, 0.0} want = []float32{1e7, 4e7, -2e7, 0.0}
compareLogits(t, "temperature(0)", want, got) compareLogits(t, "temperature(0)", want, tokens)
} }
func TestSoftmax(t *testing.T) { func TestSoftmax(t *testing.T) {
@@ -90,16 +95,17 @@ func TestSoftmax(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got := softmax(toTokens(tt.input)) tokens := toTokens(tt.input)
softmax(tokens)
if tt.expected != nil { if tt.expected != nil {
compareLogits(t, tt.name, tt.expected, got) compareLogits(t, tt.name, tt.expected, tokens)
return return
} }
// Check probabilities sum to 1 // Check probabilities sum to 1
var sum float32 var sum float32
for _, token := range got { for _, token := range tokens {
sum += token.value sum += token.value
if token.value < 0 || token.value > 1 { if token.value < 0 || token.value > 1 {
t.Errorf("probability out of range [0,1]: got %f", token.value) t.Errorf("probability out of range [0,1]: got %f", token.value)
@@ -114,38 +120,44 @@ func TestSoftmax(t *testing.T) {
func TestTopK(t *testing.T) { func TestTopK(t *testing.T) {
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367} input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
tokens := toTokens(input)
// Test k=5 tokens = topK(tokens, 5)
got := topK(toTokens(input), 5) if len(tokens) != 5 {
if len(got) != 5 { t.Errorf("topK(5): wrong length: want 5, got %d", len(tokens))
t.Errorf("topK(5): wrong length: want 5, got %d", len(got))
} }
// Should keep highest 3 values in descending order
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154} want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
compareLogits(t, "topK(3)", want, got) compareLogits(t, "topK(3)", want, tokens)
got = topK(toTokens(input), 20) tokens = toTokens(input)
if len(got) != len(input) { tokens = topK(tokens, 20)
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got)) if len(tokens) != len(input) {
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(tokens))
} }
// Test k=-1
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367} input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839} want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
got = topK(toTokens(input), -1) tokens = toTokens(input)
if len(got) != len(input) { tokens = topK(tokens, -1)
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got)) if len(tokens) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
} }
compareLogits(t, "topK(-1)", want, got) compareLogits(t, "topK(-1)", want, tokens)
// Test k=0
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367} input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839} want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
got = topK(toTokens(input), 0) tokens = toTokens(input)
if len(got) != len(input) { tokens = topK(tokens, 0)
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got)) if len(tokens) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
}
compareLogits(t, "topK(-1)", want, tokens)
input = []float32{-1e7, -2e7, -3e7, -4e7}
tokens = toTokens(input)
tokens = topK(tokens, 1)
if len(tokens) < 1 {
t.Error("topK should keep at least one token")
} }
compareLogits(t, "topK(-1)", want, got)
} }
func TestTopP(t *testing.T) { func TestTopP(t *testing.T) {
@@ -153,50 +165,134 @@ func TestTopP(t *testing.T) {
tokens := toTokens(input) tokens := toTokens(input)
// First apply temperature and softmax to get probabilities // First apply temperature and softmax to get probabilities
tokens = softmax(tokens) softmax(tokens)
tokens = topK(tokens, 20) tokens = topK(tokens, 20)
// Then apply topP // Test with very high p value
got := topP(tokens, 0.95) got := topP(tokens, 1.0)
// Should keep all tokens since p is 1
if len(got) != len(input) {
t.Errorf("topP(1.0): should keep all tokens, got %d, want %d", len(got), len(input))
}
// Test with normal p value
got = topP(tokens, 0.95)
// Should keep tokens until cumsum > 0.95
if len(got) > 3 { if len(got) > 3 {
t.Errorf("topP(0.95): kept too many tokens: got %d", len(got)) t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
t.Logf("got: %v", got)
}
// Test edge case - ensure at least one token remains
input = []float32{-1e6, -1e6, -1e7}
tokens = toTokens(input)
tokens = topK(tokens, 20)
softmax(tokens)
got = topP(tokens, 0.0)
if len(got) < 1 {
t.Error("topP should keep at least one token")
}
// Test with zero p value
got = topP(tokens, 0.0)
// Should keep only the highest probability token
if len(got) != 1 {
t.Errorf("topP(0.0): should keep only one token, got %d", len(got))
t.Logf("got: %v", got)
}
tokens = toTokens(input)
tokens = topK(tokens, 20)
softmax(tokens)
got = topP(tokens, 1e-10)
if len(got) == 0 {
t.Errorf("topP(1e-10): should keep at least one token, got %d", len(got))
t.Logf("got: %v", got) t.Logf("got: %v", got)
} }
} }
func TestMinP(t *testing.T) { func TestMinP(t *testing.T) {
input := []float32{-3, -2, -1, 0, 1, 2, 4, 3} input := []float32{-2, 0, -1, -3, 2, 1, 4, 3}
tokens := toTokens(input) tokens := toTokens(input)
// First apply temperature and softmax // First apply temperature and softmax
tokens = softmax(tokens) tokens = topK(tokens, 20)
softmax(tokens)
// Then apply minP tokens = minP(tokens, 1.0)
got := minP(tokens, 0.2)
if len(tokens) != 1 {
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(tokens), len(tokens))
}
// Test with normal p value
tokens = toTokens(input) // Reset tokens
tokens = topK(tokens, 20)
softmax(tokens)
tokens = minP(tokens, 0.2)
// Should keep tokens with prob >= 0.2 * max_prob // Should keep tokens with prob >= 0.2 * max_prob
if len(got) > 3 { if len(tokens) > 3 {
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got)) t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens))
t.Logf("got: %v", tokens)
} }
}
func TestSortLogits(t *testing.T) {
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
tokens := toTokens(input)
// Test with zero p value
tokens = toTokens(input) // Reset tokens
tokens = topK(tokens, 20) tokens = topK(tokens, 20)
softmax(tokens)
tokens = minP(tokens, 0.0)
for i := 1; i < len(tokens); i++ { // Should keep only the highest probability token
if tokens[i].value > tokens[i-1].value { if len(tokens) != len(input) {
t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f", t.Errorf("minP(0.0): should keep only one token, got %d", len(tokens))
i, tokens[i].value, tokens[i-1].value) t.Logf("got: %v", tokens)
}
// Test with single token
tokens = toTokens(input[:1])
tokens = topK(tokens, 20)
softmax(tokens)
tokens = minP(tokens, 0.1)
// Should keep only the highest probability token
if len(tokens) != 1 {
t.Errorf("minP(0.1): should return single token, got %d", len(tokens))
t.Logf("got: %v", tokens)
}
input = []float32{1e-10, 1e-10, 1e-10}
tokens = toTokens(input)
softmax(tokens)
tokens = minP(tokens, 1.0)
if len(tokens) < 1 {
t.Error("minP should keep at least one token even with extreme probabilities")
got := minP(tokens, 1.0)
if len(got) != 1 {
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(got), len(tokens))
}
// Test with normal p value
got = minP(tokens, 0.2)
// Should keep tokens with prob >= 0.2 * max_prob
if len(got) > 3 {
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
t.Logf("got: %v", got)
}
// Test with zero p value
got = minP(tokens, 0.0)
// Should keep only the highest probability token
if len(got) != len(tokens) {
t.Errorf("minP(0.0): should keep only one token, got %d", len(got))
t.Logf("got: %v", got)
} }
} }
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
compareLogits(t, "sortLogits", want, tokens)
} }
func BenchmarkTransforms(b *testing.B) { func BenchmarkTransforms(b *testing.B) {
@@ -231,7 +327,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
copy(tokensCopy, tokens) copy(tokensCopy, tokens)
topK(tokensCopy, 10) tokens = topK(tokensCopy, 10)
} }
}) })
@@ -239,7 +335,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
copy(tokensCopy, tokens) copy(tokensCopy, tokens)
topP(tokensCopy, 0.9) tokens = topP(tokensCopy, 0.9)
} }
}) })
@@ -247,7 +343,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
copy(tokensCopy, tokens) copy(tokensCopy, tokens)
minP(tokensCopy, 0.2) tokens = minP(tokensCopy, 0.2)
} }
}) })
@@ -255,7 +351,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
copy(tokensCopy, tokens) copy(tokensCopy, tokens)
topK(tokensCopy, 200000) tokens = topK(tokensCopy, 200000)
} }
}) })
} }

View File

@@ -37,7 +37,6 @@ import (
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/ollama/ollama/server/internal/cache/blob" "github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/internal/backoff"
"github.com/ollama/ollama/server/internal/internal/names" "github.com/ollama/ollama/server/internal/internal/names"
_ "embed" _ "embed"
@@ -60,6 +59,11 @@ var (
// ErrCached is passed to [Trace.PushUpdate] when a layer already // ErrCached is passed to [Trace.PushUpdate] when a layer already
// exists. It is a non-fatal error and is never returned by [Registry.Push]. // exists. It is a non-fatal error and is never returned by [Registry.Push].
ErrCached = errors.New("cached") ErrCached = errors.New("cached")
// ErrIncomplete is returned by [Registry.Pull] when a model pull was
// incomplete due to one or more layer download failures. Users that
// want specific errors should use [WithTrace].
ErrIncomplete = errors.New("incomplete")
) )
// Defaults // Defaults
@@ -213,12 +217,6 @@ type Registry struct {
// request. If zero, [DefaultChunkingThreshold] is used. // request. If zero, [DefaultChunkingThreshold] is used.
ChunkingThreshold int64 ChunkingThreshold int64
// MaxChunkSize is the maximum size of a chunk to download. If zero,
// the default is [DefaultMaxChunkSize].
//
// It is only used when a layer is larger than [MaxChunkingThreshold].
MaxChunkSize int64
// Mask, if set, is the name used to convert non-fully qualified names // Mask, if set, is the name used to convert non-fully qualified names
// to fully qualified names. If empty, [DefaultMask] is used. // to fully qualified names. If empty, [DefaultMask] is used.
Mask string Mask string
@@ -278,8 +276,19 @@ func DefaultRegistry() (*Registry, error) {
func UserAgent() string { func UserAgent() string {
buildinfo, _ := debug.ReadBuildInfo() buildinfo, _ := debug.ReadBuildInfo()
version := buildinfo.Main.Version
if version == "(devel)" {
// When using `go run .` the version is "(devel)". This is seen
// as an invalid version by ollama.com and so it defaults to
// "needs upgrade" for some requests, such as pulls. These
// checks can be skipped by using the special version "v0.0.0",
// so we set it to that here.
version = "v0.0.0"
}
return fmt.Sprintf("ollama/%s (%s %s) Go/%s", return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
buildinfo.Main.Version, version,
runtime.GOARCH, runtime.GOARCH,
runtime.GOOS, runtime.GOOS,
runtime.Version(), runtime.Version(),
@@ -425,13 +434,14 @@ func canRetry(err error) bool {
// //
// It always calls update with a nil error. // It always calls update with a nil error.
type trackingReader struct { type trackingReader struct {
r io.Reader l *Layer
n *atomic.Int64 r io.Reader
update func(l *Layer, n int64, err error)
} }
func (r *trackingReader) Read(p []byte) (n int, err error) { func (r *trackingReader) Read(p []byte) (n int, err error) {
n, err = r.r.Read(p) n, err = r.r.Read(p)
r.n.Add(int64(n)) r.update(r.l, int64(n), nil)
return return
} }
@@ -447,6 +457,11 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
if err != nil { if err != nil {
return err return err
} }
// TODO(bmizerany): decide if this should be considered valid. Maybe
// server-side we special case '{}' to have some special meaning? Maybe
// "archiving" a tag (which is how we reason about it in the registry
// already, just with a different twist).
if len(m.Layers) == 0 { if len(m.Layers) == 0 {
return fmt.Errorf("%w: no layers", ErrManifestInvalid) return fmt.Errorf("%w: no layers", ErrManifestInvalid)
} }
@@ -456,11 +471,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
return err return err
} }
exists := func(l *Layer) bool { // TODO(bmizerany): work to remove the need to do this
info, err := c.Get(l.Digest)
return err == nil && info.Size == l.Size
}
layers := m.Layers layers := m.Layers
if m.Config != nil && m.Config.Digest.IsValid() { if m.Config != nil && m.Config.Digest.IsValid() {
layers = append(layers, m.Config) layers = append(layers, m.Config)
@@ -468,99 +479,97 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
// Send initial layer trace events to allow clients to have an // Send initial layer trace events to allow clients to have an
// understanding of work to be done before work starts. // understanding of work to be done before work starts.
var expected int64
t := traceFromContext(ctx) t := traceFromContext(ctx)
skip := make([]bool, len(layers)) for _, l := range layers {
for i, l := range layers {
t.update(l, 0, nil) t.update(l, 0, nil)
if exists(l) { expected += l.Size
skip[i] = true
t.update(l, l.Size, ErrCached)
}
} }
g, ctx := errgroup.WithContext(ctx) var received atomic.Int64
var g errgroup.Group
g.SetLimit(r.maxStreams()) g.SetLimit(r.maxStreams())
for i, l := range layers { for _, l := range layers {
if skip[i] { info, err := c.Get(l.Digest)
if err == nil && info.Size == l.Size {
received.Add(l.Size)
t.update(l, l.Size, ErrCached)
continue continue
} }
var wg sync.WaitGroup
chunked, err := c.Chunked(l.Digest, l.Size) chunked, err := c.Chunked(l.Digest, l.Size)
if err != nil { if err != nil {
t.update(l, 0, err) t.update(l, 0, err)
continue continue
} }
defer chunked.Close()
var progress atomic.Int64
for cs, err := range r.chunksums(ctx, name, l) { for cs, err := range r.chunksums(ctx, name, l) {
if err != nil { if err != nil {
t.update(l, progress.Load(), err) // Chunksum stream interrupted. Note in trace
// log and let in-flight downloads complete.
// This will naturally trigger ErrIncomplete
// since received < expected bytes.
t.update(l, 0, err)
break break
} }
wg.Add(1)
g.Go(func() (err error) { g.Go(func() (err error) {
defer func() { t.update(l, progress.Load(), err) }() defer func() {
if err == nil {
for _, err := range backoff.Loop(ctx, 3*time.Second) { received.Add(cs.Chunk.Size())
if err != nil { } else {
return err err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
} }
err := func() error { wg.Done()
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil) }()
if err != nil {
return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
// Count bytes towards req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
// progress, as they arrive, so if err != nil {
// that our bytes piggyback return err
// other chunk updates on
// completion.
//
// This tactic is enough to
// show "smooth" progress given
// the current CLI client. In
// the near future, the server
// should report download rate
// since it knows better than
// a client that is measuring
// rate based on wall-clock
// time-since-last-update.
body := &trackingReader{r: res.Body, n: &progress}
err = chunked.Put(cs.Chunk, cs.Digest, body)
if err != nil {
return err
}
return nil
}()
if !canRetry(err) {
return err
}
} }
return nil req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
body := &trackingReader{l: l, r: res.Body, update: t.update}
return chunked.Put(cs.Chunk, cs.Digest, body)
}) })
} }
// Close writer immediately after downloads finish, not at Pull
// exit. Using defer would keep file descriptors open until all
// layers complete, potentially exhausting system limits with
// many layers.
//
// The WaitGroup tracks when all chunks finish downloading,
// allowing precise writer closure in a background goroutine.
// Each layer briefly uses one extra goroutine while at most
// maxStreams()-1 chunks download in parallel.
//
// This caps file descriptors at maxStreams() instead of
// growing with layer count.
g.Go(func() error {
wg.Wait()
chunked.Close()
return nil
})
} }
if err := g.Wait(); err != nil { if err := g.Wait(); err != nil {
return err return err
} }
if received.Load() != expected {
return fmt.Errorf("%w: received %d/%d", ErrIncomplete, received.Load(), expected)
}
// store the manifest blob
md := blob.DigestFromBytes(m.Data) md := blob.DigestFromBytes(m.Data)
if err := blob.PutBytes(c, md, m.Data); err != nil { if err := blob.PutBytes(c, md, m.Data); err != nil {
return err return err
} }
// commit the manifest with a link
return c.Link(m.Name, md) return c.Link(m.Name, md)
} }

View File

@@ -17,6 +17,7 @@ import (
"reflect" "reflect"
"slices" "slices"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
@@ -24,6 +25,28 @@ import (
"github.com/ollama/ollama/server/internal/testutil" "github.com/ollama/ollama/server/internal/testutil"
) )
func ExampleRegistry_cancelOnFirstError() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx = WithTrace(ctx, &Trace{
Update: func(l *Layer, n int64, err error) {
if err != nil {
// Discontinue pulling layers if there is an
// error instead of continuing to pull more
// data.
cancel()
}
},
})
var r Registry
if err := r.Pull(ctx, "model"); err != nil {
// panic for demo purposes
panic(err)
}
}
func TestManifestMarshalJSON(t *testing.T) { func TestManifestMarshalJSON(t *testing.T) {
// All manifests should contain an "empty" config object. // All manifests should contain an "empty" config object.
var m Manifest var m Manifest
@@ -56,21 +79,21 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
// newClient constructs a cache with predefined manifests for testing. The manifests are: // newClient constructs a cache with predefined manifests for testing. The manifests are:
// //
// empty: no data // empty: no data
// zero: no layers // zero: no layers
// single: one layer with the contents "exists" // single: one layer with the contents "exists"
// multiple: two layers with the contents "exists" and "here" // multiple: two layers with the contents "exists" and "here"
// notfound: a layer that does not exist in the cache // notfound: a layer that does not exist in the cache
// null: one null layer (e.g. [null]) // null: one null layer (e.g. [null])
// sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size) // sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
// invalid: a layer with invalid JSON data // invalid: a layer with invalid JSON data
// //
// Tests that want to ensure the client does not communicate with the upstream // Tests that want to ensure the client does not communicate with the upstream
// registry should pass a nil handler, which will cause a panic if // registry should pass a nil handler, which will cause a panic if
// communication is attempted. // communication is attempted.
// //
// To simulate a network error, pass a handler that returns a 499 status code. // To simulate a network error, pass a handler that returns a 499 status code.
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) { func newClient(t *testing.T, upstreamRegistry http.HandlerFunc) (*Registry, *blob.DiskCache) {
t.Helper() t.Helper()
c, err := blob.Open(t.TempDir()) c, err := blob.Open(t.TempDir())
@@ -88,7 +111,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
r := &Registry{ r := &Registry{
Cache: c, Cache: c,
HTTPClient: &http.Client{ HTTPClient: &http.Client{
Transport: recordRoundTripper(h), Transport: recordRoundTripper(upstreamRegistry),
}, },
} }
@@ -767,3 +790,79 @@ func TestUnlink(t *testing.T) {
} }
}) })
} }
func TestPullChunksums(t *testing.T) {
check := testutil.Checker(t)
content := "hello"
var chunksums string
contentDigest := func() blob.Digest {
return blob.DigestFromBytes(content)
}
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/manifests/latest"):
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content))
case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()):
loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest())
w.Header().Set("Content-Location", loc)
io.WriteString(w, chunksums)
case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()):
http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content))
default:
t.Errorf("unexpected request: %v", r)
http.NotFound(w, r)
}
})
rc.MaxStreams = 1 // prevent concurrent chunk downloads
rc.ChunkingThreshold = 1 // for all blobs to be chunked
var mu sync.Mutex
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Logf("Update: %v %d %v", l, n, err)
mu.Lock()
reads = append(reads, n)
mu.Unlock()
},
})
chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n",
blob.DigestFromBytes("hel"),
blob.DigestFromBytes("lo"),
)
err := rc.Pull(ctx, "test")
check(err)
wantReads := []int64{
0, // initial signaling of layer pull starting
3, // first chunk read
2, // second chunk read
}
if !slices.Equal(reads, wantReads) {
t.Errorf("reads = %v; want %v", reads, wantReads)
}
mw, err := rc.Resolve(t.Context(), "test")
check(err)
mg, err := rc.ResolveLocal("test")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
}
for i := range mg.Layers {
_, err = c.Get(mg.Layers[i].Digest)
if err != nil {
t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err)
}
}
// missing chunks
content = "llama"
chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll"))
err = rc.Pull(ctx, "missingchunks")
if err == nil {
t.Error("expected error because of missing chunks")
}
}

View File

@@ -200,7 +200,7 @@ type params struct {
// //
// Unfortunately, this API was designed to be a bit awkward. Stream is // Unfortunately, this API was designed to be a bit awkward. Stream is
// defined to default to true if not present, so we need a way to check // defined to default to true if not present, so we need a way to check
// if the client decisively it to false. So, we use a pointer to a // if the client decisively set it to false. So, we use a pointer to a
// bool. Gross. // bool. Gross.
// //
// Use [stream()] to get the correct value for this field. // Use [stream()] to get the correct value for this field.
@@ -280,17 +280,17 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
progress := make(map[*ollama.Layer]int64) progress := make(map[*ollama.Layer]int64)
progressCopy := make(map[*ollama.Layer]int64, len(progress)) progressCopy := make(map[*ollama.Layer]int64, len(progress))
pushUpdate := func() { flushProgress := func() {
defer maybeFlush() defer maybeFlush()
// TODO(bmizerany): This scales poorly with more layers due to // TODO(bmizerany): Flushing every layer in one update doesn't
// needing to flush out them all in one big update. We _could_ // scale well. We could flush only the modified layers or track
// just flush on the changed ones, or just track the whole // the full download. Needs further consideration, though it's
// download. Needs more thought. This is fine for now. // fine for now.
mu.Lock() mu.Lock()
maps.Copy(progressCopy, progress) maps.Copy(progressCopy, progress)
mu.Unlock() mu.Unlock()
for l, n := range progress { for l, n := range progressCopy {
enc.Encode(progressUpdateJSON{ enc.Encode(progressUpdateJSON{
Digest: l.Digest, Digest: l.Digest,
Total: l.Size, Total: l.Size,
@@ -298,19 +298,26 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
}) })
} }
} }
defer flushProgress()
t := time.NewTicker(time.Hour) // "unstarted" timer t := time.NewTicker(1000 * time.Hour) // "unstarted" timer
start := sync.OnceFunc(func() { start := sync.OnceFunc(func() {
pushUpdate() flushProgress() // flush initial state
t.Reset(100 * time.Millisecond) t.Reset(100 * time.Millisecond)
}) })
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{ ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) { Update: func(l *ollama.Layer, n int64, err error) {
if n > 0 { if n > 0 {
start() // flush initial state // Block flushing progress updates until every
// layer is accounted for. Clients depend on a
// complete model size to calculate progress
// correctly; if they use an incomplete total,
// progress indicators would erratically jump
// as new layers are registered.
start()
} }
mu.Lock() mu.Lock()
progress[l] = n progress[l] += n
mu.Unlock() mu.Unlock()
}, },
}) })
@@ -323,9 +330,9 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
for { for {
select { select {
case <-t.C: case <-t.C:
pushUpdate() flushProgress()
case err := <-done: case err := <-done:
pushUpdate() flushProgress()
if err != nil { if err != nil {
var status string var status string
if errors.Is(err, ollama.ErrModelNotFound) { if errors.Is(err, ollama.ErrModelNotFound) {

View File

@@ -82,7 +82,7 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
for _, layer := range layers { for _, layer := range layers {
if s := layer.GGML.KV().ChatTemplate(); s != "" { if s := layer.GGML.KV().ChatTemplate(); s != "" {
if t, err := template.Named(s); err != nil { if t, err := template.Named(s); err != nil {
slog.Debug("template detection", "error", err) slog.Debug("template detection", "error", err, "template", s)
} else { } else {
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template") layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
if err != nil { if err != nil {

View File

@@ -0,0 +1,13 @@
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 }}
{{- if eq .Role "user" }}<start_of_turn>user
{{- if and (eq $i 1) $.System }}
{{ $.System }}
{{ end }}
{{ .Content }}<end_of_turn>
{{ else if eq .Role "assistant" }}<start_of_turn>model
{{ .Content }}<end_of_turn>
{{ end }}
{{- if $last }}<start_of_turn>model
{{ end }}
{{- end }}

View File

@@ -0,0 +1,6 @@
{
"stop": [
"<end_of_turn>"
],
"temperature": 0.1
}

View File

@@ -87,6 +87,10 @@
"template": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}", "template": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
"name": "gemma-instruct" "name": "gemma-instruct"
}, },
{
"template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\n'}}\n{%- endif -%}\n",
"name": "gemma3-instruct"
},
{ {
"template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", "template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
"name": "llama3-instruct" "name": "llama3-instruct"

View File

@@ -0,0 +1,10 @@
<start_of_turn>user
You are a helpful assistant.
Hello, how are you?<end_of_turn>
<start_of_turn>model
I'm doing great. How can I help you today?<end_of_turn>
<start_of_turn>user
I'd like to show off how chat templating works!<end_of_turn>
<start_of_turn>model

View File

@@ -0,0 +1,4 @@
<start_of_turn>user
Hello, how are you?<end_of_turn>
<start_of_turn>model

View File

@@ -0,0 +1,8 @@
<start_of_turn>user
Hello, how are you?<end_of_turn>
<start_of_turn>model
I'm doing great. How can I help you today?<end_of_turn>
<start_of_turn>user
I'd like to show off how chat templating works!<end_of_turn>
<start_of_turn>model