Merge branch 'ollama:main' into main

This commit is contained in:
likelovewant
2025-03-14 13:44:39 +08:00
committed by GitHub
27 changed files with 861 additions and 591 deletions

View File

@@ -349,6 +349,7 @@ type ShowResponse struct {
Messages []Message `json:"messages,omitempty"`
ModelInfo map[string]any `json:"model_info,omitempty"`
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
Tensors []Tensor `json:"tensors,omitempty"`
ModifiedAt time.Time `json:"modified_at,omitempty"`
}
@@ -467,6 +468,13 @@ type ModelDetails struct {
QuantizationLevel string `json:"quantization_level"`
}
// Tensor describes the metadata for a given tensor.
type Tensor struct {
Name string `json:"name"`
Type string `json:"type"`
Shape []uint64 `json:"shape"`
}
func (m *Metrics) Summary() {
if m.TotalDuration > 0 {
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)

View File

@@ -18,6 +18,7 @@ import (
"os/signal"
"path/filepath"
"runtime"
"sort"
"strconv"
"strings"
"sync/atomic"
@@ -568,8 +569,9 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
parameters, errParams := cmd.Flags().GetBool("parameters")
system, errSystem := cmd.Flags().GetBool("system")
template, errTemplate := cmd.Flags().GetBool("template")
verbose, errVerbose := cmd.Flags().GetBool("verbose")
for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate} {
for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate, errVerbose} {
if boolErr != nil {
return errors.New("error retrieving flags")
}
@@ -607,7 +609,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified")
}
req := api.ShowRequest{Name: args[0]}
req := api.ShowRequest{Name: args[0], Verbose: verbose}
resp, err := client.Show(cmd.Context(), &req)
if err != nil {
return err
@@ -630,10 +632,10 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
return nil
}
return showInfo(resp, os.Stdout)
return showInfo(resp, verbose, os.Stdout)
}
func showInfo(resp *api.ShowResponse, w io.Writer) error {
func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
tableRender := func(header string, rows func() [][]string) {
fmt.Fprintln(w, " ", header)
table := tablewriter.NewWriter(w)
@@ -690,6 +692,45 @@ func showInfo(resp *api.ShowResponse, w io.Writer) error {
})
}
if resp.ModelInfo != nil && verbose {
tableRender("Metadata", func() (rows [][]string) {
keys := make([]string, 0, len(resp.ModelInfo))
for k := range resp.ModelInfo {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
var v string
switch vData := resp.ModelInfo[k].(type) {
case string:
v = vData
case float64:
v = fmt.Sprintf("%g", vData)
case []any:
n := 3
if len(vData) < n {
n = len(vData)
}
v = fmt.Sprintf("%v", vData[:n])
default:
v = fmt.Sprintf("%T", vData)
}
rows = append(rows, []string{"", k, v})
}
return
})
}
if len(resp.Tensors) > 0 && verbose {
tableRender("Tensors", func() (rows [][]string) {
for _, t := range resp.Tensors {
rows = append(rows, []string{"", t.Name, t.Type, fmt.Sprint(t.Shape)})
}
return
})
}
head := func(s string, n int) (rows [][]string) {
scanner := bufio.NewScanner(strings.NewReader(s))
for scanner.Scan() && (len(rows) < n || n < 0) {
@@ -1196,6 +1237,7 @@ func NewCLI() *cobra.Command {
showCmd.Flags().Bool("parameters", false, "Show parameters of a model")
showCmd.Flags().Bool("template", false, "Show template of a model")
showCmd.Flags().Bool("system", false, "Show system message of a model")
showCmd.Flags().BoolP("verbose", "v", false, "Show detailed model information")
runCmd := &cobra.Command{
Use: "run MODEL [PROMPT]",

View File

@@ -27,7 +27,7 @@ func TestShowInfo(t *testing.T) {
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
}, &b); err != nil {
}, false, &b); err != nil {
t.Fatal(err)
}
@@ -57,7 +57,7 @@ func TestShowInfo(t *testing.T) {
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
}, &b); err != nil {
}, false, &b); err != nil {
t.Fatal(err)
}
@@ -68,6 +68,56 @@ func TestShowInfo(t *testing.T) {
embedding length 0
quantization FP16
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
t.Run("verbose model", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "8B",
QuantizationLevel: "FP16",
},
Parameters: `
stop up`,
ModelInfo: map[string]any{
"general.architecture": "test",
"general.parameter_count": float64(8_000_000_000),
"test.context_length": float64(1000),
"test.embedding_length": float64(11434),
},
Tensors: []api.Tensor{
{Name: "blk.0.attn_k.weight", Type: "BF16", Shape: []uint64{42, 3117}},
{Name: "blk.0.attn_q.weight", Type: "FP16", Shape: []uint64{3117, 42}},
},
}, true, &b); err != nil {
t.Fatal(err)
}
expect := ` Model
architecture test
parameters 8B
context length 1000
embedding length 11434
quantization FP16
Parameters
stop up
Metadata
general.architecture test
general.parameter_count 8e+09
test.context_length 1000
test.embedding_length 11434
Tensors
blk.0.attn_k.weight BF16 [42 3117]
blk.0.attn_q.weight FP16 [3117 42]
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
@@ -89,7 +139,7 @@ func TestShowInfo(t *testing.T) {
stop you
stop up
temperature 99`,
}, &b); err != nil {
}, false, &b); err != nil {
t.Fatal(err)
}
@@ -126,7 +176,7 @@ func TestShowInfo(t *testing.T) {
"clip.vision.embedding_length": float64(0),
"clip.vision.projection_dim": float64(0),
},
}, &b); err != nil {
}, false, &b); err != nil {
t.Fatal(err)
}
@@ -159,7 +209,7 @@ func TestShowInfo(t *testing.T) {
Ahoy, matey!
Weigh anchor!
`,
}, &b); err != nil {
}, false, &b); err != nil {
t.Fatal(err)
}
@@ -188,7 +238,7 @@ Weigh anchor!
QuantizationLevel: "FP16",
},
License: license,
}, &b); err != nil {
}, false, &b); err != nil {
t.Fatal(err)
}

View File

@@ -347,7 +347,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
switch args[1] {
case "info":
_ = showInfo(resp, os.Stderr)
_ = showInfo(resp, false, os.Stderr)
case "license":
if resp.License == "" {
fmt.Println("No license was specified for this model.")

View File

@@ -87,7 +87,7 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv["gemma3.embedding_length"] = p.HiddenSize
kv["gemma3.feed_forward_length"] = p.IntermediateSize
default:
kv["gemma3.context_length"] = cmp.Or(p.MaxPositionEmbeddings, 8192)
kv["gemma3.context_length"] = cmp.Or(p.MaxPositionEmbeddings, 131072)
kv["gemma3.embedding_length"] = p.TextModel.HiddenSize
kv["gemma3.feed_forward_length"] = p.TextModel.IntermediateSize
kv["gemma3.attention.sliding_window"] = p.TextModel.SlidingWindow

View File

@@ -187,6 +187,13 @@ cloudflared tunnel --url http://localhost:11434 --http-host-header="localhost:11
Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`.
For browser extensions, you'll need to explicitly allow the extension's origin pattern. Set `OLLAMA_ORIGINS` to include `chrome-extension://*`, `moz-extension://*`, and `safari-web-extension://*` if you wish to allow all browser extensions access, or specific extensions as needed:
```
# Allow all Chrome, Firefox, and Safari extensions
OLLAMA_ORIGINS=chrome-extension://*,moz-extension://*,safari-web-extension://* ollama serve
```
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
## Where are models stored?

View File

@@ -327,6 +327,10 @@ func (t Tensor) Size() uint64 {
return t.parameters() * t.typeSize() / t.blockSize()
}
func (t Tensor) Type() string {
return fileType(t.Kind).String()
}
type container interface {
Name() string
Decode(io.ReadSeeker) (model, error)
@@ -579,39 +583,52 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
}
func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
if llm.KV().Uint("vision.block_count") == 0 {
return
}
for name, layer := range llm.Tensors().GroupLayers() {
if name == "v" || strings.HasPrefix(name, "v.") {
for _, tensor := range layer {
weights += tensor.Size()
}
}
}
imageSize := uint64(llm.KV().Uint("vision.image_size"))
patchSize := uint64(llm.KV().Uint("vision.patch_size"))
if patchSize == 0 {
slog.Warn("unknown patch size for vision model")
return
}
numChannels := uint64(llm.KV().Uint("vision.num_channels"))
numPatches := (imageSize / patchSize) * (imageSize / patchSize)
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
numPatches++
}
headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
switch llm.KV().Architecture() {
case "mllama":
for _, layer := range llm.Tensors().GroupLayers()["v"] {
weights += layer.Size()
}
kv := func(n string) uint64 {
if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
return uint64(v)
}
return 0
}
imageSize := kv("image_size")
maxNumTiles := kv("max_num_tiles")
embeddingLength := kv("embedding_length")
headCount := kv("attention.head_count")
numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
numPatches++
}
numPaddedPatches := numPatches + 8 - (numPatches%8)%8
maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
graphSize = 4 * (8 +
imageSize*imageSize*kv("num_channels")*maxNumTiles +
imageSize*imageSize*numChannels*maxNumTiles +
embeddingLength*numPatches*maxNumTiles +
9*embeddingLength*numPaddedPatches*maxNumTiles +
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
case "gemma3":
graphSize = 4 * (imageSize*imageSize*numChannels +
embeddingLength*patchSize +
numPatches*numPatches*headCount)
}
return weights, graphSize
}

