mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 22:33:56 +00:00
Merge remote-tracking branch 'upstream/main'
This commit is contained in:
@@ -416,6 +416,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||||||
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
|
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
|
||||||
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
|
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
|
||||||
- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
|
- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
|
||||||
|
- [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).)
|
||||||
|
|
||||||
### Cloud
|
### Cloud
|
||||||
|
|
||||||
@@ -456,6 +457,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||||||
- [aichat](https://github.com/sigoden/aichat) All-in-one LLM CLI tool featuring Shell Assistant, Chat-REPL, RAG, AI tools & agents, with access to OpenAI, Claude, Gemini, Ollama, Groq, and more.
|
- [aichat](https://github.com/sigoden/aichat) All-in-one LLM CLI tool featuring Shell Assistant, Chat-REPL, RAG, AI tools & agents, with access to OpenAI, Claude, Gemini, Ollama, Groq, and more.
|
||||||
- [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
|
- [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
|
||||||
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
|
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
|
||||||
|
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull and download models from Ollama Registry in your terminal.
|
||||||
|
|
||||||
### Apple Vision Pro
|
### Apple Vision Pro
|
||||||
|
|
||||||
@@ -534,6 +536,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||||||
- [Ollama for Zig](https://github.com/dravenk/ollama-zig)
|
- [Ollama for Zig](https://github.com/dravenk/ollama-zig)
|
||||||
- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
|
- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
|
||||||
- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
|
- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
|
||||||
|
- [Ollama for D](https://github.com/kassane/ollama-d)
|
||||||
|
|
||||||
### Mobile
|
### Mobile
|
||||||
|
|
||||||
|
|||||||
178
benchmark/server_benchmark_test.go
Normal file
178
benchmark/server_benchmark_test.go
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
package benchmark
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Command line flags
|
||||||
|
var modelFlag string
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
|
||||||
|
flag.Lookup("m").DefValue = "model"
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelName returns the model name from flags, failing the test if not set
|
||||||
|
func modelName(b *testing.B) string {
|
||||||
|
if modelFlag == "" {
|
||||||
|
b.Fatal("Error: -m flag is required for benchmark tests")
|
||||||
|
}
|
||||||
|
return modelFlag
|
||||||
|
}
|
||||||
|
|
||||||
|
type TestCase struct {
|
||||||
|
name string
|
||||||
|
prompt string
|
||||||
|
maxTokens int
|
||||||
|
}
|
||||||
|
|
||||||
|
// runGenerateBenchmark contains the common generate and metrics logic
|
||||||
|
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
|
||||||
|
start := time.Now()
|
||||||
|
var ttft time.Duration
|
||||||
|
var metrics api.Metrics
|
||||||
|
|
||||||
|
err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||||
|
if ttft == 0 && resp.Response != "" {
|
||||||
|
ttft = time.Since(start)
|
||||||
|
}
|
||||||
|
if resp.Done {
|
||||||
|
metrics = resp.Metrics
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Report custom metrics as part of the benchmark results
|
||||||
|
b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
|
||||||
|
b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")
|
||||||
|
|
||||||
|
// Token throughput metrics
|
||||||
|
promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
|
||||||
|
genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
|
||||||
|
b.ReportMetric(promptThroughput, "prompt_tok/s")
|
||||||
|
b.ReportMetric(genThroughput, "gen_tok/s")
|
||||||
|
|
||||||
|
// Token counts
|
||||||
|
b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
|
||||||
|
b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkColdStart runs benchmarks with model loading from cold state
|
||||||
|
func BenchmarkColdStart(b *testing.B) {
|
||||||
|
client := setup(b)
|
||||||
|
tests := []TestCase{
|
||||||
|
{"short_prompt", "Write a long story", 100},
|
||||||
|
{"medium_prompt", "Write a detailed economic analysis", 500},
|
||||||
|
{"long_prompt", "Write a comprehensive AI research paper", 1000},
|
||||||
|
}
|
||||||
|
m := modelName(b)
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Set number of tokens as our throughput metric
|
||||||
|
b.SetBytes(int64(tt.maxTokens))
|
||||||
|
|
||||||
|
for b.Loop() {
|
||||||
|
b.StopTimer()
|
||||||
|
// Ensure model is unloaded before each iteration
|
||||||
|
unload(client, m, b)
|
||||||
|
b.StartTimer()
|
||||||
|
|
||||||
|
req := &api.GenerateRequest{
|
||||||
|
Model: m,
|
||||||
|
Prompt: tt.prompt,
|
||||||
|
Options: map[string]interface{}{"num_predict": tt.maxTokens, "temperature": 0.1},
|
||||||
|
}
|
||||||
|
|
||||||
|
runGenerateBenchmark(b, ctx, client, req)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkWarmStart runs benchmarks with pre-loaded model
|
||||||
|
func BenchmarkWarmStart(b *testing.B) {
|
||||||
|
client := setup(b)
|
||||||
|
tests := []TestCase{
|
||||||
|
{"short_prompt", "Write a long story", 100},
|
||||||
|
{"medium_prompt", "Write a detailed economic analysis", 500},
|
||||||
|
{"long_prompt", "Write a comprehensive AI research paper", 1000},
|
||||||
|
}
|
||||||
|
m := modelName(b)
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Pre-warm the model
|
||||||
|
warmup(client, m, tt.prompt, b)
|
||||||
|
|
||||||
|
// Set number of tokens as our throughput metric
|
||||||
|
b.SetBytes(int64(tt.maxTokens))
|
||||||
|
|
||||||
|
for b.Loop() {
|
||||||
|
req := &api.GenerateRequest{
|
||||||
|
Model: m,
|
||||||
|
Prompt: tt.prompt,
|
||||||
|
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
|
||||||
|
}
|
||||||
|
|
||||||
|
runGenerateBenchmark(b, ctx, client, req)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setup verifies server and model availability
|
||||||
|
func setup(b *testing.B) *api.Client {
|
||||||
|
client, err := api.ClientFromEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
|
||||||
|
b.Fatalf("Model unavailable: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
// warmup ensures the model is loaded and warmed up
|
||||||
|
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
|
||||||
|
for range 3 {
|
||||||
|
err := client.Generate(
|
||||||
|
context.Background(),
|
||||||
|
&api.GenerateRequest{
|
||||||
|
Model: model,
|
||||||
|
Prompt: prompt,
|
||||||
|
Options: map[string]interface{}{"num_predict": 50, "temperature": 0.1},
|
||||||
|
},
|
||||||
|
func(api.GenerateResponse) error { return nil },
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
b.Logf("Error during model warm-up: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// unload forces model unloading using KeepAlive: 0 parameter
|
||||||
|
func unload(client *api.Client, model string, b *testing.B) {
|
||||||
|
req := &api.GenerateRequest{
|
||||||
|
Model: model,
|
||||||
|
KeepAlive: &api.Duration{Duration: 0},
|
||||||
|
}
|
||||||
|
if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
|
||||||
|
b.Logf("Unload error: %v", err)
|
||||||
|
}
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
}
|
||||||
@@ -703,6 +703,8 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
|||||||
for _, k := range keys {
|
for _, k := range keys {
|
||||||
var v string
|
var v string
|
||||||
switch vData := resp.ModelInfo[k].(type) {
|
switch vData := resp.ModelInfo[k].(type) {
|
||||||
|
case bool:
|
||||||
|
v = fmt.Sprintf("%t", vData)
|
||||||
case string:
|
case string:
|
||||||
v = vData
|
v = vData
|
||||||
case float64:
|
case float64:
|
||||||
|
|||||||
@@ -87,6 +87,8 @@ func TestShowInfo(t *testing.T) {
|
|||||||
ModelInfo: map[string]any{
|
ModelInfo: map[string]any{
|
||||||
"general.architecture": "test",
|
"general.architecture": "test",
|
||||||
"general.parameter_count": float64(8_000_000_000),
|
"general.parameter_count": float64(8_000_000_000),
|
||||||
|
"some.true_bool": true,
|
||||||
|
"some.false_bool": false,
|
||||||
"test.context_length": float64(1000),
|
"test.context_length": float64(1000),
|
||||||
"test.embedding_length": float64(11434),
|
"test.embedding_length": float64(11434),
|
||||||
},
|
},
|
||||||
@@ -111,6 +113,8 @@ func TestShowInfo(t *testing.T) {
|
|||||||
Metadata
|
Metadata
|
||||||
general.architecture test
|
general.architecture test
|
||||||
general.parameter_count 8e+09
|
general.parameter_count 8e+09
|
||||||
|
some.false_bool false
|
||||||
|
some.true_bool true
|
||||||
test.context_length 1000
|
test.context_length 1000
|
||||||
test.embedding_length 11434
|
test.embedding_length 11434
|
||||||
|
|
||||||
|
|||||||
@@ -201,7 +201,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
|||||||
case "CohereForCausalLM":
|
case "CohereForCausalLM":
|
||||||
conv = &commandrModel{}
|
conv = &commandrModel{}
|
||||||
default:
|
default:
|
||||||
return errors.New("unsupported architecture")
|
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.Unmarshal(bts, conv); err != nil {
|
if err := json.Unmarshal(bts, conv); err != nil {
|
||||||
|
|||||||
@@ -558,6 +558,10 @@ Final response:
|
|||||||
{
|
{
|
||||||
"model": "llama3.2",
|
"model": "llama3.2",
|
||||||
"created_at": "2023-08-04T19:22:45.499127Z",
|
"created_at": "2023-08-04T19:22:45.499127Z",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": ""
|
||||||
|
},
|
||||||
"done": true,
|
"done": true,
|
||||||
"total_duration": 4883583458,
|
"total_duration": 4883583458,
|
||||||
"load_duration": 1334875,
|
"load_duration": 1334875,
|
||||||
|
|||||||
59
docs/benchmark.md
Normal file
59
docs/benchmark.md
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
# Benchmark
|
||||||
|
|
||||||
|
Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.
|
||||||
|
|
||||||
|
## When to use
|
||||||
|
|
||||||
|
Run these benchmarks when:
|
||||||
|
- Making changes to the model inference engine
|
||||||
|
- Modifying model loading/unloading logic
|
||||||
|
- Changing prompt processing or token generation code
|
||||||
|
- Implementing a new model architecture
|
||||||
|
- Testing performance across different hardware setups
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`
|
||||||
|
## Usage and Examples
|
||||||
|
|
||||||
|
>[!NOTE]
|
||||||
|
>All commands must be run from the root directory of the Ollama project.
|
||||||
|
|
||||||
|
Basic syntax:
|
||||||
|
```bash
|
||||||
|
go test -bench=. ./benchmark/... -m $MODEL_NAME
|
||||||
|
```
|
||||||
|
|
||||||
|
Required flags:
|
||||||
|
- `-bench=.`: Run all benchmarks
|
||||||
|
- `-m`: Model name to benchmark
|
||||||
|
|
||||||
|
Optional flags:
|
||||||
|
- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
|
||||||
|
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)
|
||||||
|
|
||||||
|
Common usage patterns:
|
||||||
|
|
||||||
|
Single benchmark run with a model specified:
|
||||||
|
```bash
|
||||||
|
go test -bench=. ./benchmark/... -m llama3.3
|
||||||
|
```
|
||||||
|
|
||||||
|
## Output metrics
|
||||||
|
|
||||||
|
The benchmark reports several key metrics:
|
||||||
|
|
||||||
|
- `gen_tok/s`: Generated tokens per second
|
||||||
|
- `prompt_tok/s`: Prompt processing tokens per second
|
||||||
|
- `ttft_ms`: Time to first token in milliseconds
|
||||||
|
- `load_ms`: Model load time in milliseconds
|
||||||
|
- `gen_tokens`: Total tokens generated
|
||||||
|
- `prompt_tokens`: Total prompt tokens processed
|
||||||
|
|
||||||
|
Each benchmark runs two scenarios:
|
||||||
|
- Cold start: Model is loaded from disk for each test
|
||||||
|
- Warm start: Model is pre-loaded in memory
|
||||||
|
|
||||||
|
Three prompt lengths are tested for each scenario:
|
||||||
|
- Short prompt (100 tokens)
|
||||||
|
- Medium prompt (500 tokens)
|
||||||
|
- Long prompt (1000 tokens)
|
||||||
@@ -20,7 +20,13 @@ Please refer to the [GPU docs](./gpu.md).
|
|||||||
|
|
||||||
## How can I specify the context window size?
|
## How can I specify the context window size?
|
||||||
|
|
||||||
By default, Ollama uses a context window size of 2048 tokens. This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context length to 8K, use: `OLLAMA_CONTEXT_LENGTH=8192 ollama serve`.
|
By default, Ollama uses a context window size of 2048 tokens.
|
||||||
|
|
||||||
|
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
OLLAMA_CONTEXT_LENGTH=8192 ollama serve
|
||||||
|
```
|
||||||
|
|
||||||
To change this when using `ollama run`, use `/set parameter`:
|
To change this when using `ollama run`, use `/set parameter`:
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ cat ~/.ollama/logs/server.log
|
|||||||
On **Linux** systems with systemd, the logs can be found with this command:
|
On **Linux** systems with systemd, the logs can be found with this command:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
journalctl -u ollama --no-pager
|
journalctl -u ollama --no-pager --follow --pager-end
|
||||||
```
|
```
|
||||||
|
|
||||||
When you run Ollama in a **container**, the logs go to stdout/stderr in the container:
|
When you run Ollama in a **container**, the logs go to stdout/stderr in the container:
|
||||||
|
|||||||
@@ -413,7 +413,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
|
|||||||
}, offset, nil
|
}, offset, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
|
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||||
embedding := f.KV().EmbeddingLength()
|
embedding := f.KV().EmbeddingLength()
|
||||||
heads := f.KV().HeadCount()
|
heads := f.KV().HeadCount()
|
||||||
headsKV := f.KV().HeadCountKV()
|
headsKV := f.KV().HeadCountKV()
|
||||||
@@ -426,7 +426,10 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
|||||||
layers := f.Tensors().GroupLayers()
|
layers := f.Tensors().GroupLayers()
|
||||||
|
|
||||||
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
||||||
kv = uint64(float64(context*f.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
kv = make([]uint64, f.KV().BlockCount())
|
||||||
|
for i := range kv {
|
||||||
|
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||||
|
}
|
||||||
|
|
||||||
switch f.KV().Architecture() {
|
switch f.KV().Architecture() {
|
||||||
case "llama":
|
case "llama":
|
||||||
@@ -460,16 +463,14 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
|||||||
case "mllama":
|
case "mllama":
|
||||||
var visionTokens, tiles uint64 = 1601, 4
|
var visionTokens, tiles uint64 = 1601, 4
|
||||||
|
|
||||||
if crossAttentionLayers, ok := f.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
|
crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers")
|
||||||
kv = headsKV *
|
for i := range kv {
|
||||||
(embeddingHeadsK + embeddingHeadsV) * // one for K, one for V
|
if slices.Contains(crossAttentionLayers, uint32(i)) {
|
||||||
(2* // sizeof(float16)
|
kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
|
||||||
(f.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
|
|
||||||
context +
|
|
||||||
4 * // sizeof(float32)
|
4 * // sizeof(float32)
|
||||||
uint64(crossAttentionLayers.size)* // num cross attention layers
|
|
||||||
visionTokens *
|
visionTokens *
|
||||||
tiles)
|
tiles
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fullOffload = max(
|
fullOffload = max(
|
||||||
@@ -505,6 +506,20 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
|||||||
4*embeddingHeadsK*context*8+
|
4*embeddingHeadsK*context*8+
|
||||||
embedding*embeddingHeadsK*heads*9/16,
|
embedding*embeddingHeadsK*heads*9/16,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
|
||||||
|
// engine. Gemma3 always uses the Ollama engine.
|
||||||
|
if f.KV().Architecture() == "gemma3" {
|
||||||
|
const gemma3GlobalCacheCount = 6
|
||||||
|
slidingWindow := (uint64(numParallel) * uint64(f.KV().Uint("attention.sliding_window"))) + batch
|
||||||
|
for i := range kv {
|
||||||
|
// Every 6th layer is a global layer, which is the full context size that has already been set. The other
|
||||||
|
// layers are the smaller local (sliding) layers.
|
||||||
|
if (i+1)%gemma3GlobalCacheCount != 0 {
|
||||||
|
kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
case "command-r":
|
case "command-r":
|
||||||
fullOffload = max(
|
fullOffload = max(
|
||||||
4*batch*(embedding+vocab),
|
4*batch*(embedding+vocab),
|
||||||
|
|||||||
@@ -43,8 +43,13 @@ type Cache interface {
|
|||||||
|
|
||||||
// ** cache management **
|
// ** cache management **
|
||||||
|
|
||||||
// Init sets up runtime parameters
|
// Init sets up runtime parameters.
|
||||||
Init(backend ml.Backend, dtype ml.DType, capacity int32)
|
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
|
||||||
|
// dtype: The data type for storing cache entries
|
||||||
|
// maxSequences: The maximum number of sequences stored in the cache - across all batches
|
||||||
|
// capacity: The number of cache entries to store, per sequence
|
||||||
|
// maxBatch: The maximum number of tokens that can occur in a single batch
|
||||||
|
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
|
||||||
|
|
||||||
// Close closes the cache and frees resources associated with it
|
// Close closes the cache and frees resources associated with it
|
||||||
Close()
|
Close()
|
||||||
@@ -52,7 +57,7 @@ type Cache interface {
|
|||||||
// StartForward is called before the start of the model's forward pass.
|
// StartForward is called before the start of the model's forward pass.
|
||||||
// For each token in the coming batch, there must be a corresponding
|
// For each token in the coming batch, there must be a corresponding
|
||||||
// entry in positions and seqs.
|
// entry in positions and seqs.
|
||||||
StartForward(ctx ml.Context, opts input.Options) error
|
StartForward(ctx ml.Context, batch input.Batch) error
|
||||||
|
|
||||||
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
|
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
|
||||||
CopyPrefix(srcSeq, dstSeq int, len int32)
|
CopyPrefix(srcSeq, dstSeq int, len int32)
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
|
|||||||
// The mask is of shape history size, batch size
|
// The mask is of shape history size, batch size
|
||||||
type Causal struct {
|
type Causal struct {
|
||||||
DType ml.DType
|
DType ml.DType
|
||||||
Capacity int32
|
|
||||||
windowSize int32
|
windowSize int32
|
||||||
|
|
||||||
opts CausalOptions
|
opts CausalOptions
|
||||||
@@ -98,7 +97,7 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||||
if c.config == nil {
|
if c.config == nil {
|
||||||
var config ml.CacheConfig
|
var config ml.CacheConfig
|
||||||
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||||
@@ -119,9 +118,16 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
|||||||
c.config.MaskDType = ml.DTypeF32
|
c.config.MaskDType = ml.DTypeF32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var cacheSize int
|
||||||
|
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize) {
|
||||||
|
cacheSize = maxSequences * capacity
|
||||||
|
} else {
|
||||||
|
cacheSize = (maxSequences * int(c.windowSize)) + maxBatch
|
||||||
|
}
|
||||||
|
cacheSize = roundUp(cacheSize, c.config.CachePadding)
|
||||||
|
c.cells = make([]cacheCell, cacheSize)
|
||||||
|
|
||||||
c.DType = dtype
|
c.DType = dtype
|
||||||
c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
|
|
||||||
c.cells = make([]cacheCell, c.Capacity)
|
|
||||||
c.cellRanges = make(map[int]cellRange)
|
c.cellRanges = make(map[int]cellRange)
|
||||||
c.backend = backend
|
c.backend = backend
|
||||||
}
|
}
|
||||||
@@ -140,12 +146,14 @@ func (c *Causal) Close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
|
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
|
||||||
c.curBatchSize = len(opts.Positions)
|
c.curBatchSize = len(batch.Positions)
|
||||||
c.curSequences = opts.Sequences
|
c.curSequences = batch.Sequences
|
||||||
c.curPositions = opts.Positions
|
c.curPositions = batch.Positions
|
||||||
c.opts.Except = nil
|
c.opts.Except = nil
|
||||||
|
|
||||||
|
c.updateSlidingWindow()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
c.curLoc, err = c.findStartLoc()
|
c.curLoc, err = c.findStartLoc()
|
||||||
if errors.Is(err, ErrKvCacheFull) {
|
if errors.Is(err, ErrKvCacheFull) {
|
||||||
@@ -157,8 +165,8 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.curCellRange = newRange()
|
c.curCellRange = newRange()
|
||||||
for i, pos := range opts.Positions {
|
for i, pos := range batch.Positions {
|
||||||
seq := opts.Sequences[i]
|
seq := batch.Sequences[i]
|
||||||
|
|
||||||
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
|
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
|
||||||
|
|
||||||
@@ -210,7 +218,51 @@ func (c *Causal) findStartLoc() (int, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
|
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Causal) updateSlidingWindow() {
|
||||||
|
if c.windowSize == math.MaxInt32 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a map of unique sequences to the lowest position in that sequence
|
||||||
|
lowestPos := make(map[int]int32)
|
||||||
|
for i := range c.curPositions {
|
||||||
|
seq := c.curSequences[i]
|
||||||
|
|
||||||
|
pos, ok := lowestPos[seq]
|
||||||
|
if !ok {
|
||||||
|
pos = c.curPositions[i]
|
||||||
|
} else if c.curPositions[i] < pos {
|
||||||
|
pos = c.curPositions[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
lowestPos[seq] = pos
|
||||||
|
}
|
||||||
|
|
||||||
|
// delete any entries that are beyond the window of the oldest position in the sequence
|
||||||
|
for seq, pos := range lowestPos {
|
||||||
|
oldRange, ok := c.cellRanges[seq]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
newRange := newRange()
|
||||||
|
|
||||||
|
for i := oldRange.min; i <= oldRange.max; i++ {
|
||||||
|
if slices.Contains(c.cells[i].sequences, seq) {
|
||||||
|
if c.cells[i].pos < pos-c.windowSize {
|
||||||
|
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
||||||
|
} else {
|
||||||
|
newRange.min = min(newRange.min, i)
|
||||||
|
newRange.max = max(newRange.max, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.cellRanges[seq] = newRange
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func roundDown(length, pad int) int {
|
func roundDown(length, pad int) int {
|
||||||
@@ -265,7 +317,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
|
|||||||
return maskTensor, nil
|
return maskTensor, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
|
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
|
||||||
for i, key := range c.keys {
|
for i, key := range c.keys {
|
||||||
if key == nil {
|
if key == nil {
|
||||||
continue
|
continue
|
||||||
@@ -275,8 +327,8 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
|
|||||||
numKVHeads := key.Dim(1)
|
numKVHeads := key.Dim(1)
|
||||||
rowSize := key.Stride(2)
|
rowSize := key.Stride(2)
|
||||||
|
|
||||||
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
|
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
|
||||||
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
|
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
|
||||||
|
|
||||||
value := c.values[i]
|
value := c.values[i]
|
||||||
var vSrcView, vDstView ml.Tensor
|
var vSrcView, vDstView ml.Tensor
|
||||||
@@ -284,14 +336,14 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
|
|||||||
vHeadDim := value.Dim(1)
|
vHeadDim := value.Dim(1)
|
||||||
elemSize := value.Stride(0)
|
elemSize := value.Stride(0)
|
||||||
|
|
||||||
vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
|
vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
|
||||||
vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
|
vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
|
||||||
} else {
|
} else {
|
||||||
vHeadDim := value.Dim(0)
|
vHeadDim := value.Dim(0)
|
||||||
rowSize := value.Stride(2)
|
rowSize := value.Stride(2)
|
||||||
|
|
||||||
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
|
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
|
||||||
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
|
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.Forward(
|
ctx.Forward(
|
||||||
@@ -321,7 +373,8 @@ func (c *Causal) defrag() {
|
|||||||
ctx := c.backend.NewContext()
|
ctx := c.backend.NewContext()
|
||||||
|
|
||||||
// For every move, 6 tensors are required per layer (2 views and a
|
// For every move, 6 tensors are required per layer (2 views and a
|
||||||
// copy for each of k and v).
|
// copy for each of k and v). We also need to refer to the original
|
||||||
|
// k and v cache tensors - once per layer, not per move.
|
||||||
layers := 0
|
layers := 0
|
||||||
for _, key := range c.keys {
|
for _, key := range c.keys {
|
||||||
if key == nil {
|
if key == nil {
|
||||||
@@ -330,7 +383,7 @@ func (c *Causal) defrag() {
|
|||||||
layers++
|
layers++
|
||||||
}
|
}
|
||||||
|
|
||||||
maxMoves := ctx.MaxGraphNodes() / (6 * layers)
|
maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
|
||||||
moves := 0
|
moves := 0
|
||||||
|
|
||||||
var pendingSrc, pendingDst, pendingLen int
|
var pendingSrc, pendingDst, pendingLen int
|
||||||
@@ -479,14 +532,14 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := c.keys[c.curLayer]; !ok {
|
if _, ok := c.keys[c.curLayer]; !ok {
|
||||||
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
|
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells))
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := c.values[c.curLayer]; !ok {
|
if _, ok := c.values[c.curLayer]; !ok {
|
||||||
if c.config.PermutedV {
|
if c.config.PermutedV {
|
||||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
|
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads)
|
||||||
} else {
|
} else {
|
||||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
|
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -497,7 +550,7 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
|||||||
elemSize := c.values[c.curLayer].Stride(0)
|
elemSize := c.values[c.curLayer].Stride(0)
|
||||||
|
|
||||||
value = value.Permute(ctx, 1, 2, 0, 3)
|
value = value.Permute(ctx, 1, 2, 0, 3)
|
||||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
|
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
|
||||||
} else {
|
} else {
|
||||||
rowSize := c.values[c.curLayer].Stride(2)
|
rowSize := c.values[c.curLayer].Stride(2)
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ func TestStore(t *testing.T) {
|
|||||||
cache := NewCausalCache(nil)
|
cache := NewCausalCache(nil)
|
||||||
defer cache.Close()
|
defer cache.Close()
|
||||||
|
|
||||||
cache.Init(backend, ml.DTypeF16, 16)
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
@@ -58,11 +58,11 @@ func TestSWA(t *testing.T) {
|
|||||||
cache := NewSWACache(1, nil)
|
cache := NewSWACache(1, nil)
|
||||||
defer cache.Close()
|
defer cache.Close()
|
||||||
|
|
||||||
cache.Init(backend, ml.DTypeF32, 16)
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
name: "SlidingWindow",
|
name: "FirstBatch",
|
||||||
in: []float32{1, 2, 3, 4},
|
in: []float32{1, 2, 3, 4},
|
||||||
inShape: []int{1, 1, 4},
|
inShape: []int{1, 1, 4},
|
||||||
seqs: []int{0, 0, 0, 0},
|
seqs: []int{0, 0, 0, 0},
|
||||||
@@ -71,6 +71,16 @@ func TestSWA(t *testing.T) {
|
|||||||
expectedShape: []int{1, 1, 4},
|
expectedShape: []int{1, 1, 4},
|
||||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "SecondBatch",
|
||||||
|
in: []float32{5, 6},
|
||||||
|
inShape: []int{1, 1, 2},
|
||||||
|
seqs: []int{0, 0},
|
||||||
|
pos: []int32{4, 5},
|
||||||
|
expected: []float32{5, 6, 3, 4},
|
||||||
|
expectedShape: []int{1, 1, 4},
|
||||||
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
testCache(t, backend, cache, tests)
|
testCache(t, backend, cache, tests)
|
||||||
@@ -81,7 +91,7 @@ func TestSequences(t *testing.T) {
|
|||||||
cache := NewCausalCache(nil)
|
cache := NewCausalCache(nil)
|
||||||
defer cache.Close()
|
defer cache.Close()
|
||||||
|
|
||||||
cache.Init(backend, ml.DTypeF16, 16)
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
@@ -116,7 +126,7 @@ func TestRemove(t *testing.T) {
|
|||||||
})
|
})
|
||||||
defer cache.Close()
|
defer cache.Close()
|
||||||
|
|
||||||
cache.Init(backend, ml.DTypeF16, 16)
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
@@ -181,7 +191,7 @@ func TestDefrag(t *testing.T) {
|
|||||||
})
|
})
|
||||||
defer cache.Close()
|
defer cache.Close()
|
||||||
|
|
||||||
cache.Init(backend, ml.DTypeF16, 16)
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
@@ -229,7 +239,7 @@ func TestCopy(t *testing.T) {
|
|||||||
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
|
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
|
||||||
defer cache.Close()
|
defer cache.Close()
|
||||||
|
|
||||||
cache.Init(backend, ml.DTypeF16, 16)
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
@@ -270,7 +280,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
|
|||||||
context := backend.NewContext()
|
context := backend.NewContext()
|
||||||
defer context.Close()
|
defer context.Close()
|
||||||
|
|
||||||
err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs})
|
err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@@ -352,7 +362,6 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *testContext) Input() ml.Context { return c }
|
func (c *testContext) Input() ml.Context { return c }
|
||||||
func (c *testContext) Output() ml.Context { return c }
|
|
||||||
func (c *testContext) Layer(int) ml.Context { return c }
|
func (c *testContext) Layer(int) ml.Context { return c }
|
||||||
|
|
||||||
func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
|
func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ func NewEncoderCache() *EncoderCache {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||||
if c.config == nil {
|
if c.config == nil {
|
||||||
var config ml.CacheConfig
|
var config ml.CacheConfig
|
||||||
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||||
@@ -58,6 +58,10 @@ func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
|
|||||||
c.config = &config
|
c.config = &config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if maxSequences > 1 {
|
||||||
|
panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
|
||||||
|
}
|
||||||
|
|
||||||
if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
|
if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
|
||||||
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
|
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
|
||||||
}
|
}
|
||||||
@@ -79,10 +83,10 @@ func (c *EncoderCache) Close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error {
|
func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch) error {
|
||||||
// We work with the most recent image
|
// We work with the most recent image
|
||||||
if len(opts.Multimodal) > 0 {
|
if len(batch.Multimodal) > 0 {
|
||||||
c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index]
|
c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -23,9 +23,9 @@ func NewWrapperCache(caches ...Cache) *WrapperCache {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||||
for _, cache := range c.caches {
|
for _, cache := range c.caches {
|
||||||
cache.Init(backend, dtype, capacity)
|
cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -41,14 +41,14 @@ func (c *WrapperCache) Close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error {
|
func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch) error {
|
||||||
for i, cache := range c.caches {
|
for i, cache := range c.caches {
|
||||||
err := cache.StartForward(ctx, opts)
|
err := cache.StartForward(ctx, batch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
|
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
|
||||||
for j := i - 1; j >= 0; j-- {
|
for j := i - 1; j >= 0; j-- {
|
||||||
for k := range opts.Positions {
|
for k := range batch.Positions {
|
||||||
_ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32)
|
_ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
|
|||||||
103
llama/patches/0022-add-rdna4-support.patch
Normal file
103
llama/patches/0022-add-rdna4-support.patch
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||||
|
From: Saman <saman.khatir@amd.com>
|
||||||
|
Date: Wed, 19 Mar 2025 14:02:26 -0700
|
||||||
|
Subject: [PATCH] add rdna4 support
|
||||||
|
|
||||||
|
---
|
||||||
|
ggml/src/ggml-cuda/common.cuh | 6 ++++--
|
||||||
|
ggml/src/ggml-cuda/mmq.cu | 2 +-
|
||||||
|
ggml/src/ggml-cuda/mmq.cuh | 4 ++--
|
||||||
|
ggml/src/ggml-cuda/mmvq.cu | 4 ++--
|
||||||
|
ggml/src/ggml-cuda/vendors/hip.h | 4 ++++
|
||||||
|
5 files changed, 13 insertions(+), 7 deletions(-)
|
||||||
|
|
||||||
|
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
|
||||||
|
index adf0d3ec..b24593fc 100644
|
||||||
|
--- a/ggml/src/ggml-cuda/common.cuh
|
||||||
|
+++ b/ggml/src/ggml-cuda/common.cuh
|
||||||
|
@@ -61,11 +61,13 @@
|
||||||
|
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
|
||||||
|
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
|
||||||
|
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
|
||||||
|
+#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
|
||||||
|
|
||||||
|
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
|
||||||
|
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
|
||||||
|
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
|
||||||
|
-#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
|
||||||
|
+#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
|
||||||
|
+#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
|
||||||
|
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
|
||||||
|
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
|
||||||
|
|
||||||
|
@@ -386,7 +388,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
|
||||||
|
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
|
||||||
|
c = __builtin_amdgcn_sdot4(a, b, c, false);
|
||||||
|
-#elif defined(RDNA3)
|
||||||
|
+#elif defined(RDNA3) || defined(RDNA4)
|
||||||
|
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
|
||||||
|
#elif defined(__gfx1010__) || defined(__gfx900__)
|
||||||
|
int tmp1;
|
||||||
|
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
|
||||||
|
index 10f2ebb1..933d945c 100644
|
||||||
|
--- a/ggml/src/ggml-cuda/mmq.cu
|
||||||
|
+++ b/ggml/src/ggml-cuda/mmq.cu
|
||||||
|
@@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||||
|
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||||
|
}
|
||||||
|
|
||||||
|
- return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||||
|
+ return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||||
|
}
|
||||||
|
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
|
||||||
|
index 0451c65f..66ce2bc9 100644
|
||||||
|
--- a/ggml/src/ggml-cuda/mmq.cuh
|
||||||
|
+++ b/ggml/src/ggml-cuda/mmq.cuh
|
||||||
|
@@ -2577,9 +2577,9 @@ static __device__ void mul_mat_q_process_tile(
|
||||||
|
|
||||||
|
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
|
||||||
|
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
-#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||||
|
+#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||||
|
__launch_bounds__(WARP_SIZE*nwarps, 2)
|
||||||
|
-#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||||
|
+#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||||
|
#else
|
||||||
|
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
||||||
|
__launch_bounds__(WARP_SIZE*nwarps, 1)
|
||||||
|
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
|
||||||
|
index 4fb466ca..23ae7abc 100644
|
||||||
|
--- a/ggml/src/ggml-cuda/mmvq.cu
|
||||||
|
+++ b/ggml/src/ggml-cuda/mmvq.cu
|
||||||
|
@@ -62,13 +62,13 @@ static __global__ void mul_mat_vec_q(
|
||||||
|
|
||||||
|
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
|
||||||
|
|
||||||
|
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
|
||||||
|
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4))
|
||||||
|
constexpr int nwarps = 1;
|
||||||
|
constexpr int rows_per_cuda_block = 1;
|
||||||
|
#else
|
||||||
|
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
|
||||||
|
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
|
||||||
|
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
|
||||||
|
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) && !defined(RDNA4)
|
||||||
|
|
||||||
|
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
||||||
|
const int row0 = rows_per_cuda_block*blockIdx.x;
|
||||||
|
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
|
||||||
|
index 81964611..a62544b5 100644
|
||||||
|
--- a/ggml/src/ggml-cuda/vendors/hip.h
|
||||||
|
+++ b/ggml/src/ggml-cuda/vendors/hip.h
|
||||||
|
@@ -150,6 +150,10 @@
|
||||||
|
#define CDNA
|
||||||
|
#endif
|
||||||
|
|
||||||
|
+#if defined(__gfx1200__) || defined(__gfx1201__)
|
||||||
|
+#define RDNA4
|
||||||
|
+#endif
|
||||||
|
+
|
||||||
|
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
|
||||||
|
defined(__gfx1150__) || defined(__gfx1151__)
|
||||||
|
#define RDNA3
|
||||||
@@ -15,12 +15,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// This algorithm looks for a complete fit to determine if we need to unload other models
|
// This algorithm looks for a complete fit to determine if we need to unload other models
|
||||||
func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
|
func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
|
||||||
// Split up the GPUs by type and try them
|
// Split up the GPUs by type and try them
|
||||||
var estimatedVRAM uint64
|
var estimatedVRAM uint64
|
||||||
for _, gpus := range allGpus.ByLibrary() {
|
for _, gpus := range allGpus.ByLibrary() {
|
||||||
var layerCount int
|
var layerCount int
|
||||||
estimate := EstimateGPULayers(gpus, f, projectors, opts)
|
estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
|
||||||
layerCount, estimatedVRAM = estimate.Layers, estimate.VRAMSize
|
layerCount, estimatedVRAM = estimate.Layers, estimate.VRAMSize
|
||||||
if opts.NumGPU < 0 {
|
if opts.NumGPU < 0 {
|
||||||
if layerCount > 0 && layerCount >= int(f.KV().BlockCount()+1) {
|
if layerCount > 0 && layerCount >= int(f.KV().BlockCount()+1) {
|
||||||
@@ -71,7 +71,7 @@ type MemoryEstimate struct {
|
|||||||
|
|
||||||
// Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
|
// Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
|
||||||
// The GPUs provided must all be the same Library
|
// The GPUs provided must all be the same Library
|
||||||
func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options) MemoryEstimate {
|
func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options, numParallel int) MemoryEstimate {
|
||||||
// Graph size for a partial offload, applies to all GPUs
|
// Graph size for a partial offload, applies to all GPUs
|
||||||
var graphPartialOffload uint64
|
var graphPartialOffload uint64
|
||||||
|
|
||||||
@@ -137,13 +137,19 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), kvct)
|
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct)
|
||||||
|
|
||||||
// KV is proportional to the number of layers
|
if len(kv) > 0 {
|
||||||
layerSize += kv / f.KV().BlockCount()
|
layerSize += kv[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
var kvTotal uint64
|
||||||
|
for _, kvLayer := range kv {
|
||||||
|
kvTotal += kvLayer
|
||||||
|
}
|
||||||
|
|
||||||
if graphPartialOffload == 0 {
|
if graphPartialOffload == 0 {
|
||||||
graphPartialOffload = f.KV().GQA() * kv / 6
|
graphPartialOffload = f.KV().GQA() * kvTotal / 6
|
||||||
}
|
}
|
||||||
if graphFullOffload == 0 {
|
if graphFullOffload == 0 {
|
||||||
graphFullOffload = graphPartialOffload
|
graphFullOffload = graphPartialOffload
|
||||||
@@ -217,7 +223,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
|||||||
// Some models have inconsistent layer sizes
|
// Some models have inconsistent layer sizes
|
||||||
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
|
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
|
||||||
layerSize = blk.Size()
|
layerSize = blk.Size()
|
||||||
layerSize += kv / f.KV().BlockCount()
|
layerSize += kv[i]
|
||||||
memoryWeights += blk.Size()
|
memoryWeights += blk.Size()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -315,7 +321,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
|||||||
layersRequested: opts.NumGPU,
|
layersRequested: opts.NumGPU,
|
||||||
layersModel: int(f.KV().BlockCount()) + 1,
|
layersModel: int(f.KV().BlockCount()) + 1,
|
||||||
availableList: availableList,
|
availableList: availableList,
|
||||||
kv: kv,
|
kv: kvTotal,
|
||||||
allocationsList: allocationsList,
|
allocationsList: allocationsList,
|
||||||
memoryWeights: memoryWeights,
|
memoryWeights: memoryWeights,
|
||||||
memoryLayerOutput: memoryLayerOutput,
|
memoryLayerOutput: memoryLayerOutput,
|
||||||
@@ -374,7 +380,7 @@ func (m MemoryEstimate) LogValue() slog.Value {
|
|||||||
slog.Group(
|
slog.Group(
|
||||||
"weights",
|
"weights",
|
||||||
// memory of the weights
|
// memory of the weights
|
||||||
"total", format.HumanBytes2(m.memoryWeights),
|
"total", format.HumanBytes2(m.memoryWeights+m.memoryLayerOutput),
|
||||||
// memory of repeating layers
|
// memory of repeating layers
|
||||||
"repeating", format.HumanBytes2(m.memoryWeights),
|
"repeating", format.HumanBytes2(m.memoryWeights),
|
||||||
// memory of non-repeating layers
|
// memory of non-repeating layers
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ func TestEstimateGPULayers(t *testing.T) {
|
|||||||
projectors := []string{}
|
projectors := []string{}
|
||||||
opts := api.DefaultOptions()
|
opts := api.DefaultOptions()
|
||||||
t.Run("cpu", func(t *testing.T) {
|
t.Run("cpu", func(t *testing.T) {
|
||||||
estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
|
estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1)
|
||||||
assert.Equal(t, 0, estimate.Layers)
|
assert.Equal(t, 0, estimate.Layers)
|
||||||
assert.Equal(t, uint64(0), estimate.Graph)
|
assert.Equal(t, uint64(0), estimate.Graph)
|
||||||
})
|
})
|
||||||
@@ -112,7 +112,7 @@ func TestEstimateGPULayers(t *testing.T) {
|
|||||||
gpus[1].FreeMemory += gpuMinimumMemory + layerSize + s.layer1*layerSize + 1
|
gpus[1].FreeMemory += gpuMinimumMemory + layerSize + s.layer1*layerSize + 1
|
||||||
gpus[0].FreeMemory += max(graphFullOffload, graphPartialOffload)
|
gpus[0].FreeMemory += max(graphFullOffload, graphPartialOffload)
|
||||||
gpus[1].FreeMemory += max(graphFullOffload, graphPartialOffload)
|
gpus[1].FreeMemory += max(graphFullOffload, graphPartialOffload)
|
||||||
estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
|
estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1)
|
||||||
assert.Equal(t, int(s.expect0+s.expect1), estimate.Layers, "scenario %d: %v", i, s)
|
assert.Equal(t, int(s.expect0+s.expect1), estimate.Layers, "scenario %d: %v", i, s)
|
||||||
assert.Equal(t, fmt.Sprintf("%d,%d", s.expect0, s.expect1), estimate.TensorSplit, "scenario %d: %v", i, s)
|
assert.Equal(t, fmt.Sprintf("%d,%d", s.expect0, s.expect1), estimate.TensorSplit, "scenario %d: %v", i, s)
|
||||||
var layerSums uint64
|
var layerSums uint64
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
|||||||
gpus = discover.GetCPUInfo()
|
gpus = discover.GetCPUInfo()
|
||||||
}
|
}
|
||||||
|
|
||||||
estimate := EstimateGPULayers(gpus, f, projectors, opts)
|
estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
|
||||||
if len(gpus) > 1 || gpus[0].Library != "cpu" {
|
if len(gpus) > 1 || gpus[0].Library != "cpu" {
|
||||||
switch {
|
switch {
|
||||||
case gpus[0].Library == "metal" && estimate.VRAMSize > systemTotalMemory:
|
case gpus[0].Library == "metal" && estimate.VRAMSize > systemTotalMemory:
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package ml
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
@@ -60,6 +61,10 @@ type CacheConfig struct {
|
|||||||
|
|
||||||
// BackendParams controls how the backend loads and executes models
|
// BackendParams controls how the backend loads and executes models
|
||||||
type BackendParams struct {
|
type BackendParams struct {
|
||||||
|
// Progress is a callback function that allows reporting percentage completion
|
||||||
|
// of model loading
|
||||||
|
Progress func(float32)
|
||||||
|
|
||||||
// NumThreads sets the number of threads to use if running on the CPU
|
// NumThreads sets the number of threads to use if running on the CPU
|
||||||
NumThreads int
|
NumThreads int
|
||||||
|
|
||||||
@@ -76,9 +81,9 @@ type BackendParams struct {
|
|||||||
FlashAttention bool
|
FlashAttention bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
|
var backends = make(map[string]func(context.Context, *os.File, BackendParams) (Backend, error))
|
||||||
|
|
||||||
func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
|
func RegisterBackend(name string, f func(context.Context, *os.File, BackendParams) (Backend, error)) {
|
||||||
if _, ok := backends[name]; ok {
|
if _, ok := backends[name]; ok {
|
||||||
panic("backend: backend already registered")
|
panic("backend: backend already registered")
|
||||||
}
|
}
|
||||||
@@ -86,9 +91,9 @@ func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, erro
|
|||||||
backends[name] = f
|
backends[name] = f
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBackend(f *os.File, params BackendParams) (Backend, error) {
|
func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend, error) {
|
||||||
if backend, ok := backends["ggml"]; ok {
|
if backend, ok := backends["ggml"]; ok {
|
||||||
return backend(f, params)
|
return backend(ctx, f, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("unsupported backend")
|
return nil, fmt.Errorf("unsupported backend")
|
||||||
@@ -105,12 +110,10 @@ type Context interface {
|
|||||||
MaxGraphNodes() int
|
MaxGraphNodes() int
|
||||||
Close()
|
Close()
|
||||||
|
|
||||||
// Input returns a context appropriate for creating input tensors
|
// Input returns a context appropriate for creating tensors that are
|
||||||
|
// inputs to the model (which includes things like output locations)
|
||||||
Input() Context
|
Input() Context
|
||||||
|
|
||||||
// Output returns a context appropriate for creating output tensors
|
|
||||||
Output() Context
|
|
||||||
|
|
||||||
// Layer returns a context appropriate for creating intermediate tensors
|
// Layer returns a context appropriate for creating intermediate tensors
|
||||||
Layer(int) Context
|
Layer(int) Context
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,15 +9,17 @@ package ggml
|
|||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"maps"
|
"maps"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"unicode"
|
"unicode"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
@@ -46,9 +48,6 @@ type Backend struct {
|
|||||||
// input is the backend used for inputs
|
// input is the backend used for inputs
|
||||||
input *C.struct_ggml_backend_buffer_type
|
input *C.struct_ggml_backend_buffer_type
|
||||||
|
|
||||||
// output is the backend used for outputs
|
|
||||||
output *C.struct_ggml_backend_buffer_type
|
|
||||||
|
|
||||||
// layers is the backend used for repeating layers
|
// layers is the backend used for repeating layers
|
||||||
layers map[int]*C.struct_ggml_backend_buffer_type
|
layers map[int]*C.struct_ggml_backend_buffer_type
|
||||||
|
|
||||||
@@ -58,7 +57,7 @@ type Backend struct {
|
|||||||
maxGraphNodes int
|
maxGraphNodes int
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
||||||
meta, n, err := fs.Decode(r, -1)
|
meta, n, err := fs.Decode(r, -1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -297,12 +296,16 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// concurrently read in tensor data. uses a section reader which is safe for concurrent reads
|
var doneBytes atomic.Uint64
|
||||||
sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
|
totalBytes := uint64(n) - meta.Tensors().Offset
|
||||||
var g errgroup.Group
|
|
||||||
|
g, ctx := errgroup.WithContext(ctx)
|
||||||
|
g.SetLimit(runtime.GOMAXPROCS(0))
|
||||||
for _, t := range meta.Tensors().Items() {
|
for _, t := range meta.Tensors().Items() {
|
||||||
for _, target := range targets[t.Name] {
|
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
|
tts := make([]*C.struct_ggml_tensor, max(1, len(targets[t.Name])))
|
||||||
|
for i := range tts {
|
||||||
|
target := targets[t.Name][i]
|
||||||
if target == "" {
|
if target == "" {
|
||||||
target = t.Name
|
target = t.Name
|
||||||
}
|
}
|
||||||
@@ -312,25 +315,44 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
|||||||
return fmt.Errorf("unassigned tensor: %s", t.Name)
|
return fmt.Errorf("unassigned tensor: %s", t.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
bts := C.malloc(C.size_t(t.Size()))
|
tts[i] = tt
|
||||||
if bts == nil {
|
}
|
||||||
return errors.New("failed to allocate tensor buffer")
|
|
||||||
}
|
sr := io.NewSectionReader(r, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
|
||||||
defer C.free(bts)
|
bts := make([]byte, 128*format.KibiByte)
|
||||||
|
|
||||||
buf := unsafe.Slice((*byte)(bts), t.Size())
|
var s uint64
|
||||||
n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), buf)
|
for s < t.Size() {
|
||||||
if err != nil || n != len(buf) {
|
n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
|
||||||
return errors.New("read failed")
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tts {
|
||||||
|
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
s += uint64(n)
|
||||||
|
|
||||||
|
if params.Progress != nil {
|
||||||
|
done := doneBytes.Add(uint64(n))
|
||||||
|
params.Progress(float32(done) / float32(totalBytes))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
C.ggml_backend_tensor_set(tt, bts, 0, C.size_t(t.Size()))
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if g.Wait() != nil {
|
// start a goroutine to cancel the errgroup if the parent context is done
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
g.Go(func() error {
|
||||||
|
return ctx.Err()
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := g.Wait(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -376,7 +398,6 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
|||||||
C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
|
C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
|
||||||
),
|
),
|
||||||
input: deviceBufferTypes[input.d],
|
input: deviceBufferTypes[input.d],
|
||||||
output: deviceBufferTypes[output.d],
|
|
||||||
layers: func() map[int]*C.struct_ggml_backend_buffer_type {
|
layers: func() map[int]*C.struct_ggml_backend_buffer_type {
|
||||||
m := make(map[int]*C.struct_ggml_backend_buffer_type)
|
m := make(map[int]*C.struct_ggml_backend_buffer_type)
|
||||||
for i, layer := range layers {
|
for i, layer := range layers {
|
||||||
@@ -457,19 +478,6 @@ func (c Context) Input() ml.Context {
|
|||||||
return &c
|
return &c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Context) Output() ml.Context {
|
|
||||||
if c.b.output != nil {
|
|
||||||
return &Context{
|
|
||||||
b: c.b,
|
|
||||||
ctx: c.ctx,
|
|
||||||
buft: c.b.output,
|
|
||||||
maxGraphNodes: c.maxGraphNodes,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &c
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c Context) Layer(i int) ml.Context {
|
func (c Context) Layer(i int) ml.Context {
|
||||||
if buft, ok := c.b.layers[i]; ok {
|
if buft, ok := c.b.layers[i]; ok {
|
||||||
return &Context{
|
return &Context{
|
||||||
|
|||||||
@@ -61,11 +61,13 @@
|
|||||||
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
|
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
|
||||||
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
|
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
|
||||||
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
|
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
|
||||||
|
#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
|
||||||
|
|
||||||
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
|
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
|
||||||
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
|
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
|
||||||
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
|
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
|
||||||
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
|
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
|
||||||
|
#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
|
||||||
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
|
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
|
||||||
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
|
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
|
||||||
|
|
||||||
@@ -386,7 +388,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
|
|||||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||||
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
|
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
|
||||||
c = __builtin_amdgcn_sdot4(a, b, c, false);
|
c = __builtin_amdgcn_sdot4(a, b, c, false);
|
||||||
#elif defined(RDNA3)
|
#elif defined(RDNA3) || defined(RDNA4)
|
||||||
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
|
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
|
||||||
#elif defined(__gfx1010__) || defined(__gfx900__)
|
#elif defined(__gfx1010__) || defined(__gfx900__)
|
||||||
int tmp1;
|
int tmp1;
|
||||||
|
|||||||
2
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu
vendored
2
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu
vendored
@@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
|||||||
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||||
}
|
}
|
||||||
|
|
||||||
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||||
}
|
}
|
||||||
|
|||||||
4
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh
vendored
4
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh
vendored
@@ -2577,9 +2577,9 @@ static __device__ void mul_mat_q_process_tile(
|
|||||||
|
|
||||||
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
|
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
|
||||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||||
#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||||
__launch_bounds__(WARP_SIZE*nwarps, 2)
|
__launch_bounds__(WARP_SIZE*nwarps, 2)
|
||||||
#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||||
#else
|
#else
|
||||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
||||||
__launch_bounds__(WARP_SIZE*nwarps, 1)
|
__launch_bounds__(WARP_SIZE*nwarps, 1)
|
||||||
|
|||||||
4
ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu
vendored
4
ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu
vendored
@@ -62,13 +62,13 @@ static __global__ void mul_mat_vec_q(
|
|||||||
|
|
||||||
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
|
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
|
||||||
|
|
||||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
|
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4))
|
||||||
constexpr int nwarps = 1;
|
constexpr int nwarps = 1;
|
||||||
constexpr int rows_per_cuda_block = 1;
|
constexpr int rows_per_cuda_block = 1;
|
||||||
#else
|
#else
|
||||||
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
|
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
|
||||||
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
|
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
|
||||||
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
|
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) && !defined(RDNA4)
|
||||||
|
|
||||||
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
||||||
const int row0 = rows_per_cuda_block*blockIdx.x;
|
const int row0 = rows_per_cuda_block*blockIdx.x;
|
||||||
|
|||||||
@@ -150,6 +150,10 @@
|
|||||||
#define CDNA
|
#define CDNA
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(__gfx1200__) || defined(__gfx1201__)
|
||||||
|
#define RDNA4
|
||||||
|
#endif
|
||||||
|
|
||||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
|
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
|
||||||
defined(__gfx1150__) || defined(__gfx1151__)
|
defined(__gfx1150__) || defined(__gfx1151__)
|
||||||
#define RDNA3
|
#define RDNA3
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
package input
|
package input
|
||||||
|
|
||||||
|
import "github.com/ollama/ollama/ml"
|
||||||
|
|
||||||
// Input represents one token in the input stream
|
// Input represents one token in the input stream
|
||||||
type Input struct {
|
type Input struct {
|
||||||
// Token is a single element of text.
|
// Token is a single element of text.
|
||||||
@@ -33,11 +35,24 @@ type MultimodalIndex struct {
|
|||||||
Multimodal any
|
Multimodal any
|
||||||
}
|
}
|
||||||
|
|
||||||
// Options contains the inputs for a model forward pass
|
// Batch contains the inputs for a model forward pass
|
||||||
type Options struct {
|
type Batch struct {
|
||||||
Inputs []int32
|
// Inputs is the input tokens, including placeholders for multimodal inputs.
|
||||||
|
Inputs ml.Tensor
|
||||||
|
|
||||||
|
// Multimodal is a set of multimodal embeddings previously created by
|
||||||
|
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
|
||||||
|
// models or for batches without multimodal elements.
|
||||||
Multimodal []MultimodalIndex
|
Multimodal []MultimodalIndex
|
||||||
|
|
||||||
|
// Positions is the position for each Input, relative to its sequence. Equal
|
||||||
|
// in length to Inputs.
|
||||||
Positions []int32
|
Positions []int32
|
||||||
|
|
||||||
|
// Sequences is the sequence for each Input. Equal in length to Inputs.
|
||||||
Sequences []int
|
Sequences []int
|
||||||
|
|
||||||
|
// Outputs are the set of indicies into Inputs for which output data should
|
||||||
|
// be returned.
|
||||||
Outputs []int32
|
Outputs []int32
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
_ "image/jpeg"
|
_ "image/jpeg"
|
||||||
@@ -26,7 +27,7 @@ var ErrNoVisionModel = errors.New("this model is missing data required for image
|
|||||||
|
|
||||||
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
|
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
|
||||||
type Model interface {
|
type Model interface {
|
||||||
Forward(ml.Context, input.Options) (ml.Tensor, error)
|
Forward(ml.Context, input.Batch) (ml.Tensor, error)
|
||||||
|
|
||||||
Backend() ml.Backend
|
Backend() ml.Backend
|
||||||
Config() config
|
Config() config
|
||||||
@@ -94,14 +95,14 @@ func Register(name string, f func(ml.Config) (Model, error)) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// New initializes a new model instance with the provided configuration based on the metadata in the model file
|
// New initializes a new model instance with the provided configuration based on the metadata in the model file
|
||||||
func New(modelPath string, params ml.BackendParams) (Model, error) {
|
func New(ctx context.Context, modelPath string, params ml.BackendParams) (Model, error) {
|
||||||
r, err := os.Open(modelPath)
|
r, err := os.Open(modelPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer r.Close()
|
defer r.Close()
|
||||||
|
|
||||||
b, err := ml.NewBackend(r, params)
|
b, err := ml.NewBackend(ctx, r, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -280,24 +281,30 @@ func canNil(t reflect.Type) bool {
|
|||||||
t.Kind() == reflect.Slice
|
t.Kind() == reflect.Slice
|
||||||
}
|
}
|
||||||
|
|
||||||
func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
|
func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
|
||||||
if len(opts.Positions) != len(opts.Sequences) {
|
if len(batch.Positions) != len(batch.Sequences) {
|
||||||
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
|
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(opts.Positions) < 1 {
|
if len(batch.Positions) < 1 {
|
||||||
return nil, errors.New("batch size cannot be less than 1")
|
return nil, errors.New("batch size cannot be less than 1")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
cache := m.Config().Cache
|
cache := m.Config().Cache
|
||||||
if cache != nil {
|
if cache != nil {
|
||||||
err := cache.StartForward(ctx, opts)
|
err := cache.StartForward(ctx, batch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
t, err := m.Forward(ctx, opts)
|
t, err := m.Forward(ctx, batch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -163,7 +163,7 @@ func TestGetTextProcessor(t *testing.T) {
|
|||||||
|
|
||||||
type notTextProcessorModel struct{}
|
type notTextProcessorModel struct{}
|
||||||
|
|
||||||
func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) {
|
func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
|
||||||
panic("unimplemented")
|
panic("unimplemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -168,23 +168,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
|
|||||||
return hiddenState.Add(ctx, residual)
|
return hiddenState.Add(ctx, residual)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
|
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
|
||||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
|
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
|
||||||
|
|
||||||
if len(m.Layers) == gemma27BLayerCount {
|
if len(m.Layers) == gemma27BLayerCount {
|
||||||
@@ -211,8 +206,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
|||||||
// final logit softcap
|
// final logit softcap
|
||||||
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
|
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
|
||||||
hiddenState = hiddenState.Tanh(ctx)
|
hiddenState = hiddenState.Tanh(ctx)
|
||||||
hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
|
return hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap)), nil
|
||||||
return hiddenState.Rows(ctx, outputs), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|||||||
@@ -139,23 +139,18 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
|
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|||||||
@@ -171,13 +171,13 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
|
|||||||
return hiddenState.Add(ctx, residual)
|
return hiddenState.Add(ctx, residual)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
|
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
||||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
|
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
|
||||||
|
|
||||||
// set image embeddings
|
// set image embeddings
|
||||||
var except []int
|
var except []int
|
||||||
for _, image := range opts.Multimodal {
|
for _, image := range batch.Multimodal {
|
||||||
visionOutputs := image.Multimodal.(ml.Tensor)
|
visionOutputs := image.Multimodal.(ml.Tensor)
|
||||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
|
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
|
||||||
|
|
||||||
|
|||||||
@@ -139,23 +139,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
|
|||||||
return hiddenState.Add(ctx, residual)
|
return hiddenState.Add(ctx, residual)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
|
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
|
||||||
|
|
||||||
for i, layer := range m.Layers {
|
for i, layer := range m.Layers {
|
||||||
m.Cache.SetLayer(i)
|
m.Cache.SetLayer(i)
|
||||||
|
|||||||
@@ -135,32 +135,27 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
|||||||
return inputs, nil
|
return inputs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
var crossAttentionStates ml.Tensor
|
var crossAttentionStates ml.Tensor
|
||||||
if len(opts.Multimodal) > 0 {
|
if len(batch.Multimodal) > 0 {
|
||||||
images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor)
|
images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal.([]ml.Tensor)
|
||||||
if len(images) > 0 {
|
if len(images) > 0 {
|
||||||
crossAttentionStates = images[len(images)-1]
|
crossAttentionStates = images[len(images)-1]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
|
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: attention mask, cross attention mask
|
// TODO: attention mask, cross attention mask
|
||||||
return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
|
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|||||||
@@ -31,8 +31,10 @@ type InputCache struct {
|
|||||||
cache kvcache.Cache
|
cache kvcache.Cache
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, multiUserCache bool) (*InputCache, error) {
|
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
|
||||||
if kvSize/int32(numSlots) < 1 {
|
numCtx := kvSize / int32(numSlots)
|
||||||
|
|
||||||
|
if numCtx < 1 {
|
||||||
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
|
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,11 +46,11 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
|
|||||||
|
|
||||||
cache := model.Config().Cache
|
cache := model.Config().Cache
|
||||||
if cache != nil {
|
if cache != nil {
|
||||||
cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), kvSize)
|
cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), numSlots, int(numCtx), batchSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &InputCache{
|
return &InputCache{
|
||||||
numCtx: kvSize / int32(numSlots),
|
numCtx: numCtx,
|
||||||
enabled: cache != nil,
|
enabled: cache != nil,
|
||||||
slots: slots,
|
slots: slots,
|
||||||
multiUserCache: multiUserCache,
|
multiUserCache: multiUserCache,
|
||||||
|
|||||||
@@ -348,7 +348,8 @@ func (s *Server) processBatch() error {
|
|||||||
}
|
}
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
var options input.Options
|
var batchInputs []int32
|
||||||
|
var batch input.Batch
|
||||||
|
|
||||||
for i, seq := range s.seqs {
|
for i, seq := range s.seqs {
|
||||||
if seq == nil {
|
if seq == nil {
|
||||||
@@ -395,17 +396,17 @@ func (s *Server) processBatch() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
options.Inputs = append(options.Inputs, inp.Token)
|
batchInputs = append(batchInputs, inp.Token)
|
||||||
if inp.Multimodal != nil {
|
if inp.Multimodal != nil {
|
||||||
options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
|
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal})
|
||||||
}
|
}
|
||||||
|
|
||||||
options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
|
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
|
||||||
options.Sequences = append(options.Sequences, seq.cache.Id)
|
batch.Sequences = append(batch.Sequences, seq.cache.Id)
|
||||||
|
|
||||||
seq.iBatch = len(options.Outputs)
|
seq.iBatch = len(batch.Outputs)
|
||||||
if j+1 == len(seq.inputs) {
|
if j+1 == len(seq.inputs) {
|
||||||
options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
|
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
|
||||||
}
|
}
|
||||||
seq.pendingInputs = append(seq.pendingInputs, inp)
|
seq.pendingInputs = append(seq.pendingInputs, inp)
|
||||||
}
|
}
|
||||||
@@ -413,14 +414,14 @@ func (s *Server) processBatch() error {
|
|||||||
seq.inputs = seq.inputs[len(seq.pendingInputs):]
|
seq.inputs = seq.inputs[len(seq.pendingInputs):]
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(options.Inputs) == 0 {
|
if len(batchInputs) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := s.model.Backend().NewContext()
|
ctx := s.model.Backend().NewContext()
|
||||||
defer ctx.Close()
|
defer ctx.Close()
|
||||||
|
|
||||||
modelOutput, err := model.Forward(ctx, s.model, options)
|
modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to decode batch: %w", err)
|
return fmt.Errorf("failed to decode batch: %w", err)
|
||||||
}
|
}
|
||||||
@@ -460,7 +461,7 @@ func (s *Server) processBatch() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sample a token
|
// sample a token
|
||||||
vocabSize := len(logits) / len(options.Outputs)
|
vocabSize := len(logits) / len(batch.Outputs)
|
||||||
|
|
||||||
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
|
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -677,6 +678,7 @@ func (m *multiLPath) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) loadModel(
|
func (s *Server) loadModel(
|
||||||
|
ctx context.Context,
|
||||||
mpath string,
|
mpath string,
|
||||||
params ml.BackendParams,
|
params ml.BackendParams,
|
||||||
lpath multiLPath,
|
lpath multiLPath,
|
||||||
@@ -686,7 +688,7 @@ func (s *Server) loadModel(
|
|||||||
multiUserCache bool,
|
multiUserCache bool,
|
||||||
) {
|
) {
|
||||||
var err error
|
var err error
|
||||||
s.model, err = model.New(mpath, params)
|
s.model, err = model.New(ctx, mpath, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@@ -698,7 +700,7 @@ func (s *Server) loadModel(
|
|||||||
panic("loras are not yet implemented")
|
panic("loras are not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, multiUserCache)
|
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@@ -782,6 +784,9 @@ func Execute(args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
params := ml.BackendParams{
|
params := ml.BackendParams{
|
||||||
|
Progress: func(progress float32) {
|
||||||
|
server.progress = progress
|
||||||
|
},
|
||||||
NumThreads: *threads,
|
NumThreads: *threads,
|
||||||
NumGPULayers: *numGPULayers,
|
NumGPULayers: *numGPULayers,
|
||||||
MainGPU: *mainGPU,
|
MainGPU: *mainGPU,
|
||||||
@@ -790,13 +795,13 @@ func Execute(args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
server.ready.Add(1)
|
server.ready.Add(1)
|
||||||
go server.loadModel(*mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
|
|
||||||
|
|
||||||
server.cond = sync.NewCond(&server.mu)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
|
||||||
|
|
||||||
|
server.cond = sync.NewCond(&server.mu)
|
||||||
|
|
||||||
go server.run(ctx)
|
go server.run(ctx)
|
||||||
|
|
||||||
addr := "127.0.0.1:" + strconv.Itoa(*port)
|
addr := "127.0.0.1:" + strconv.Itoa(*port)
|
||||||
|
|||||||
@@ -26,6 +26,10 @@ type Sampler struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||||
|
if len(logits) == 0 {
|
||||||
|
return -1, errors.New("sample: no logits provided to sample")
|
||||||
|
}
|
||||||
|
|
||||||
tokens := make([]token, len(logits))
|
tokens := make([]token, len(logits))
|
||||||
for i := range logits {
|
for i := range logits {
|
||||||
tokens[i].id = int32(i)
|
tokens[i].id = int32(i)
|
||||||
@@ -94,13 +98,6 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
|||||||
tokens = topP(tokens, s.topP)
|
tokens = topP(tokens, s.topP)
|
||||||
tokens = minP(tokens, s.minP)
|
tokens = minP(tokens, s.minP)
|
||||||
|
|
||||||
// TODO: this should fall back to greedy sampling
|
|
||||||
// or topP, topK values etc should be such that
|
|
||||||
// there are always tokens to sample from
|
|
||||||
if len(tokens) == 0 {
|
|
||||||
return token{}, errors.New("no tokens to sample from")
|
|
||||||
}
|
|
||||||
|
|
||||||
var r float32
|
var r float32
|
||||||
if s.rng != nil {
|
if s.rng != nil {
|
||||||
r = s.rng.Float32()
|
r = s.rng.Float32()
|
||||||
@@ -123,6 +120,9 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
|||||||
return 1
|
return 1
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if math.IsNaN(float64(sum)) {
|
||||||
|
return token{}, errors.New("sample: logits sum to NaN, check model output")
|
||||||
|
}
|
||||||
return tokens[idx], nil
|
return tokens[idx], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package sample
|
package sample
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"math"
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@@ -29,6 +30,29 @@ func TestWeighted(t *testing.T) {
|
|||||||
if want != got {
|
if want != got {
|
||||||
t.Errorf("index mismatch: want %d, got %d", want, got)
|
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test very high p
|
||||||
|
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
|
||||||
|
// Use extremely small topP to filter out all tokens
|
||||||
|
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
|
||||||
|
got, err = sampler.Sample(logits)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Should get the token with the highest logit
|
||||||
|
want = int32(0)
|
||||||
|
if want != got {
|
||||||
|
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
|
||||||
|
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
|
||||||
|
got, err = sampler.Sample(logits)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error, got %d", got)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkSample(b *testing.B) {
|
func BenchmarkSample(b *testing.B) {
|
||||||
|
|||||||
@@ -168,27 +168,53 @@ func TestTopP(t *testing.T) {
|
|||||||
softmax(tokens)
|
softmax(tokens)
|
||||||
tokens = topK(tokens, 20)
|
tokens = topK(tokens, 20)
|
||||||
|
|
||||||
// Then apply topP
|
// Test with very high p value
|
||||||
tokens = topP(tokens, 0.95)
|
got := topP(tokens, 1.0)
|
||||||
|
|
||||||
// Should keep tokens until cumsum > 0.95
|
// Should keep all tokens since p is 1
|
||||||
if len(tokens) > 3 {
|
if len(got) != len(input) {
|
||||||
|
t.Errorf("topP(1.0): should keep all tokens, got %d, want %d", len(got), len(input))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with normal p value
|
||||||
|
got = topP(tokens, 0.95)
|
||||||
|
|
||||||
|
if len(got) > 3 {
|
||||||
t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
|
t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
|
||||||
t.Logf("got: %v", tokens)
|
t.Logf("got: %v", got)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test edge case - ensure at least one token remains
|
// Test edge case - ensure at least one token remains
|
||||||
input = []float32{-1e6, -1e6, -1e6} // One dominant token
|
input = []float32{-1e6, -1e6, -1e7}
|
||||||
tokens = toTokens(input)
|
tokens = toTokens(input)
|
||||||
|
tokens = topK(tokens, 20)
|
||||||
softmax(tokens)
|
softmax(tokens)
|
||||||
tokens = topP(tokens, 0.0) // Very small p
|
got = topP(tokens, 0.0)
|
||||||
if len(tokens) < 1 {
|
if len(got) < 1 {
|
||||||
t.Error("topP should keep at least one token")
|
t.Error("topP should keep at least one token")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test with zero p value
|
||||||
|
got = topP(tokens, 0.0)
|
||||||
|
|
||||||
|
// Should keep only the highest probability token
|
||||||
|
if len(got) != 1 {
|
||||||
|
t.Errorf("topP(0.0): should keep only one token, got %d", len(got))
|
||||||
|
t.Logf("got: %v", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens = toTokens(input)
|
||||||
|
tokens = topK(tokens, 20)
|
||||||
|
softmax(tokens)
|
||||||
|
got = topP(tokens, 1e-10)
|
||||||
|
if len(got) == 0 {
|
||||||
|
t.Errorf("topP(1e-10): should keep at least one token, got %d", len(got))
|
||||||
|
t.Logf("got: %v", got)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMinP(t *testing.T) {
|
func TestMinP(t *testing.T) {
|
||||||
input := []float32{-3, -2, -1, 0, 1, 2, 4, 3}
|
input := []float32{-2, 0, -1, -3, 2, 1, 4, 3}
|
||||||
tokens := toTokens(input)
|
tokens := toTokens(input)
|
||||||
|
|
||||||
// First apply temperature and softmax
|
// First apply temperature and softmax
|
||||||
@@ -225,30 +251,48 @@ func TestMinP(t *testing.T) {
|
|||||||
t.Logf("got: %v", tokens)
|
t.Logf("got: %v", tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test with single token
|
||||||
|
tokens = toTokens(input[:1])
|
||||||
|
tokens = topK(tokens, 20)
|
||||||
|
softmax(tokens)
|
||||||
|
tokens = minP(tokens, 0.1)
|
||||||
|
|
||||||
|
// Should keep only the highest probability token
|
||||||
|
if len(tokens) != 1 {
|
||||||
|
t.Errorf("minP(0.1): should return single token, got %d", len(tokens))
|
||||||
|
t.Logf("got: %v", tokens)
|
||||||
|
}
|
||||||
|
|
||||||
input = []float32{1e-10, 1e-10, 1e-10}
|
input = []float32{1e-10, 1e-10, 1e-10}
|
||||||
tokens = toTokens(input)
|
tokens = toTokens(input)
|
||||||
softmax(tokens)
|
softmax(tokens)
|
||||||
tokens = minP(tokens, 1.0)
|
tokens = minP(tokens, 1.0)
|
||||||
if len(tokens) < 1 {
|
if len(tokens) < 1 {
|
||||||
t.Error("minP should keep at least one token even with extreme probabilities")
|
t.Error("minP should keep at least one token even with extreme probabilities")
|
||||||
}
|
got := minP(tokens, 1.0)
|
||||||
|
|
||||||
|
if len(got) != 1 {
|
||||||
|
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(got), len(tokens))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSortLogits(t *testing.T) {
|
// Test with normal p value
|
||||||
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
got = minP(tokens, 0.2)
|
||||||
tokens := toTokens(input)
|
|
||||||
|
|
||||||
tokens = topK(tokens, 20)
|
// Should keep tokens with prob >= 0.2 * max_prob
|
||||||
|
if len(got) > 3 {
|
||||||
for i := 1; i < len(tokens); i++ {
|
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
|
||||||
if tokens[i].value > tokens[i-1].value {
|
t.Logf("got: %v", got)
|
||||||
t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
|
|
||||||
i, tokens[i].value, tokens[i-1].value)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
|
// Test with zero p value
|
||||||
compareLogits(t, "sortLogits", want, tokens)
|
got = minP(tokens, 0.0)
|
||||||
|
|
||||||
|
// Should keep only the highest probability token
|
||||||
|
if len(got) != len(tokens) {
|
||||||
|
t.Errorf("minP(0.0): should keep only one token, got %d", len(got))
|
||||||
|
t.Logf("got: %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkTransforms(b *testing.B) {
|
func BenchmarkTransforms(b *testing.B) {
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ import (
|
|||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||||
"github.com/ollama/ollama/server/internal/internal/backoff"
|
|
||||||
"github.com/ollama/ollama/server/internal/internal/names"
|
"github.com/ollama/ollama/server/internal/internal/names"
|
||||||
|
|
||||||
_ "embed"
|
_ "embed"
|
||||||
@@ -60,6 +59,11 @@ var (
|
|||||||
// ErrCached is passed to [Trace.PushUpdate] when a layer already
|
// ErrCached is passed to [Trace.PushUpdate] when a layer already
|
||||||
// exists. It is a non-fatal error and is never returned by [Registry.Push].
|
// exists. It is a non-fatal error and is never returned by [Registry.Push].
|
||||||
ErrCached = errors.New("cached")
|
ErrCached = errors.New("cached")
|
||||||
|
|
||||||
|
// ErrIncomplete is returned by [Registry.Pull] when a model pull was
|
||||||
|
// incomplete due to one or more layer download failures. Users that
|
||||||
|
// want specific errors should use [WithTrace].
|
||||||
|
ErrIncomplete = errors.New("incomplete")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Defaults
|
// Defaults
|
||||||
@@ -213,12 +217,6 @@ type Registry struct {
|
|||||||
// request. If zero, [DefaultChunkingThreshold] is used.
|
// request. If zero, [DefaultChunkingThreshold] is used.
|
||||||
ChunkingThreshold int64
|
ChunkingThreshold int64
|
||||||
|
|
||||||
// MaxChunkSize is the maximum size of a chunk to download. If zero,
|
|
||||||
// the default is [DefaultMaxChunkSize].
|
|
||||||
//
|
|
||||||
// It is only used when a layer is larger than [MaxChunkingThreshold].
|
|
||||||
MaxChunkSize int64
|
|
||||||
|
|
||||||
// Mask, if set, is the name used to convert non-fully qualified names
|
// Mask, if set, is the name used to convert non-fully qualified names
|
||||||
// to fully qualified names. If empty, [DefaultMask] is used.
|
// to fully qualified names. If empty, [DefaultMask] is used.
|
||||||
Mask string
|
Mask string
|
||||||
@@ -278,8 +276,19 @@ func DefaultRegistry() (*Registry, error) {
|
|||||||
|
|
||||||
func UserAgent() string {
|
func UserAgent() string {
|
||||||
buildinfo, _ := debug.ReadBuildInfo()
|
buildinfo, _ := debug.ReadBuildInfo()
|
||||||
|
|
||||||
|
version := buildinfo.Main.Version
|
||||||
|
if version == "(devel)" {
|
||||||
|
// When using `go run .` the version is "(devel)". This is seen
|
||||||
|
// as an invalid version by ollama.com and so it defaults to
|
||||||
|
// "needs upgrade" for some requests, such as pulls. These
|
||||||
|
// checks can be skipped by using the special version "v0.0.0",
|
||||||
|
// so we set it to that here.
|
||||||
|
version = "v0.0.0"
|
||||||
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
|
return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
|
||||||
buildinfo.Main.Version,
|
version,
|
||||||
runtime.GOARCH,
|
runtime.GOARCH,
|
||||||
runtime.GOOS,
|
runtime.GOOS,
|
||||||
runtime.Version(),
|
runtime.Version(),
|
||||||
@@ -425,13 +434,14 @@ func canRetry(err error) bool {
|
|||||||
//
|
//
|
||||||
// It always calls update with a nil error.
|
// It always calls update with a nil error.
|
||||||
type trackingReader struct {
|
type trackingReader struct {
|
||||||
|
l *Layer
|
||||||
r io.Reader
|
r io.Reader
|
||||||
n *atomic.Int64
|
update func(l *Layer, n int64, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *trackingReader) Read(p []byte) (n int, err error) {
|
func (r *trackingReader) Read(p []byte) (n int, err error) {
|
||||||
n, err = r.r.Read(p)
|
n, err = r.r.Read(p)
|
||||||
r.n.Add(int64(n))
|
r.update(r.l, int64(n), nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -447,6 +457,11 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(bmizerany): decide if this should be considered valid. Maybe
|
||||||
|
// server-side we special case '{}' to have some special meaning? Maybe
|
||||||
|
// "archiving" a tag (which is how we reason about it in the registry
|
||||||
|
// already, just with a different twist).
|
||||||
if len(m.Layers) == 0 {
|
if len(m.Layers) == 0 {
|
||||||
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
|
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
|
||||||
}
|
}
|
||||||
@@ -456,11 +471,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
exists := func(l *Layer) bool {
|
// TODO(bmizerany): work to remove the need to do this
|
||||||
info, err := c.Get(l.Digest)
|
|
||||||
return err == nil && info.Size == l.Size
|
|
||||||
}
|
|
||||||
|
|
||||||
layers := m.Layers
|
layers := m.Layers
|
||||||
if m.Config != nil && m.Config.Digest.IsValid() {
|
if m.Config != nil && m.Config.Digest.IsValid() {
|
||||||
layers = append(layers, m.Config)
|
layers = append(layers, m.Config)
|
||||||
@@ -468,45 +479,52 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
|
|
||||||
// Send initial layer trace events to allow clients to have an
|
// Send initial layer trace events to allow clients to have an
|
||||||
// understanding of work to be done before work starts.
|
// understanding of work to be done before work starts.
|
||||||
|
var expected int64
|
||||||
t := traceFromContext(ctx)
|
t := traceFromContext(ctx)
|
||||||
skip := make([]bool, len(layers))
|
for _, l := range layers {
|
||||||
for i, l := range layers {
|
|
||||||
t.update(l, 0, nil)
|
t.update(l, 0, nil)
|
||||||
if exists(l) {
|
expected += l.Size
|
||||||
skip[i] = true
|
|
||||||
t.update(l, l.Size, ErrCached)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
g, ctx := errgroup.WithContext(ctx)
|
var received atomic.Int64
|
||||||
|
var g errgroup.Group
|
||||||
g.SetLimit(r.maxStreams())
|
g.SetLimit(r.maxStreams())
|
||||||
for i, l := range layers {
|
for _, l := range layers {
|
||||||
if skip[i] {
|
info, err := c.Get(l.Digest)
|
||||||
|
if err == nil && info.Size == l.Size {
|
||||||
|
received.Add(l.Size)
|
||||||
|
t.update(l, l.Size, ErrCached)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
chunked, err := c.Chunked(l.Digest, l.Size)
|
chunked, err := c.Chunked(l.Digest, l.Size)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.update(l, 0, err)
|
t.update(l, 0, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
defer chunked.Close()
|
|
||||||
|
|
||||||
var progress atomic.Int64
|
|
||||||
for cs, err := range r.chunksums(ctx, name, l) {
|
for cs, err := range r.chunksums(ctx, name, l) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.update(l, progress.Load(), err)
|
// Chunksum stream interrupted. Note in trace
|
||||||
|
// log and let in-flight downloads complete.
|
||||||
|
// This will naturally trigger ErrIncomplete
|
||||||
|
// since received < expected bytes.
|
||||||
|
t.update(l, 0, err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
g.Go(func() (err error) {
|
g.Go(func() (err error) {
|
||||||
defer func() { t.update(l, progress.Load(), err) }()
|
defer func() {
|
||||||
|
if err == nil {
|
||||||
for _, err := range backoff.Loop(ctx, 3*time.Second) {
|
received.Add(cs.Chunk.Size())
|
||||||
if err != nil {
|
} else {
|
||||||
return err
|
err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
|
||||||
}
|
}
|
||||||
err := func() error {
|
wg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -518,49 +536,40 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
}
|
}
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
|
|
||||||
// Count bytes towards
|
body := &trackingReader{l: l, r: res.Body, update: t.update}
|
||||||
// progress, as they arrive, so
|
return chunked.Put(cs.Chunk, cs.Digest, body)
|
||||||
// that our bytes piggyback
|
|
||||||
// other chunk updates on
|
|
||||||
// completion.
|
|
||||||
//
|
|
||||||
// This tactic is enough to
|
|
||||||
// show "smooth" progress given
|
|
||||||
// the current CLI client. In
|
|
||||||
// the near future, the server
|
|
||||||
// should report download rate
|
|
||||||
// since it knows better than
|
|
||||||
// a client that is measuring
|
|
||||||
// rate based on wall-clock
|
|
||||||
// time-since-last-update.
|
|
||||||
body := &trackingReader{r: res.Body, n: &progress}
|
|
||||||
|
|
||||||
err = chunked.Put(cs.Chunk, cs.Digest, body)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}()
|
|
||||||
if !canRetry(err) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close writer immediately after downloads finish, not at Pull
|
||||||
|
// exit. Using defer would keep file descriptors open until all
|
||||||
|
// layers complete, potentially exhausting system limits with
|
||||||
|
// many layers.
|
||||||
|
//
|
||||||
|
// The WaitGroup tracks when all chunks finish downloading,
|
||||||
|
// allowing precise writer closure in a background goroutine.
|
||||||
|
// Each layer briefly uses one extra goroutine while at most
|
||||||
|
// maxStreams()-1 chunks download in parallel.
|
||||||
|
//
|
||||||
|
// This caps file descriptors at maxStreams() instead of
|
||||||
|
// growing with layer count.
|
||||||
|
g.Go(func() error {
|
||||||
|
wg.Wait()
|
||||||
|
chunked.Close()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
if err := g.Wait(); err != nil {
|
if err := g.Wait(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if received.Load() != expected {
|
||||||
|
return fmt.Errorf("%w: received %d/%d", ErrIncomplete, received.Load(), expected)
|
||||||
|
}
|
||||||
|
|
||||||
// store the manifest blob
|
|
||||||
md := blob.DigestFromBytes(m.Data)
|
md := blob.DigestFromBytes(m.Data)
|
||||||
if err := blob.PutBytes(c, md, m.Data); err != nil {
|
if err := blob.PutBytes(c, md, m.Data); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// commit the manifest with a link
|
|
||||||
return c.Link(m.Name, md)
|
return c.Link(m.Name, md)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -24,6 +25,28 @@ import (
|
|||||||
"github.com/ollama/ollama/server/internal/testutil"
|
"github.com/ollama/ollama/server/internal/testutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func ExampleRegistry_cancelOnFirstError() {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
ctx = WithTrace(ctx, &Trace{
|
||||||
|
Update: func(l *Layer, n int64, err error) {
|
||||||
|
if err != nil {
|
||||||
|
// Discontinue pulling layers if there is an
|
||||||
|
// error instead of continuing to pull more
|
||||||
|
// data.
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
var r Registry
|
||||||
|
if err := r.Pull(ctx, "model"); err != nil {
|
||||||
|
// panic for demo purposes
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestManifestMarshalJSON(t *testing.T) {
|
func TestManifestMarshalJSON(t *testing.T) {
|
||||||
// All manifests should contain an "empty" config object.
|
// All manifests should contain an "empty" config object.
|
||||||
var m Manifest
|
var m Manifest
|
||||||
@@ -70,7 +93,7 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
|
|||||||
// communication is attempted.
|
// communication is attempted.
|
||||||
//
|
//
|
||||||
// To simulate a network error, pass a handler that returns a 499 status code.
|
// To simulate a network error, pass a handler that returns a 499 status code.
|
||||||
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
func newClient(t *testing.T, upstreamRegistry http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
c, err := blob.Open(t.TempDir())
|
c, err := blob.Open(t.TempDir())
|
||||||
@@ -88,7 +111,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
|||||||
r := &Registry{
|
r := &Registry{
|
||||||
Cache: c,
|
Cache: c,
|
||||||
HTTPClient: &http.Client{
|
HTTPClient: &http.Client{
|
||||||
Transport: recordRoundTripper(h),
|
Transport: recordRoundTripper(upstreamRegistry),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -767,3 +790,79 @@ func TestUnlink(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPullChunksums(t *testing.T) {
|
||||||
|
check := testutil.Checker(t)
|
||||||
|
|
||||||
|
content := "hello"
|
||||||
|
var chunksums string
|
||||||
|
contentDigest := func() blob.Digest {
|
||||||
|
return blob.DigestFromBytes(content)
|
||||||
|
}
|
||||||
|
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case strings.Contains(r.URL.Path, "/manifests/latest"):
|
||||||
|
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content))
|
||||||
|
case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()):
|
||||||
|
loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest())
|
||||||
|
w.Header().Set("Content-Location", loc)
|
||||||
|
io.WriteString(w, chunksums)
|
||||||
|
case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()):
|
||||||
|
http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content))
|
||||||
|
default:
|
||||||
|
t.Errorf("unexpected request: %v", r)
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
rc.MaxStreams = 1 // prevent concurrent chunk downloads
|
||||||
|
rc.ChunkingThreshold = 1 // for all blobs to be chunked
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
var reads []int64
|
||||||
|
ctx := WithTrace(t.Context(), &Trace{
|
||||||
|
Update: func(l *Layer, n int64, err error) {
|
||||||
|
t.Logf("Update: %v %d %v", l, n, err)
|
||||||
|
mu.Lock()
|
||||||
|
reads = append(reads, n)
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n",
|
||||||
|
blob.DigestFromBytes("hel"),
|
||||||
|
blob.DigestFromBytes("lo"),
|
||||||
|
)
|
||||||
|
err := rc.Pull(ctx, "test")
|
||||||
|
check(err)
|
||||||
|
wantReads := []int64{
|
||||||
|
0, // initial signaling of layer pull starting
|
||||||
|
3, // first chunk read
|
||||||
|
2, // second chunk read
|
||||||
|
}
|
||||||
|
if !slices.Equal(reads, wantReads) {
|
||||||
|
t.Errorf("reads = %v; want %v", reads, wantReads)
|
||||||
|
}
|
||||||
|
|
||||||
|
mw, err := rc.Resolve(t.Context(), "test")
|
||||||
|
check(err)
|
||||||
|
mg, err := rc.ResolveLocal("test")
|
||||||
|
check(err)
|
||||||
|
if !reflect.DeepEqual(mw, mg) {
|
||||||
|
t.Errorf("mw = %v; mg = %v", mw, mg)
|
||||||
|
}
|
||||||
|
for i := range mg.Layers {
|
||||||
|
_, err = c.Get(mg.Layers[i].Digest)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// missing chunks
|
||||||
|
content = "llama"
|
||||||
|
chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll"))
|
||||||
|
err = rc.Pull(ctx, "missingchunks")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error because of missing chunks")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -200,7 +200,7 @@ type params struct {
|
|||||||
//
|
//
|
||||||
// Unfortunately, this API was designed to be a bit awkward. Stream is
|
// Unfortunately, this API was designed to be a bit awkward. Stream is
|
||||||
// defined to default to true if not present, so we need a way to check
|
// defined to default to true if not present, so we need a way to check
|
||||||
// if the client decisively it to false. So, we use a pointer to a
|
// if the client decisively set it to false. So, we use a pointer to a
|
||||||
// bool. Gross.
|
// bool. Gross.
|
||||||
//
|
//
|
||||||
// Use [stream()] to get the correct value for this field.
|
// Use [stream()] to get the correct value for this field.
|
||||||
@@ -280,17 +280,17 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
|||||||
progress := make(map[*ollama.Layer]int64)
|
progress := make(map[*ollama.Layer]int64)
|
||||||
|
|
||||||
progressCopy := make(map[*ollama.Layer]int64, len(progress))
|
progressCopy := make(map[*ollama.Layer]int64, len(progress))
|
||||||
pushUpdate := func() {
|
flushProgress := func() {
|
||||||
defer maybeFlush()
|
defer maybeFlush()
|
||||||
|
|
||||||
// TODO(bmizerany): This scales poorly with more layers due to
|
// TODO(bmizerany): Flushing every layer in one update doesn't
|
||||||
// needing to flush out them all in one big update. We _could_
|
// scale well. We could flush only the modified layers or track
|
||||||
// just flush on the changed ones, or just track the whole
|
// the full download. Needs further consideration, though it's
|
||||||
// download. Needs more thought. This is fine for now.
|
// fine for now.
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
maps.Copy(progressCopy, progress)
|
maps.Copy(progressCopy, progress)
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
for l, n := range progress {
|
for l, n := range progressCopy {
|
||||||
enc.Encode(progressUpdateJSON{
|
enc.Encode(progressUpdateJSON{
|
||||||
Digest: l.Digest,
|
Digest: l.Digest,
|
||||||
Total: l.Size,
|
Total: l.Size,
|
||||||
@@ -298,19 +298,26 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
defer flushProgress()
|
||||||
|
|
||||||
t := time.NewTicker(time.Hour) // "unstarted" timer
|
t := time.NewTicker(1000 * time.Hour) // "unstarted" timer
|
||||||
start := sync.OnceFunc(func() {
|
start := sync.OnceFunc(func() {
|
||||||
pushUpdate()
|
flushProgress() // flush initial state
|
||||||
t.Reset(100 * time.Millisecond)
|
t.Reset(100 * time.Millisecond)
|
||||||
})
|
})
|
||||||
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
|
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
|
||||||
Update: func(l *ollama.Layer, n int64, err error) {
|
Update: func(l *ollama.Layer, n int64, err error) {
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
start() // flush initial state
|
// Block flushing progress updates until every
|
||||||
|
// layer is accounted for. Clients depend on a
|
||||||
|
// complete model size to calculate progress
|
||||||
|
// correctly; if they use an incomplete total,
|
||||||
|
// progress indicators would erratically jump
|
||||||
|
// as new layers are registered.
|
||||||
|
start()
|
||||||
}
|
}
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
progress[l] = n
|
progress[l] += n
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -323,9 +330,9 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-t.C:
|
case <-t.C:
|
||||||
pushUpdate()
|
flushProgress()
|
||||||
case err := <-done:
|
case err := <-done:
|
||||||
pushUpdate()
|
flushProgress()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var status string
|
var status string
|
||||||
if errors.Is(err, ollama.ErrModelNotFound) {
|
if errors.Is(err, ollama.ErrModelNotFound) {
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
|
|||||||
for _, layer := range layers {
|
for _, layer := range layers {
|
||||||
if s := layer.GGML.KV().ChatTemplate(); s != "" {
|
if s := layer.GGML.KV().ChatTemplate(); s != "" {
|
||||||
if t, err := template.Named(s); err != nil {
|
if t, err := template.Named(s); err != nil {
|
||||||
slog.Debug("template detection", "error", err)
|
slog.Debug("template detection", "error", err, "template", s)
|
||||||
} else {
|
} else {
|
||||||
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -711,7 +711,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuIn
|
|||||||
req.opts.NumCtx = req.origNumCtx * p
|
req.opts.NumCtx = req.origNumCtx * p
|
||||||
if !envconfig.SchedSpread() {
|
if !envconfig.SchedSpread() {
|
||||||
for _, g := range sgl {
|
for _, g := range sgl {
|
||||||
if ok, estimatedVRAM = llm.PredictServerFit([]discover.GpuInfo{g}, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
|
if ok, estimatedVRAM = llm.PredictServerFit([]discover.GpuInfo{g}, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, p); ok {
|
||||||
slog.Info("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "parallel", p, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
|
slog.Info("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "parallel", p, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
|
||||||
*numParallel = p
|
*numParallel = p
|
||||||
return []discover.GpuInfo{g}
|
return []discover.GpuInfo{g}
|
||||||
@@ -727,7 +727,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuIn
|
|||||||
// Now try all the GPUs
|
// Now try all the GPUs
|
||||||
for _, p := range numParallelToTry {
|
for _, p := range numParallelToTry {
|
||||||
req.opts.NumCtx = req.origNumCtx * p
|
req.opts.NumCtx = req.origNumCtx * p
|
||||||
if ok, estimatedVRAM = llm.PredictServerFit(sgl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
|
if ok, estimatedVRAM = llm.PredictServerFit(sgl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, p); ok {
|
||||||
slog.Info("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "parallel", p, "required", format.HumanBytes2(estimatedVRAM))
|
slog.Info("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "parallel", p, "required", format.HumanBytes2(estimatedVRAM))
|
||||||
*numParallel = p
|
*numParallel = p
|
||||||
return sgl
|
return sgl
|
||||||
@@ -750,7 +750,7 @@ func pickBestPartialFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.Gp
|
|||||||
var bestEstimate uint64
|
var bestEstimate uint64
|
||||||
var bestFit int
|
var bestFit int
|
||||||
for i, gl := range byLibrary {
|
for i, gl := range byLibrary {
|
||||||
_, estimatedVRAM := llm.PredictServerFit(gl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
|
_, estimatedVRAM := llm.PredictServerFit(gl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, *numParallel)
|
||||||
if estimatedVRAM > bestEstimate {
|
if estimatedVRAM > bestEstimate {
|
||||||
bestEstimate = estimatedVRAM
|
bestEstimate = estimatedVRAM
|
||||||
bestFit = i
|
bestFit = i
|
||||||
@@ -825,7 +825,7 @@ func (s *Scheduler) expireRunner(model *Model) {
|
|||||||
// If not, pick a runner to unload, else return nil and the request can be loaded
|
// If not, pick a runner to unload, else return nil and the request can be loaded
|
||||||
func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList) *runnerRef {
|
func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList) *runnerRef {
|
||||||
slog.Debug("evaluating if CPU model load will fit in available system memory")
|
slog.Debug("evaluating if CPU model load will fit in available system memory")
|
||||||
estimate := llm.EstimateGPULayers(gpus, f, req.model.ProjectorPaths, req.opts)
|
estimate := llm.EstimateGPULayers(gpus, f, req.model.ProjectorPaths, req.opts, req.opts.NumCtx/req.origNumCtx)
|
||||||
if estimate.TotalSize <= gpus[0].FreeMemory {
|
if estimate.TotalSize <= gpus[0].FreeMemory {
|
||||||
slog.Debug("cpu inference mode, model fits in available system memory", "model", format.HumanBytes2(estimate.TotalSize), "available", format.HumanBytes2(gpus[0].FreeMemory))
|
slog.Debug("cpu inference mode, model fits in available system memory", "model", format.HumanBytes2(estimate.TotalSize), "available", format.HumanBytes2(gpus[0].FreeMemory))
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
13
template/gemma3-instruct.gotmpl
Normal file
13
template/gemma3-instruct.gotmpl
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
{{- range $i, $_ := .Messages }}
|
||||||
|
{{- $last := eq (len (slice $.Messages $i)) 1 }}
|
||||||
|
{{- if eq .Role "user" }}<start_of_turn>user
|
||||||
|
{{- if and (eq $i 1) $.System }}
|
||||||
|
{{ $.System }}
|
||||||
|
{{ end }}
|
||||||
|
{{ .Content }}<end_of_turn>
|
||||||
|
{{ else if eq .Role "assistant" }}<start_of_turn>model
|
||||||
|
{{ .Content }}<end_of_turn>
|
||||||
|
{{ end }}
|
||||||
|
{{- if $last }}<start_of_turn>model
|
||||||
|
{{ end }}
|
||||||
|
{{- end }}
|
||||||
6
template/gemma3-instruct.json
Normal file
6
template/gemma3-instruct.json
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"stop": [
|
||||||
|
"<end_of_turn>"
|
||||||
|
],
|
||||||
|
"temperature": 0.1
|
||||||
|
}
|
||||||
@@ -87,6 +87,10 @@
|
|||||||
"template": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
"template": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
||||||
"name": "gemma-instruct"
|
"name": "gemma-instruct"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\n'}}\n{%- endif -%}\n",
|
||||||
|
"name": "gemma3-instruct"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
|
"template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
|
||||||
"name": "llama3-instruct"
|
"name": "llama3-instruct"
|
||||||
|
|||||||
10
template/testdata/gemma3-instruct.gotmpl/system-user-assistant-user
vendored
Normal file
10
template/testdata/gemma3-instruct.gotmpl/system-user-assistant-user
vendored
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
<start_of_turn>user
|
||||||
|
You are a helpful assistant.
|
||||||
|
|
||||||
|
Hello, how are you?<end_of_turn>
|
||||||
|
<start_of_turn>model
|
||||||
|
I'm doing great. How can I help you today?<end_of_turn>
|
||||||
|
<start_of_turn>user
|
||||||
|
I'd like to show off how chat templating works!<end_of_turn>
|
||||||
|
<start_of_turn>model
|
||||||
|
|
||||||
4
template/testdata/gemma3-instruct.gotmpl/user
vendored
Normal file
4
template/testdata/gemma3-instruct.gotmpl/user
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
<start_of_turn>user
|
||||||
|
Hello, how are you?<end_of_turn>
|
||||||
|
<start_of_turn>model
|
||||||
|
|
||||||
8
template/testdata/gemma3-instruct.gotmpl/user-assistant-user
vendored
Normal file
8
template/testdata/gemma3-instruct.gotmpl/user-assistant-user
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
<start_of_turn>user
|
||||||
|
Hello, how are you?<end_of_turn>
|
||||||
|
<start_of_turn>model
|
||||||
|
I'm doing great. How can I help you today?<end_of_turn>
|
||||||
|
<start_of_turn>user
|
||||||
|
I'd like to show off how chat templating works!<end_of_turn>
|
||||||
|
<start_of_turn>model
|
||||||
|
|
||||||
Reference in New Issue
Block a user