Mirror of https://github.com/likelovewant/ollama-for-amd.git, synced 2025-12-22 14:53:56 +00:00

Compare commits

31 commits:

17bb5ea679
ce929984a3
4b34930a31
74bd09652d
fb6252d786
c794fef2f2
00ebda8cc4
d14ce75b95
2d6eac9084
3ed7ad3ab3
6d1103048e
0ff28758b3
d3e9ca3eda
0fbfcf3c9c
0c220935bd
ffbfe833da
42a14f7f63
f8c3dbe5b5
b078dd157c
2ddacd7516
da0e345200
df94175a0f
61a8825216
a69a1e6e63
021dcf089d
bf24498b1e
95e271d98f
364629b8d6
108fe02165
4561fff36e
50b5962042
@@ -98,7 +98,7 @@ if(CMAKE_HIP_COMPILER)
   find_package(hip REQUIRED)
   if(NOT AMDGPU_TARGETS)
-    list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(803|900(:xnack-)|902|906(:xnack-)|90c(:xnack-)|1010(:xnack-)|1011|1012(:xnack-)|103[0-6]|110[0-3]|1150)$")
+    list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(803|900(:xnack-)|902|906(:xnack-)|90c(:xnack-)|1010(:xnack-)|1011(:xnack-)|1012(:xnack-)|103[0-6]|110[0-3]|115[01]|1201)$")
   elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
     list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
   endif()
@@ -56,7 +56,7 @@
       "name": "ROCm 6",
       "inherits": [ "ROCm" ],
       "cacheVariables": {
-        "AMDGPU_TARGETS": "gfx803;gfx902;gfx1011;gfx1030;gfx1031;gfx1032;gfx1034;gfx1035;gfx1036;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx900:xnack-;gfx906:xnack-;gfx90c:xnack-;gfx1010:xnack-;gfx1012:xnack-;"
+        "AMDGPU_TARGETS": "gfx803;gfx902;gfx1030;gfx1031;gfx1032;gfx1034;gfx1035;gfx1036;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1201;gfx900:xnack-;gfx906:xnack-;gfx90c:xnack-;gfx1010:xnack-;gfx1011:xnack-;gfx1012:xnack-;"
       }
     }
   ],
benchmark/server_benchmark_test.go (new file, 178 lines)
@@ -0,0 +1,178 @@
package benchmark

import (
	"context"
	"flag"
	"fmt"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
)

// Command line flags
var modelFlag string

func init() {
	flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
	flag.Lookup("m").DefValue = "model"
}

// modelName returns the model name from flags, failing the test if not set
func modelName(b *testing.B) string {
	if modelFlag == "" {
		b.Fatal("Error: -m flag is required for benchmark tests")
	}
	return modelFlag
}

type TestCase struct {
	name      string
	prompt    string
	maxTokens int
}

// runGenerateBenchmark contains the common generate and metrics logic
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
	start := time.Now()
	var ttft time.Duration
	var metrics api.Metrics

	err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
		if ttft == 0 && resp.Response != "" {
			ttft = time.Since(start)
		}
		if resp.Done {
			metrics = resp.Metrics
		}
		return nil
	})

	// Report custom metrics as part of the benchmark results
	b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
	b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")

	// Token throughput metrics
	promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
	genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
	b.ReportMetric(promptThroughput, "prompt_tok/s")
	b.ReportMetric(genThroughput, "gen_tok/s")

	// Token counts
	b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
	b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")

	if err != nil {
		b.Fatal(err)
	}
}

// BenchmarkColdStart runs benchmarks with model loading from cold state
func BenchmarkColdStart(b *testing.B) {
	client := setup(b)
	tests := []TestCase{
		{"short_prompt", "Write a long story", 100},
		{"medium_prompt", "Write a detailed economic analysis", 500},
		{"long_prompt", "Write a comprehensive AI research paper", 1000},
	}
	m := modelName(b)

	for _, tt := range tests {
		b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
			ctx := context.Background()

			// Set number of tokens as our throughput metric
			b.SetBytes(int64(tt.maxTokens))

			for b.Loop() {
				b.StopTimer()
				// Ensure model is unloaded before each iteration
				unload(client, m, b)
				b.StartTimer()

				req := &api.GenerateRequest{
					Model:   m,
					Prompt:  tt.prompt,
					Options: map[string]interface{}{"num_predict": tt.maxTokens, "temperature": 0.1},
				}

				runGenerateBenchmark(b, ctx, client, req)
			}
		})
	}
}

// BenchmarkWarmStart runs benchmarks with pre-loaded model
func BenchmarkWarmStart(b *testing.B) {
	client := setup(b)
	tests := []TestCase{
		{"short_prompt", "Write a long story", 100},
		{"medium_prompt", "Write a detailed economic analysis", 500},
		{"long_prompt", "Write a comprehensive AI research paper", 1000},
	}
	m := modelName(b)

	for _, tt := range tests {
		b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
			ctx := context.Background()

			// Pre-warm the model
			warmup(client, m, tt.prompt, b)

			// Set number of tokens as our throughput metric
			b.SetBytes(int64(tt.maxTokens))

			for b.Loop() {
				req := &api.GenerateRequest{
					Model:   m,
					Prompt:  tt.prompt,
					Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
				}

				runGenerateBenchmark(b, ctx, client, req)
			}
		})
	}
}

// setup verifies server and model availability
func setup(b *testing.B) *api.Client {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		b.Fatal(err)
	}
	if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
		b.Fatalf("Model unavailable: %v", err)
	}

	return client
}

// warmup ensures the model is loaded and warmed up
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
	for range 3 {
		err := client.Generate(
			context.Background(),
			&api.GenerateRequest{
				Model:   model,
				Prompt:  prompt,
				Options: map[string]interface{}{"num_predict": 50, "temperature": 0.1},
			},
			func(api.GenerateResponse) error { return nil },
		)
		if err != nil {
			b.Logf("Error during model warm-up: %v", err)
		}
	}
}

// unload forces model unloading using KeepAlive: 0 parameter
func unload(client *api.Client, model string, b *testing.B) {
	req := &api.GenerateRequest{
		Model:     model,
		KeepAlive: &api.Duration{Duration: 0},
	}
	if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
		b.Logf("Unload error: %v", err)
	}
	time.Sleep(1 * time.Second)
}
@@ -703,6 +703,8 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
 		for _, k := range keys {
 			var v string
 			switch vData := resp.ModelInfo[k].(type) {
+			case bool:
+				v = fmt.Sprintf("%t", vData)
 			case string:
 				v = vData
 			case float64:
@@ -87,6 +87,8 @@ func TestShowInfo(t *testing.T) {
 			ModelInfo: map[string]any{
 				"general.architecture":    "test",
 				"general.parameter_count": float64(8_000_000_000),
+				"some.true_bool":          true,
+				"some.false_bool":         false,
 				"test.context_length":     float64(1000),
 				"test.embedding_length":   float64(11434),
 			},
@@ -111,6 +113,8 @@ func TestShowInfo(t *testing.T) {
   Metadata
     general.architecture       test
     general.parameter_count    8e+09
+    some.false_bool            false
+    some.true_bool             true
     test.context_length        1000
     test.embedding_length      11434
@@ -201,7 +201,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
 	case "CohereForCausalLM":
 		conv = &commandrModel{}
 	default:
-		return errors.New("unsupported architecture")
+		return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
 	}

 	if err := json.Unmarshal(bts, conv); err != nil {
@@ -558,6 +558,10 @@ Final response:
 {
   "model": "llama3.2",
   "created_at": "2023-08-04T19:22:45.499127Z",
+  "message": {
+    "role": "assistant",
+    "content": ""
+  },
   "done": true,
   "total_duration": 4883583458,
   "load_duration": 1334875,
docs/benchmark.md (new file, 59 lines)
@@ -0,0 +1,59 @@
# Benchmark

Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.

## When to use

Run these benchmarks when:

- Making changes to the model inference engine
- Modifying model loading/unloading logic
- Changing prompt processing or token generation code
- Implementing a new model architecture
- Testing performance across different hardware setups

## Prerequisites

- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`

## Usage and Examples

>[!NOTE]
>All commands must be run from the root directory of the Ollama project.

Basic syntax:

```bash
go test -bench=. ./benchmark/... -m $MODEL_NAME
```

Required flags:

- `-bench=.`: Run all benchmarks
- `-m`: Model name to benchmark

Optional flags:

- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)

Common usage patterns:

Single benchmark run with a model specified:

```bash
go test -bench=. ./benchmark/... -m llama3.3
```

## Output metrics

The benchmark reports several key metrics:

- `gen_tok/s`: Generated tokens per second
- `prompt_tok/s`: Prompt processing tokens per second
- `ttft_ms`: Time to first token in milliseconds
- `load_ms`: Model load time in milliseconds
- `gen_tokens`: Total tokens generated
- `prompt_tokens`: Total prompt tokens processed

Each benchmark runs two scenarios:

- Cold start: Model is loaded from disk for each test
- Warm start: Model is pre-loaded in memory

Three prompt lengths are tested for each scenario:

- Short prompt (100 tokens)
- Medium prompt (500 tokens)
- Long prompt (1000 tokens)
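The optional flags compose with the required ones; for instance, `go test -bench=. ./benchmark/... -m llama3.3 -count 6 -timeout 30m` would collect six samples per benchmark for statistical comparison (an illustrative combination of the documented flags, not a command taken from the file itself).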
@@ -43,8 +43,13 @@ type Cache interface {

 	// ** cache management **

-	// Init sets up runtime parameters
-	Init(backend ml.Backend, dtype ml.DType, capacity int32)
+	// Init sets up runtime parameters.
+	// backend: Used to allocate cache data storage and execute management operations (such as defrag)
+	// dtype: The data type for storing cache entries
+	// maxSequences: The maximum number of sequences stored in the cache - across all batches
+	// capacity: The number of cache entries to store, per sequence
+	// maxBatch: The maximum number of tokens that can occur in a single batch
+	Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)

 	// Close closes the cache and frees resources associated with it
 	Close()
@@ -52,7 +57,7 @@ type Cache interface {
 	// StartForward is called before the start of the model's forward pass.
 	// For each token in the coming batch, there must be a corresponding
 	// entry in positions and seqs.
-	StartForward(ctx ml.Context, opts input.Options) error
+	StartForward(ctx ml.Context, batch input.Batch) error

 	// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
 	CopyPrefix(srcSeq, dstSeq int, len int32)
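The expanded `Init` signature separates per-sequence capacity from sequence and batch limits. A minimal sketch of how a caller might wire it up, with illustrative sizing values (the repository's real call site is `NewInputCache`, further down in this diff):

```go
package sketch

import (
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
)

// newCache is a sketch, not repository code: it wires a causal cache with
// the new Init parameters using made-up but plausible values.
func newCache(backend ml.Backend) kvcache.Cache {
	c := kvcache.NewCausalCache(nil)
	c.Init(backend, ml.DTypeF16,
		4,    // maxSequences: parallel sequences sharing the cache
		2048, // capacity: cache entries per sequence
		512,  // maxBatch: most tokens a single forward pass may carry
	)
	return c
}
```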
@@ -20,7 +20,6 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
 // The mask is of shape history size, batch size
 type Causal struct {
 	DType      ml.DType
-	Capacity   int32
 	windowSize int32

 	opts CausalOptions
@@ -98,7 +97,7 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
 	}
 }

-func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
+func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
 	if c.config == nil {
 		var config ml.CacheConfig
 		if cc, ok := backend.(ml.BackendCacheConfig); ok {
@@ -119,9 +118,16 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
 		c.config.MaskDType = ml.DTypeF32
 	}

+	var cacheSize int
+	if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize)+maxBatch {
+		cacheSize = maxSequences * capacity
+	} else {
+		cacheSize = maxSequences * (int(c.windowSize) + maxBatch)
+	}
+	cacheSize = roundUp(cacheSize, c.config.CachePadding)
+	c.cells = make([]cacheCell, cacheSize)
+
 	c.DType = dtype
-	c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
-	c.cells = make([]cacheCell, c.Capacity)
 	c.cellRanges = make(map[int]cellRange)
 	c.backend = backend
 }
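This sizing branch is the point of the change: a sliding-window cache only ever needs `windowSize + maxBatch` cells per sequence, so it no longer has to be allocated at full context length. A standalone sketch of the arithmetic with made-up numbers (the real code also pads the result to `CachePadding`):

```go
package main

import "fmt"

// cellsNeeded mirrors the sizing branch in Causal.Init: a sliding-window
// cache needs windowSize+maxBatch cells per sequence, capped at the full
// per-sequence capacity (padding omitted for clarity).
func cellsNeeded(maxSequences, capacity, windowSize, maxBatch int) int {
	perSeq := capacity
	if windowSize+maxBatch < capacity {
		perSeq = windowSize + maxBatch
	}
	return maxSequences * perSeq
}

func main() {
	// 4 sequences with a 32768-entry context: a 512-token window and
	// 256-token batches need 3072 cells instead of 131072.
	fmt.Println(cellsNeeded(4, 32768, 512, 256))   // 3072
	fmt.Println(cellsNeeded(4, 32768, 32768, 256)) // 131072 (window >= capacity)
}
```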
@@ -140,12 +146,14 @@ func (c *Causal) Close() {
 	}
 }

-func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
-	c.curBatchSize = len(opts.Positions)
-	c.curSequences = opts.Sequences
-	c.curPositions = opts.Positions
+func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
+	c.curBatchSize = len(batch.Positions)
+	c.curSequences = batch.Sequences
+	c.curPositions = batch.Positions
 	c.opts.Except = nil

+	c.updateSlidingWindow()
+
 	var err error
 	c.curLoc, err = c.findStartLoc()
 	if errors.Is(err, ErrKvCacheFull) {
@@ -157,8 +165,8 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
 	}

 	c.curCellRange = newRange()
-	for i, pos := range opts.Positions {
-		seq := opts.Sequences[i]
+	for i, pos := range batch.Positions {
+		seq := batch.Sequences[i]

 		c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
@@ -210,7 +218,51 @@ func (c *Causal) findStartLoc() (int, error) {
 		}
 	}

-	return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
+	return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
+}
+
+func (c *Causal) updateSlidingWindow() {
+	if c.windowSize == math.MaxInt32 {
+		return
+	}
+
+	// create a map of unique sequences to the lowest position in that sequence
+	lowestPos := make(map[int]int32)
+	for i := range c.curPositions {
+		seq := c.curSequences[i]
+
+		pos, ok := lowestPos[seq]
+		if !ok {
+			pos = c.curPositions[i]
+		} else if c.curPositions[i] < pos {
+			pos = c.curPositions[i]
+		}
+
+		lowestPos[seq] = pos
+	}
+
+	// delete any entries that are beyond the window of the oldest position in the sequence
+	for seq, pos := range lowestPos {
+		oldRange, ok := c.cellRanges[seq]
+		if !ok {
+			continue
+		}
+
+		newRange := newRange()
+
+		for i := oldRange.min; i <= oldRange.max; i++ {
+			if slices.Contains(c.cells[i].sequences, seq) {
+				if c.cells[i].pos < pos-c.windowSize {
+					c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
+				} else {
+					newRange.min = min(newRange.min, i)
+					newRange.max = max(newRange.max, i)
+				}
+			}
+		}
+
+		c.cellRanges[seq] = newRange
+	}
 }

 func roundDown(length, pad int) int {
@@ -265,7 +317,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
 	return maskTensor, nil
 }

-func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
+func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
 	for i, key := range c.keys {
 		if key == nil {
 			continue
@@ -275,8 +327,8 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
 		numKVHeads := key.Dim(1)
 		rowSize := key.Stride(2)

-		kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
-		kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
+		kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
+		kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)

 		value := c.values[i]
 		var vSrcView, vDstView ml.Tensor
@@ -284,14 +336,14 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
 			vHeadDim := value.Dim(1)
 			elemSize := value.Stride(0)

-			vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
-			vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
+			vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
+			vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
 		} else {
 			vHeadDim := value.Dim(0)
 			rowSize := value.Stride(2)

-			vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
-			vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
+			vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
+			vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
 		}

 		ctx.Forward(
@@ -321,7 +373,8 @@ func (c *Causal) defrag() {
 	ctx := c.backend.NewContext()

 	// For every move, 6 tensors are required per layer (2 views and a
-	// copy for each of k and v).
+	// copy for each of k and v). We also need to refer to the original
+	// k and v cache tensors - once per layer, not per move.
 	layers := 0
 	for _, key := range c.keys {
 		if key == nil {
@@ -330,7 +383,7 @@ func (c *Causal) defrag() {
 		layers++
 	}

-	maxMoves := ctx.MaxGraphNodes() / (6 * layers)
+	maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
 	moves := 0

 	var pendingSrc, pendingDst, pendingLen int
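With illustrative numbers: at 8192 graph nodes and 80 cached layers, the old budget was 8192/(6*80) = 17 moves per graph, while the new one first reserves 2*80 = 160 nodes for the per-layer source tensors and allows (8192-160)/(6*80) = 16, keeping the defrag graph within the scheduler's node limit.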
@@ -479,14 +532,14 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
 	}

 	if _, ok := c.keys[c.curLayer]; !ok {
-		c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
+		c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells))
 	}

 	if _, ok := c.values[c.curLayer]; !ok {
 		if c.config.PermutedV {
-			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
+			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads)
 		} else {
-			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
+			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells))
 		}
 	}

@@ -497,7 +550,7 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
 		elemSize := c.values[c.curLayer].Stride(0)

 		value = value.Permute(ctx, 1, 2, 0, 3)
-		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
+		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
 	} else {
 		rowSize := c.values[c.curLayer].Stride(2)
@@ -25,7 +25,7 @@ func TestStore(t *testing.T) {
 	cache := NewCausalCache(nil)
 	defer cache.Close()

-	cache.Init(backend, ml.DTypeF16, 16)
+	cache.Init(backend, ml.DTypeF16, 1, 16, 16)

 	tests := []testCase{
 		{
@@ -58,11 +58,11 @@ func TestSWA(t *testing.T) {
 	cache := NewSWACache(1, nil)
 	defer cache.Close()

-	cache.Init(backend, ml.DTypeF32, 16)
+	cache.Init(backend, ml.DTypeF16, 1, 16, 16)

 	tests := []testCase{
 		{
-			name:     "SlidingWindow",
+			name:     "FirstBatch",
 			in:       []float32{1, 2, 3, 4},
 			inShape:  []int{1, 1, 4},
 			seqs:     []int{0, 0, 0, 0},
@@ -71,6 +71,16 @@ func TestSWA(t *testing.T) {
 			expectedShape: []int{1, 1, 4},
 			expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
 		},
+		{
+			name:          "SecondBatch",
+			in:            []float32{5, 6},
+			inShape:       []int{1, 1, 2},
+			seqs:          []int{0, 0},
+			pos:           []int32{4, 5},
+			expected:      []float32{5, 6, 3, 4},
+			expectedShape: []int{1, 1, 4},
+			expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
+		},
 	}

 	testCache(t, backend, cache, tests)
@@ -81,7 +91,7 @@ func TestSequences(t *testing.T) {
 	cache := NewCausalCache(nil)
 	defer cache.Close()

-	cache.Init(backend, ml.DTypeF16, 16)
+	cache.Init(backend, ml.DTypeF16, 1, 16, 16)

 	tests := []testCase{
 		{
@@ -116,7 +126,7 @@ func TestRemove(t *testing.T) {
 	})
 	defer cache.Close()

-	cache.Init(backend, ml.DTypeF16, 16)
+	cache.Init(backend, ml.DTypeF16, 1, 16, 16)

 	tests := []testCase{
 		{
@@ -181,7 +191,7 @@ func TestDefrag(t *testing.T) {
 	})
 	defer cache.Close()

-	cache.Init(backend, ml.DTypeF16, 16)
+	cache.Init(backend, ml.DTypeF16, 1, 16, 16)

 	tests := []testCase{
 		{
@@ -229,7 +239,7 @@ func TestCopy(t *testing.T) {
 	cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
 	defer cache.Close()

-	cache.Init(backend, ml.DTypeF16, 16)
+	cache.Init(backend, ml.DTypeF16, 1, 16, 16)

 	tests := []testCase{
 		{
@@ -270,7 +280,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
 	context := backend.NewContext()
 	defer context.Close()

-	err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs})
+	err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs})
 	if err != nil {
 		panic(err)
 	}
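Reading the new SecondBatch case against updateSlidingWindow above: with a window of 1, the lowest new position is 4, so cells holding positions below 4-1 are evicted from the sequence, the freed cells are reused for the incoming tokens, and the surviving cell contents come out as {5, 6, 3, 4} (an interpretation of the test data, not text from the commit).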
@@ -49,7 +49,7 @@ func NewEncoderCache() *EncoderCache {
 	}
 }

-func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
+func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
 	if c.config == nil {
 		var config ml.CacheConfig
 		if cc, ok := backend.(ml.BackendCacheConfig); ok {
@@ -58,6 +58,10 @@ func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
 		c.config = &config
 	}

+	if maxSequences > 1 {
+		panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
+	}
+
 	if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
 		panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
 	}
@@ -79,10 +83,10 @@ func (c *EncoderCache) Close() {
 	}
 }

-func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error {
+func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch) error {
 	// We work with the most recent image
-	if len(opts.Multimodal) > 0 {
-		c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index]
+	if len(batch.Multimodal) > 0 {
+		c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
 	}

 	return nil
@@ -23,9 +23,9 @@ func NewWrapperCache(caches ...Cache) *WrapperCache {
 	}
 }

-func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
+func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
 	for _, cache := range c.caches {
-		cache.Init(backend, dtype, capacity)
+		cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
 	}
 }

@@ -41,14 +41,14 @@ func (c *WrapperCache) Close() {
 	}
 }

-func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error {
+func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch) error {
 	for i, cache := range c.caches {
-		err := cache.StartForward(ctx, opts)
+		err := cache.StartForward(ctx, batch)
 		if err != nil {
 			// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
 			for j := i - 1; j >= 0; j-- {
-				for k := range opts.Positions {
-					_ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32)
+				for k := range batch.Positions {
+					_ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
 				}
 			}
 			return err
@@ -2,6 +2,7 @@ package ml

 import (
 	"bytes"
+	"context"
 	"encoding/binary"
 	"fmt"
 	"os"
@@ -60,6 +61,10 @@ type CacheConfig struct {

 // BackendParams controls how the backend loads and executes models
 type BackendParams struct {
+	// Progress is a callback function that allows reporting percentage completion
+	// of model loading
+	Progress func(float32)
+
 	// NumThreads sets the number of threads to use if running on the CPU
 	NumThreads int

@@ -76,9 +81,9 @@ type BackendParams struct {
 	FlashAttention bool
 }

-var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
+var backends = make(map[string]func(context.Context, *os.File, BackendParams) (Backend, error))

-func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
+func RegisterBackend(name string, f func(context.Context, *os.File, BackendParams) (Backend, error)) {
 	if _, ok := backends[name]; ok {
 		panic("backend: backend already registered")
 	}
@@ -86,9 +91,9 @@ func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, erro
 	backends[name] = f
 }

-func NewBackend(f *os.File, params BackendParams) (Backend, error) {
+func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend, error) {
 	if backend, ok := backends["ggml"]; ok {
-		return backend(f, params)
+		return backend(ctx, f, params)
 	}

 	return nil, fmt.Errorf("unsupported backend")
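The new `Progress` field gives callers a completion callback during model loading. A sketch of what a caller might pass, assuming only the fields visible in this diff:

```go
package sketch

import (
	"fmt"
	"runtime"

	"github.com/ollama/ollama/ml"
)

// loadParams builds illustrative BackendParams: the Progress callback
// receives a completion fraction in [0, 1] as tensor data streams in.
func loadParams() ml.BackendParams {
	return ml.BackendParams{
		NumThreads: runtime.NumCPU(),
		Progress: func(p float32) {
			fmt.Printf("\rloading model: %3.0f%%", p*100)
		},
	}
}
```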
@@ -9,15 +9,17 @@ package ggml
import "C"

 import (
-	"errors"
+	"context"
 	"fmt"
 	"io"
 	"log/slog"
 	"maps"
 	"os"
+	"runtime"
 	"slices"
 	"strconv"
 	"strings"
+	"sync/atomic"
 	"unicode"
 	"unsafe"

@@ -58,7 +60,7 @@ type Backend struct {
 	maxGraphNodes int
 }

-func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
+func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) {
 	meta, n, err := fs.Decode(r, -1)
 	if err != nil {
 		return nil, err
@@ -297,12 +299,16 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 		}
 	}

-	// concurrently read in tensor data. uses a section reader which is safe for concurrent reads
-	sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
-	var g errgroup.Group
+	var doneBytes atomic.Uint64
+	totalBytes := uint64(n) - meta.Tensors().Offset
+
+	g, ctx := errgroup.WithContext(ctx)
+	g.SetLimit(runtime.GOMAXPROCS(0))
 	for _, t := range meta.Tensors().Items() {
-		for _, target := range targets[t.Name] {
-			g.Go(func() error {
+		g.Go(func() error {
+			tts := make([]*C.struct_ggml_tensor, max(1, len(targets[t.Name])))
+			for i := range tts {
+				target := targets[t.Name][i]
 				if target == "" {
 					target = t.Name
 				}
@@ -312,23 +318,44 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 					return fmt.Errorf("unassigned tensor: %s", t.Name)
 				}

-				bts := make([]byte, t.Size())
-				n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
-				if err != nil {
-					return err
-				}
-
-				if n != len(bts) {
-					return errors.New("short read")
-				}
-
-				C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size()))
-				return nil
-			})
-		}
+				tts[i] = tt
+			}
+
+			sr := io.NewSectionReader(r, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
+			bts := make([]byte, 128*format.KibiByte)
+
+			var s uint64
+			for s < t.Size() {
+				n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
+				if err != nil {
+					return err
+				}
+
+				for _, tt := range tts {
+					C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
+				}
+
+				s += uint64(n)
+
+				if params.Progress != nil {
+					done := doneBytes.Add(uint64(n))
+					params.Progress(float32(done) / float32(totalBytes))
+				}
+			}
+
+			return nil
+		})
 	}

-	if g.Wait() != nil {
+	// start a goroutine to cancel the errgroup if the parent context is done
+	go func() {
+		<-ctx.Done()
+		g.Go(func() error {
+			return ctx.Err()
+		})
+	}()
+
+	if err := g.Wait(); err != nil {
 		return nil, err
 	}

@@ -371,7 +398,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 			(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
 			C.int(len(schedBackends)),
 			C.size_t(maxGraphNodes),
-			true,
+			C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
 		),
 		input:  deviceBufferTypes[input.d],
 		output: deviceBufferTypes[output.d],
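The rewritten loop streams each tensor through a fixed 128 KiB buffer instead of materializing it whole, which caps peak memory during load and gives a natural place to report progress. A self-contained sketch of the same pattern, detached from the ggml types:

```go
package main

import (
	"fmt"
	"io"
	"strings"
)

// copyChunked mirrors the tensor-loading loop above: read a file section
// through a fixed buffer, handing each chunk and its offset to a sink,
// so peak memory stays at the buffer size rather than the tensor size.
func copyChunked(r io.ReaderAt, off, size int64, sink func(chunk []byte, dst int64) error) error {
	sr := io.NewSectionReader(r, off, size)
	buf := make([]byte, 128*1024) // 128 KiB, as in the diff
	var done int64
	for done < size {
		n := int64(len(buf))
		if size-done < n {
			n = size - done
		}
		if _, err := io.ReadFull(sr, buf[:n]); err != nil {
			return err
		}
		if err := sink(buf[:n], done); err != nil {
			return err
		}
		done += n
	}
	return nil
}

func main() {
	data := strings.NewReader(strings.Repeat("x", 300*1024))
	err := copyChunked(data, 0, 300*1024, func(chunk []byte, dst int64) error {
		fmt.Printf("chunk of %d bytes at offset %d\n", len(chunk), dst)
		return nil
	})
	if err != nil {
		fmt.Println(err)
	}
}
```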
@@ -1,5 +1,7 @@
 package input

+import "github.com/ollama/ollama/ml"
+
 // Input represents one token in the input stream
 type Input struct {
 	// Token is a single element of text.
@@ -33,11 +35,24 @@ type MultimodalIndex struct {
 	Multimodal any
 }

-// Options contains the inputs for a model forward pass
-type Options struct {
-	Inputs     []int32
+// Batch contains the inputs for a model forward pass
+type Batch struct {
+	// Inputs is the input tokens, including placeholders for multimodal inputs.
+	Inputs ml.Tensor
+
+	// Multimodal is a set of multimodal embeddings previously created by
+	// EncodeMultimodal, along with an index into Inputs. Unused for text-only
+	// models or for batches without multimodal elements.
 	Multimodal []MultimodalIndex
-	Positions  []int32
-	Sequences  []int
-	Outputs    []int32
+
+	// Positions is the position for each Input, relative to its sequence. Equal
+	// in length to Inputs.
+	Positions []int32
+
+	// Sequences is the sequence for each Input. Equal in length to Inputs.
+	Sequences []int
+
+	// Outputs are the set of indices into Inputs for which output data should
+	// be returned.
+	Outputs []int32
 }
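A sketch of constructing the new `Batch` for a text-only forward pass, with illustrative values; `Inputs` is populated by `model.Forward` from the raw token IDs, so it is left unset here:

```go
package sketch

import "github.com/ollama/ollama/model/input"

// makeBatch builds an illustrative two-token batch for sequence 0 that
// requests output logits only for the second token.
func makeBatch() input.Batch {
	return input.Batch{
		Positions: []int32{4, 5},
		Sequences: []int{0, 0},
		Outputs:   []int32{1},
	}
}
```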
@@ -1,6 +1,7 @@
 package model

 import (
+	"context"
 	"errors"
 	"fmt"
 	_ "image/jpeg"
@@ -26,7 +27,7 @@ var ErrNoVisionModel = errors.New("this model is missing data required for image
 // Model implements a specific model architecture, defining the forward pass and any model-specific configuration
 type Model interface {
-	Forward(ml.Context, input.Options) (ml.Tensor, error)
+	Forward(ml.Context, input.Batch) (ml.Tensor, error)

 	Backend() ml.Backend
 	Config() config
@@ -94,14 +95,14 @@ func Register(name string, f func(ml.Config) (Model, error)) {
 }

 // New initializes a new model instance with the provided configuration based on the metadata in the model file
-func New(modelPath string, params ml.BackendParams) (Model, error) {
+func New(ctx context.Context, modelPath string, params ml.BackendParams) (Model, error) {
 	r, err := os.Open(modelPath)
 	if err != nil {
 		return nil, err
 	}
 	defer r.Close()

-	b, err := ml.NewBackend(r, params)
+	b, err := ml.NewBackend(ctx, r, params)
 	if err != nil {
 		return nil, err
 	}
@@ -280,24 +281,30 @@ func canNil(t reflect.Type) bool {
 		t.Kind() == reflect.Slice
 }

-func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
-	if len(opts.Positions) != len(opts.Sequences) {
-		return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
+func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
+	if len(batch.Positions) != len(batch.Sequences) {
+		return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
 	}

-	if len(opts.Positions) < 1 {
+	if len(batch.Positions) < 1 {
 		return nil, errors.New("batch size cannot be less than 1")
 	}

+	var err error
+	batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
+	if err != nil {
+		return nil, err
+	}
+
 	cache := m.Config().Cache
 	if cache != nil {
-		err := cache.StartForward(ctx, opts)
+		err := cache.StartForward(ctx, batch)
 		if err != nil {
 			return nil, err
 		}
 	}

-	t, err := m.Forward(ctx, opts)
+	t, err := m.Forward(ctx, batch)
 	if err != nil {
 		return nil, err
 	}
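With this change the raw token IDs travel alongside the batch metadata and are converted to a tensor inside `Forward`. An illustrative call shape (token values made up):

```go
package sketch

import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
)

// forwardOnce is a sketch of the new entry point: three tokens of
// sequence 0, requesting logits only for the last position.
func forwardOnce(ctx ml.Context, m model.Model) (ml.Tensor, error) {
	return model.Forward(ctx, m,
		[]int32{101, 2009, 318},
		input.Batch{
			Positions: []int32{0, 1, 2},
			Sequences: []int{0, 0, 0},
			Outputs:   []int32{2}, // logits only for the last position
		})
}
```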
@@ -163,7 +163,7 @@ func TestGetTextProcessor(t *testing.T) {

 type notTextProcessorModel struct{}

-func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) {
+func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
 	panic("unimplemented")
 }
@@ -168,23 +168,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 	return hiddenState.Add(ctx, residual)
 }

-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
-	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
-	if err != nil {
-		return nil, err
-	}
-
-	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
 	}

-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 	if err != nil {
 		return nil, err
 	}

-	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))

 	if len(m.Layers) == gemma27BLayerCount {
@@ -211,8 +206,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 	// final logit softcap
 	hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
 	hiddenState = hiddenState.Tanh(ctx)
-	hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
-	return hiddenState.Rows(ctx, outputs), nil
+	return hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap)), nil
 }

 func init() {
@@ -139,23 +139,18 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	return result, nil
 }

-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
-	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
-	if err != nil {
-		return nil, err
-	}
-
-	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
 	}

-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
-	if err != nil {
-		return nil, err
-	}
-
-	return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
+	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+	if err != nil {
+		return nil, err
+	}
+
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
 }

 func init() {
@@ -171,13 +171,13 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
 	return hiddenState.Add(ctx, residual)
 }

-func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
+func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))

 	// set image embeddings
 	var except []int
-	for _, image := range opts.Multimodal {
+	for _, image := range batch.Multimodal {
 		visionOutputs := image.Multimodal.(ml.Tensor)
 		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
@@ -139,23 +139,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 	return hiddenState.Add(ctx, residual)
 }

-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
-	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
-	if err != nil {
-		return nil, err
-	}
-
-	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
 	}

-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 	if err != nil {
 		return nil, err
 	}

-	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)

 	for i, layer := range m.Layers {
 		m.Cache.SetLayer(i)
@@ -135,32 +135,27 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	return inputs, nil
 }

-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	var crossAttentionStates ml.Tensor
-	if len(opts.Multimodal) > 0 {
-		images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor)
+	if len(batch.Multimodal) > 0 {
+		images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal.([]ml.Tensor)
 		if len(images) > 0 {
 			crossAttentionStates = images[len(images)-1]
 		}
 	}

-	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
-	if err != nil {
-		return nil, err
-	}
-
-	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
-	if err != nil {
-		return nil, err
-	}
-
-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
 	}

+	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+	if err != nil {
+		return nil, err
+	}
+
 	// TODO: attention mask, cross attention mask
-	return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
 }

 func init() {
@@ -31,8 +31,10 @@ type InputCache struct {
 	cache kvcache.Cache
 }

-func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, multiUserCache bool) (*InputCache, error) {
-	if kvSize/int32(numSlots) < 1 {
+func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
+	numCtx := kvSize / int32(numSlots)
+
+	if numCtx < 1 {
 		return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
 	}

@@ -44,11 +46,11 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots

 	cache := model.Config().Cache
 	if cache != nil {
-		cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), kvSize)
+		cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), numSlots, int(numCtx), batchSize)
 	}

 	return &InputCache{
-		numCtx:         kvSize / int32(numSlots),
+		numCtx:         numCtx,
 		enabled:        cache != nil,
 		slots:          slots,
 		multiUserCache: multiUserCache,
@@ -89,7 +91,7 @@ type InputCacheSlot struct {
 	lastUsed time.Time
 }

-func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
+func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
 	var slot *InputCacheSlot
 	var numPast int32
 	var err error
@@ -107,11 +109,6 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp
 		return nil, nil, err
 	}

-	// TODO (brucemacd): cachePrompt is always true for completion, but false for embedding, can this be improved?
-	if !cachePrompt {
-		numPast = 0
-	}
-
 	slot.InUse = true
 	slot.lastUsed = time.Now()
@@ -297,3 +297,131 @@ func TestShiftDiscard(t *testing.T) {
 		})
 	}
 }
+
+func TestLoadCacheSlot(t *testing.T) {
+	tests := []struct {
+		name           string
+		cache          InputCache
+		prompt         []input.Input
+		wantErr        bool
+		expectedSlotId int
+		expectedPrompt int // expected length of remaining prompt
+	}{
+		{
+			name: "Basic cache hit - single user",
+			cache: InputCache{
+				multiUserCache: false,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    false,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+					{
+						Id:       1,
+						Inputs:   []input.Input{},
+						InUse:    false,
+						lastUsed: time.Now().Add(-2 * time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			wantErr:        false,
+			expectedSlotId: 0,
+			expectedPrompt: 1, // Only token 3 remains
+		},
+		{
+			name: "Basic cache hit - multi user",
+			cache: InputCache{
+				multiUserCache: true,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    false,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+					{
+						Id:       1,
+						Inputs:   []input.Input{},
+						InUse:    false,
+						lastUsed: time.Now().Add(-2 * time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			wantErr:        false,
+			expectedSlotId: 0,
+			expectedPrompt: 1, // Only token 3 remains
+		},
+		{
+			name: "Exact match - leave one input",
+			cache: InputCache{
+				multiUserCache: false,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    false,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}},
+			wantErr:        false,
+			expectedSlotId: 0,
+			expectedPrompt: 1, // Should leave 1 token for sampling
+		},
+		{
+			name: "No available slots",
+			cache: InputCache{
+				multiUserCache: false,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    true,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			wantErr:        true,
+			expectedSlotId: -1,
+			expectedPrompt: -1,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
+
+			// Check error state
+			if (err != nil) != tt.wantErr {
+				t.Errorf("LoadCacheSlot() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+
+			if tt.wantErr {
+				return // Skip further checks if we expected an error
+			}
+
+			// Verify slot ID
+			if slot.Id != tt.expectedSlotId {
+				t.Errorf("LoadCacheSlot() slot ID = %v, expected %v", slot.Id, tt.expectedSlotId)
+			}
+
+			// Verify slot is now marked in use
+			if !slot.InUse {
+				t.Errorf("LoadCacheSlot() slot not marked InUse")
+			}
+
+			// Verify remaining prompt length
+			if len(remainingPrompt) != tt.expectedPrompt {
+				t.Errorf("LoadCacheSlot() remaining prompt length = %v, expected %v",
+					len(remainingPrompt), tt.expectedPrompt)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@@ -115,6 +115,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
        params.numKeep = int32(len(inputs))
    }

+    // TODO(jessegross): We should ensure that we always leave minBatch of context space to shift,
+    // otherwise we might truncate or split the batch against the model's wishes
+
    // Ensure that at least 1 input can be discarded during shift
    params.numKeep = min(params.numKeep, s.cache.numCtx-1)

@@ -345,7 +348,8 @@ func (s *Server) processBatch() error {
    }
    defer s.mu.Unlock()

-    var options input.Options
+    var batchInputs []int32
+    var batch input.Batch

    for i, seq := range s.seqs {
        if seq == nil {

@@ -366,17 +370,6 @@ func (s *Server) processBatch() error {
        batchSize := s.batchSize

        for j, inp := range seq.inputs {
-            if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
-                if len(seq.pendingInputs) == 0 {
-                    err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
-                    if err != nil {
-                        return err
-                    }
-                } else {
-                    break
-                }
-            }
-
            // If we are required to put following inputs into a single batch then extend the
            // batch size. Since we are only extending the size the minimum amount possible, this
            // will cause a break if we have pending inputs.

@@ -389,17 +382,31 @@ func (s *Server) processBatch() error {
                break
            }

-            options.Inputs = append(options.Inputs, inp.Token)
-            if inp.Multimodal != nil {
-                options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
-            }
-
-            options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
-            options.Sequences = append(options.Sequences, seq.cache.Id)
-
-            seq.iBatch = len(options.Outputs)
+            // If the sum of our working set (already processed tokens, tokens we added to this
+            // batch, required following tokens) exceeds the context size, then trigger a shift
+            // now so we don't have to do one later when we can't break the batch.
+            if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
+                if len(seq.pendingInputs) != 0 {
+                    break
+                }
+
+                err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
+                if err != nil {
+                    return err
+                }
+            }
+
+            batchInputs = append(batchInputs, inp.Token)
+            if inp.Multimodal != nil {
+                batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal})
+            }
+
+            batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
+            batch.Sequences = append(batch.Sequences, seq.cache.Id)
+
+            seq.iBatch = len(batch.Outputs)
            if j+1 == len(seq.inputs) {
-                options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
+                batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
            }
            seq.pendingInputs = append(seq.pendingInputs, inp)
        }

@@ -407,14 +414,14 @@ func (s *Server) processBatch() error {
        seq.inputs = seq.inputs[len(seq.pendingInputs):]
    }

-    if len(options.Inputs) == 0 {
+    if len(batchInputs) == 0 {
        return nil
    }

    ctx := s.model.Backend().NewContext()
    defer ctx.Close()

-    modelOutput, err := model.Forward(ctx, s.model, options)
+    modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
    if err != nil {
        return fmt.Errorf("failed to decode batch: %w", err)
    }

@@ -454,7 +461,7 @@ func (s *Server) processBatch() error {
    }

    // sample a token
-    vocabSize := len(logits) / len(options.Outputs)
+    vocabSize := len(logits) / len(batch.Outputs)

    token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
    if err != nil {

@@ -590,7 +597,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
    found := false
    for i, sq := range s.seqs {
        if sq == nil {
-            seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
+            seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
            if err != nil {
                s.mu.Unlock()
                http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)

@@ -671,6 +678,7 @@ func (m *multiLPath) String() string {
}

func (s *Server) loadModel(
+    ctx context.Context,
    mpath string,
    params ml.BackendParams,
    lpath multiLPath,

@@ -680,7 +688,7 @@ func (s *Server) loadModel(
    multiUserCache bool,
) {
    var err error
-    s.model, err = model.New(mpath, params)
+    s.model, err = model.New(ctx, mpath, params)
    if err != nil {
        panic(err)
    }

@@ -692,7 +700,7 @@ func (s *Server) loadModel(
        panic("loras are not yet implemented")
    }

-    s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, multiUserCache)
+    s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
    if err != nil {
        panic(err)
    }

@@ -776,6 +784,9 @@ func Execute(args []string) error {
    }

    params := ml.BackendParams{
+        Progress: func(progress float32) {
+            server.progress = progress
+        },
        NumThreads:   *threads,
        NumGPULayers: *numGPULayers,
        MainGPU:      *mainGPU,

@@ -784,13 +795,13 @@ func Execute(args []string) error {
    }

    server.ready.Add(1)
-    go server.loadModel(*mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
-
-    server.cond = sync.NewCond(&server.mu)

    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

+    go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
+
+    server.cond = sync.NewCond(&server.mu)
+
    go server.run(ctx)

    addr := "127.0.0.1:" + strconv.Itoa(*port)
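A compressed illustration of the new shift policy in processBatch (hypothetical helper, not in the diff): the overflow check now accounts for the minBatch tokens that must still follow, so the cache is shifted before the batch would otherwise have to be split mid-stream.

// needShift reports whether adding the next minBatch tokens for a sequence
// would overflow its per-slot context window; if so, the runner evicts the
// oldest non-keep entries first. Hypothetical helper mirroring the inline check.
func needShift(cached, pending, minBatch, numCtx int32) bool {
    return cached+pending+minBatch > numCtx
}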
@@ -26,6 +26,10 @@ type Sampler struct {
}

func (s *Sampler) Sample(logits []float32) (int32, error) {
+    if len(logits) == 0 {
+        return -1, errors.New("sample: no logits provided to sample")
+    }
+
    tokens := make([]token, len(logits))
    for i := range logits {
        tokens[i].id = int32(i)

@@ -87,19 +91,13 @@ func (s *Sampler) sample(tokens []token) (token, error) {
    // topK also sorts the tokens in descending order of logits
    tokens = topK(tokens, s.topK)

-    tokens = temperature(tokens, s.temperature)
-    tokens = softmax(tokens)
+    // scale and normalize the tokens in place
+    temperature(tokens, s.temperature)
+    softmax(tokens)

    tokens = topP(tokens, s.topP)
    tokens = minP(tokens, s.minP)

-    // TODO: this should fall back to greedy sampling
-    // or topP, topK values etc should be such that
-    // there are always tokens to sample from
-    if len(tokens) == 0 {
-        return token{}, errors.New("no tokens to sample from")
-    }
-
    var r float32
    if s.rng != nil {
        r = s.rng.Float32()

@@ -122,6 +120,9 @@ func (s *Sampler) sample(tokens []token) (token, error) {
        return 1
    })

+    if math.IsNaN(float64(sum)) {
+        return token{}, errors.New("sample: logits sum to NaN, check model output")
+    }
    return tokens[idx], nil
}
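As a standalone illustration of why the single NaN guard above is enough (a sketch, not the sampler's code): NaN is absorbing under addition, so one corrupt logit anywhere in the batch makes the cumulative sum NaN, and checking the sum once catches it.

package main

import (
    "fmt"
    "math"
)

func main() {
    logits := []float32{1.5, float32(math.NaN()), 0.2}
    var sum float32
    for _, v := range logits {
        sum += v // any NaN term poisons the whole sum
    }
    fmt.Println(math.IsNaN(float64(sum))) // true
}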
@@ -1,6 +1,7 @@
package sample

import (
+    "math"
    "math/rand/v2"
    "testing"
)

@@ -29,6 +30,29 @@ func TestWeighted(t *testing.T) {
    if want != got {
        t.Errorf("index mismatch: want %d, got %d", want, got)
    }
+
+    // Test very high p
+    logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
+    // Use extremely small topP to filter out all tokens
+    sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
+    got, err = sampler.Sample(logits)
+    if err != nil {
+        t.Error(err)
+        return
+    }
+    // Should get the token with the highest logit
+    want = int32(0)
+    if want != got {
+        t.Errorf("index mismatch: want %d, got %d", want, got)
+    }
+
+    logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
+    sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
+    got, err = sampler.Sample(logits)
+    if err == nil {
+        t.Errorf("expected error, got %d", got)
+        return
+    }
}

func BenchmarkSample(b *testing.B) {
@@ -26,17 +26,16 @@ func (h *tokenHeap) Pop() any {
}

// temperature applies scaling to the logits
-func temperature(ts []token, temp float32) []token {
+func temperature(ts []token, temp float32) {
    // Ensure temperature clipping near 0 to avoid numerical instability
    temp = max(temp, 1e-7)
    for i := range ts {
        ts[i].value = ts[i].value / temp
    }
-    return ts
}

// softmax applies normalization to the logits
-func softmax(ts []token) []token {
+func softmax(ts []token) {
    // Find max logit for numerical stability
    maxLogit := float32(math.Inf(-1))
    for _, t := range ts {

@@ -56,8 +55,6 @@ func softmax(ts []token) []token {
    for i := range ts {
        ts[i].value /= sum
    }
-
-    return ts
}

// topK limits the number of tokens considered to the k highest logits

@@ -99,6 +96,7 @@ func topK(ts []token, k int) []token {
}

// topP limits tokens to those with cumulative probability p
+// requires ts to be sorted in descending order of probabilities
func topP(ts []token, p float32) []token {
    if p == 1.0 {
        return ts

@@ -109,37 +107,24 @@ func topP(ts []token, p float32) []token {
    for i, t := range ts {
        sum += t.value
        if sum > float32(p) {
-            ts = ts[:i+1]
-            return ts
+            return ts[:i+1]
        }
    }

    return ts
}

-// minP limits tokens to those with cumulative probability p
+// minP filters tokens with probabilities >= p * max_prob
+// requires ts to be sorted in descending order of probabilities
func minP(ts []token, p float32) []token {
-    if p == 1.0 {
-        return ts
-    }
-
-    maxProb := float32(math.Inf(-1))
-    for _, token := range ts {
-        if token.value > maxProb {
-            maxProb = token.value
-        }
-    }
-
-    threshold := maxProb * float32(p)
-
-    // Filter tokens in-place
-    validTokens := ts[:0]
-    for i, token := range ts {
-        if token.value >= threshold {
-            validTokens = append(validTokens, ts[i])
-        }
-    }
-
-    ts = validTokens
+    maxProb := ts[0].value
+    threshold := maxProb * p
+    for i, t := range ts {
+        if t.value < threshold {
+            return ts[:i]
+        }
+    }
+
    return ts
}
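A runnable sketch of the new minP behavior on a pre-sorted probability list (standalone toy types, not the package's): because the slice is sorted in descending order, the first token below the threshold marks the cut point, so the filter becomes a single prefix slice instead of the old append loop.

package main

import "fmt"

type token struct {
    id    int32
    value float32 // probability, sorted descending
}

// minP keeps tokens with probability >= p * max; mirrors the diff's logic.
func minP(ts []token, p float32) []token {
    threshold := ts[0].value * p
    for i, t := range ts {
        if t.value < threshold {
            return ts[:i]
        }
    }
    return ts
}

func main() {
    ts := []token{{0, 0.5}, {1, 0.3}, {2, 0.15}, {3, 0.05}}
    fmt.Println(minP(ts, 0.2)) // threshold 0.1: keeps the first three tokens
}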
@@ -34,17 +34,22 @@ func compareLogits(t *testing.T, name string, want []float32, got []token) {

func TestTemperature(t *testing.T) {
    input := []float32{1.0, 4.0, -2.0, 0.0}
-    got := temperature(toTokens(input), 0.5)
+    tokens := toTokens(input)
+    temperature(tokens, 0.5)
    want := []float32{2.0, 8.0, -4.0, 0.0}
-    compareLogits(t, "temperature(0.5)", want, got)
+    compareLogits(t, "temperature(0.5)", want, tokens)

-    got = temperature(toTokens(input), 1.0)
+    input = []float32{1.0, 4.0, -2.0, 0.0}
+    tokens = toTokens(input)
+    temperature(tokens, 1.0)
    want = []float32{1.0, 4.0, -2.0, 0.0}
-    compareLogits(t, "temperature(1)", want, got)
+    compareLogits(t, "temperature(1)", want, tokens)

-    got = temperature(toTokens(input), 0.0)
+    input = []float32{1.0, 4.0, -2.0, 0.0}
+    tokens = toTokens(input)
+    temperature(tokens, 0.0)
    want = []float32{1e7, 4e7, -2e7, 0.0}
-    compareLogits(t, "temperature(0)", want, got)
+    compareLogits(t, "temperature(0)", want, tokens)
}

func TestSoftmax(t *testing.T) {

@@ -90,16 +95,17 @@ func TestSoftmax(t *testing.T) {

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
-            got := softmax(toTokens(tt.input))
+            tokens := toTokens(tt.input)
+            softmax(tokens)

            if tt.expected != nil {
-                compareLogits(t, tt.name, tt.expected, got)
+                compareLogits(t, tt.name, tt.expected, tokens)
                return
            }

            // Check probabilities sum to 1
            var sum float32
-            for _, token := range got {
+            for _, token := range tokens {
                sum += token.value
                if token.value < 0 || token.value > 1 {
                    t.Errorf("probability out of range [0,1]: got %f", token.value)

@@ -114,38 +120,44 @@ func TestSoftmax(t *testing.T) {

func TestTopK(t *testing.T) {
    input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
-
-    // Test k=5
-    got := topK(toTokens(input), 5)
-    if len(got) != 5 {
-        t.Errorf("topK(5): wrong length: want 5, got %d", len(got))
+    tokens := toTokens(input)
+    tokens = topK(tokens, 5)
+    if len(tokens) != 5 {
+        t.Errorf("topK(5): wrong length: want 5, got %d", len(tokens))
    }
-    // Should keep highest 3 values in descending order
    want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
-    compareLogits(t, "topK(3)", want, got)
+    compareLogits(t, "topK(3)", want, tokens)

-    got = topK(toTokens(input), 20)
-    if len(got) != len(input) {
-        t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got))
+    tokens = toTokens(input)
+    tokens = topK(tokens, 20)
+    if len(tokens) != len(input) {
+        t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(tokens))
    }

-    // Test k=-1
    input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
    want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
-    got = topK(toTokens(input), -1)
-    if len(got) != len(input) {
-        t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
+    tokens = toTokens(input)
+    tokens = topK(tokens, -1)
+    if len(tokens) != len(input) {
+        t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
    }
-    compareLogits(t, "topK(-1)", want, got)
+    compareLogits(t, "topK(-1)", want, tokens)

-    // Test k=0
    input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
    want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
-    got = topK(toTokens(input), 0)
-    if len(got) != len(input) {
-        t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
+    tokens = toTokens(input)
+    tokens = topK(tokens, 0)
+    if len(tokens) != len(input) {
+        t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
+    }
+    compareLogits(t, "topK(-1)", want, tokens)
+
+    input = []float32{-1e7, -2e7, -3e7, -4e7}
+    tokens = toTokens(input)
+    tokens = topK(tokens, 1)
+    if len(tokens) < 1 {
+        t.Error("topK should keep at least one token")
    }
-    compareLogits(t, "topK(-1)", want, got)
}

func TestTopP(t *testing.T) {

@@ -153,50 +165,134 @@ func TestTopP(t *testing.T) {
    tokens := toTokens(input)

    // First apply temperature and softmax to get probabilities
-    tokens = softmax(tokens)
+    softmax(tokens)
    tokens = topK(tokens, 20)

-    // Then apply topP
-    got := topP(tokens, 0.95)
+    // Test with very high p value
+    got := topP(tokens, 1.0)
+
+    // Should keep all tokens since p is 1
+    if len(got) != len(input) {
+        t.Errorf("topP(1.0): should keep all tokens, got %d, want %d", len(got), len(input))
+    }
+
+    // Test with normal p value
+    got = topP(tokens, 0.95)

    // Should keep tokens until cumsum > 0.95
    if len(got) > 3 {
-        t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
+        t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
+        t.Logf("got: %v", got)
+    }
+
+    // Test edge case - ensure at least one token remains
+    input = []float32{-1e6, -1e6, -1e7}
+    tokens = toTokens(input)
+    tokens = topK(tokens, 20)
+    softmax(tokens)
+    got = topP(tokens, 0.0)
+    if len(got) < 1 {
+        t.Error("topP should keep at least one token")
+    }
+
+    // Test with zero p value
+    got = topP(tokens, 0.0)
+
+    // Should keep only the highest probability token
+    if len(got) != 1 {
+        t.Errorf("topP(0.0): should keep only one token, got %d", len(got))
+        t.Logf("got: %v", got)
+    }
+
+    tokens = toTokens(input)
+    tokens = topK(tokens, 20)
+    softmax(tokens)
+    got = topP(tokens, 1e-10)
+    if len(got) == 0 {
+        t.Errorf("topP(1e-10): should keep at least one token, got %d", len(got))
        t.Logf("got: %v", got)
    }
}

func TestMinP(t *testing.T) {
-    input := []float32{-3, -2, -1, 0, 1, 2, 4, 3}
+    input := []float32{-2, 0, -1, -3, 2, 1, 4, 3}
    tokens := toTokens(input)

    // First apply temperature and softmax
-    tokens = softmax(tokens)
-
-    // Then apply minP
-    got := minP(tokens, 1.0)
-
-    if len(got) != 1 {
-        t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(got), len(tokens))
-    }
-
-    // Test with normal p value
-    got = minP(tokens, 0.2)
-
-    // Should keep tokens with prob >= 0.2 * max_prob
-    if len(got) > 3 {
-        t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
-        t.Logf("got: %v", got)
-    }
-
-    // Test with zero p value
-    got = minP(tokens, 0.0)
-
-    // Should keep only the highest probability token
-    if len(got) != len(tokens) {
-        t.Errorf("minP(0.0): should keep only one token, got %d", len(got))
-        t.Logf("got: %v", got)
-    }
-}
-
-func TestSortLogits(t *testing.T) {
-    input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
-    tokens := toTokens(input)
-
-    tokens = topK(tokens, 20)
-
-    for i := 1; i < len(tokens); i++ {
-        if tokens[i].value > tokens[i-1].value {
-            t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
-                i, tokens[i].value, tokens[i-1].value)
-        }
-    }
-
-    want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
-    compareLogits(t, "sortLogits", want, tokens)
+    tokens = topK(tokens, 20)
+    softmax(tokens)
+
+    tokens = minP(tokens, 1.0)
+
+    if len(tokens) != 1 {
+        t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(tokens), len(tokens))
+    }
+
+    // Test with normal p value
+    tokens = toTokens(input) // Reset tokens
+    tokens = topK(tokens, 20)
+    softmax(tokens)
+    tokens = minP(tokens, 0.2)
+
+    // Should keep tokens with prob >= 0.2 * max_prob
+    if len(tokens) > 3 {
+        t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens))
+        t.Logf("got: %v", tokens)
+    }
+
+    // Test with zero p value
+    tokens = toTokens(input) // Reset tokens
+    tokens = topK(tokens, 20)
+    softmax(tokens)
+    tokens = minP(tokens, 0.0)
+
+    // Should keep only the highest probability token
+    if len(tokens) != len(input) {
+        t.Errorf("minP(0.0): should keep only one token, got %d", len(tokens))
+        t.Logf("got: %v", tokens)
+    }
+
+    // Test with single token
+    tokens = toTokens(input[:1])
+    tokens = topK(tokens, 20)
+    softmax(tokens)
+    tokens = minP(tokens, 0.1)
+
+    // Should keep only the highest probability token
+    if len(tokens) != 1 {
+        t.Errorf("minP(0.1): should return single token, got %d", len(tokens))
+        t.Logf("got: %v", tokens)
+    }
+
+    input = []float32{1e-10, 1e-10, 1e-10}
+    tokens = toTokens(input)
+    softmax(tokens)
+    tokens = minP(tokens, 1.0)
+    if len(tokens) < 1 {
+        t.Error("minP should keep at least one token even with extreme probabilities")
+    }
}

func BenchmarkTransforms(b *testing.B) {

@@ -231,7 +327,7 @@ func BenchmarkTransforms(b *testing.B) {
        b.ResetTimer()
        for b.Loop() {
            copy(tokensCopy, tokens)
-            topK(tokensCopy, 10)
+            tokens = topK(tokensCopy, 10)
        }
    })

@@ -239,7 +335,7 @@ func BenchmarkTransforms(b *testing.B) {
        b.ResetTimer()
        for b.Loop() {
            copy(tokensCopy, tokens)
-            topP(tokensCopy, 0.9)
+            tokens = topP(tokensCopy, 0.9)
        }
    })

@@ -247,7 +343,7 @@ func BenchmarkTransforms(b *testing.B) {
        b.ResetTimer()
        for b.Loop() {
            copy(tokensCopy, tokens)
-            minP(tokensCopy, 0.2)
+            tokens = minP(tokensCopy, 0.2)
        }
    })

@@ -255,7 +351,7 @@ func BenchmarkTransforms(b *testing.B) {
        b.ResetTimer()
        for b.Loop() {
            copy(tokensCopy, tokens)
-            topK(tokensCopy, 200000)
+            tokens = topK(tokensCopy, 200000)
        }
    })
}
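One detail worth calling out in the benchmark changes above (illustrative sketch using the package's token, toTokens, and topK helpers; the sink variable is hypothetical): assigning each transform's result keeps the call live, so the compiler cannot treat the benchmarked work as dead code.

// Hypothetical standalone benchmark showing the same keep-the-result-live pattern.
var sink []token

func BenchmarkTopKSketch(b *testing.B) {
    tokens := toTokens(make([]float32, 1<<10))
    tokensCopy := make([]token, len(tokens))
    for b.Loop() {
        copy(tokensCopy, tokens)
        sink = topK(tokensCopy, 10) // assignment prevents dead-code elimination
    }
}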
@@ -37,7 +37,6 @@ import (
    "golang.org/x/sync/errgroup"

    "github.com/ollama/ollama/server/internal/cache/blob"
-    "github.com/ollama/ollama/server/internal/internal/backoff"
    "github.com/ollama/ollama/server/internal/internal/names"

    _ "embed"

@@ -60,6 +59,11 @@ var (
    // ErrCached is passed to [Trace.PushUpdate] when a layer already
    // exists. It is a non-fatal error and is never returned by [Registry.Push].
    ErrCached = errors.New("cached")
+
+    // ErrIncomplete is returned by [Registry.Pull] when a model pull was
+    // incomplete due to one or more layer download failures. Users that
+    // want specific errors should use [WithTrace].
+    ErrIncomplete = errors.New("incomplete")
)

// Defaults

@@ -213,12 +217,6 @@ type Registry struct {
    // request. If zero, [DefaultChunkingThreshold] is used.
    ChunkingThreshold int64

-    // MaxChunkSize is the maximum size of a chunk to download. If zero,
-    // the default is [DefaultMaxChunkSize].
-    //
-    // It is only used when a layer is larger than [MaxChunkingThreshold].
-    MaxChunkSize int64
-
    // Mask, if set, is the name used to convert non-fully qualified names
    // to fully qualified names. If empty, [DefaultMask] is used.
    Mask string

@@ -278,8 +276,19 @@ func DefaultRegistry() (*Registry, error) {

func UserAgent() string {
    buildinfo, _ := debug.ReadBuildInfo()
+
+    version := buildinfo.Main.Version
+    if version == "(devel)" {
+        // When using `go run .` the version is "(devel)". This is seen
+        // as an invalid version by ollama.com and so it defaults to
+        // "needs upgrade" for some requests, such as pulls. These
+        // checks can be skipped by using the special version "v0.0.0",
+        // so we set it to that here.
+        version = "v0.0.0"
+    }
+
    return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
-        buildinfo.Main.Version,
+        version,
        runtime.GOARCH,
        runtime.GOOS,
        runtime.Version(),

@@ -425,13 +434,14 @@ func canRetry(err error) bool {
//
// It always calls update with a nil error.
type trackingReader struct {
-    r io.Reader
-    n *atomic.Int64
+    l      *Layer
+    r      io.Reader
+    update func(l *Layer, n int64, err error)
}

func (r *trackingReader) Read(p []byte) (n int, err error) {
    n, err = r.r.Read(p)
-    r.n.Add(int64(n))
+    r.update(r.l, int64(n), nil)
    return
}

@@ -447,6 +457,11 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
    if err != nil {
        return err
    }

+    // TODO(bmizerany): decide if this should be considered valid. Maybe
+    // server-side we special case '{}' to have some special meaning? Maybe
+    // "archiving" a tag (which is how we reason about it in the registry
+    // already, just with a different twist).
    if len(m.Layers) == 0 {
        return fmt.Errorf("%w: no layers", ErrManifestInvalid)
    }

@@ -456,11 +471,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
        return err
    }

-    exists := func(l *Layer) bool {
-        info, err := c.Get(l.Digest)
-        return err == nil && info.Size == l.Size
-    }
+    // TODO(bmizerany): work to remove the need to do this

    layers := m.Layers
    if m.Config != nil && m.Config.Digest.IsValid() {
        layers = append(layers, m.Config)

@@ -468,99 +479,97 @@ func (r *Registry) Pull(ctx context.Context, name string) error {

    // Send initial layer trace events to allow clients to have an
    // understanding of work to be done before work starts.
+    var expected int64
    t := traceFromContext(ctx)
-    skip := make([]bool, len(layers))
-    for i, l := range layers {
+    for _, l := range layers {
        t.update(l, 0, nil)
-        if exists(l) {
-            skip[i] = true
-            t.update(l, l.Size, ErrCached)
-        }
+        expected += l.Size
    }

-    g, ctx := errgroup.WithContext(ctx)
+    var received atomic.Int64
+    var g errgroup.Group
    g.SetLimit(r.maxStreams())
-    for i, l := range layers {
-        if skip[i] {
+    for _, l := range layers {
+        info, err := c.Get(l.Digest)
+        if err == nil && info.Size == l.Size {
+            received.Add(l.Size)
+            t.update(l, l.Size, ErrCached)
            continue
        }

+        var wg sync.WaitGroup
        chunked, err := c.Chunked(l.Digest, l.Size)
        if err != nil {
            t.update(l, 0, err)
            continue
        }
-        defer chunked.Close()

-        var progress atomic.Int64
        for cs, err := range r.chunksums(ctx, name, l) {
            if err != nil {
-                t.update(l, progress.Load(), err)
+                // Chunksum stream interrupted. Note in trace
+                // log and let in-flight downloads complete.
+                // This will naturally trigger ErrIncomplete
+                // since received < expected bytes.
+                t.update(l, 0, err)
                break
            }

+            wg.Add(1)
            g.Go(func() (err error) {
-                defer func() { t.update(l, progress.Load(), err) }()
-
-                for _, err := range backoff.Loop(ctx, 3*time.Second) {
-                    if err != nil {
-                        return err
-                    }
-                    err := func() error {
-                        req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
-                        if err != nil {
-                            return err
-                        }
-                        req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
-                        res, err := sendRequest(r.client(), req)
-                        if err != nil {
-                            return err
-                        }
-                        defer res.Body.Close()
-
-                        // Count bytes towards
-                        // progress, as they arrive, so
-                        // that our bytes piggyback
-                        // other chunk updates on
-                        // completion.
-                        //
-                        // This tactic is enough to
-                        // show "smooth" progress given
-                        // the current CLI client. In
-                        // the near future, the server
-                        // should report download rate
-                        // since it knows better than
-                        // a client that is measuring
-                        // rate based on wall-clock
-                        // time-since-last-update.
-                        body := &trackingReader{r: res.Body, n: &progress}
-
-                        err = chunked.Put(cs.Chunk, cs.Digest, body)
-                        if err != nil {
-                            return err
-                        }
-
-                        return nil
-                    }()
-                    if !canRetry(err) {
-                        return err
-                    }
-                }
-                return nil
+                defer func() {
+                    if err == nil {
+                        received.Add(cs.Chunk.Size())
+                    } else {
+                        err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
+                    }
+                    wg.Done()
+                }()
+
+                req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
+                if err != nil {
+                    return err
+                }
+                req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
+                res, err := sendRequest(r.client(), req)
+                if err != nil {
+                    return err
+                }
+                defer res.Body.Close()

+                body := &trackingReader{l: l, r: res.Body, update: t.update}
+                return chunked.Put(cs.Chunk, cs.Digest, body)
            })
        }

+        // Close writer immediately after downloads finish, not at Pull
+        // exit. Using defer would keep file descriptors open until all
+        // layers complete, potentially exhausting system limits with
+        // many layers.
+        //
+        // The WaitGroup tracks when all chunks finish downloading,
+        // allowing precise writer closure in a background goroutine.
+        // Each layer briefly uses one extra goroutine while at most
+        // maxStreams()-1 chunks download in parallel.
+        //
+        // This caps file descriptors at maxStreams() instead of
+        // growing with layer count.
+        g.Go(func() error {
+            wg.Wait()
+            chunked.Close()
+            return nil
+        })
    }
    if err := g.Wait(); err != nil {
        return err
    }
+    if received.Load() != expected {
+        return fmt.Errorf("%w: received %d/%d", ErrIncomplete, received.Load(), expected)
+    }

-    // store the manifest blob
    md := blob.DigestFromBytes(m.Data)
    if err := blob.PutBytes(c, md, m.Data); err != nil {
        return err
    }

-    // commit the manifest with a link
    return c.Link(m.Name, md)
}
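A self-contained sketch of the callback-reader pattern the new trackingReader uses (toy names, not the registry's types): instead of accumulating bytes into a shared atomic counter, each Read reports its byte delta to a callback, which is what lets per-layer trace updates fire as data arrives.

package main

import (
    "fmt"
    "io"
    "strings"
)

// progressReader mirrors the trackingReader idea: every Read reports
// the number of bytes just consumed to a caller-supplied callback.
type progressReader struct {
    r      io.Reader
    report func(n int64)
}

func (p *progressReader) Read(b []byte) (int, error) {
    n, err := p.r.Read(b)
    p.report(int64(n))
    return n, err
}

func main() {
    var total int64
    r := &progressReader{
        r:      strings.NewReader("hello, world"),
        report: func(n int64) { total += n },
    }
    _, _ = io.Copy(io.Discard, r)
    fmt.Println(total) // 12: every byte was reported as it streamed through
}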
@@ -17,6 +17,7 @@ import (
    "reflect"
    "slices"
    "strings"
+    "sync"
    "testing"
    "time"

@@ -24,6 +25,28 @@ import (
    "github.com/ollama/ollama/server/internal/testutil"
)

+func ExampleRegistry_cancelOnFirstError() {
+    ctx, cancel := context.WithCancel(context.Background())
+    defer cancel()
+
+    ctx = WithTrace(ctx, &Trace{
+        Update: func(l *Layer, n int64, err error) {
+            if err != nil {
+                // Discontinue pulling layers if there is an
+                // error instead of continuing to pull more
+                // data.
+                cancel()
+            }
+        },
+    })
+
+    var r Registry
+    if err := r.Pull(ctx, "model"); err != nil {
+        // panic for demo purposes
+        panic(err)
+    }
+}
+
func TestManifestMarshalJSON(t *testing.T) {
    // All manifests should contain an "empty" config object.
    var m Manifest

@@ -56,21 +79,21 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error

// newClient constructs a cache with predefined manifests for testing. The manifests are:
//
//  empty: no data
//  zero: no layers
//  single: one layer with the contents "exists"
//  multiple: two layers with the contents "exists" and "here"
//  notfound: a layer that does not exist in the cache
//  null: one null layer (e.g. [null])
//  sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
//  invalid: a layer with invalid JSON data
//
// Tests that want to ensure the client does not communicate with the upstream
// registry should pass a nil handler, which will cause a panic if
// communication is attempted.
//
// To simulate a network error, pass a handler that returns a 499 status code.
-func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
+func newClient(t *testing.T, upstreamRegistry http.HandlerFunc) (*Registry, *blob.DiskCache) {
    t.Helper()

    c, err := blob.Open(t.TempDir())

@@ -88,7 +111,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
    r := &Registry{
        Cache: c,
        HTTPClient: &http.Client{
-            Transport: recordRoundTripper(h),
+            Transport: recordRoundTripper(upstreamRegistry),
        },
    }

@@ -767,3 +790,79 @@ func TestUnlink(t *testing.T) {
        }
    })
}
+
+func TestPullChunksums(t *testing.T) {
+    check := testutil.Checker(t)
+
+    content := "hello"
+    var chunksums string
+    contentDigest := func() blob.Digest {
+        return blob.DigestFromBytes(content)
+    }
+    rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+        switch {
+        case strings.Contains(r.URL.Path, "/manifests/latest"):
+            fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content))
+        case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()):
+            loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest())
+            w.Header().Set("Content-Location", loc)
+            io.WriteString(w, chunksums)
+        case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()):
+            http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content))
+        default:
+            t.Errorf("unexpected request: %v", r)
+            http.NotFound(w, r)
+        }
+    })
+
+    rc.MaxStreams = 1        // prevent concurrent chunk downloads
+    rc.ChunkingThreshold = 1 // for all blobs to be chunked
+
+    var mu sync.Mutex
+    var reads []int64
+    ctx := WithTrace(t.Context(), &Trace{
+        Update: func(l *Layer, n int64, err error) {
+            t.Logf("Update: %v %d %v", l, n, err)
+            mu.Lock()
+            reads = append(reads, n)
+            mu.Unlock()
+        },
+    })
+
+    chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n",
+        blob.DigestFromBytes("hel"),
+        blob.DigestFromBytes("lo"),
+    )
+    err := rc.Pull(ctx, "test")
+    check(err)
+    wantReads := []int64{
+        0, // initial signaling of layer pull starting
+        3, // first chunk read
+        2, // second chunk read
+    }
+    if !slices.Equal(reads, wantReads) {
+        t.Errorf("reads = %v; want %v", reads, wantReads)
+    }
+
+    mw, err := rc.Resolve(t.Context(), "test")
+    check(err)
+    mg, err := rc.ResolveLocal("test")
+    check(err)
+    if !reflect.DeepEqual(mw, mg) {
+        t.Errorf("mw = %v; mg = %v", mw, mg)
+    }
+    for i := range mg.Layers {
+        _, err = c.Get(mg.Layers[i].Digest)
+        if err != nil {
+            t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err)
+        }
+    }
+
+    // missing chunks
+    content = "llama"
+    chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll"))
+    err = rc.Pull(ctx, "missingchunks")
+    if err == nil {
+        t.Error("expected error because of missing chunks")
+    }
+}
@@ -200,7 +200,7 @@ type params struct {
    //
    // Unfortunately, this API was designed to be a bit awkward. Stream is
    // defined to default to true if not present, so we need a way to check
-    // if the client decisively it to false. So, we use a pointer to a
+    // if the client decisively set it to false. So, we use a pointer to a
    // bool. Gross.
    //
    // Use [stream()] to get the correct value for this field.

@@ -280,17 +280,17 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
    progress := make(map[*ollama.Layer]int64)

    progressCopy := make(map[*ollama.Layer]int64, len(progress))
-    pushUpdate := func() {
+    flushProgress := func() {
        defer maybeFlush()

-        // TODO(bmizerany): This scales poorly with more layers due to
-        // needing to flush out them all in one big update. We _could_
-        // just flush on the changed ones, or just track the whole
-        // download. Needs more thought. This is fine for now.
+        // TODO(bmizerany): Flushing every layer in one update doesn't
+        // scale well. We could flush only the modified layers or track
+        // the full download. Needs further consideration, though it's
+        // fine for now.
        mu.Lock()
        maps.Copy(progressCopy, progress)
        mu.Unlock()
-        for l, n := range progress {
+        for l, n := range progressCopy {
            enc.Encode(progressUpdateJSON{
                Digest: l.Digest,
                Total:  l.Size,

@@ -298,19 +298,26 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
            })
        }
    }
+    defer flushProgress()

-    t := time.NewTicker(time.Hour) // "unstarted" timer
+    t := time.NewTicker(1000 * time.Hour) // "unstarted" timer
    start := sync.OnceFunc(func() {
-        pushUpdate()
+        flushProgress() // flush initial state
        t.Reset(100 * time.Millisecond)
    })
    ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
        Update: func(l *ollama.Layer, n int64, err error) {
            if n > 0 {
-                start() // flush initial state
+                // Block flushing progress updates until every
+                // layer is accounted for. Clients depend on a
+                // complete model size to calculate progress
+                // correctly; if they use an incomplete total,
+                // progress indicators would erratically jump
+                // as new layers are registered.
+                start()
            }
            mu.Lock()
-            progress[l] = n
+            progress[l] += n
            mu.Unlock()
        },
    })

@@ -323,9 +330,9 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
    for {
        select {
        case <-t.C:
-            pushUpdate()
+            flushProgress()
        case err := <-done:
-            pushUpdate()
+            flushProgress()
            if err != nil {
                var status string
                if errors.Is(err, ollama.ErrModelNotFound) {
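A toy version of the "unstarted ticker" trick used above (standalone sketch, not the handler's code): the ticker is created with an effectively infinite interval and armed only once the first byte arrives, so no progress frames are emitted before the layer totals are known.

package main

import (
    "fmt"
    "sync"
    "time"
)

func main() {
    t := time.NewTicker(1000 * time.Hour) // effectively unstarted
    defer t.Stop()

    start := sync.OnceFunc(func() {
        fmt.Println("first byte: flushing initial state")
        t.Reset(100 * time.Millisecond) // begin periodic flushes
    })

    start() // in the handler this is called from the trace callback when n > 0
    <-t.C
    fmt.Println("periodic flush")
}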
@@ -82,7 +82,7 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
    for _, layer := range layers {
        if s := layer.GGML.KV().ChatTemplate(); s != "" {
            if t, err := template.Named(s); err != nil {
-                slog.Debug("template detection", "error", err)
+                slog.Debug("template detection", "error", err, "template", s)
            } else {
                layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
                if err != nil {
13 template/gemma3-instruct.gotmpl Normal file
@@ -0,0 +1,13 @@
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 }}
{{- if eq .Role "user" }}<start_of_turn>user
{{- if and (eq $i 1) $.System }}
{{ $.System }}
{{ end }}
{{ .Content }}<end_of_turn>
{{ else if eq .Role "assistant" }}<start_of_turn>model
{{ .Content }}<end_of_turn>
{{ end }}
{{- if $last }}<start_of_turn>model
{{ end }}
{{- end }}
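To see roughly what this template produces, here is a standalone sketch using Go's text/template (the real renderer lives in ollama's template package; this version is trimmed of the System-message handling, and the Messages/Role/Content field names simply mirror what the template references):

package main

import (
    "os"
    "text/template"
)

type Message struct {
    Role    string
    Content string
}

type Data struct {
    Messages []Message
}

func main() {
    src := `{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 }}
{{- if eq .Role "user" }}<start_of_turn>user
{{ .Content }}<end_of_turn>
{{ else if eq .Role "assistant" }}<start_of_turn>model
{{ .Content }}<end_of_turn>
{{ end }}
{{- if $last }}<start_of_turn>model
{{ end }}
{{- end }}`

    t := template.Must(template.New("gemma3").Parse(src))
    // Renders alternating user/model turns and opens a final model turn
    // for generation, matching the testdata file below.
    _ = t.Execute(os.Stdout, Data{Messages: []Message{
        {Role: "user", Content: "Hello, how are you?"},
        {Role: "assistant", Content: "I'm doing great. How can I help you today?"},
        {Role: "user", Content: "Tell me more."},
    }})
}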
6 template/gemma3-instruct.json Normal file
@@ -0,0 +1,6 @@
{
  "stop": [
    "<end_of_turn>"
  ],
  "temperature": 0.1
}
@@ -87,6 +87,10 @@
    "template": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
    "name": "gemma-instruct"
  },
+  {
+    "template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n    {%- if messages[0]['content'] is string -%}\n        {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n    {%- else -%}\n        {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n    {%- endif -%}\n    {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n    {%- set first_user_prefix = \"\" -%}\n    {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n        {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n    {%- endif -%}\n    {%- if (message['role'] == 'assistant') -%}\n        {%- set role = \"model\" -%}\n    {%- else -%}\n        {%- set role = message['role'] -%}\n    {%- endif -%}\n    {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n    {%- if message['content'] is string -%}\n        {{ message['content'] | trim }}\n    {%- elif message['content'] is iterable -%}\n        {%- for item in message['content'] -%}\n            {%- if item['type'] == 'image' -%}\n                {{ '<start_of_image>' }}\n            {%- elif item['type'] == 'text' -%}\n                {{ item['text'] | trim }}\n            {%- endif -%}\n        {%- endfor -%}\n    {%- else -%}\n        {{ raise_exception(\"Invalid content type\") }}\n    {%- endif -%}\n    {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n    {{'<start_of_turn>model\n'}}\n{%- endif -%}\n",
+    "name": "gemma3-instruct"
+  },
  {
    "template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
    "name": "llama3-instruct"
|||||||
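Each record in this index pairs an upstream Jinja chat template, as shipped in a model's tokenizer config, with the name of the Go template that replaces it; the new entry is what lets gemma3 models resolve to the gemma3-instruct template above. A sketch of reading such records (the struct shape mirrors the JSON shown; the index.json file name is an assumption):

package main

import (
	"encoding/json"
	"fmt"
	"os"
)

// entry mirrors the records above: upstream Jinja source paired with the
// name of the replacement Go template.
type entry struct {
	Template string `json:"template"`
	Name     string `json:"name"`
}

func main() {
	// "index.json" is an assumed name for the file this hunk modifies.
	raw, err := os.ReadFile("index.json")
	if err != nil {
		panic(err)
	}
	var index []entry
	if err := json.Unmarshal(raw, &index); err != nil {
		panic(err)
	}
	for _, e := range index {
		fmt.Printf("%s (%d bytes of jinja)\n", e.Name, len(e.Template))
	}
}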
10 template/testdata/gemma3-instruct.gotmpl/system-user-assistant-user vendored Normal file
@@ -0,0 +1,10 @@
+<start_of_turn>user
+You are a helpful assistant.
+
+Hello, how are you?<end_of_turn>
+<start_of_turn>model
+I'm doing great. How can I help you today?<end_of_turn>
+<start_of_turn>user
+I'd like to show off how chat templating works!<end_of_turn>
+<start_of_turn>model
+
4 template/testdata/gemma3-instruct.gotmpl/user vendored Normal file
@@ -0,0 +1,4 @@
+<start_of_turn>user
+Hello, how are you?<end_of_turn>
+<start_of_turn>model
+
8 template/testdata/gemma3-instruct.gotmpl/user-assistant-user vendored Normal file
@@ -0,0 +1,8 @@
+<start_of_turn>user
+Hello, how are you?<end_of_turn>
+<start_of_turn>model
+I'm doing great. How can I help you today?<end_of_turn>
+<start_of_turn>user
+I'd like to show off how chat templating works!<end_of_turn>
+<start_of_turn>model
+
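These vendored files are golden outputs: each file name encodes a message sequence, and the template tests render the .gotmpl against that sequence and compare byte-for-byte. A minimal sketch of such a golden-file check; the harness shape, paths, and the way the system prompt is passed are illustrative assumptions, not the repo's actual test code:

package template_test

import (
	"bytes"
	"os"
	"path/filepath"
	"testing"
	"text/template"
)

type message struct {
	Role    string
	Content string
}

// cases maps each golden-file name to the message sequence it encodes,
// using the same text as the vendored files above.
var cases = map[string][]message{
	"user": {
		{Role: "user", Content: "Hello, how are you?"},
	},
	"user-assistant-user": {
		{Role: "user", Content: "Hello, how are you?"},
		{Role: "assistant", Content: "I'm doing great. How can I help you today?"},
		{Role: "user", Content: "I'd like to show off how chat templating works!"},
	},
	"system-user-assistant-user": {
		{Role: "system", Content: "You are a helpful assistant."},
		{Role: "user", Content: "Hello, how are you?"},
		{Role: "assistant", Content: "I'm doing great. How can I help you today?"},
		{Role: "user", Content: "I'd like to show off how chat templating works!"},
	},
}

func TestGemma3Golden(t *testing.T) {
	src, err := os.ReadFile("gemma3-instruct.gotmpl") // assumes the template dir as CWD
	if err != nil {
		t.Fatal(err)
	}
	tmpl := template.Must(template.New("gemma3").Parse(string(src)))

	for name, msgs := range cases {
		var system string
		if msgs[0].Role == "system" {
			system = msgs[0].Content
		}
		want, err := os.ReadFile(filepath.Join("testdata", "gemma3-instruct.gotmpl", name))
		if err != nil {
			t.Fatal(err)
		}
		var got bytes.Buffer
		if err := tmpl.Execute(&got, map[string]any{"System": system, "Messages": msgs}); err != nil {
			t.Fatal(err)
		}
		if !bytes.Equal(got.Bytes(), want) {
			t.Errorf("%s: rendered output does not match golden file", name)
		}
	}
}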