View File

@@ -218,8 +218,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
layerSize = blk.Size()
layerSize += kv / f.KV().BlockCount()
memoryWeights += blk.Size()
}
memoryWeights += layerSize
if opts.NumGPU >= 0 && layerCount >= opts.NumGPU {
// Stop allocating on GPU(s) once we hit the users target NumGPU
@@ -376,7 +376,7 @@ func (m MemoryEstimate) LogValue() slog.Value {
// memory of the weights
"total", format.HumanBytes2(m.memoryWeights),
// memory of repeating layers
"repeating", format.HumanBytes2(m.memoryWeights-m.memoryLayerOutput),
"repeating", format.HumanBytes2(m.memoryWeights),
// memory of non-repeating layers
"nonrepeating", format.HumanBytes2(m.memoryLayerOutput),
),

View File

@@ -1,4 +1,5 @@
#include <string.h>
#include <inttypes.h>
#include "ollama-debug.h"
@@ -24,7 +25,7 @@ static void print_tensor(const void *tensor, void (*cb)(const void *, int),
fprintf(stderr, "[");
for (int i = 0; i < dims[0]; i++) {
if (i >= nitems && i < dims[0] - nitems) {
fprintf(stderr, "... (%lld more), ", dims[0] - 2 * nitems);
fprintf(stderr, "... (%" PRIi64 " more), ", dims[0] - 2 * nitems);
int skip = dims[0] - 2 * nitems;
if (ndims > 1) {
stride += mul(dims + 1, ndims - 1) * skip;
@@ -67,7 +68,7 @@ static void print_tensor_i32(const void *tensor, int i) {
}
static void ollama_debug_tensor(const struct ggml_tensor *tensor, bool verbose, const char *prefix, int indent) {
fprintf(stderr, "%s%s %s (%s): [%lld %lld %lld %lld]\n", prefix, tensor->name,
fprintf(stderr, "%s%s %s (%s): [%" PRIi64 " %" PRIi64 " %" PRIi64 " %" PRIi64 "]\n", prefix, tensor->name,
ggml_op_name(tensor->op), ggml_type_name(tensor->type), tensor->ne[0],
tensor->ne[1], tensor->ne[2], tensor->ne[3]);

View File

@@ -22,6 +22,8 @@ import (
"github.com/ollama/ollama/model/input"
)
var ErrNoVisionModel = errors.New("this model is missing data required for image input")
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
type Model interface {
Forward(ml.Context, input.Options) (ml.Tensor, error)

View File

@@ -84,6 +84,10 @@ func New(c ml.Config) (model.Model, error) {
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err

View File

@@ -15,7 +15,6 @@ type TextOptions struct {
attnKeyLen, attnValLen int
eps, ropeScale float32
ropeLocalBase, ropeGlobalBase float32
finalLogitSoftcap float32
largeModelScaling bool
}
@@ -57,16 +56,15 @@ func newTextModel(c ml.Config) *TextModel {
),
Layers: make([]TextLayer, numBlocks),
TextOptions: &TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length", 256)),
attnValLen: int(c.Uint("attention.value_length", 256)),
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
ropeScale: c.Float("rope.freq_scale", 1.0),
finalLogitSoftcap: c.Float("final_logit_softcapping", 30.0),
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length", 256)),
attnValLen: int(c.Uint("attention.value_length", 256)),
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
ropeScale: c.Float("rope.freq_scale", 1.0),
},
}
@@ -245,10 +243,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
hiddenState = m.Output.Forward(ctx, hiddenState)
// final logit softcap
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextOptions.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx)
return hiddenState.Scale(ctx, float64(m.TextOptions.finalLogitSoftcap))
return m.Output.Forward(ctx, hiddenState)
}

View File

@@ -63,6 +63,10 @@ func New(c ml.Config) (model.Model, error) {
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Transformer.Layers) == 0 || len(m.GlobalTransformer.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err

View File

@@ -116,19 +116,9 @@ func (i *Instance) Readline() (string, error) {
switch r {
case KeyUp:
if i.History.Pos > 0 {
if i.History.Pos == i.History.Size() {
currentLineBuf = []rune(buf.String())
}
buf.Replace([]rune(i.History.Prev()))
}
i.historyPrev(buf, &currentLineBuf)
case KeyDown:
if i.History.Pos < i.History.Size() {
buf.Replace([]rune(i.History.Next()))
if i.History.Pos == i.History.Size() {
buf.Replace(currentLineBuf)
}
}
i.historyNext(buf, &currentLineBuf)
case KeyLeft:
buf.MoveLeft()
case KeyRight:
@@ -185,6 +175,10 @@ func (i *Instance) Readline() (string, error) {
esc = true
case CharInterrupt:
return "", ErrInterrupt
case CharPrev:
i.historyPrev(buf, &currentLineBuf)
case CharNext:
i.historyNext(buf, &currentLineBuf)
case CharLineStart:
buf.MoveToStart()
case CharLineEnd:
@@ -246,6 +240,24 @@ func (i *Instance) HistoryDisable() {
i.History.Enabled = false
}
func (i *Instance) historyPrev(buf *Buffer, currentLineBuf *[]rune) {
if i.History.Pos > 0 {
if i.History.Pos == i.History.Size() {
*currentLineBuf = []rune(buf.String())
}
buf.Replace([]rune(i.History.Prev()))
}
}
func (i *Instance) historyNext(buf *Buffer, currentLineBuf *[]rune) {
if i.History.Pos < i.History.Size() {
buf.Replace([]rune(i.History.Next()))
if i.History.Pos == i.History.Size() {
buf.Replace(*currentLineBuf)
}
}
}
func NewTerminal() (*Terminal, error) {
fd := os.Stdin.Fd()
termios, err := SetRawMode(fd)

View File

@@ -691,65 +691,6 @@ type EmbeddingResponse struct {
Embedding []float32 `json:"embedding"`
}
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
var req EmbeddingRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
slog.Debug("embedding request", "content", req.Content)
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
if err != nil {
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
return
}
// Ensure there is a place to put the sequence, released when removed from s.seqs
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
if errors.Is(err, context.Canceled) {
slog.Info("aborting embeddings request due to client closing the connection")
} else {
slog.Error("Failed to acquire semaphore", "error", err)
}
return
}
s.mu.Lock()
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
if err != nil {
s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
return
}
s.seqs[i] = seq
s.cond.Signal()
found = true
break
}
}
s.mu.Unlock()
if !found {
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
return
}
embedding := <-seq.embedding
if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
Embedding: embedding,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}
}
type HealthResponse struct {
Status string `json:"status"`
Progress float32 `json:"progress"`
@@ -927,9 +868,13 @@ func Execute(args []string) error {
defer listener.Close()
mux := http.NewServeMux()
mux.HandleFunc("/embedding", server.embeddings)
mux.HandleFunc("/completion", server.completion)
mux.HandleFunc("/health", server.health)
// TODO: support embeddings
mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
})
mux.HandleFunc("POST /completion", server.completion)
mux.HandleFunc("GET /health", server.health)
httpServer := http.Server{
Handler: mux,

View File

@@ -84,14 +84,11 @@ func (s *Sampler) sample(tokens []token) (token, error) {
return greedy(tokens), nil
}
if s.topK > 0 {
tokens = topK(tokens, s.topK)
} else {
sortLogits(tokens)
}
// topK also sorts the tokens in descending order of logits
tokens = topK(tokens, s.topK)
// token logit values are updated to probabilities
tokens = temperature(tokens, s.temperature)
tokens = softmax(tokens)
tokens = topP(tokens, s.topP)
tokens = minP(tokens, s.minP)

View File

@@ -1,12 +1,42 @@
package sample
import (
"container/heap"
"math"
"slices"
)
// temperature applies scaling and softmax to the logits
// tokenHeap implements heap.Interface and holds tokens as a min-heap to track k largest elements
type tokenHeap []token
func (h tokenHeap) Len() int { return len(h) }
func (h tokenHeap) Less(i, j int) bool { return h[i].value < h[j].value }
func (h tokenHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *tokenHeap) Push(x any) {
*h = append(*h, x.(token))
}
func (h *tokenHeap) Pop() any {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}
// temperature applies scaling to the logits
func temperature(ts []token, temp float32) []token {
// Ensure temperature clipping near 0 to avoid numerical instability
temp = max(temp, 1e-7)
for i := range ts {
ts[i].value = ts[i].value / temp
}
return ts
}
// softmax applies normalization to the logits
func softmax(ts []token) []token {
// Find max logit for numerical stability
maxLogit := float32(math.Inf(-1))
for _, t := range ts {
@@ -15,15 +45,14 @@ func temperature(ts []token, temp float32) []token {
}
}
// Apply temperature and compute exp(x - max)
temp = max(temp, 1e-7)
// Compute exp(x - max)
var sum float32
for i, v := range ts {
ts[i].value = float32(math.Exp(float64((v.value - maxLogit) / temp)))
ts[i].value = float32(math.Exp(float64(v.value - maxLogit)))
sum += ts[i].value
}
// Normalize
// exp(x - max) / sum(exp(x - max))
for i := range ts {
ts[i].value /= sum
}
@@ -31,62 +60,42 @@ func temperature(ts []token, temp float32) []token {
return ts
}
// siftDown maintains a min-heap property by recursively moving larger elements down the heap.
//
// The heap is represented as an array where for any node at index i:
// - Left child is at index 2i + 1
// - Right child is at index 2i + 2
// - Parent is at index (i-1)/2
//
// The function compares a node with its children and:
// 1. Finds the smallest value between the node and its children
// 2. If the node is not the smallest, swaps it with its smallest child
// 3. Continues this process down the affected path until the min-heap property is restored
func siftDown(data []token, start, end int) {
root := start
for {
child := 2*root + 1
if child >= end {
break
}
// Find smaller child (we want min heap)
if child+1 < end && data[child+1].value < data[child].value {
child++
}
// Exit if root is already smaller than children
if data[root].value <= data[child].value {
break
}
// Swap with smaller child and continue
data[root], data[child] = data[child], data[root]
root = child
}
}
// topK limits the number of tokens considered to the k highest logits
func topK(ts []token, k int) []token {
if k >= len(ts) {
if k >= len(ts) || k <= 0 {
slices.SortFunc(ts, func(a, b token) int {
switch {
case a.value < b.value:
return 1
case a.value > b.value:
return -1
default:
return 0
}
})
return ts
}
// Heapify + siftDown - O(nlog(k))
// Build min-heap of first k elements
heap := ts[:k]
for i := k/2 - 1; i >= 0; i-- {
siftDown(heap, i, k)
}
// Process remaining elements - if larger than heap root, replace root
// Initialize min-heap with first k elements
h := make(tokenHeap, k)
copy(h, ts[:k])
heap.Init(&h)
// Process remaining elements
for i := k; i < len(ts); i++ {
if ts[i].value > heap[0].value {
heap[0] = ts[i]
siftDown(heap, 0, k)
if ts[i].value > h[0].value {
heap.Pop(&h)
heap.Push(&h, ts[i])
}
}
slices.Reverse(heap)
// Convert heap to sorted slice in descending order
result := make([]token, len(h))
for i := k - 1; i >= 0; i-- {
result[i] = heap.Pop(&h).(token)
}
ts = heap
return ts
return result
}
// topP limits tokens to those with cumulative probability p
@@ -134,62 +143,3 @@ func minP(ts []token, p float32) []token {
ts = validTokens
return ts
}
// TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584
// sortLogits sorts implementation to sort tokens by logits using counting sort
// counting sort is faster than built-in sort for this use case
func sortLogits(tokens []token) {
if len(tokens) <= 1 {
return
}
// Find max/min in a single pass
minLogit, maxLogit := tokens[0].value, tokens[0].value
for _, t := range tokens[1:] {
if t.value < minLogit {
minLogit = t.value
} else if t.value > maxLogit {
maxLogit = t.value
}
}
// Calculate scaling to map to uint32 range
logitRange := maxLogit - minLogit
if logitRange < 1e-6 {
return // All values effectively equal
}
// Count frequencies directly from tokens
const maxInt = (1 << 24) - 1 // Use 24 bits for good granularity
var counts [256]int // For first byte
// First pass: count frequencies
for _, t := range tokens {
// Map to [0, maxInt] range
score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt)
counts[score>>16]++
}
// Calculate offsets
var offset int
for i := range counts {
count := counts[i]
counts[i] = offset
offset += count
}
// Second pass: place elements in correct position
output := make([]token, len(tokens))
// Track current positions
countsCopy := counts
for i, t := range tokens {
score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt)
pos := countsCopy[score>>16]
countsCopy[score>>16]++
output[len(tokens)-1-pos] = tokens[i]
}
copy(tokens, output)
}

View File

@@ -6,80 +6,155 @@ import (
"testing"
)
// Helper to convert float64 slice to logit slice
func toTokens(values []float64) []token {
// Helper to convert float32 slice to logit slice
func toTokens(values []float32) []token {
tokens := make([]token, len(values))
for i, v := range values {
tokens[i] = token{
id: int32(i),
value: float32(v),
value: v,
}
}
return tokens
}
// Helper to compare logit slices
func compareLogits(t *testing.T, name string, want []float64, got []token) {
func compareLogits(t *testing.T, name string, want []float32, got []token) {
t.Helper()
if len(want) != len(got) {
t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got))
return
}
for i := range want {
if math.Abs(float64(got[i].value)-want[i]) > 1e-6 {
if math.Abs(float64(got[i].value-want[i])) > 1e-6 {
t.Errorf("%s: index %d: want %f, got %f", name, i, want[i], got[i].value)
}
}
}
func TestTemperatureAndSoftmax(t *testing.T) {
input := []float64{1, 4, -2, 0}
func TestTemperature(t *testing.T) {
input := []float32{1.0, 4.0, -2.0, 0.0}
got := temperature(toTokens(input), 0.5)
want := []float32{2.0, 8.0, -4.0, 0.0}
compareLogits(t, "temperature(0.5)", want, got)
// Check probabilities sum to 1
var sum float32
for _, token := range got {
sum += token.value
}
if math.Abs(float64(sum)-1.0) > 1e-6 {
t.Errorf("probabilities don't sum to 1: got %f", sum)
got = temperature(toTokens(input), 1.0)
want = []float32{1.0, 4.0, -2.0, 0.0}
compareLogits(t, "temperature(1)", want, got)
got = temperature(toTokens(input), 0.0)
want = []float32{1e7, 4e7, -2e7, 0.0}
compareLogits(t, "temperature(0)", want, got)
}
func TestSoftmax(t *testing.T) {
tests := []struct {
name string
input []float32
expected []float32
}{
{
name: "correctness softmax",
input: []float32{1, -2, 3, 0},
expected: []float32{0.113550, 0.005653, 0.839024, 0.041773},
},
{
name: "normal distribution",
input: []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367},
},
{
name: "single value",
input: []float32{1.0},
},
{
name: "identical values",
input: []float32{0.9, 0.9, 0.9},
},
{
name: "large values",
input: []float32{1000.0, 2000.0, 3000.0},
},
{
name: "small values",
input: []float32{1e-6, 2e-6, 3e-6},
},
{
name: "negative values",
input: []float32{-1.0, -2.0, -3.0},
},
{
name: "mixed values",
input: []float32{-100.0, 0.0, 100.0},
},
}
got = temperature(toTokens(input), 1)
// Check probabilities sum to 1
sum = 0.0
for _, token := range got {
sum += token.value
}
if math.Abs(float64(sum)-1.0) > 1e-6 {
t.Errorf("probabilities don't sum to 1: got %f", sum)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := softmax(toTokens(tt.input))
if tt.expected != nil {
compareLogits(t, tt.name, tt.expected, got)
return
}
// Check probabilities sum to 1
var sum float32
for _, token := range got {
sum += token.value
if token.value < 0 || token.value > 1 {
t.Errorf("probability out of range [0,1]: got %f", token.value)
}
}
if math.Abs(float64(sum-1.0)) > 1e-6 {
t.Errorf("probabilities don't sum to 1: got %f", sum)
}
})
}
}
func TestTopK(t *testing.T) {
input := []float64{-3, -2, -1, 0, 1, 2, 4}
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
// Test k=3
got := topK(toTokens(input), 3)
if len(got) != 3 {
t.Errorf("topK(3): wrong length: want 3, got %d", len(got))
// Test k=5
got := topK(toTokens(input), 5)
if len(got) != 5 {
t.Errorf("topK(5): wrong length: want 5, got %d", len(got))
}
// Should keep highest 3 values: 4, 2, 1
want := []float64{4, 2, 1}
// Should keep highest 3 values in descending order
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
compareLogits(t, "topK(3)", want, got)
// Test k > len
got = topK(toTokens(input), 10)
compareLogits(t, "topK(10)", input, got)
got = topK(toTokens(input), 20)
if len(got) != len(input) {
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got))
}
// Test k=-1
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
got = topK(toTokens(input), -1)
if len(got) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
}
compareLogits(t, "topK(-1)", want, got)
// Test k=0
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
got = topK(toTokens(input), 0)
if len(got) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
}
compareLogits(t, "topK(-1)", want, got)
}
func TestTopP(t *testing.T) {
input := []float64{-3, -2, -1, 0, 1, 2, 4}
input := []float32{-3, -2, -1, 0, 1, 2, 4}
tokens := toTokens(input)
// First apply temperature and softmax to get probabilities
tokens = temperature(tokens, 1)
sortLogits(tokens)
tokens = softmax(tokens)
tokens = topK(tokens, 20)
// Then apply topP
got := topP(tokens, 0.95)
@@ -92,11 +167,11 @@ func TestTopP(t *testing.T) {
}
func TestMinP(t *testing.T) {
input := []float64{-3, -2, -1, 0, 1, 2, 4, 3}
input := []float32{-3, -2, -1, 0, 1, 2, 4, 3}
tokens := toTokens(input)
// First apply temperature and softmax
tokens = temperature(tokens, 1)
tokens = softmax(tokens)
// Then apply minP
got := minP(tokens, 0.2)
@@ -108,10 +183,10 @@ func TestMinP(t *testing.T) {
}
func TestSortLogits(t *testing.T) {
input := []float64{3, 1, 4, 2, -1, 0, -2}
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
tokens := toTokens(input)
sortLogits(tokens)
tokens = topK(tokens, 20)
for i := 1; i < len(tokens); i++ {
if tokens[i].value > tokens[i-1].value {
@@ -120,7 +195,7 @@ func TestSortLogits(t *testing.T) {
}
}
want := []float64{4, 3, 2, 1, 0, -1, -2}
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
compareLogits(t, "sortLogits", want, tokens)
}
@@ -144,6 +219,14 @@ func BenchmarkTransforms(b *testing.B) {
}
})
b.Run("Softmax", func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
softmax(tokensCopy)
}
})
b.Run("TopK", func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
@@ -172,7 +255,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
sortLogits(tokensCopy)
topK(tokensCopy, 200000)
}
})
}

View File

@@ -146,7 +146,7 @@ func debugger(err *error) func(step string) {
// be in either of the following forms:
//
// @<digest>
// <name>
// <name>@<digest>
// <name>
//
// If a digest is provided, it is returned as is and nothing else happens.
@@ -160,8 +160,6 @@ func debugger(err *error) func(step string) {
// hashed is passed to a PutBytes call to ensure that the manifest is in the
// blob store. This is done to ensure that future calls to [Get] succeed in
// these cases.
//
// TODO(bmizerany): Move Links/Resolve/etc. out of this package.
func (c *DiskCache) Resolve(name string) (Digest, error) {
name, digest := splitNameDigest(name)
if digest != "" {
@@ -279,18 +277,6 @@ func (c *DiskCache) Get(d Digest) (Entry, error) {
// It returns an error if either the name or digest is invalid, or if link
// creation encounters any issues.
func (c *DiskCache) Link(name string, d Digest) error {
// TODO(bmizerany): Move link handling from cache to registry.
//
// We originally placed links in the cache due to its storage
// knowledge. However, the registry likely offers better context for
// naming concerns, and our API design shouldn't be tightly coupled to
// our on-disk format.
//
// Links work effectively when independent from physical location -
// they can reference content with matching SHA regardless of storage
// location. In an upcoming change, we plan to shift this
// responsibility to the registry where it better aligns with the
// system's conceptual model.
manifest, err := c.manifestPath(name)
if err != nil {
return err
@@ -341,7 +327,9 @@ func (c *DiskCache) GetFile(d Digest) string {
return absJoin(c.dir, "blobs", filename)
}
// Links returns a sequence of links in the cache in lexical order.
// Links returns a sequence of link names. The sequence is in lexical order.
// Names are converted from their relative path form to their name form but are
// not guaranteed to be valid. Callers should validate the names before using.
func (c *DiskCache) Links() iter.Seq2[string, error] {
return func(yield func(string, error) bool) {
for path, err := range c.links() {
@@ -414,12 +402,14 @@ func (c *DiskCache) links() iter.Seq2[string, error] {
}
type checkWriter struct {
d Digest
size int64
n int64
h hash.Hash
d Digest
f *os.File
err error
h hash.Hash
w io.Writer // underlying writer; set by creator
n int64
err error
testHookBeforeFinalWrite func(*os.File)
}
@@ -435,6 +425,10 @@ func (w *checkWriter) seterr(err error) error {
// underlying writer is guaranteed to be the last byte of p as verified by the
// hash.
func (w *checkWriter) Write(p []byte) (int, error) {
if w.err != nil {
return 0, w.err
}
_, err := w.h.Write(p)
if err != nil {
return 0, w.seterr(err)
@@ -453,7 +447,7 @@ func (w *checkWriter) Write(p []byte) (int, error) {
if nextSize > w.size {
return 0, w.seterr(fmt.Errorf("content exceeds expected size: %d > %d", nextSize, w.size))
}
n, err := w.f.Write(p)
n, err := w.w.Write(p)
w.n += int64(n)
return n, w.seterr(err)
}
@@ -493,10 +487,12 @@ func (c *DiskCache) copyNamedFile(name string, file io.Reader, out Digest, size
// Copy file to f, but also into h to double-check hash.
cw := &checkWriter{
d: out,
size: size,
h: sha256.New(),
f: f,
d: out,
size: size,
h: sha256.New(),
f: f,
w: f,
testHookBeforeFinalWrite: c.testHookBeforeFinalWrite,
}
n, err := io.Copy(cw, file)
@@ -532,11 +528,6 @@ func splitNameDigest(s string) (name, digest string) {
var errInvalidName = errors.New("invalid name")
func nameToPath(name string) (_ string, err error) {
if strings.Contains(name, "@") {
// TODO(bmizerany): HACK: Fix names.Parse to validate.
// TODO(bmizerany): merge with default parts (maybe names.Merge(a, b))
return "", errInvalidName
}
n := names.Parse(name)
if !n.IsFullyQualified() {
return "", errInvalidName
@@ -547,8 +538,7 @@ func nameToPath(name string) (_ string, err error) {
func absJoin(pp ...string) string {
abs, err := filepath.Abs(filepath.Join(pp...))
if err != nil {
// Likely a bug bug or a bad OS problem. Just panic.
panic(err)
panic(err) // this should never happen
}
return abs
}

66
server/internal/cache/blob/chunked.go vendored Normal file
View File

@@ -0,0 +1,66 @@
package blob
import (
"crypto/sha256"
"errors"
"io"
"os"
"github.com/ollama/ollama/server/internal/chunks"
)
type Chunk = chunks.Chunk // TODO: move chunks here?
// Chunker writes to a blob in chunks.
// Its zero value is invalid. Use [DiskCache.Chunked] to create a new Chunker.
type Chunker struct {
digest Digest
size int64
f *os.File // nil means pre-validated
}
// Chunked returns a new Chunker, ready for use storing a blob of the given
// size in chunks.
//
// Use [Chunker.Put] to write data to the blob at specific offsets.
func (c *DiskCache) Chunked(d Digest, size int64) (*Chunker, error) {
name := c.GetFile(d)
info, err := os.Stat(name)
if err == nil && info.Size() == size {
return &Chunker{}, nil
}
f, err := os.OpenFile(name, os.O_CREATE|os.O_WRONLY, 0o666)
if err != nil {
return nil, err
}
return &Chunker{digest: d, size: size, f: f}, nil
}
// Put copies chunk.Size() bytes from r to the blob at the given offset,
// merging the data with the existing blob. It returns an error if any. As a
// special case, if r has less than chunk.Size() bytes, Put returns
// io.ErrUnexpectedEOF.
func (c *Chunker) Put(chunk Chunk, d Digest, r io.Reader) error {
if c.f == nil {
return nil
}
cw := &checkWriter{
d: d,
size: chunk.Size(),
h: sha256.New(),
f: c.f,
w: io.NewOffsetWriter(c.f, chunk.Start),
}
_, err := io.CopyN(cw, r, chunk.Size())
if err != nil && errors.Is(err, io.EOF) {
return io.ErrUnexpectedEOF
}
return err
}
// Close closes the underlying file.
func (c *Chunker) Close() error {
return c.f.Close()
}

View File

@@ -63,6 +63,10 @@ func (d Digest) Short() string {
return fmt.Sprintf("%x", d.sum[:4])
}
func (d Digest) Sum() [32]byte {
return d.sum
}
func (d Digest) Compare(other Digest) int {
return slices.Compare(d.sum[:], other.sum[:])
}

View File

@@ -31,18 +31,21 @@ func ParseRange(s string) (unit string, _ Chunk, _ error) {
}
// Parse parses a string in the form "start-end" and returns the Chunk.
func Parse(s string) (Chunk, error) {
startStr, endStr, _ := strings.Cut(s, "-")
start, err := strconv.ParseInt(startStr, 10, 64)
if err != nil {
return Chunk{}, fmt.Errorf("invalid start: %v", err)
func Parse[S ~string | ~[]byte](s S) (Chunk, error) {
startPart, endPart, found := strings.Cut(string(s), "-")
if !found {
return Chunk{}, fmt.Errorf("chunks: invalid range %q: missing '-'", s)
}
end, err := strconv.ParseInt(endStr, 10, 64)
start, err := strconv.ParseInt(startPart, 10, 64)
if err != nil {
return Chunk{}, fmt.Errorf("invalid end: %v", err)
return Chunk{}, fmt.Errorf("chunks: invalid start to %q: %v", s, err)
}
end, err := strconv.ParseInt(endPart, 10, 64)
if err != nil {
return Chunk{}, fmt.Errorf("chunks: invalid end to %q: %v", s, err)
}
if start > end {
return Chunk{}, fmt.Errorf("invalid range %d-%d: start > end", start, end)
return Chunk{}, fmt.Errorf("chunks: invalid range %q: start > end", s)
}
return Chunk{start, end}, nil
}

View File

@@ -19,6 +19,7 @@ import (
"fmt"
"io"
"io/fs"
"iter"
"log/slog"
"net/http"
"os"
@@ -38,7 +39,6 @@ import (
"github.com/ollama/ollama/server/internal/chunks"
"github.com/ollama/ollama/server/internal/internal/backoff"
"github.com/ollama/ollama/server/internal/internal/names"
"github.com/ollama/ollama/server/internal/internal/syncs"
_ "embed"
)
@@ -66,12 +66,7 @@ var (
const (
// DefaultChunkingThreshold is the threshold at which a layer should be
// split up into chunks when downloading.
DefaultChunkingThreshold = 128 << 20
// DefaultMaxChunkSize is the default maximum size of a chunk to
// download. It is configured based on benchmarks and aims to strike a
// balance between download speed and memory usage.
DefaultMaxChunkSize = 8 << 20
DefaultChunkingThreshold = 64 << 20
)
var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
@@ -211,8 +206,7 @@ type Registry struct {
// pushing or pulling models. If zero, the number of streams is
// determined by [runtime.GOMAXPROCS].
//
// Clients that want "unlimited" streams should set this to a large
// number.
// A negative value means no limit.
MaxStreams int
// ChunkingThreshold is the maximum size of a layer to download in a single
@@ -282,24 +276,13 @@ func DefaultRegistry() (*Registry, error) {
}
func (r *Registry) maxStreams() int {
n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
// Large downloads require a writter stream, so ensure we have at least
// two streams to avoid a deadlock.
return max(n, 2)
return cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
}
func (r *Registry) maxChunkingThreshold() int64 {
return cmp.Or(r.ChunkingThreshold, DefaultChunkingThreshold)
}
// chunkSizeFor returns the chunk size for a layer of the given size. If the
// size is less than or equal to the max chunking threshold, the size is
// returned; otherwise, the max chunk size is returned.
func (r *Registry) maxChunkSize() int64 {
return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize)
}
type PushParams struct {
// From is an optional destination name for the model. If empty, the
// destination name is the same as the source name.
@@ -426,6 +409,21 @@ func canRetry(err error) bool {
return re.Status >= 500
}
// trackingReader is an io.Reader that tracks the number of bytes read and
// calls the update function with the layer, the number of bytes read.
//
// It always calls update with a nil error.
type trackingReader struct {
r io.Reader
n *atomic.Int64
}
func (r *trackingReader) Read(p []byte) (n int, err error) {
n, err = r.r.Read(p)
r.n.Add(int64(n))
return
}
// Pull pulls the model with the given name from the remote registry into the
// cache.
//
@@ -434,11 +432,6 @@ func canRetry(err error) bool {
// typically slower than splitting the model up across layers, and is mostly
// utilized for layers of type equal to "application/vnd.ollama.image".
func (r *Registry) Pull(ctx context.Context, name string) error {
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
return err
}
m, err := r.Resolve(ctx, name)
if err != nil {
return err
@@ -457,126 +450,95 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
return err == nil && info.Size == l.Size
}
t := traceFromContext(ctx)
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(r.maxStreams())
layers := m.Layers
if m.Config != nil && m.Config.Digest.IsValid() {
layers = append(layers, m.Config)
}
for _, l := range layers {
// Send initial layer trace events to allow clients to have an
// understanding of work to be done before work starts.
t := traceFromContext(ctx)
skip := make([]bool, len(layers))
for i, l := range layers {
t.update(l, 0, nil)
if exists(l) {
skip[i] = true
t.update(l, l.Size, ErrCached)
}
}
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(r.maxStreams())
for i, l := range layers {
if skip[i] {
continue
}
blobURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), l.Digest)
req, err := r.newRequest(ctx, "GET", blobURL, nil)
chunked, err := c.Chunked(l.Digest, l.Size)
if err != nil {
t.update(l, 0, err)
continue
}
defer chunked.Close()
t.update(l, 0, nil)
if l.Size <= r.maxChunkingThreshold() {
g.Go(func() error {
// TODO(bmizerany): retry/backoff like below in
// the chunking case
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
err = c.Put(l.Digest, res.Body, l.Size)
if err == nil {
t.update(l, l.Size, nil)
}
return err
})
} else {
q := syncs.NewRelayReader()
var progress atomic.Int64
for cs, err := range r.chunksums(ctx, name, l) {
if err != nil {
t.update(l, progress.Load(), err)
break
}
g.Go(func() (err error) {
defer func() { q.CloseWithError(err) }()
return c.Put(l.Digest, q, l.Size)
})
defer func() { t.update(l, progress.Load(), err) }()
var progress atomic.Int64
// We want to avoid extra round trips per chunk due to
// redirects from the registry to the blob store, so
// fire an initial request to get the final URL and
// then use that URL for the chunk requests.
req.Header.Set("Range", "bytes=0-0")
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
res.Body.Close()
req = res.Request.WithContext(req.Context())
wp := writerPool{size: r.maxChunkSize()}
for chunk := range chunks.Of(l.Size, r.maxChunkSize()) {
if ctx.Err() != nil {
break
}
ticket := q.Take()
g.Go(func() (err error) {
defer func() {
if err != nil {
q.CloseWithError(err)
}
ticket.Close()
t.update(l, progress.Load(), err)
}()
for _, err := range backoff.Loop(ctx, 3*time.Second) {
if err != nil {
return err
}
err := func() error {
req := req.Clone(req.Context())
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
tw := wp.get()
tw.Reset(ticket)
defer wp.put(tw)
_, err = io.CopyN(tw, res.Body, chunk.Size())
if err != nil {
return maybeUnexpectedEOF(err)
}
if err := tw.Flush(); err != nil {
return err
}
total := progress.Add(chunk.Size())
if total >= l.Size {
q.Close()
}
return nil
}()
if !canRetry(err) {
return err
}
for _, err := range backoff.Loop(ctx, 3*time.Second) {
if err != nil {
return err
}
return nil
})
}
err := func() error {
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
if err != nil {
return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%s", cs.Chunk))
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
// Count bytes towards
// progress, as they arrive, so
// that our bytes piggyback
// other chunk updates on
// completion.
//
// This tactic is enough to
// show "smooth" progress given
// the current CLI client. In
// the near future, the server
// should report download rate
// since it knows better than
// a client that is measuring
// rate based on wall-clock
// time-since-last-update.
body := &trackingReader{r: res.Body, n: &progress}
err = chunked.Put(cs.Chunk, cs.Digest, body)
if err != nil {
return err
}
return nil
}()
if !canRetry(err) {
return err
}
}
return nil
})
}
}
if err := g.Wait(); err != nil {
return err
}
@@ -615,8 +577,6 @@ type Manifest struct {
Config *Layer `json:"config"`
}
var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000")
// Layer returns the layer with the given
// digest, or nil if not found.
func (m *Manifest) Layer(d blob.Digest) *Layer {
@@ -643,10 +603,9 @@ func (m Manifest) MarshalJSON() ([]byte, error) {
// last phase of the commit which expects it, but does nothing
// with it. This will be fixed in a future release of
// ollama.com.
Config *Layer `json:"config"`
Config Layer `json:"config"`
}{
M: M(m),
Config: &Layer{Digest: emptyDigest},
M: M(m),
}
return json.Marshal(v)
}
@@ -736,6 +695,123 @@ func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error)
return m, nil
}
type chunksum struct {
URL string
Chunk blob.Chunk
Digest blob.Digest
}
// chunksums returns a sequence of chunksums for the given layer. If the layer is under the
// chunking threshold, a single chunksum is returned that covers the entire layer. If the layer
// is over the chunking threshold, the chunksums are read from the chunksums endpoint.
func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
return func(yield func(chunksum, error) bool) {
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
yield(chunksum{}, err)
return
}
if l.Size < r.maxChunkingThreshold() {
// any layer under the threshold should be downloaded
// in one go.
cs := chunksum{
URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
),
Chunk: blob.Chunk{Start: 0, End: l.Size - 1},
Digest: l.Digest,
}
yield(cs, nil)
return
}
// A chunksums response is a sequence of chunksums in a
// simple, easy to parse line-oriented format.
//
// Example:
//
// >> GET /v2/<namespace>/<model>/chunksums/<digest>
//
// << HTTP/1.1 200 OK
// << Content-Location: <blobURL>
// <<
// << <digest> <start>-<end>
// << ...
//
// The blobURL is the URL to download the chunks from.
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
)
req, err := r.newRequest(ctx, "GET", chunksumsURL, nil)
if err != nil {
yield(chunksum{}, err)
return
}
res, err := sendRequest(r.client(), req)
if err != nil {
yield(chunksum{}, err)
return
}
defer res.Body.Close()
if res.StatusCode != 200 {
err := fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode)
yield(chunksum{}, err)
return
}
blobURL := res.Header.Get("Content-Location")
s := bufio.NewScanner(res.Body)
s.Split(bufio.ScanWords)
for {
if !s.Scan() {
if s.Err() != nil {
yield(chunksum{}, s.Err())
}
return
}
d, err := blob.ParseDigest(s.Bytes())
if err != nil {
yield(chunksum{}, fmt.Errorf("invalid digest: %q", s.Bytes()))
return
}
if !s.Scan() {
err := s.Err()
if err == nil {
err = fmt.Errorf("missing chunk range for digest %s", d)
}
yield(chunksum{}, err)
return
}
chunk, err := chunks.Parse(s.Bytes())
if err != nil {
yield(chunksum{}, fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes()))
return
}
cs := chunksum{
URL: blobURL,
Chunk: chunk,
Digest: d,
}
if !yield(cs, nil) {
return
}
}
}
}
func (r *Registry) client() *http.Client {
if r.HTTPClient != nil {
return r.HTTPClient
@@ -898,13 +974,6 @@ func checkData(url string) string {
return fmt.Sprintf("GET,%s,%s", url, zeroSum)
}
func maybeUnexpectedEOF(err error) error {
if errors.Is(err, io.EOF) {
return io.ErrUnexpectedEOF
}
return err
}
type publicError struct {
wrapped error
message string
@@ -990,28 +1059,3 @@ func splitExtended(s string) (scheme, name, digest string) {
}
return scheme, s, digest
}
type writerPool struct {
size int64 // set by the caller
mu sync.Mutex
ws []*bufio.Writer
}
func (p *writerPool) get() *bufio.Writer {
p.mu.Lock()
defer p.mu.Unlock()
if len(p.ws) == 0 {
return bufio.NewWriterSize(nil, int(p.size))
}
w := p.ws[len(p.ws)-1]
p.ws = p.ws[:len(p.ws)-1]
return w
}
func (p *writerPool) put(w *bufio.Writer) {
p.mu.Lock()
defer p.mu.Unlock()
w.Reset(nil)
p.ws = append(p.ws, w)
}

View File

@@ -428,7 +428,7 @@ func TestRegistryPullCached(t *testing.T) {
err := rc.Pull(ctx, "single")
testutil.Check(t, err)
want := []int64{6}
want := []int64{0, 6}
if !errors.Is(errors.Join(errs...), ErrCached) {
t.Errorf("errs = %v; want %v", errs, ErrCached)
}
@@ -532,6 +532,8 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
}
func TestRegistryPullChunking(t *testing.T) {
t.Skip("TODO: BRING BACK BEFORE LANDING")
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
if r.URL.Host != "blob.store" {

View File

@@ -1,6 +1,5 @@
// Package registry provides an http.Handler for handling local Ollama API
// requests for performing tasks related to the ollama.com model registry and
// the local disk cache.
// Package registry implements an http.Handler for handling local Ollama API
// model management requests. See [Local] for details.
package registry
import (
@@ -10,6 +9,7 @@ import (
"fmt"
"io"
"log/slog"
"maps"
"net/http"
"sync"
"time"
@@ -18,16 +18,11 @@ import (
"github.com/ollama/ollama/server/internal/client/ollama"
)
// Local is an http.Handler for handling local Ollama API requests for
// performing tasks related to the ollama.com model registry combined with the
// local disk cache.
// Local implements an http.Handler for handling local Ollama API model
// management requests, such as pushing, pulling, and deleting models.
//
// It is not concern of Local, or this package, to handle model creation, which
// proceeds any registry operations for models it produces.
//
// NOTE: The package built for dealing with model creation should use
// [DefaultCache] to access the blob store and not attempt to read or write
// directly to the blob disk cache.
// It can be arranged for all unknown requests to be passed through to a
// fallback handler, if one is provided.
type Local struct {
Client *ollama.Registry // required
Logger *slog.Logger // required
@@ -63,6 +58,7 @@ func (e serverError) Error() string {
var (
errMethodNotAllowed = &serverError{405, "method_not_allowed", "method not allowed"}
errNotFound = &serverError{404, "not_found", "not found"}
errModelNotFound = &serverError{404, "not_found", "model not found"}
errInternalError = &serverError{500, "internal_error", "internal server error"}
)
@@ -175,8 +171,16 @@ func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
}
type params struct {
DeprecatedName string `json:"name"` // Use [params.model]
Model string `json:"model"` // Use [params.model]
// DeprecatedName is the name of the model to push, pull, or delete,
// but is deprecated. New clients should use [Model] instead.
//
// Use [model()] to get the model name for both old and new API requests.
DeprecatedName string `json:"name"`
// Model is the name of the model to push, pull, or delete.
//
// Use [model()] to get the model name for both old and new API requests.
Model string `json:"model"`
// AllowNonTLS is a flag that indicates a client using HTTP
// is doing so, deliberately.
@@ -189,9 +193,18 @@ type params struct {
// confusing flags such as this.
AllowNonTLS bool `json:"insecure"`
// ProgressStream is a flag that indicates the client is expecting a stream of
// progress updates.
ProgressStream bool `json:"stream"`
// Stream, if true, will make the server send progress updates in a
// streaming of JSON objects. If false, the server will send a single
// JSON object with the final status as "success", or an error object
// if an error occurred.
//
// Unfortunately, this API was designed to be a bit awkward. Stream is
// defined to default to true if not present, so we need a way to check
// if the client decisively it to false. So, we use a pointer to a
// bool. Gross.
//
// Use [stream()] to get the correct value for this field.
Stream *bool `json:"stream"`
}
// model returns the model name for both old and new API requests.
@@ -199,6 +212,13 @@ func (p params) model() string {
return cmp.Or(p.Model, p.DeprecatedName)
}
func (p params) stream() bool {
if p.Stream == nil {
return true
}
return *p.Stream
}
func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
if r.Method != "DELETE" {
return errMethodNotAllowed
@@ -212,16 +232,16 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
return err
}
if !ok {
return &serverError{404, "not_found", "model not found"}
return errModelNotFound
}
if s.Prune == nil {
return nil
if s.Prune != nil {
return s.Prune()
}
return s.Prune()
return nil
}
type progressUpdateJSON struct {
Status string `json:"status"`
Status string `json:"status,omitempty,omitzero"`
Digest blob.Digest `json:"digest,omitempty,omitzero"`
Total int64 `json:"total,omitempty,omitzero"`
Completed int64 `json:"completed,omitempty,omitzero"`
@@ -237,6 +257,17 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
return err
}
enc := json.NewEncoder(w)
if !p.stream() {
if err := s.Client.Pull(r.Context(), p.model()); err != nil {
if errors.Is(err, ollama.ErrModelNotFound) {
return errModelNotFound
}
return err
}
return enc.Encode(progressUpdateJSON{Status: "success"})
}
maybeFlush := func() {
fl, _ := w.(http.Flusher)
if fl != nil {
@@ -246,69 +277,67 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
defer maybeFlush()
var mu sync.Mutex
enc := json.NewEncoder(w)
enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
progress := make(map[*ollama.Layer]int64)
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
mu.Lock()
defer mu.Unlock()
progressCopy := make(map[*ollama.Layer]int64, len(progress))
pushUpdate := func() {
defer maybeFlush()
// TODO(bmizerany): coalesce these updates; writing per
// update is expensive
// TODO(bmizerany): This scales poorly with more layers due to
// needing to flush out them all in one big update. We _could_
// just flush on the changed ones, or just track the whole
// download. Needs more thought. This is fine for now.
mu.Lock()
maps.Copy(progressCopy, progress)
mu.Unlock()
for l, n := range progress {
enc.Encode(progressUpdateJSON{
Digest: l.Digest,
Status: "pulling",
Total: l.Size,
Completed: n,
})
}
}
t := time.NewTicker(time.Hour) // "unstarted" timer
start := sync.OnceFunc(func() {
pushUpdate()
t.Reset(100 * time.Millisecond)
})
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
if n > 0 {
start() // flush initial state
}
mu.Lock()
progress[l] = n
mu.Unlock()
},
})
done := make(chan error, 1)
go func() {
// TODO(bmizerany): continue to support non-streaming responses
done <- s.Client.Pull(ctx, p.model())
}()
func() {
t := time.NewTicker(100 * time.Millisecond)
defer t.Stop()
for {
select {
case <-t.C:
mu.Lock()
maybeFlush()
mu.Unlock()
case err := <-done:
if err != nil {
var status string
if errors.Is(err, ollama.ErrModelNotFound) {
status = fmt.Sprintf("error: model %q not found", p.model())
enc.Encode(progressUpdateJSON{Status: status})
} else {
status = fmt.Sprintf("error: %v", err)
enc.Encode(progressUpdateJSON{Status: status})
}
return
for {
select {
case <-t.C:
pushUpdate()
case err := <-done:
pushUpdate()
if err != nil {
var status string
if errors.Is(err, ollama.ErrModelNotFound) {
status = fmt.Sprintf("error: model %q not found", p.model())
} else {
status = fmt.Sprintf("error: %v", err)
}
// These final updates are not strictly necessary, because they have
// already happened at this point. Our pull handler code used to do
// these steps after, not during, the pull, and they were slow, so we
// wanted to provide feedback to users what was happening. For now, we
// keep them to not jar users who are used to seeing them. We can phase
// them out with a new and nicer UX later. One without progress bars
// and digests that no one cares about.
enc.Encode(progressUpdateJSON{Status: "verifying layers"})
enc.Encode(progressUpdateJSON{Status: "writing manifest"})
enc.Encode(progressUpdateJSON{Status: "success"})
return
enc.Encode(progressUpdateJSON{Status: status})
}
return nil
}
}()
return nil
}
}
func decodeUserJSON[T any](r io.Reader) (T, error) {

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"io/fs"
"net"
@@ -160,7 +159,6 @@ var registryFS = sync.OnceValue(func() fs.FS {
// to \n when parsing the txtar on Windows.
data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
a := txtar.Parse(data)
fmt.Printf("%q\n", a.Comment)
fsys, err := txtar.FS(a)
if err != nil {
panic(err)
@@ -179,7 +177,7 @@ func TestServerPull(t *testing.T) {
w.WriteHeader(404)
io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
default:
t.Logf("serving file: %s", r.URL.Path)
t.Logf("serving blob: %s", r.URL.Path)
modelsHandler.ServeHTTP(w, r)
}
})
@@ -188,7 +186,7 @@ func TestServerPull(t *testing.T) {
t.Helper()
if got.Code != 200 {
t.Fatalf("Code = %d; want 200", got.Code)
t.Errorf("Code = %d; want 200", got.Code)
}
gotlines := got.Body.String()
t.Logf("got:\n%s", gotlines)
@@ -197,35 +195,29 @@ func TestServerPull(t *testing.T) {
want, unwanted := strings.CutPrefix(want, "!")
want = strings.TrimSpace(want)
if !unwanted && !strings.Contains(gotlines, want) {
t.Fatalf("! missing %q in body", want)
t.Errorf("! missing %q in body", want)
}
if unwanted && strings.Contains(gotlines, want) {
t.Fatalf("! unexpected %q in body", want)
t.Errorf("! unexpected %q in body", want)
}
}
}
got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
checkResponse(got, `
{"status":"pulling manifest"}
{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
`)
got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
checkResponse(got, `
{"status":"pulling manifest"}
{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
{"status":"verifying layers"}
{"status":"writing manifest"}
{"status":"success"}
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
`)
got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
checkResponse(got, `
{"status":"pulling manifest"}
{"status":"error: model \"unknown\" not found"}
`)
@@ -240,19 +232,39 @@ func TestServerPull(t *testing.T) {
got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
checkResponse(got, `
{"status":"pulling manifest"}
{"status":"error: invalid or missing name: \"\""}
!verifying
!writing
!success
`)
// Non-streaming pulls
got = s.send(t, "POST", "/api/pull", `{"model": "://", "stream": false}`)
checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
got = s.send(t, "POST", "/api/pull", `{"model": "smol", "stream": false}`)
checkResponse(got, `
{"status":"success"}
!digest
!total
!completed
`)
got = s.send(t, "POST", "/api/pull", `{"model": "unknown", "stream": false}`)
checkErrorResponse(t, got, 404, "not_found", "model not found")
}
func TestServerUnknownPath(t *testing.T) {
s := newTestServer(t, nil)
got := s.send(t, "DELETE", "/api/unknown", `{}`)
checkErrorResponse(t, got, 404, "not_found", "not found")
var fellback bool
s.Fallback = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fellback = true
})
got = s.send(t, "DELETE", "/api/unknown", `{}`)
if !fellback {
t.Fatal("expected Fallback to be called")
}
if got.Code != 200 {
t.Fatalf("Code = %d; want 200", got.Code)
}
}
func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {

View File

@@ -435,7 +435,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
kvData, err := getKVData(m.ModelPath, false)
kvData, _, err := getModelData(m.ModelPath, false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -483,8 +483,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
}
if err := g.Wait(); err != nil {
slog.Error("embedding generation failed", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embeddings: %v", err)})
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
return
}
@@ -545,8 +544,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embedding: %v", err)})
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
return
}
@@ -850,16 +848,23 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
fmt.Fprint(&sb, m.String())
resp.Modelfile = sb.String()
kvData, err := getKVData(m.ModelPath, req.Verbose)
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
if err != nil {
return nil, err
}
delete(kvData, "general.name")
delete(kvData, "tokenizer.chat_template")
resp.ModelInfo = kvData
tensorData := make([]api.Tensor, len(tensors.Items()))
for cnt, t := range tensors.Items() {
tensorData[cnt] = api.Tensor{Name: t.Name, Type: t.Type(), Shape: t.Shape}
}
resp.Tensors = tensorData
if len(m.ProjectorPaths) > 0 {
projectorData, err := getKVData(m.ProjectorPaths[0], req.Verbose)
projectorData, _, err := getModelData(m.ProjectorPaths[0], req.Verbose)
if err != nil {
return nil, err
}
@@ -869,17 +874,17 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
return resp, nil
}
func getKVData(digest string, verbose bool) (ggml.KV, error) {
func getModelData(digest string, verbose bool) (ggml.KV, ggml.Tensors, error) {
maxArraySize := 0
if verbose {
maxArraySize = -1
}
kvData, err := llm.LoadModel(digest, maxArraySize)
data, err := llm.LoadModel(digest, maxArraySize)
if err != nil {
return nil, err
return nil, ggml.Tensors{}, err
}
kv := kvData.KV()
kv := data.KV()
if !verbose {
for k := range kv {
@@ -889,7 +894,7 @@ func getKVData(digest string, verbose bool) (ggml.KV, error) {
}
}
return kv, nil
return kv, data.Tensors(), nil
}
func (s *Server) ListHandler(c *gin.Context) {