Merge branch 'ollama:main' into main

Merged by likelovewant on 2025-04-04 20:46:54 +08:00 (committed via GitHub).
79 changed files with 2805 additions and 1006 deletions


@@ -307,6 +307,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [Saddle](https://github.com/jikkuatwork/saddle)
- [TagSpaces](https://www.tagspaces.org) (A platform for file based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
@@ -346,6 +347,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) (RWKV offline LLM deployment tool, also usable as a client for ChatGPT and Ollama)
- [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) (app to evaluate and compare models)
- [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support and multiple large language models.)
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
- [Shinkai Desktop](https://github.com/dcSpark/shinkai-apps) (Two click install Local AI using Ollama + Files + RAG)
@@ -417,6 +419,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
- [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).)
- [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama)
### Cloud
@@ -456,6 +459,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [SwollamaCLI](https://github.com/marcusziade/Swollama) bundled with the Swollama Swift package. [Demo](https://github.com/marcusziade/Swollama?tab=readme-ov-file#cli-usage)
- [aichat](https://github.com/sigoden/aichat) All-in-one LLM CLI tool featuring Shell Assistant, Chat-REPL, RAG, AI tools & agents, with access to OpenAI, Claude, Gemini, Ollama, Groq, and more.
- [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
- [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis.
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull and download models from Ollama Registry in your terminal.


@@ -12,6 +12,7 @@ import (
"time"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
// StatusError is an error with an HTTP status code and message.
@@ -81,7 +82,7 @@ type GenerateRequest struct {
// Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it.
Options map[string]interface{} `json:"options"`
Options map[string]any `json:"options"`
}
// ChatRequest describes a request sent by [Client.Chat].
@@ -106,7 +107,7 @@ type ChatRequest struct {
Tools `json:"tools,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
Options map[string]any `json:"options"`
}
type Tools []Tool
@@ -260,7 +261,7 @@ type EmbedRequest struct {
Truncate *bool `json:"truncate,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
Options map[string]any `json:"options"`
}
// EmbedResponse is the response from [Client.Embed].
@@ -286,7 +287,7 @@ type EmbeddingRequest struct {
KeepAlive *Duration `json:"keep_alive,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
Options map[string]any `json:"options"`
}
// EmbeddingResponse is the response from [Client.Embeddings].
@@ -332,7 +333,7 @@ type ShowRequest struct {
Template string `json:"template"`
Verbose bool `json:"verbose"`
Options map[string]interface{} `json:"options"`
Options map[string]any `json:"options"`
// Deprecated: set the model name with Model instead
Name string `json:"name"`
@@ -350,6 +351,7 @@ type ShowResponse struct {
ModelInfo map[string]any `json:"model_info,omitempty"`
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
Tensors []Tensor `json:"tensors,omitempty"`
Capabilities []model.Capability `json:"capabilities,omitempty"`
ModifiedAt time.Time `json:"modified_at,omitempty"`
}
@@ -503,7 +505,7 @@ func (m *Metrics) Summary() {
}
}
func (opts *Options) FromMap(m map[string]interface{}) error {
func (opts *Options) FromMap(m map[string]any) error {
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
@@ -560,12 +562,12 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
}
field.SetString(val)
case reflect.Slice:
// JSON unmarshals to []interface{}, not []string
val, ok := val.([]interface{})
// JSON unmarshals to []any, not []string
val, ok := val.([]any)
if !ok {
return fmt.Errorf("option %q must be of type array", key)
}
// convert []interface{} to []string
// convert []any to []string
slice := make([]string, len(val))
for i, item := range val {
str, ok := item.(string)
@@ -672,7 +674,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
}
// FormatParams converts specified parameter options to their correct types
func FormatParams(params map[string][]string) (map[string]interface{}, error) {
func FormatParams(params map[string][]string) (map[string]any, error) {
opts := Options{}
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
@@ -686,7 +688,7 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) {
}
}
out := make(map[string]interface{})
out := make(map[string]any)
// iterate params and set values based on json struct tags
for key, vals := range params {
if opt, ok := jsonOpts[key]; !ok {

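Note on the `map[string]interface{}` → `map[string]any` change above: the alias is identical at runtime, so callers keep passing model-specific options exactly as before. A minimal sketch of setting options on a generate request (the model name and option values are illustrative, not part of this change):

```go
package main

import (
	"context"
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		panic(err)
	}

	req := &api.GenerateRequest{
		Model:  "llama3.2", // illustrative model name
		Prompt: "why is the sky blue?",
		// Options is a plain map[string]any; keys are model-specific,
		// e.g. temperature if the model supports it.
		Options: map[string]any{
			"temperature": 0.1,
			"num_predict": 64,
		},
	}

	if err := client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
		fmt.Print(resp.Response)
		return nil
	}); err != nil {
		panic(err)
	}
}
```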

@@ -134,7 +134,7 @@ func TestUseMmapParsingFromJSON(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var oMap map[string]interface{}
var oMap map[string]any
err := json.Unmarshal([]byte(test.req), &oMap)
require.NoError(t, err)
opts := DefaultOptions()


@@ -92,7 +92,7 @@ func BenchmarkColdStart(b *testing.B) {
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]interface{}{"num_predict": tt.maxTokens, "temperature": 0.1},
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
@@ -155,7 +155,7 @@ func warmup(client *api.Client, model string, prompt string, b *testing.B) {
&api.GenerateRequest{
Model: model,
Prompt: prompt,
Options: map[string]interface{}{"num_predict": 50, "temperature": 0.1},
Options: map[string]any{"num_predict": 50, "temperature": 0.1},
},
func(api.GenerateResponse) error { return nil },
)


@@ -18,6 +18,7 @@ import (
"os/signal"
"path/filepath"
"runtime"
"slices"
"sort"
"strconv"
"strings"
@@ -267,7 +268,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
opts := runOptions{
Model: args[0],
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]interface{}{},
Options: map[string]any{},
}
format, err := cmd.Flags().GetString("format")
@@ -339,6 +340,11 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return err
}
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
// TODO: remove the projector info and vision info checks below,
// these are left in for backwards compatibility with older servers
// that don't have the capabilities field in the model info
if len(info.ProjectorInfo) != 0 {
opts.MultiModal = true
}
@@ -669,6 +675,15 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
return
})
if len(resp.Capabilities) > 0 {
tableRender("Capabilities", func() (rows [][]string) {
for _, capability := range resp.Capabilities {
rows = append(rows, []string{"", capability.String()})
}
return
})
}
if resp.ProjectorInfo != nil {
tableRender("Projector", func() (rows [][]string) {
arch := resp.ProjectorInfo["general.architecture"].(string)
@@ -837,7 +852,7 @@ type runOptions struct {
Format string
System string
Images []api.ImageData
Options map[string]interface{}
Options map[string]any
MultiModal bool
KeepAlive *api.Duration
}

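The run handler above now derives multimodal support from the new capabilities list, keeping the projector check only as a fallback. A rough sketch of that check in isolation (the package name and helper are hypothetical):

```go
package example

import (
	"slices"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/types/model"
)

// multiModal mirrors the RunHandler logic above: prefer the capabilities
// reported by the server, then fall back to the legacy projector check for
// older servers that predate the capabilities field.
func multiModal(info *api.ShowResponse) bool {
	if slices.Contains(info.Capabilities, model.CapabilityVision) {
		return true
	}
	return len(info.ProjectorInfo) != 0
}
```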

@@ -16,6 +16,7 @@ import (
"github.com/spf13/cobra"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
func TestShowInfo(t *testing.T) {
@@ -260,6 +261,34 @@ Weigh anchor!
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
t.Run("capabilities", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
Capabilities: []model.Capability{model.CapabilityVision, model.CapabilityTools},
}, false, &b); err != nil {
t.Fatal(err)
}
expect := " Model\n" +
" architecture test \n" +
" parameters 7B \n" +
" quantization FP16 \n" +
"\n" +
" Capabilities\n" +
" vision \n" +
" tools \n" +
"\n"
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
}
func TestDeleteHandler(t *testing.T) {


@@ -182,8 +182,10 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
var conv ModelConverter
switch p.Architectures[0] {
case "LlamaForCausalLM", "MistralForCausalLM":
case "LlamaForCausalLM":
conv = &llamaModel{}
case "Mistral3ForConditionalGeneration":
conv = &mistral3Model{}
case "MixtralForCausalLM":
conv = &mixtralModel{}
case "GemmaForCausalLM":

convert/convert_mistral.go (new file, 190 lines)

@@ -0,0 +1,190 @@
package convert
import (
"cmp"
"fmt"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type mistral3Model struct {
ModelParameters
ImageTokenIndex uint32 `json:"image_token_index"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
VisionFeatureLayer int32 `json:"vision_feature_layer"`
TextModel struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
SlidingWindow *uint32 `json:"sliding_window"`
HiddenAct string `json:"hidden_act"`
VocabSize uint32 `json:"vocab_size"`
} `json:"text_config"`
VisionModel struct {
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
ImageSize uint32 `json:"image_size"`
NumChannels uint32 `json:"num_channels"`
PatchSize uint32 `json:"patch_size"`
HeadDim uint32 `json:"head_dim"`
HiddenAct string `json:"hidden_act"`
RopeTheta float32 `json:"rope_theta"`
} `json:"vision_config"`
MultiModalProjectorBias bool `json:"multimodal_projector_bias"`
ProjectorHiddenAct string `json:"projector_hidden_act"`
}
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.TextModel.VocabSize
// Text configuration
kv["mistral3.block_count"] = p.TextModel.NumHiddenLayers
kv["mistral3.context_length"] = p.TextModel.MaxPositionEmbeddings
kv["mistral3.embedding_length"] = p.TextModel.HiddenSize
kv["mistral3.feed_forward_length"] = p.TextModel.IntermediateSize
kv["mistral3.attention.head_count"] = p.TextModel.NumAttentionHeads
kv["mistral3.attention.head_count_kv"] = p.TextModel.NumKeyValueHeads
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
// Vision configuration
kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers
kv["mistral3.vision.embedding_length"] = p.VisionModel.HiddenSize
kv["mistral3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
kv["mistral3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
kv["mistral3.vision.attention.key_length"] = p.VisionModel.HeadDim
kv["mistral3.vision.image_size"] = p.VisionModel.ImageSize
kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
// kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta
// Multimodal configuration
kv["mistral3.image_token_index"] = p.ImageTokenIndex
kv["mistral3.spatial_merge_size"] = p.SpatialMergeSize
kv["mistral3.mm.projector_bias"] = p.MultiModalProjectorBias
if p.ProjectorHiddenAct != "" {
kv["mistral3.mm.projector_hidden_act"] = p.ProjectorHiddenAct
}
return kv
}
func (p *mistral3Model) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") {
if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
strings.HasSuffix(t.Name(), ".attn_k.weight") {
t.SetRepacker(p.repack)
}
}
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *mistral3Model) Replacements() []string {
return []string{
"language_model.model.norm", "output_norm",
"language_model.model.", "",
"language_model.", "",
"layers", "blk",
"transformer.layers", "blk",
"vision_tower", "v",
"ln_pre", "encoder_norm",
"input_layernorm", "attn_norm",
"post_attention_layernorm", "ffn_norm",
"embed_tokens", "token_embd",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"mlp.down_proj", "ffn_down",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"attention.q_proj", "attn_q",
"attention.k_proj", "attn_k",
"attention.v_proj", "attn_v",
"attention.o_proj", "attn_output",
"attention_norm", "attn_norm",
"feed_forward.gate_proj", "ffn_gate",
"feed_forward.down_proj", "ffn_down",
"feed_forward.up_proj", "ffn_up",
"multi_modal_projector", "mm",
"ffn_norm", "ffn_norm",
"lm_head", "output",
}
}
func (p *mistral3Model) repack(name string, data []float32, shape []uint64) ([]float32, error) {
var dims []int
for _, dim := range shape {
dims = append(dims, int(dim))
}
var heads uint32
if strings.HasSuffix(name, ".attn_q.weight") {
heads = p.TextModel.NumAttentionHeads
} else if strings.HasSuffix(name, ".attn_k.weight") {
heads = cmp.Or(p.TextModel.NumKeyValueHeads, p.TextModel.NumAttentionHeads)
} else {
return nil, fmt.Errorf("unknown tensor for repack: %s", name)
}
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(dims...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 1)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}

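The Replacements pairs above are consumed via strings.NewReplacer when tensors are parsed (parseTensors in the next file takes a *strings.Replacer), mapping safetensors names onto GGUF-style names. A small sketch using a subset of those pairs and an illustrative tensor name:

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	// A subset of the old/new pairs returned by (*mistral3Model).Replacements.
	replacer := strings.NewReplacer(
		"language_model.model.", "",
		"layers", "blk",
		"self_attn.q_proj", "attn_q",
	)

	// Prints "blk.0.attn_q.weight" for this illustrative input name.
	fmt.Println(replacer.Replace("language_model.model.layers.0.self_attn.q_proj.weight"))
}
```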

@@ -62,10 +62,7 @@ func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
Pattern string
Func func(fs.FS, *strings.Replacer, ...string) ([]Tensor, error)
}{
{"model-*-of-*.safetensors", parseSafetensors},
{"model.safetensors", parseSafetensors},
{"adapters.safetensors", parseSafetensors},
{"adapter_model.safetensors", parseSafetensors},
{"*.safetensors", parseSafetensors},
{"pytorch_model-*-of-*.bin", parseTorch},
{"pytorch_model.bin", parseTorch},
{"consolidated.*.pth", parseTorch},


@@ -1360,7 +1360,7 @@ func file_sentencepiece_model_proto_rawDescGZIP() []byte {
var file_sentencepiece_model_proto_enumTypes = make([]protoimpl.EnumInfo, 2)
var file_sentencepiece_model_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
var file_sentencepiece_model_proto_goTypes = []interface{}{
var file_sentencepiece_model_proto_goTypes = []any{
(TrainerSpec_ModelType)(0), // 0: sentencepiece.TrainerSpec.ModelType
(ModelProto_SentencePiece_Type)(0), // 1: sentencepiece.ModelProto.SentencePiece.Type
(*TrainerSpec)(nil), // 2: sentencepiece.TrainerSpec
@@ -1392,7 +1392,7 @@ func file_sentencepiece_model_proto_init() {
return
}
if !protoimpl.UnsafeEnabled {
file_sentencepiece_model_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
file_sentencepiece_model_proto_msgTypes[0].Exporter = func(v any, i int) any {
switch v := v.(*TrainerSpec); i {
case 0:
return &v.state
@@ -1406,7 +1406,7 @@ func file_sentencepiece_model_proto_init() {
return nil
}
}
file_sentencepiece_model_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
file_sentencepiece_model_proto_msgTypes[1].Exporter = func(v any, i int) any {
switch v := v.(*NormalizerSpec); i {
case 0:
return &v.state
@@ -1420,7 +1420,7 @@ func file_sentencepiece_model_proto_init() {
return nil
}
}
file_sentencepiece_model_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
file_sentencepiece_model_proto_msgTypes[2].Exporter = func(v any, i int) any {
switch v := v.(*SelfTestData); i {
case 0:
return &v.state
@@ -1434,7 +1434,7 @@ func file_sentencepiece_model_proto_init() {
return nil
}
}
file_sentencepiece_model_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
file_sentencepiece_model_proto_msgTypes[3].Exporter = func(v any, i int) any {
switch v := v.(*ModelProto); i {
case 0:
return &v.state
@@ -1448,7 +1448,7 @@ func file_sentencepiece_model_proto_init() {
return nil
}
}
file_sentencepiece_model_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
file_sentencepiece_model_proto_msgTypes[4].Exporter = func(v any, i int) any {
switch v := v.(*SelfTestData_Sample); i {
case 0:
return &v.state
@@ -1460,7 +1460,7 @@ func file_sentencepiece_model_proto_init() {
return nil
}
}
file_sentencepiece_model_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
file_sentencepiece_model_proto_msgTypes[5].Exporter = func(v any, i int) any {
switch v := v.(*ModelProto_SentencePiece); i {
case 0:
return &v.state


@@ -12,7 +12,7 @@ func IsNUMA() bool {
// numa support in llama.cpp is linux only
return false
}
ids := map[string]interface{}{}
ids := map[string]any{}
packageIds, _ := filepath.Glob("/sys/devices/system/cpu/cpu*/topology/physical_package_id")
for _, packageId := range packageIds {
id, err := os.ReadFile(packageId)


@@ -111,6 +111,7 @@ func GetCPUDetails() ([]CPU, error) {
if err != nil {
return nil, err
}
defer file.Close()
return linuxCPUDetails(file)
}
@@ -168,13 +169,11 @@ func linuxCPUDetails(file io.Reader) ([]CPU, error) {
for id, s := range socketByID {
s.CoreCount = len(coreBySocket[id])
s.ThreadCount = 0
for _, tc := range threadsByCoreBySocket[id] {
s.ThreadCount += tc
}
// This only works if HT is enabled, consider a more reliable model, maybe cache size comparisons?
efficiencyCoreCount := 0
for _, threads := range threadsByCoreBySocket[id] {
s.ThreadCount += threads
if threads == 1 {
efficiencyCoreCount++
}


@@ -1217,7 +1217,7 @@ Show information about a model including details, modelfile, template, parameter
```shell
curl http://localhost:11434/api/show -d '{
"model": "llama3.2"
"model": "llava"
}'
```
@@ -1260,7 +1260,11 @@ curl http://localhost:11434/api/show -d '{
"tokenizer.ggml.pre": "llama-bpe",
"tokenizer.ggml.token_type": [], // populates if `verbose=true`
"tokenizer.ggml.tokens": [] // populates if `verbose=true`
}
},
"capabilities": [
"completion",
"vision"
],
}
```

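For Go clients, the same capabilities list is exposed on api.ShowResponse (see the api package changes earlier in this diff). A minimal sketch, with the model name purely illustrative:

```go
package main

import (
	"context"
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		panic(err)
	}

	resp, err := client.Show(context.Background(), &api.ShowRequest{Model: "llava"})
	if err != nil {
		panic(err)
	}

	// For a vision model this might print "completion" and "vision".
	for _, capability := range resp.Capabilities {
		fmt.Println(capability)
	}
}
```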

@@ -5,7 +5,7 @@ import (
"time"
)
func assertEqual(t *testing.T, a interface{}, b interface{}) {
func assertEqual(t *testing.T, a any, b any) {
if a != b {
t.Errorf("Assert failed, expected %v, got %v", b, a)
}

fs/config.go (new file, 13 lines)

@@ -0,0 +1,13 @@
package fs
type Config interface {
Architecture() string
String(string, ...string) string
Uint(string, ...uint32) uint32
Float(string, ...float32) float32
Bool(string, ...bool) bool
Strings(string, ...[]string) []string
Uints(string, ...[]uint32) []uint32
Floats(string, ...[]float32) []float32
}

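This small interface is what model code now receives in place of the old ml.Config (ggml.KV is one implementation; see the ml and model changes later in this diff). A hedged sketch of reading hyperparameters through it; the helper and its defaults are illustrative, though the keys follow the `<arch>.<key>` convention used by the converter above:

```go
package example

import "github.com/ollama/ollama/fs"

// textOptions is a hypothetical hyperparameter bundle used only for this sketch.
type textOptions struct {
	blockCount uint32
	rmsNormEps float32
}

// loadTextOptions reads values, with optional defaults, through fs.Config.
func loadTextOptions(c fs.Config) textOptions {
	return textOptions{
		blockCount: c.Uint("mistral3.block_count"),
		rmsNormEps: c.Float("mistral3.attention.layer_norm_rms_epsilon", 1e-5),
	}
}
```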

@@ -134,7 +134,10 @@ func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
}
func (kv KV) OllamaEngineRequired() bool {
return kv.Architecture() == "gemma3"
return slices.Contains([]string{
"gemma3",
"mistral3",
}, kv.Architecture())
}
func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
@@ -638,7 +641,7 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
embeddingLength*numPatches*maxNumTiles +
9*embeddingLength*numPaddedPatches*maxNumTiles +
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
case "gemma3":
case "gemma3", "mistral3":
graphSize = 4 * (imageSize*imageSize*numChannels +
embeddingLength*patchSize +
numPatches*numPatches*headCount)


@@ -22,7 +22,7 @@ func TestOrcaMiniBlueSky(t *testing.T) {
Model: "orca-mini",
Prompt: "why is the sky blue?",
Stream: &stream,
Options: map[string]interface{}{
Options: map[string]any{
"temperature": 0,
"seed": 123,
},
@@ -39,7 +39,7 @@ func TestUnicode(t *testing.T) {
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K",
Prompt: "天空为什么是蓝色的?",
Stream: &stream,
Options: map[string]interface{}{
Options: map[string]any{
"temperature": 0,
"seed": 123,
// Workaround deepseek context shifting bug
@@ -61,7 +61,7 @@ func TestExtendedUnicodeOutput(t *testing.T) {
Model: "gemma2:2b",
Prompt: "Output some smily face emoji",
Stream: &stream,
Options: map[string]interface{}{
Options: map[string]any{
"temperature": 0,
"seed": 123,
},
@@ -96,7 +96,7 @@ func TestUnicodeModelDir(t *testing.T) {
Model: "orca-mini",
Prompt: "why is the sky blue?",
Stream: &stream,
Options: map[string]interface{}{
Options: map[string]any{
"temperature": 0,
"seed": 123,
},


@@ -25,7 +25,7 @@ func TestMultiModelConcurrency(t *testing.T) {
Prompt: "why is the ocean blue?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]interface{}{
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
@@ -34,7 +34,7 @@ func TestMultiModelConcurrency(t *testing.T) {
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]interface{}{
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},


@@ -23,7 +23,7 @@ func TestLongInputContext(t *testing.T) {
Model: "llama2",
Prompt: "Oh, dont speak to me of Austria. Perhaps I dont understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexanders loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I dont believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
Stream: &stream,
Options: map[string]interface{}{
Options: map[string]any{
"temperature": 0,
"seed": 123,
"num_ctx": 128,
@@ -50,7 +50,7 @@ func TestContextExhaustion(t *testing.T) {
Model: "llama2",
Prompt: "Write me a story with a ton of emojis?",
Stream: &stream,
Options: map[string]interface{}{
Options: map[string]any{
"temperature": 0,
"seed": 123,
"num_ctx": 128,


@@ -19,7 +19,7 @@ func TestIntegrationLlava(t *testing.T) {
Model: "llava:7b",
Prompt: "what does the text in this image say?",
Stream: &stream,
Options: map[string]interface{}{
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
@@ -47,7 +47,7 @@ func TestIntegrationMllama(t *testing.T) {
Model: "x/llama3.2-vision",
Prompt: "what does the text in this image say?",
Stream: &stream,
Options: map[string]interface{}{
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
@@ -75,7 +75,7 @@ func TestIntegrationSplitBatch(t *testing.T) {
System: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed aliquet, justo in malesuada lobortis, odio ligula volutpat quam, quis faucibus ipsum magna quis sapien. Aliquam in venenatis diam, eu viverra magna. Phasellus imperdiet hendrerit volutpat. Vivamus sem ex, facilisis placerat felis non, dictum elementum est. Phasellus aliquam imperdiet lacus, eget placerat ligula sodales vel. Pellentesque nec auctor mi. Curabitur arcu nisi, faucibus eget nunc id, viverra interdum mi. Curabitur ornare ipsum ex, ac euismod ex aliquam in. Vestibulum id magna at purus accumsan fermentum. Proin scelerisque posuere nunc quis interdum. Maecenas sed mollis nisl. Etiam vitae ipsum interdum, placerat est quis, tincidunt velit. Nullam tempor nibh non lorem volutpat efficitur. Cras laoreet diam imperdiet ipsum auctor bibendum. Suspendisse ultrices urna sed metus sagittis suscipit. Quisque ullamcorper aliquam nibh ut mollis. Aenean dapibus mauris pharetra, venenatis elit ac, hendrerit odio. Cras vestibulum erat tempor, lobortis justo eu, lobortis ipsum. Nam laoreet dapibus sem. Proin vel diam ultrices, elementum ante et, ornare lectus. Proin eu accumsan nisl. Praesent ac ex vitae ipsum vulputate tristique facilisis sit amet lacus. Nullam faucibus magna a pellentesque pretium. Nunc lacinia ullamcorper sollicitudin. Donec vitae accumsan turpis, sed porttitor est. Donec porttitor mi vitae augue faucibus, vel mollis diam tincidunt.",
Prompt: "what does the text in this image say?",
Stream: &stream,
Options: map[string]interface{}{
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},


@@ -20,7 +20,7 @@ var (
Model: "orca-mini",
Prompt: "why is the ocean blue?",
Stream: &stream,
Options: map[string]interface{}{
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
@@ -28,7 +28,7 @@ var (
Model: "orca-mini",
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
Options: map[string]interface{}{
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},


@@ -32,7 +32,7 @@ func TestMaxQueue(t *testing.T) {
req := api.GenerateRequest{
Model: "orca-mini",
Prompt: "write a long historical fiction story about christopher columbus. use at least 10 facts from his actual journey",
Options: map[string]interface{}{
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},


@@ -291,7 +291,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
Prompt: "why is the ocean blue?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]interface{}{
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
@@ -300,7 +300,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
Prompt: "why is the color of dirt brown?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]interface{}{
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
@@ -309,7 +309,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]interface{}{
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
@@ -318,7 +318,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
Prompt: "what is the origin of independence day?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]interface{}{
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
@@ -327,7 +327,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
Prompt: "what is the composition of air?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]interface{}{
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},


@@ -62,6 +62,11 @@ type Cache interface {
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32)
// CanResume returns true if the cache can continue with the next token at
// the given position and sequence. Assumes that the caller has already
// verified the contents of the cache.
CanResume(seq int, pos int32) bool
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
//

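To make the CanResume contract concrete, here is a small worked example of the sliding-window containment check that the causal cache performs (see Causal.CanResume below); all numbers are hypothetical:

```go
package main

import "fmt"

func main() {
	// Hypothetical state: window size 4, and the newest cached position for
	// the sequence is 5, so only positions in [1, 5] are still stored.
	windowSize := int32(4)
	last := int32(5)
	lastWindowStart := max(int32(0), last-windowSize) // 1

	for _, pos := range []int32{2, 5} {
		posWindowStart := max(int32(0), pos-windowSize)
		// Resuming at pos is only safe if its window is fully contained
		// in what the cache still holds.
		fmt.Println(pos, posWindowStart >= lastWindowStart) // 2 false, 5 true
	}
}
```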

@@ -581,6 +581,35 @@ func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
c.cellRanges[dstSeq] = seqRange
}
func (c *Causal) CanResume(seq int, pos int32) bool {
if c.windowSize == math.MaxInt32 {
return true
}
seqRange, ok := c.cellRanges[seq]
if !ok {
return false
}
// for sliding window, check that the window of the new sequence is contained in
// the window of what we are storing
var last int32 = -1
for i := seqRange.min; i <= seqRange.max; i++ {
if slices.Contains(c.cells[i].sequences, seq) {
last = max(last, c.cells[i].pos)
}
}
if last == -1 {
return false
}
lastWindowStart := max(0, last-c.windowSize)
posWindowStart := max(0, pos-c.windowSize)
return posWindowStart >= lastWindowStart
}
func (c *Causal) shift(seq int, beginIndex, offset int32) error {
if c.shiftFn == nil {
return ErrNotSupported
@@ -635,6 +664,12 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
}
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
// TODO(jessegross): We should check to see if removing the middle of the sequence will
// cause the sliding window to encompass tokens that we no longer have. If so, then we
// should return an error, which will trigger the runner to evaluate the full history and
// rebuild the window. However, if we have multimodal inputs in our history, this reuse
// results in use after free, so we don't do it for now.
var offset int32
if endIndex != math.MaxInt32 {
offset = beginIndex - endIndex
@@ -649,8 +684,7 @@ func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
} else {
if c.cells[i].pos >= endIndex {
if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
// TODO(jessegross): Need to be careful about data shared between sequences
return errors.New("shifting on cells shared by multiple sequences not yet implemented")
return errors.New("shifting cells shared by multiple sequences not supported")
}
c.cells[i].pos += offset


@@ -5,6 +5,7 @@ import (
"slices"
"testing"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
@@ -300,9 +301,80 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
}
}
func TestCanResume(t *testing.T) {
backend := &testBackend{}
windowSize := int32(4)
cache := NewSWACache(windowSize, nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
context := backend.NewContext()
defer context.Close()
err := cache.StartForward(context, input.Batch{
Positions: []int32{0, 1, 2, 3},
Sequences: []int{0, 0, 0, 0},
})
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
cache.Put(context, tensor, tensor)
// with window size 4, nothing has slid out of the window yet
if !cache.CanResume(0, 0) {
t.Errorf("CanResume(0, 0) = false, want true (within window)")
}
if !cache.CanResume(0, 1) {
t.Errorf("CanResume(0, 1) = false, want true (within window)")
}
if !cache.CanResume(0, 2) {
t.Errorf("CanResume(0, 2) = false, want true (within window)")
}
if !cache.CanResume(0, 3) {
t.Errorf("CanResume(0, 3) = false, want true (latest position)")
}
// shift window by adding position 4
err = cache.StartForward(context, input.Batch{
Positions: []int32{4, 5},
Sequences: []int{0, 0},
})
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
cache.Put(context, tensor, tensor)
// only the latest position has overlapping windows
if cache.CanResume(0, 0) {
t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
}
if cache.CanResume(0, 1) {
t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
}
if cache.CanResume(0, 2) {
t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
}
if cache.CanResume(0, 3) {
t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
}
if cache.CanResume(0, 4) {
t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
}
if !cache.CanResume(0, 5) {
t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
}
}
type testBackend struct{}
func (b *testBackend) Config() ml.Config {
func (b *testBackend) Config() fs.Config {
panic("not implemented")
}
@@ -412,6 +484,14 @@ func (t *testTensor) Floats() []float32 {
return out
}
func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
for i := range out.data {
out.data[i] = -t.data[i]
}
return out
}
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
@@ -466,17 +546,15 @@ func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, di
panic("not implemented")
}
func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor {
func (t *testTensor) IM2Col(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) GELU(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) SILU(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Cos(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) Sin(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) SILU(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
@@ -528,6 +606,8 @@ func (t *testTensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor { panic("not implemented") }
func (t *testTensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
panic("not implemented")
}
@@ -540,3 +620,5 @@ func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
copy(t2.(*testTensor).data, t.data)
return nil
}
func (t *testTensor) Duplicate(ctx ml.Context) ml.Tensor { panic("not implemented") }


@@ -134,6 +134,10 @@ func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
panic("encoder cache does not support multiple sequences")
}
func (c *EncoderCache) CanResume(seq int, pos int32) bool {
return true
}
func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
c.encoderCached = false


@@ -87,6 +87,16 @@ func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
}
}
func (c *WrapperCache) CanResume(seq int, pos int32) bool {
for _, cache := range c.caches {
if !cache.CanResume(seq, pos) {
return false
}
}
return true
}
func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
// If one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
for _, cache := range c.caches {


@@ -65,6 +65,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_CHAMELEON, "chameleon" },
{ LLM_ARCH_SOLAR, "solar" },
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
{ LLM_ARCH_MISTRAL3, "mistral3" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@@ -1371,6 +1372,22 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" },
},
},
{
LLM_ARCH_MISTRAL3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
}
},
{
LLM_ARCH_UNKNOWN,
{


@@ -69,6 +69,7 @@ enum llm_arch {
LLM_ARCH_CHAMELEON,
LLM_ARCH_SOLAR,
LLM_ARCH_WAVTOKENIZER_DEC,
LLM_ARCH_MISTRAL3,
LLM_ARCH_UNKNOWN,
};


@@ -1277,6 +1277,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups);
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
} break;
case LLM_ARCH_MISTRAL3: break;
default: throw std::runtime_error("unsupported model architecture");
}
@@ -3537,6 +3538,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0);
output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0);
} break;
case LLM_ARCH_MISTRAL3: break;
default:
throw std::runtime_error("unknown architecture");
}
@@ -4015,6 +4017,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
case LLM_ARCH_GRANITE_MOE:
case LLM_ARCH_CHAMELEON:
case LLM_ARCH_SOLAR:
case LLM_ARCH_MISTRAL3:
return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2


@@ -738,13 +738,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
// don't quantize vision stuff
quantize &= name.find("v.blk.") == std::string::npos;
quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos;
quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos;
quantize &= name.find("v.patch_embedding.weight") == std::string::npos;
quantize &= name.find("v.position_embedding.weight") == std::string::npos;
quantize &= name.find("v.post_layernorm.weight") == std::string::npos;
quantize &= name.find("v.") == std::string::npos;
quantize &= name.find("mm.") == std::string::npos;
// quantize only 2D and 3D tensors (experts)
quantize &= (ggml_n_dims(tensor) >= 2);


@@ -166,6 +166,10 @@ func (c *Context) KvCacheDefrag() {
C.llama_kv_cache_defrag(c.c)
}
func (c *Context) KvCacheCanShift() bool {
return bool(C.llama_kv_cache_can_shift(c.c))
}
// Get the embeddings for a sequence id
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))


@@ -1,17 +1,19 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Patrick Devine <patrick@infrahq.com>
Date: Fri, 14 Mar 2025 16:33:23 -0700
Subject: [PATCH] gemma3 quantization
Subject: [PATCH] add model quantizations
- gemma3
- mistral3
---
src/llama-arch.cpp | 19 +++++++++++++++++++
src/llama-arch.h | 1 +
src/llama-model.cpp | 7 +++++++
src/llama-quant.cpp | 9 +++++++++
4 files changed, 36 insertions(+)
src/llama-arch.cpp | 36 ++++++++++++++++++++++++++++++++++++
src/llama-arch.h | 2 ++
src/llama-model.cpp | 10 ++++++++++
src/llama-quant.cpp | 4 ++++
4 files changed, 52 insertions(+)
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index b6f20286..b443fcd3 100644
index b6f20286..13a0a988 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -37,6 +37,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
@@ -22,7 +24,15 @@ index b6f20286..b443fcd3 100644
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_XVERSE, "xverse" },
@@ -804,6 +805,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
@@ -64,6 +65,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_CHAMELEON, "chameleon" },
{ LLM_ARCH_SOLAR, "solar" },
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
+ { LLM_ARCH_MISTRAL3, "mistral3" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@@ -804,6 +806,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
@@ -47,8 +57,31 @@ index b6f20286..b443fcd3 100644
{
LLM_ARCH_STARCODER2,
{
@@ -1352,6 +1372,22 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" },
},
},
+ {
+ LLM_ARCH_MISTRAL3,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ }
+ },
{
LLM_ARCH_UNKNOWN,
{
diff --git a/src/llama-arch.h b/src/llama-arch.h
index ec742224..aad92a5d 100644
index ec742224..8476ae0a 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -41,6 +41,7 @@ enum llm_arch {
@@ -59,8 +92,16 @@ index ec742224..aad92a5d 100644
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE,
@@ -68,6 +69,7 @@ enum llm_arch {
LLM_ARCH_CHAMELEON,
LLM_ARCH_SOLAR,
LLM_ARCH_WAVTOKENIZER_DEC,
+ LLM_ARCH_MISTRAL3,
LLM_ARCH_UNKNOWN,
};
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index ab1a07d1..70183041 100644
index ab1a07d1..db4f2685 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -878,6 +878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
@@ -73,7 +114,15 @@ index ab1a07d1..70183041 100644
case LLM_ARCH_STARCODER2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -2537,6 +2540,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
@@ -1274,6 +1277,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups);
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
} break;
+ case LLM_ARCH_MISTRAL3: break;
default: throw std::runtime_error("unsupported model architecture");
}
@@ -2537,6 +2541,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
@@ -83,7 +132,23 @@ index ab1a07d1..70183041 100644
case LLM_ARCH_STARCODER2:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4029,6 +4035,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
@@ -3531,6 +3538,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0);
output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0);
} break;
+ case LLM_ARCH_MISTRAL3: break;
default:
throw std::runtime_error("unknown architecture");
}
@@ -4009,6 +4017,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
case LLM_ARCH_GRANITE_MOE:
case LLM_ARCH_CHAMELEON:
case LLM_ARCH_SOLAR:
+ case LLM_ARCH_MISTRAL3:
return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2
@@ -4029,6 +4038,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
case LLM_ARCH_PHIMOE:
case LLM_ARCH_GEMMA:
case LLM_ARCH_GEMMA2:
@@ -92,21 +157,16 @@ index ab1a07d1..70183041 100644
case LLM_ARCH_OPENELM:
case LLM_ARCH_GPTNEOX:
diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index 6eb1da08..d2f3a510 100644
index 6eb1da08..ebcbafa1 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -737,6 +737,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
@@ -737,6 +737,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
// This used to be a regex, but <regex> has an extreme cost to compile times.
bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
+ // don't quantize vision stuff
+ quantize &= name.find("v.blk.") == std::string::npos;
+
+ quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos;
+ quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos;
+ quantize &= name.find("v.patch_embedding.weight") == std::string::npos;
+ quantize &= name.find("v.position_embedding.weight") == std::string::npos;
+ quantize &= name.find("v.post_layernorm.weight") == std::string::npos;
+ quantize &= name.find("v.") == std::string::npos;
+ quantize &= name.find("mm.") == std::string::npos;
+
// quantize only 2D and 3D tensors (experts)
quantize &= (ggml_n_dims(tensor) >= 2);


@@ -0,0 +1,75 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <git@mxy.ng>
Date: Wed, 2 Apr 2025 15:26:15 -0700
Subject: [PATCH] metal: add op_neg
---
ggml/src/ggml-metal/ggml-metal.m | 15 +++++++++++++++
ggml/src/ggml-metal/ggml-metal.metal | 7 +++++++
2 files changed, 22 insertions(+)
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index e4c093f9..d8422f1b 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -423,6 +423,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_SQRT,
GGML_METAL_KERNEL_TYPE_SIN,
GGML_METAL_KERNEL_TYPE_COS,
+ GGML_METAL_KERNEL_TYPE_NEG,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
@@ -1039,6 +1040,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
@@ -1202,6 +1204,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_ELU:
+ case GGML_UNARY_OP_NEG:
return ggml_is_contiguous(op->src[0]);
default:
return false;
@@ -1873,6 +1876,18 @@ static void ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
+ case GGML_UNARY_OP_NEG:
+ {
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
default:
{
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index f38909d0..bb0ff668 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -945,6 +945,13 @@ kernel void kernel_cos(
dst[tpig] = cos(src0[tpig]);
}
+kernel void kernel_neg(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = -src0[tpig];
+}
+
kernel void kernel_sum_rows(
device const float * src0,
device float * dst,

View File

@@ -675,9 +675,32 @@ type CompletionRequest struct {
Grammar string // set before sending the request to the subprocess
}
// DoneReason represents the reason why a completion response is done
type DoneReason int
const (
// DoneReasonStop indicates the completion stopped naturally
DoneReasonStop DoneReason = iota
// DoneReasonLength indicates the completion stopped due to length limits
DoneReasonLength
// DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed
DoneReasonConnectionClosed
)
func (d DoneReason) String() string {
switch d {
case DoneReasonLength:
return "length"
case DoneReasonStop:
return "stop"
default:
return "" // closed
}
}
type CompletionResponse struct {
Content string `json:"content"`
DoneReason string `json:"done_reason"`
DoneReason DoneReason `json:"done_reason"`
Done bool `json:"done"`
PromptEvalCount int `json:"prompt_eval_count"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
@@ -786,7 +809,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
continue
}
// slog.Debug("got line", "line", string(line))
evt, ok := bytes.CutPrefix(line, []byte("data: "))
if !ok {
evt = line

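With DoneReason now a typed value, callers convert it back to the wire string at the API boundary via String(). A minimal sketch of that conversion (the package name and helper are hypothetical):

```go
package example

import "github.com/ollama/ollama/llm"

// doneReasonString maps the typed reason back to the strings used by the
// HTTP API: "stop", "length", or "" when the connection was closed.
func doneReasonString(cr llm.CompletionResponse) string {
	if !cr.Done {
		return ""
	}
	return cr.DoneReason.String()
}
```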

@@ -9,22 +9,12 @@ import (
"slices"
"strconv"
"strings"
"github.com/ollama/ollama/fs"
)
type Config interface {
Architecture() string
String(string, ...string) string
Uint(string, ...uint32) uint32
Float(string, ...float32) float32
Bool(string, ...bool) bool
Strings(string, ...[]string) []string
Uints(string, ...[]uint32) []uint32
Floats(string, ...[]float32) []float32
}
type Backend interface {
Config() Config
Config() fs.Config
Get(name string) Tensor
NewContext() Context
NewContextSize(size int) Context
@@ -128,6 +118,7 @@ type Tensor interface {
Bytes() []byte
Floats() []float32
Neg(ctx Context) Tensor
Add(ctx Context, t2 Tensor) Tensor
Mul(ctx Context, t2 Tensor) Tensor
Mulmat(ctx Context, t2 Tensor) Tensor
@@ -142,7 +133,10 @@ type Tensor interface {
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
Sin(ctx Context) Tensor
Cos(ctx Context) Tensor
Tanh(ctx Context) Tensor
GELU(ctx Context) Tensor
SILU(ctx Context) Tensor
@@ -157,9 +151,13 @@ type Tensor interface {
Unpad(ctx Context, shape ...int) Tensor
Stack(ctx Context, dim int, s ...Tensor) Tensor
// Repeat repeats the tensor n times along dimension dim
Repeat(ctx Context, dim, n int) Tensor
Concat(ctx Context, t2 Tensor, dim int) Tensor
Rows(ctx Context, t2 Tensor) Tensor
Copy(ctx Context, t2 Tensor) Tensor
Duplicate(ctx Context) Tensor
}
// ScaledDotProductAttention implements a fused attention
@@ -224,7 +222,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
})
case DTypeF16, DTypeQ80, DTypeQ40:
f32 := ctx.Empty(DTypeF32, t.Shape()...)
f32 := ctx.Input().Empty(DTypeF32, t.Shape()...)
f32 = t.Copy(ctx, f32)
return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)

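The Tensor interface additions above (Neg, Sin, Cos, Repeat, Duplicate, plus the relocated IM2Col/Tanh/GELU/SILU) are implemented by the GGML backend in the next file. A tiny sketch composing two of the new methods; the helper itself is hypothetical:

```go
package example

import "github.com/ollama/ollama/ml"

// negateAndTile negates t and repeats it n times along dimension 1, using the
// new Neg and Repeat methods on ml.Tensor.
func negateAndTile(ctx ml.Context, t ml.Tensor, n int) ml.Tensor {
	return t.Neg(ctx).Repeat(ctx, 1, n)
}
```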

@@ -24,7 +24,8 @@ import (
"unsafe"
"github.com/ollama/ollama/format"
fs "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/fs"
fsggml "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml"
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
"golang.org/x/sync/errgroup"
@@ -41,7 +42,7 @@ func devices() []*C.struct_ggml_backend_device {
}
type Backend struct {
meta *fs.GGML
meta *fsggml.GGML
sched *C.struct_ggml_backend_sched
tensors map[string]*C.struct_ggml_tensor
@@ -58,7 +59,7 @@ type Backend struct {
}
func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) {
meta, n, err := fs.Decode(r, -1)
meta, n, err := fsggml.Decode(r, -1)
if err != nil {
return nil, err
}
@@ -182,7 +183,7 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
maxTensors += blocks * 2
type tensor struct {
source *fs.Tensor
source *fsggml.Tensor
target string
}
@@ -413,7 +414,7 @@ func init() {
ml.RegisterBackend("ggml", New)
}
func (b *Backend) Config() ml.Config {
func (b *Backend) Config() fs.Config {
return b.meta.KV()
}
@@ -710,6 +711,13 @@ func (t *Tensor) DType() ml.DType {
}
}
func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_neg(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
@@ -717,6 +725,27 @@ func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor {
if dim < 0 || dim >= C.GGML_MAX_DIMS {
panic("invalid dimension")
}
shape := make([]C.int64_t, C.GGML_MAX_DIMS)
for i := range C.GGML_MAX_DIMS {
if i == dim {
shape[i] = C.int64_t(t.Dim(i) * n)
} else {
shape[i] = C.int64_t(t.Dim(i))
}
}
tmpl := C.ggml_new_tensor(ctx.(*Context).ctx, t.t._type, C.int(len(shape)), unsafe.SliceData(shape))
return &Tensor{
b: t.b,
t: C.ggml_repeat(ctx.(*Context).ctx, t.t, tmpl),
}
}
func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
if len(s) > 0 {
return t.Concat(ctx, s[0].Stack(ctx, dim, s[1:]...), dim)
@@ -853,6 +882,20 @@ func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
}
}
func (t *Tensor) Sin(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sin(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Cos(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_cos(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
@@ -941,6 +984,13 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
}
}
func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_im2col(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1), true, C.GGML_TYPE_F32),
}
}
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
@@ -1009,3 +1059,10 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
}
func (t *Tensor) Duplicate(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_dup(ctx.(*Context).ctx, t.t),
}
}


@@ -3083,6 +3083,13 @@ kernel void kernel_cos(
dst[tpig] = cos(src0[tpig]);
}
kernel void kernel_neg(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = -src0[tpig];
}
kernel void kernel_sum_rows(
device const float * src0,
device float * dst,


@@ -423,6 +423,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_SQRT,
GGML_METAL_KERNEL_TYPE_SIN,
GGML_METAL_KERNEL_TYPE_COS,
GGML_METAL_KERNEL_TYPE_NEG,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
@@ -1039,6 +1040,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
@@ -1202,6 +1204,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_NEG:
return ggml_is_contiguous(op->src[0]);
default:
return false;
@@ -1873,6 +1876,18 @@ static void ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_NEG:
{
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
default:
{
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));

View File

@@ -945,6 +945,13 @@ kernel void kernel_cos(
dst[tpig] = cos(src0[tpig]);
}
kernel void kernel_neg(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = -src0[tpig];
}
kernel void kernel_sum_rows(
device const float * src0,
device float * dst,

View File

@@ -16,7 +16,8 @@ import (
_ "golang.org/x/image/tiff"
_ "golang.org/x/image/webp"
fs "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/fs"
fsggml "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
_ "github.com/ollama/ollama/ml/backend"
@@ -83,10 +84,10 @@ func (m *Base) Config() config {
return m.config
}
var models = make(map[string]func(ml.Config) (Model, error))
var models = make(map[string]func(fs.Config) (Model, error))
// Register registers a model constructor for the given architecture
func Register(name string, f func(ml.Config) (Model, error)) {
func Register(name string, f func(fs.Config) (Model, error)) {
if _, ok := models[name]; ok {
panic("model: model already registered")
}
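Architectures now register constructors that take fs.Config. A minimal registration sketch in the new shape (the "dummy" name and myModel type are illustrative only):
func init() {
	model.Register("dummy", func(c fs.Config) (model.Model, error) {
		_ = c.Uint("block_count") // read whatever metadata the architecture needs
		return &myModel{}, nil
	})
}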
@@ -131,14 +132,14 @@ func NewTextProcessor(s string) (TextProcessor, error) {
return nil, err
}
defer r.Close()
meta, _, err := fs.Decode(r, -1)
meta, _, err := fsggml.Decode(r, -1)
if err != nil {
return nil, err
}
return getTextProcessor(meta.KV())
}
func getTextProcessor(kv fs.KV) (TextProcessor, error) {
func getTextProcessor(kv fsggml.KV) (TextProcessor, error) {
arch := kv.Architecture()
f, ok := models[arch]
if !ok {

View File

@@ -7,7 +7,8 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
fs "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/fs"
fsggml "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/backend/ggml"
"github.com/ollama/ollama/ml/nn"
@@ -139,7 +140,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
}
func TestGetTextProcessor(t *testing.T) {
tp, err := getTextProcessor(fs.KV{})
tp, err := getTextProcessor(fsggml.KV{})
if err == nil {
t.Error("expected error")
} else if !strings.Contains(err.Error(), "unsupported model architecture") {
@@ -148,10 +149,10 @@ func TestGetTextProcessor(t *testing.T) {
t.Error("expected nil tp")
}
models["dummy"] = func(ml.Config) (Model, error) {
models["dummy"] = func(fs.Config) (Model, error) {
return notTextProcessorModel{}, nil
}
tp, err = getTextProcessor(fs.KV{"general.architecture": "dummy"})
tp, err = getTextProcessor(fsggml.KV{"general.architecture": "dummy"})
if err == nil {
t.Error("expected error")
} else if !strings.Contains(err.Error(), "not a TextProcessor") {

View File

@@ -3,6 +3,7 @@ package gemma2
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
@@ -35,10 +36,9 @@ const (
gemma27BLayerCount = 46
)
func New(c ml.Config) (model.Model, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
SentencePieceModel: model.NewSentencePieceModel(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),

View File

@@ -6,6 +6,7 @@ import (
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
@@ -52,10 +53,9 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
return visionOutputs
}
func New(c ml.Config) (model.Model, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
SentencePieceModel: model.NewSentencePieceModel(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),

View File

@@ -3,6 +3,7 @@ package gemma3
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
@@ -10,7 +11,7 @@ import (
"github.com/ollama/ollama/model/input"
)
type TextOptions struct {
type TextConfig struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeScale float32
@@ -27,7 +28,7 @@ type TextModel struct {
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextOptions
*TextConfig
}
const (
@@ -40,12 +41,11 @@ const (
cacheTypeCausal
)
func newTextModel(c ml.Config) *TextModel {
func newTextModel(c fs.Config) *TextModel {
numBlocks := int(c.Uint("block_count"))
m := TextModel{
SentencePieceModel: model.NewSentencePieceModel(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
@@ -55,7 +55,7 @@ func newTextModel(c ml.Config) *TextModel {
},
),
Layers: make([]TextLayer, numBlocks),
TextOptions: &TextOptions{
TextConfig: &TextConfig{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
@@ -84,7 +84,7 @@ type TextSelfAttention struct {
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(2)
@@ -120,12 +120,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeBase := m.TextOptions.ropeLocalBase
ropeBase := m.TextConfig.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = m.TextOptions.ropeGlobalBase
ropeBase = m.TextConfig.ropeGlobalBase
}
return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
return key.RoPE(ctx, shift, nil, uint32(m.TextConfig.attnKeyLen), uint32(2), ropeBase, m.TextConfig.ropeScale), nil
}
type TextMLP struct {
@@ -134,7 +134,7 @@ type TextMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
@@ -148,7 +148,7 @@ type TextLayer struct {
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
}
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@@ -173,7 +173,7 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
// set image embeddings
var except []int
@@ -206,7 +206,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)

View File

@@ -3,6 +3,7 @@ package gemma3
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
@@ -111,7 +112,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
return hiddenState
}
func newVisionModel(c ml.Config) *VisionModel {
func newVisionModel(c fs.Config) *VisionModel {
return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
VisionModelOptions: &VisionModelOptions{

View File

@@ -3,7 +3,7 @@ package gemma3
import (
"image"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/model/imageproc"
)
@@ -11,7 +11,7 @@ type ImageProcessor struct {
imageSize, patchSize, numChannels int
}
func newImageProcessor(c ml.Config) ImageProcessor {
func newImageProcessor(c fs.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),

View File

@@ -5,6 +5,7 @@ import (
"math"
"strings"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
@@ -30,7 +31,7 @@ type Model struct {
*Options
}
func New(c ml.Config) (model.Model, error) {
func New(c fs.Config) (model.Model, error) {
if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
}

View File

@@ -0,0 +1,56 @@
package mistral3
import (
"image"
_ "image/jpeg"
_ "image/png"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/model/imageproc"
)
type ImageProcessor struct {
imageSize int
patchSize int
numChannels int
longestEdge int
}
func newImageProcessor(c fs.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size", 1540)),
patchSize: int(c.Uint("vision.patch_size", 14)),
numChannels: int(c.Uint("vision.num_channels", 3)),
longestEdge: int(c.Uint("vision.longest_edge", 1540)),
}
}
// ProcessImage prepares an image for the vision model by:
// 1. Compositing transparent images
// 2. Resizing to fit model constraints while preserving aspect ratio
// 3. Normalizing pixel values
// Returns normalized image data and the final size in pixels
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, image.Point, error) {
img = imageproc.Composite(img)
size := img.Bounds().Size()
ratio := max(float64(size.Y)/float64(p.longestEdge), float64(size.X)/float64(p.longestEdge))
if ratio > 1.0 {
size = image.Point{
int(math.Floor(float64(size.X) / ratio)),
int(math.Floor(float64(size.Y) / ratio)),
}
}
patchesX := (size.X-1)/p.patchSize + 1
patchesY := (size.Y-1)/p.patchSize + 1
size = image.Point{
patchesX * p.patchSize,
patchesY * p.patchSize,
}
img = imageproc.Resize(img, size, imageproc.ResizeBilinear)
data := imageproc.Normalize(img, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true)
return data, size, nil
}
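To make the resize arithmetic concrete, a small self-contained sketch for a hypothetical 3080x1540 input with the defaults above (longestEdge=1540, patchSize=14):
package main

import (
	"fmt"
	"math"
)

func main() {
	w, h := 3080.0, 1540.0
	longestEdge, patch := 1540.0, 14
	ratio := math.Max(h/longestEdge, w/longestEdge) // 2.0
	if ratio > 1.0 {
		w, h = math.Floor(w/ratio), math.Floor(h/ratio) // 1540 x 770
	}
	patchesX := (int(w)-1)/patch + 1 // 110
	patchesY := (int(h)-1)/patch + 1 // 55
	fmt.Println(patchesX*patch, patchesY*patch) // 1540 770: the size passed to Resize and Normalize
}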

View File

@@ -0,0 +1,189 @@
package mistral3
import (
"bytes"
"image"
"slices"
"sync"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
*TextModel
*VisionModel `gguf:"v,vision"`
*MultiModalProjector `gguf:"mm"`
ImageProcessor
}
// Implement MultimodalProcessor interface
var _ model.MultimodalProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
textModel, err := NewTextModel(c)
if err != nil {
return nil, err
}
m := &Model{
TextModel: textModel,
VisionModel: newVisionModel(c),
ImageProcessor: newImageProcessor(c),
MultiModalProjector: newMultiModalProjector(c),
}
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
return m, nil
}
type PatchMerger struct {
MergingLayer *nn.Linear `gguf:"merging_layer"`
}
func (pm *PatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, size image.Point, spatialMergeSize int) ml.Tensor {
d := visionOutputs.Dim(0)
imageGrid := visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Reshape(ctx, size.X, size.Y, d)
kernel := ctx.Input().Empty(ml.DTypeF32, spatialMergeSize, spatialMergeSize, d)
patches := kernel.IM2Col(ctx, imageGrid, spatialMergeSize, spatialMergeSize, 0, 0, 1, 1)
reshaped := patches.Reshape(ctx, d*spatialMergeSize*spatialMergeSize, patches.Dim(1)*patches.Dim(2))
return pm.MergingLayer.Forward(ctx, reshaped)
}
type MultiModalProjector struct {
Norm *nn.RMSNorm `gguf:"norm"`
Linear1 *nn.Linear `gguf:"linear_1"`
Linear2 *nn.Linear `gguf:"linear_2"`
PatchMerger *PatchMerger `gguf:"patch_merger"`
spatialMergeSize int
eps float32
patchSize int
}
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, size image.Point) (ml.Tensor, image.Point) {
visionOutputs = p.Norm.Forward(ctx, visionOutputs, p.eps)
patchSizes := image.Point{size.X / p.patchSize, size.Y / p.patchSize}
visionOutputs = p.PatchMerger.Forward(ctx, visionOutputs, patchSizes, p.spatialMergeSize)
visionOutputs = p.Linear1.Forward(ctx, visionOutputs)
visionOutputs = visionOutputs.GELU(ctx)
return p.Linear2.Forward(ctx, visionOutputs), image.Point{patchSizes.X / p.spatialMergeSize, patchSizes.Y / p.spatialMergeSize}
}
func newMultiModalProjector(c fs.Config) *MultiModalProjector {
return &MultiModalProjector{
spatialMergeSize: int(c.Uint("spatial_merge_size", 2)),
eps: c.Float("text_config.rms_norm_eps", 1e-5),
patchSize: int(c.Uint("vision.patch_size", 14)),
}
}
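Taken together, the projector shrinks the vision token count by spatialMergeSize squared. Continuing the 1540x770 example from the image processor: 1540/14 x 770/14 = 110 x 55 patches come out of the vision tower, and with the default spatial_merge_size of 2 the patch merger hands the text model a 55 x 27 grid of embeddings (integer division, as in the code above).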
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
f32s, size, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, err
}
pixelValues, err := ctx.Input().FromFloatSlice(f32s, size.X, size.Y, m.ImageProcessor.numChannels)
if err != nil {
return nil, err
}
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size)
// split into patches to be sent to the text transformer
parent := imageFeatures{tensor: features}
rows := make([]*imageRow, size.Y)
for i := range rows {
rows[i] = &imageRow{parent: &parent, s: i, shape: []int{features.Dim(0), size.X}}
}
return rows, nil
}
type imageFeatures struct {
tensor ml.Tensor
dataOnce sync.Once
data []float32
}
type imageRow struct {
parent *imageFeatures
s int
shape []int
}
func (r *imageRow) data() []float32 {
n := 1
for _, s := range r.shape {
n *= s
}
return r.parent.data[r.s*n : (r.s+1)*n]
}
// PostTokenize arranges Mistral 3's inputs for the forward pass
// In Mistral 3 and Pixtral, the input patches are arranged as follows:
// [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
// Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings
// that can be processed together.
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
for _, inp := range inputs {
if inp.Multimodal == nil {
result = append(result, inp)
} else {
inputMultimodal := inp.Multimodal.([]*imageRow)
for i, row := range inputMultimodal {
// [IMG]
result = append(result, input.Input{Token: 10, Multimodal: row, MultimodalHash: inp.MultimodalHash, SameBatch: row.shape[1]})
result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.shape[1]-1)...)
if i == len(inputMultimodal)-1 {
// [IMG_END]
result = append(result, input.Input{Token: 13})
} else {
// [IMG_BREAK]
result = append(result, input.Input{Token: 12})
}
}
}
}
return result, nil
}
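As a concrete illustration of the layout, an image whose merged grid is 3 patches wide and 2 rows tall tokenizes as [IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG_END], i.e. token IDs 10,10,10,12,10,10,10,13; each row's first [IMG] carries the imageRow as multimodal data and sets SameBatch to the row width so the whole row is processed in one batch.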
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
}
func init() {
model.Register("mistral3", New)
}

View File

@@ -0,0 +1,177 @@
package mistral3
import (
"fmt"
"math"
"strings"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type TextOptions struct {
hiddenSize, numHeads, numKVHeads, headDim int
eps, ropeBase, ropeScale float32
ropeDim uint32
}
type TextModel struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextOptions
}
type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(0)
headDim := opts.headDim
if headDim == 0 {
headDim = opts.hiddenSize / opts.numHeads
}
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, nil, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *SelfAttention
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
// image embeddings
for _, image := range batch.Multimodal {
row := image.Multimodal.(*imageRow)
row.parent.dataOnce.Do(func() {
// use a new, throwaway context so the image tensor is not added to the graph
temp := m.Backend().NewContext()
temp.Forward(row.parent.tensor).Compute(row.parent.tensor)
row.parent.data = row.parent.tensor.Floats()
temp.Close()
})
imageFeature, err := ctx.Input().FromFloatSlice(row.data(), row.shape...)
if err != nil {
panic(err)
}
ctx.Forward(imageFeature.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), imageFeature.Dim(0)*imageFeature.Dim(1))))
}
for i, layer := range m.Layers {
cache.SetLayer(i)
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState)
}
func NewTextModel(c fs.Config) (*TextModel, error) {
if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
}
textModel := &TextModel{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id", 1)),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id", 2)),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
Layers: make([]Layer, c.Uint("block_count")),
TextOptions: &TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count"),
},
}
return textModel, nil
}

View File

@@ -0,0 +1,186 @@
package mistral3
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
var batchSize int = 1
func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
return x2.Neg(ctx).Concat(ctx, x1, 0)
}
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
}
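rotateHalf splits the head dimension in two and maps [x1, x2] to [-x2, x1], so applyRotaryPositionalEmbedding computes the standard rotate-half formulation of RoPE, rope(x) = x*cos + rotateHalf(x)*sin, applied identically to the queries and keys below.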
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize)
key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
return sa.Output.Forward(ctx, attention)
}
type VisionMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type VisionEncoderLayer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *VisionSelfAttention
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *VisionMLP
}
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = e.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
return hiddenStates.Add(ctx, residual)
}
type VisionModelOptions struct {
hiddenSize int
numHeads int
headDim int
intermediateSize int
imageSize int
patchSize int
numChannels int
eps float32
ropeBase float32
}
type VisionModel struct {
PatchEmbedding *nn.Conv2D `gguf:"patch_conv"`
EncoderNorm *nn.RMSNorm `gguf:"encoder_norm"`
Layers []VisionEncoderLayer `gguf:"blk"`
*VisionModelOptions
}
func (m *VisionModel) positionalEmbedding(ctx ml.Context, positionIDs ml.Tensor) ml.Tensor {
maxPatchesPerSide := m.imageSize / m.patchSize
frequencies := m.headDim / 2
frequenciesHeight := make([]float32, frequencies/2*maxPatchesPerSide)
frequenciesWidth := make([]float32, frequencies/2*maxPatchesPerSide)
for i := range frequencies {
for j := range maxPatchesPerSide {
frequency := float32(j) / float32(math.Pow(float64(m.ropeBase), float64(i)*2/float64(m.headDim)))
if i%2 == 0 {
frequenciesHeight[i/2*maxPatchesPerSide+j] = frequency
} else {
frequenciesWidth[i/2*maxPatchesPerSide+j] = frequency
}
}
}
h, err := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2)
if err != nil {
panic(err)
}
w, err := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2)
if err != nil {
panic(err)
}
h = h.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
w = w.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
h = h.Repeat(ctx, 1, maxPatchesPerSide)
h = h.Reshape(ctx, frequencies/2, maxPatchesPerSide, maxPatchesPerSide).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
w = w.Repeat(ctx, 2, maxPatchesPerSide)
inverseFrequencies := h.Concat(ctx, w, 0).Reshape(ctx, frequencies, maxPatchesPerSide*maxPatchesPerSide)
inverseFrequencies = inverseFrequencies.Concat(ctx, inverseFrequencies, 0)
return inverseFrequencies.Rows(ctx, positionIDs)
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
numPatchesW := pixelValues.Dim(0) / m.patchSize
numPatchesH := pixelValues.Dim(1) / m.patchSize
numPatches := numPatchesW * numPatchesH
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
hiddenStates = hiddenStates.Reshape(ctx, numPatches, m.hiddenSize)
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
hiddenStates = m.EncoderNorm.Forward(ctx, hiddenStates, m.VisionModelOptions.eps)
// Prepare position IDs for 2D rope
positions := make([]int32, numPatches)
for h := range numPatchesH {
for w := range numPatchesW {
idx := h*numPatchesW + w
positions[idx] = int32(h*m.imageSize/m.patchSize + w)
}
}
positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
if err != nil {
panic(err)
}
positionEmbedding := m.positionalEmbedding(ctx, positionIDs)
cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
for _, layer := range m.Layers {
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions)
}
return hiddenStates
}
func newVisionModel(c fs.Config) *VisionModel {
return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 24)),
VisionModelOptions: &VisionModelOptions{
hiddenSize: int(c.Uint("vision.embedding_length", 1024)),
numHeads: int(c.Uint("vision.attention.head_count", 16)),
headDim: int(c.Uint("vision.attention.key_length", 64)),
intermediateSize: int(c.Uint("vision.feed_forward_length", 4096)),
imageSize: int(c.Uint("vision.image_size", 1540)),
patchSize: int(c.Uint("vision.patch_size", 14)),
numChannels: int(c.Uint("vision.num_channels", 3)),
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-5),
ropeBase: c.Float("vision.rope.freq_base", 10000.0),
},
}
}

View File

@@ -8,6 +8,7 @@ import (
"image"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
@@ -32,7 +33,7 @@ const (
selfAttentionLayer
)
func New(c ml.Config) (model.Model, error) {
func New(c fs.Config) (model.Model, error) {
// Verify unified config
if c.Uint("vision.block_count") == 0 {
return nil, fmt.Errorf("non-unified vision model not supported")

View File

@@ -4,6 +4,7 @@ import (
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
@@ -220,7 +221,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, mask
return m.Output.Forward(ctx, hiddenState)
}
func newTextModel(c ml.Config) *TextModel {
func newTextModel(c fs.Config) *TextModel {
var decoderLayers []TextDecoderLayer
for i := range c.Uint("block_count") {
var textDecoderLayer TextDecoderLayer

View File

@@ -4,6 +4,7 @@ import (
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
@@ -185,7 +186,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRa
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
hiddenState = m.PreTilePositionEmbedding.Forward(ctx, hiddenState, aspectRatioIDs, m.VisionModelOptions)
hiddenState = m.ClassEmbedding.Stack(ctx, 2, slices.Repeat([]ml.Tensor{m.ClassEmbedding}, m.numTiles-1)...).Concat(ctx, hiddenState, 1)
hiddenState = m.ClassEmbedding.Repeat(ctx, 2, m.numTiles).Concat(ctx, hiddenState, 1)
hiddenState = m.PositionEmbedding.Forward(ctx, hiddenState, positionIDs, aspectRatioIDs, numPositions, m.VisionModelOptions)
hiddenState = m.PreLayerNorm.Forward(ctx, hiddenState, m.eps)
@@ -213,7 +214,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRa
return hiddenState.Concat(ctx, hiddenStates, 0)
}
func newVisionModel(c ml.Config) *VisionModel {
func newVisionModel(c fs.Config) *VisionModel {
return &VisionModel{
Transformer: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count"))},
GlobalTransformer: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.global.block_count"))},

View File

@@ -8,14 +8,14 @@ import (
"golang.org/x/image/draw"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/fs"
)
type ImageProcessor struct {
imageSize, numChannels, maxNumTiles int
}
func newImageProcessor(c ml.Config) ImageProcessor {
func newImageProcessor(c fs.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size")),
numChannels: int(c.Uint("vision.num_channels")),

View File

@@ -4,5 +4,6 @@ import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/mistral3"
_ "github.com/ollama/ollama/model/models/mllama"
)

View File

@@ -1,68 +0,0 @@
package pixtral
import (
"fmt"
"image"
_ "image/jpeg"
_ "image/png"
"io"
"math"
"github.com/ollama/ollama/model/imageproc"
)
func getNumImageTokens(imageSize, patchSize image.Point) image.Point {
return image.Point{
(imageSize.X-1)/patchSize.X + 1,
(imageSize.Y-1)/patchSize.Y + 1,
}
}
func getResizeOutputImageSize(img image.Image, longestEdge int, patchSize image.Point) image.Point {
b := img.Bounds()
le := float64(longestEdge)
ratio := math.Max(float64(b.Max.Y)/le, float64(b.Max.X)/le)
newSize := img.Bounds().Max
if ratio > 1.0 {
newSize = image.Point{
int(math.Ceil(float64(b.Max.X) / ratio)),
int(math.Ceil(float64(b.Max.Y) / ratio)),
}
}
tokens := getNumImageTokens(newSize, patchSize)
return image.Point{
tokens.X * patchSize.X,
tokens.Y * patchSize.Y,
}
}
func resizeImage(img image.Image, format string, longestEdge int, patchSize image.Point) image.Image {
if format == "png" {
img = imageproc.Composite(img)
}
newSize := getResizeOutputImageSize(img, longestEdge, patchSize)
// todo should be ResizeBicubic, but it doesn't exist
return imageproc.Resize(img, newSize, imageproc.ResizeBilinear)
}
func Preprocess(imageData io.Reader) ([]float32, map[string]any, error) {
img, format, err := image.Decode(imageData)
if err != nil {
return nil, nil, fmt.Errorf("failed to decode image: %w", err)
}
longestEdge := 1024
patchSize := image.Point{16, 16}
img = resizeImage(img, format, longestEdge, patchSize)
data := imageproc.Normalize(img, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true)
opts := map[string]any{}
return data, opts, nil
}

View File

@@ -1,219 +0,0 @@
package pixtral
import (
"bytes"
"encoding/binary"
"image"
"image/png"
"math"
"os"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestGetNumImageTokens(t *testing.T) {
type numImageTokensCase struct {
ImageSize image.Point
PatchSize image.Point
Expected image.Point
}
cases := []numImageTokensCase{
{
ImageSize: image.Point{1024, 764},
PatchSize: image.Point{16, 16},
Expected: image.Point{64, 48},
},
{
ImageSize: image.Point{800, 600},
PatchSize: image.Point{16, 16},
Expected: image.Point{50, 38},
},
{
ImageSize: image.Point{640, 480},
PatchSize: image.Point{16, 16},
Expected: image.Point{40, 30},
},
{
ImageSize: image.Point{320, 200},
PatchSize: image.Point{16, 16},
Expected: image.Point{20, 13},
},
{
ImageSize: image.Point{1320, 200},
PatchSize: image.Point{16, 16},
Expected: image.Point{83, 13},
},
{
ImageSize: image.Point{2000, 200},
PatchSize: image.Point{16, 16},
Expected: image.Point{125, 13},
},
{
ImageSize: image.Point{10000, 200},
PatchSize: image.Point{16, 16},
Expected: image.Point{625, 13},
},
{
ImageSize: image.Point{1131, 577},
PatchSize: image.Point{16, 16},
Expected: image.Point{71, 37},
},
{
ImageSize: image.Point{16, 16},
PatchSize: image.Point{16, 16},
Expected: image.Point{1, 1},
},
}
for _, c := range cases {
actual := getNumImageTokens(c.ImageSize, c.PatchSize)
if diff := cmp.Diff(actual, c.Expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
}
func TestGetResizeOutputImageSize(t *testing.T) {
type resizeCase struct {
Image image.Image
LongestEdge int
PatchSize image.Point
Expected image.Point
}
cases := []resizeCase{
{
Image: image.NewRGBA(image.Rect(0, 0, 1024, 768)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.Point{1024, 768},
},
{
Image: image.NewRGBA(image.Rect(0, 0, 1162, 690)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.Point{1024, 624},
},
{
Image: image.NewRGBA(image.Rect(0, 0, 300, 200)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.Point{304, 208},
},
{
Image: image.NewRGBA(image.Rect(0, 0, 1862, 522)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.Point{1024, 288},
},
}
for _, c := range cases {
actual := getResizeOutputImageSize(c.Image, c.LongestEdge, c.PatchSize)
if diff := cmp.Diff(actual, c.Expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
}
func TestResize(t *testing.T) {
type resizeCase struct {
Image image.Image
LongestEdge int
PatchSize image.Point
Expected image.Image
}
cases := []resizeCase{
{
Image: image.NewRGBA(image.Rect(0, 0, 1862, 522)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.NewRGBA(image.Rect(0, 0, 1024, 288)),
},
{
Image: image.NewRGBA(image.Rect(0, 0, 10, 10)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.NewRGBA(image.Rect(0, 0, 16, 16)),
},
}
for _, c := range cases {
actual := resizeImage(c.Image, "png", c.LongestEdge, c.PatchSize)
if actual.Bounds() != c.Expected.Bounds() {
t.Errorf("image size incorrect: '%#v': expected: '%#v'", actual.Bounds(), c.Expected.Bounds())
}
}
}
func TestPreprocess(t *testing.T) {
type preprocessCase struct {
TestImage image.Image
ExpectedLen int
}
cases := []preprocessCase{
{
TestImage: image.NewRGBA(image.Rect(0, 0, 10, 10)),
ExpectedLen: 16 * 16 * 3 * 1,
},
{
TestImage: image.NewRGBA(image.Rect(0, 0, 2000, 2000)),
ExpectedLen: 1024 * 1024 * 3 * 1,
},
}
for _, c := range cases {
var buf bytes.Buffer
err := png.Encode(&buf, c.TestImage)
if err != nil {
t.Fatal(err)
}
imgData, _, err := Preprocess(&buf)
if err != nil {
t.Fatalf("error processing: %q", err)
}
switch len(imgData) {
case 0:
t.Errorf("no image data returned")
case c.ExpectedLen:
// ok
default:
t.Errorf("unexpected image data length: %d, expected: %d", len(imgData), c.ExpectedLen)
}
}
}
func TestPreprocessImages(t *testing.T) {
for _, testFile := range []string{"flight.png", "sportsball.png"} {
f, err := os.Open(testFile)
if err != nil {
t.Skipf("skipping test, no test image found at %s", testFile)
}
defer f.Close()
imgData, _, err := Preprocess(f)
if err != nil {
t.Fatalf("error processing: %q", err)
}
byteData := make([]byte, len(imgData)*4) // float32 is 4 bytes
for i, f := range imgData {
binary.LittleEndian.PutUint32(byteData[i*4:], math.Float32bits(f))
}
outputPath := "processed_" + testFile + ".bin"
err = os.WriteFile(outputPath, byteData, 0o644)
if err != nil {
t.Fatalf("error writing processed image: %q", err)
}
}
}

View File

@@ -263,6 +263,10 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
continue
}
if id := bpe.vocab.Encode(pair.value); id < 0 {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil

View File

@@ -1,29 +1,23 @@
package model
import (
"iter"
"container/heap"
"fmt"
"log/slog"
"strconv"
"strings"
"github.com/dlclark/regexp2"
queue "github.com/emirpasic/gods/v2/queues/priorityqueue"
)
const spmWhitespaceSep = "▁"
func replaceWhitespaceBySeperator(s string) string {
return strings.ReplaceAll(s, " ", spmWhitespaceSep)
}
type SentencePieceModel struct {
maxTokenLen int
pre *regexp2.Regexp
vocab *Vocabulary
}
var _ TextProcessor = (*SentencePieceModel)(nil)
func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
counter := map[int]int{}
@@ -44,7 +38,6 @@ func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
return SentencePieceModel{
maxTokenLen: maxTokenLen,
pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
vocab: vocab,
}
}
@@ -53,20 +46,9 @@ func (spm SentencePieceModel) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}
func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
return func(yield func(string) bool) {
for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
if !yield(m.String()) {
break
}
}
}
}
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
id := spm.vocab.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
@@ -91,7 +73,6 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}
slog.Debug("fragments", "frags", fragments)
var ids []int32
for _, frag := range fragments {
@@ -100,26 +81,17 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
continue
}
for split := range spm.split(frag.value) {
split = replaceWhitespaceBySeperator(split)
text := strings.ReplaceAll(frag.value, " ", spmWhitespaceSep)
var sb strings.Builder
sb.Write([]byte(split))
if id := spm.vocab.Encode(sb.String()); id >= 0 {
if id := spm.vocab.Encode(text); id >= 0 {
ids = append(ids, id)
continue
}
runes := []rune(sb.String())
pq := queue.NewWith(func(a, b any) int {
priA := a.(*candidate)
priB := b.(*candidate)
if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) {
return -1
}
return 1
})
q := &queue{}
heap.Init(q)
runes := []rune(text)
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
@@ -129,8 +101,6 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
}
}
slog.Debug("tokenizer", "merges", merges)
pairwise := func(a, b int) *candidate {
if a < 0 || b >= len(runes) {
return nil
@@ -142,34 +112,24 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
a: a,
b: b,
score: spm.vocab.Scores[id],
size: len(left) + len(right),
}
}
return nil
}
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
pq.Enqueue(pair)
heap.Push(q, pair)
}
}
pqv := pq.Values()
for _, v := range pqv {
e := v.(*candidate)
slog.Debug("candidate", "candidate", e)
}
for !pq.Empty() {
v, _ := pq.Dequeue()
pair := v.(*candidate)
for q.Len() > 0 {
pair := heap.Pop(q).(*candidate)
left, right := merges[pair.a], merges[pair.b]
slog.Debug("pair", "left", left, "right", right)
if len(left.runes) == 0 || len(right.runes) == 0 {
continue
}
if id := spm.vocab.Encode(string(left.runes) + string(right.runes)); id < 0 {
if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size {
continue
}
@@ -181,24 +141,36 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
pq.Enqueue(pair)
heap.Push(q, pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
pq.Enqueue(pair)
heap.Push(q, pair)
}
}
slog.Debug("merges", "merges", merges)
for _, merge := range merges {
if len(merge.runes) > 0 {
if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
if token := string(merge.runes); token != "" {
id := spm.vocab.Encode(token)
if id >= 0 {
ids = append(ids, id)
continue
}
// Fallback to byte tokenization
var result []int32
for _, b := range []byte(token) {
byteToken := fmt.Sprintf("<0x%02X>", b)
unknownID := spm.vocab.Encode(byteToken)
if unknownID >= 0 {
result = append(result, unknownID)
} else {
slog.Debug("missing token", "token", string(merge.runes))
slog.Debug("unknown byte token", "byte", b, "token", byteToken)
}
}
ids = append(ids, result...)
}
}
}
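As an example of the fallback, a piece with no token of its own (say "ã", as in the decode test below) is emitted as its UTF-8 bytes rendered as byte tokens, here <0xC3> followed by <0xA3>, rather than being dropped.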
@@ -229,6 +201,30 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
type candidate struct {
a, b int
score float32
size int
}
type queue []*candidate
func (q queue) Len() int { return len(q) }
func (q queue) Less(i, j int) bool {
return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a)
}
func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
func (q *queue) Push(x interface{}) {
item := x.(*candidate)
*q = append(*q, item)
}
func (q *queue) Pop() interface{} {
old := *q
n := len(old)
item := old[n-1]
*q = old[0 : n-1]
return item
}
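Since container/heap pops the element that Less orders first, this ordering reproduces the behaviour of the removed priority-queue dependency: the highest-scoring candidate merge is applied first, with ties broken by the leftmost position.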
func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
@@ -236,11 +232,26 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
for _, id := range ids {
data := spm.vocab.Decode(id)
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
// For tokenizers that use byte tokens like "<0xEA>",
// convert them back to the raw byte (which may be part of a
// multi-byte Unicode character) so they are buffered correctly
// by the runner instead of being sent back to the API as "<0xEA>"
if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
byteVal, err := strconv.ParseUint(data[1:5], 0, 8)
if err != nil {
return "", fmt.Errorf("failed to parse hex byte: %v", err)
}
if err := sb.WriteByte(byte(byteVal)); err != nil {
return "", err
}
} else {
if _, err := sb.WriteString(data); err != nil {
return "", err
}
}
}
slog.Debug("decoded", "ids", ids, "text", sb.String())
return sb.String(), nil
}

View File

@@ -25,8 +25,6 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
t.Fatal(err)
}
preTokenizer := `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`
var v Vocabulary
for _, piece := range spm.GetPieces() {
@@ -47,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
}
}
return NewSentencePieceModel(preTokenizer, &v)
return NewSentencePieceModel(&v)
}
func TestSentencePieceEncode(t *testing.T) {
@@ -116,3 +114,59 @@ func TestSentencePieceEncode(t *testing.T) {
}
})
}
func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
vocab := &Vocabulary{
Values: []string{
"normal",
"<0xEA>",
"<0x41>",
"<0xC3>",
"<0xA3>",
},
Types: []uint32{
TOKEN_TYPE_NORMAL,
TOKEN_TYPE_BYTE,
TOKEN_TYPE_BYTE,
TOKEN_TYPE_BYTE,
TOKEN_TYPE_BYTE,
},
Scores: []float32{0, 0, 0, 0, 0},
}
spm := NewSentencePieceModel(vocab)
tests := []struct {
name string
ids []int32
expected string
}{
{
name: "single byte token",
ids: []int32{1},
expected: "\xea",
},
{
name: "ASCII byte token",
ids: []int32{2},
expected: "A",
},
{
name: "multiple byte tokens forming UTF-8 character",
ids: []int32{3, 4},
expected: "ã",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := spm.Decode(tt.ids)
if err != nil {
t.Errorf("failed to decode token IDs %v: %v", tt.ids, err)
}
if result != tt.expected {
t.Errorf("got %q, want %q", result, tt.expected)
}
})
}
}

View File

@@ -25,7 +25,7 @@ var finishReasonToolCalls = "tool_calls"
type Error struct {
Message string `json:"message"`
Type string `json:"type"`
Param interface{} `json:"param"`
Param any `json:"param"`
Code *string `json:"code"`
}
@@ -465,7 +465,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
}
}
options := make(map[string]interface{})
options := make(map[string]any)
switch stop := r.Stop.(type) {
case string:

View File

@@ -219,7 +219,7 @@ func TestChatMiddleware(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]interface{}{
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
},

View File

@@ -211,16 +211,10 @@ func filesForModel(path string) ([]string, error) {
}
var files []string
if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 {
// safetensors files might be unresolved git lfs references; skip if they are
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
files = append(files, st...)
} else if st, _ := glob(filepath.Join(path, "adapters.safetensors"), "application/octet-stream"); len(st) > 0 {
// covers adapters.safetensors
files = append(files, st...)
} else if st, _ := glob(filepath.Join(path, "adapter_model.safetensors"), "application/octet-stream"); len(st) > 0 {
// covers adapter_model.safetensors
files = append(files, st...)
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin

View File

@@ -213,8 +213,16 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
return discard
}
// Frees up space in the KV cache by deleting the oldest half of history and shifting
// the newest half into that space (saving numKeep inputs at the beginning).
type ErrReprocessInputs struct {
Inputs []input
}
func (e *ErrReprocessInputs) Error() string {
return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs))
}
// ShiftCacheSlot frees up space in the KV cache by deleting the oldest half of history
// and shifting the newest half into that space (saving numKeep inputs at the beginning).
//
// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
@@ -222,7 +230,8 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
}
discard := c.ShiftDiscard(len(slot.Inputs), numKeep)
inputLen := len(slot.Inputs)
discard := c.ShiftDiscard(inputLen, numKeep)
if discard <= 0 {
return nil
@@ -231,16 +240,42 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
"keep", numKeep, "discard", discard)
// TODO (jessegross): KV cache removal can fail for certain types of models
if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard)
}
c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)
var shiftFailed bool
for i := numKeep + discard; i < len(slot.Inputs); i++ {
if c.lc.KvCacheCanShift() {
// For models that support shifting, attempt to shift the KV cache
if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
shiftFailed = true
slog.Debug("kv cache removal not supported, clearing cache and returning inputs for reprocessing", "id", slot.Id)
} else {
c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, inputLen, -discard)
}
} else {
// For models that don't support shifting
shiftFailed = true
slog.Debug("kv cache cannot shift, clearing cache and returning inputs for reprocessing", "id", slot.Id)
}
if shiftFailed {
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
newInputs := make([]input, numKeep+inputLen-(numKeep+discard))
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
// Clear the entire KV cache
_ = c.lc.KvCacheSeqRm(slot.Id, 0, -1)
// Reset the slot inputs since we've cleared the cache
slot.Inputs = []input{}
// Return error with inputs that need to be reprocessed
return &ErrReprocessInputs{Inputs: newInputs}
}
// Standard shift succeeded - update input array
for i := numKeep + discard; i < inputLen; i++ {
slot.Inputs[i-discard] = slot.Inputs[i]
}
slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
slot.Inputs = slot.Inputs[:inputLen-discard]
return nil
}

View File

@@ -83,7 +83,7 @@ type Sequence struct {
// true if an embedding is to be returned instead of text generation
embeddingOnly bool
doneReason string
doneReason llm.DoneReason
// Metrics
startProcessingTime time.Time
@@ -301,7 +301,7 @@ func flushPending(seq *Sequence) bool {
}
}
func (s *Server) removeSequence(seqIndex int, reason string) {
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
seq := s.seqs[seqIndex]
flushPending(seq)
@@ -380,7 +380,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(seqIdx, "limit")
s.removeSequence(seqIdx, llm.DoneReasonLength)
continue
}
@@ -389,8 +389,16 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
if len(seq.pendingInputs) == 0 {
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
var reprocess *ErrReprocessInputs
if errors.As(err, &reprocess) {
// Prepend these inputs to the sequence's inputs queue for reprocessing
seq.inputs = append(reprocess.Inputs, seq.inputs...)
// Continue processing as normal
continue
} else {
return err
}
}
} else {
break
}
@@ -474,7 +482,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
}
seq.embedding <- embed
s.removeSequence(i, "")
s.removeSequence(i, llm.DoneReasonStop)
continue
}
@@ -491,7 +499,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
// as it's important for the /api/generate context
// seq.responses <- piece
s.removeSequence(i, "stop")
s.removeSequence(i, llm.DoneReasonStop)
continue
}
@@ -522,7 +530,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
}
seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
s.removeSequence(i, "stop")
s.removeSequence(i, llm.DoneReasonStop)
continue
}
@@ -535,7 +543,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
}
if !flushPending(seq) {
s.removeSequence(i, "connection")
s.removeSequence(i, llm.DoneReasonConnectionClosed)
}
}
@@ -599,7 +607,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
if errors.Is(err, context.Canceled) {
slog.Info("aborting completion request due to client closing the connection")
} else {
slog.Error("Failed to acquire semaphore", "error", err)
http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
}
return
}
@@ -611,6 +619,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
if err != nil {
s.mu.Unlock()
s.seqsSem.Release(1)
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
return
}
@@ -626,6 +635,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
s.mu.Unlock()
if !found {
s.seqsSem.Release(1)
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
return
}
@@ -647,14 +657,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
flusher.Flush()
} else {
// Send the final response
doneReason := "stop"
if seq.doneReason == "limit" {
doneReason = "length"
}
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Done: true,
DoneReason: doneReason,
DoneReason: seq.doneReason,
PromptEvalCount: seq.numPromptInputs,
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
EvalCount: seq.numDecoded,
@@ -691,7 +696,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
if errors.Is(err, context.Canceled) {
slog.Info("aborting embeddings request due to client closing the connection")
} else {
slog.Error("Failed to acquire semaphore", "error", err)
http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
}
return
}
@@ -703,6 +708,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
if err != nil {
s.mu.Unlock()
s.seqsSem.Release(1)
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
return
}
@@ -715,6 +721,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
s.mu.Unlock()
if !found {
s.seqsSem.Release(1)
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
return
}

View File

@@ -118,6 +118,10 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
}
if c.cache != nil {
if numPast > 0 && !c.cache.CanResume(slot.Id, numPast) {
numPast = 0
}
err = c.cache.Remove(slot.Id, numPast, math.MaxInt32)
if err != nil {
// Some models don't support partial erasure
@@ -225,6 +229,8 @@ func countCommonPrefix(a []input.Input, b []input.Input) int32 {
return count
}
// TODO(jessegross): If we need to reprocess the inputs we should ensure that
// we don't split up a SameBatch
func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
targetFree := (c.numCtx - numKeep) / 2
targetFree = max(targetFree, 1)
@@ -239,6 +245,14 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
return discard
}
type ErrReprocessInputs struct {
Inputs []input.Input
}
func (e *ErrReprocessInputs) Error() string {
return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs))
}
// Frees up space in the KV cache by deleting the oldest half of history and shifting
// the newest half into that space (saving numKeep inputs at the beginning).
//
@@ -258,11 +272,23 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
"keep", numKeep, "discard", discard)
// TODO (jessegross): KV cache removal can fail for certain types of models
if c.cache != nil {
err := c.cache.Remove(slot.Id, numKeep, numKeep+discard)
if err != nil {
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err)
slog.Debug("kv cache removal unsupported, clearing cache and returning inputs for reprocessing",
"id", slot.Id, "error", err)
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard))
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
// Reset the cache
_ = c.cache.Remove(slot.Id, 0, -1)
slot.Inputs = []input.Input{}
// Return error with inputs that need to be reprocessed
return &ErrReprocessInputs{Inputs: newInputs}
}
}
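A worked example of the arithmetic above, using the numbers from the "Normal shift" case in the test file that follows: with numCtx = 10, numKeep = 2 and 10 inputs, ShiftDiscard targets freeing (10 - 2) / 2 = 4 inputs, so the preserved slice has numKeep + (10 - (2 + 4)) = 6 entries, the 2 kept inputs followed by the newest 4. When removal is unsupported, those same 6 inputs are handed back to the caller inside ErrReprocessInputs and slot.Inputs is cleared, which is why the "Cache removal fails" case expects a slot length of 0.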

View File

@@ -1,10 +1,13 @@
package ollamarunner
import (
"errors"
"fmt"
"image"
"testing"
"time"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
@@ -425,3 +428,92 @@ func TestLoadCacheSlot(t *testing.T) {
})
}
}
// Mock implementation of the Cache interface
type mockCache struct {
shouldFail bool
}
// Implement only the methods needed for the test
func (m *mockCache) Remove(seq int, beginIndex, endIndex int32) error {
if m.shouldFail {
return fmt.Errorf("mock cache removal error")
}
return nil
}
// Stub implementations for other interface methods
func (m *mockCache) SetLayer(layer int) {}
func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { return nil, nil, nil }
func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor) {}
func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {}
func (m *mockCache) Close() {}
func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch) error { return nil }
func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {}
func (m *mockCache) SetConfig(ml.CacheConfig) {}
func (m *mockCache) CanResume(seq int, pos int32) bool { return true }
func TestShiftCacheSlot(t *testing.T) {
tests := []struct {
name string
numCtx int32
inputs []input.Input
numKeep int32
cacheErr bool
wantErr any
wantInputsLen int
}{
{
name: "Normal shift",
numCtx: 10,
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
numKeep: 2,
cacheErr: false, // No error
wantErr: nil,
wantInputsLen: 6, // After discarding 4 tokens
},
{
name: "Cache removal fails",
numCtx: 10,
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
numKeep: 2,
cacheErr: true,
wantErr: &ErrReprocessInputs{},
wantInputsLen: 0, // Original inputs should be cleared
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mock := &mockCache{shouldFail: tt.cacheErr}
c := InputCache{
numCtx: tt.numCtx,
cache: mock,
}
slot := &InputCacheSlot{
Id: 123,
Inputs: make([]input.Input, len(tt.inputs)),
}
copy(slot.Inputs, tt.inputs)
err := c.ShiftCacheSlot(slot, tt.numKeep)
if tt.wantErr != nil {
if err == nil {
t.Errorf("Expected error but got nil")
return
}
if !errors.As(err, &tt.wantErr) {
t.Errorf("Expected error of type %T but got %T: %v", tt.wantErr, err, err)
}
} else if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(slot.Inputs) != tt.wantInputsLen {
t.Errorf("Slot inputs length after operation: got %v, want %v", len(slot.Inputs), tt.wantInputsLen)
}
})
}
}

View File

@@ -82,7 +82,7 @@ type Sequence struct {
// true if an embedding is to be returned instead of text generation
embeddingOnly bool
doneReason string
doneReason llm.DoneReason
// Metrics
startProcessingTime time.Time
@@ -115,16 +115,41 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
params.numKeep = int32(len(inputs))
}
// TODO(jessegross): We should ensure that we always leave minBatch of context space to shift,
// otherwise we might truncate or split the batch against the model's wishes
// Ensure that at least 1 input can be discarded during shift
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
if int32(len(inputs)) > s.cache.numCtx {
discard := int32(len(inputs)) - s.cache.numCtx
promptStart := params.numKeep + discard
// If we need to truncate in the middle of an unbreakable batch, remove the entire batch
sameBatch := 0
for i, inp := range inputs {
if sameBatch > 0 {
sameBatch--
if promptStart == int32(i) {
promptStart++
}
} else if promptStart == int32(i) {
break
}
if inp.SameBatch != 0 {
if int32(i) < params.numKeep {
return nil, fmt.Errorf("SameBatch may not be specified within numKeep (index: %v numKeep: %v SameBatch: %v)", i, params.numKeep, inp.SameBatch)
}
sameBatch = inp.SameBatch
}
}
if promptStart >= int32(len(inputs)) {
return nil, errors.New("entire prompt removed by truncation")
}
newInputs := inputs[:params.numKeep]
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
newInputs = append(newInputs, inputs[promptStart:]...)
slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
inputs = newInputs
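A worked example of the truncation above, with hypothetical numbers: 12 inputs, numCtx = 10 and numKeep = 2 give discard = 2 and promptStart = 4. If inputs[3] sets SameBatch = 2, indices 4 and 5 belong to its batch, so the loop walks promptStart forward to 6 and the whole unbreakable batch (indices 3 through 5) is dropped together with the truncated region; the rebuilt prompt is inputs[:2] followed by inputs[6:], 8 inputs in total.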
@@ -267,6 +292,9 @@ type Server struct {
// KV cache
cache *InputCache
// next sequence for prompt processing to avoid starvation
nextSeq int
// multimodalHash generates hashes for comparing equality
// of non-text data
multimodalHash maphash.Hash
@@ -313,7 +341,7 @@ func flushPending(seq *Sequence) bool {
}
}
func (s *Server) removeSequence(seqIndex int, reason string) {
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
seq := s.seqs[seqIndex]
flushPending(seq)
@@ -351,14 +379,19 @@ func (s *Server) processBatch() error {
var batchInputs []int32
var batch input.Batch
for i, seq := range s.seqs {
resumeSeq := -1
seqIdx := s.nextSeq - 1
for range s.seqs {
seqIdx = (seqIdx + 1) % len(s.seqs)
seq := s.seqs[seqIdx]
if seq == nil {
continue
}
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(i, "limit")
s.removeSequence(seqIdx, llm.DoneReasonLength)
continue
}
@@ -369,16 +402,23 @@ func (s *Server) processBatch() error {
batchSize := s.batchSize
for j, inp := range seq.inputs {
for i, inp := range seq.inputs {
// If we are required to put following inputs into a single batch then extend the
// batch size. Since we are only extending the size the minimum amount possible, this
// will cause a break if we have pending inputs.
// will cause a break if we have existing inputs.
minBatch := 1 + inp.SameBatch
if minBatch > batchSize {
batchSize = minBatch
}
if len(seq.pendingInputs)+minBatch > batchSize {
// Stop if the required batch would put us over the total batch size (including tokens
// added by other sequences). If we haven't been able to add anything yet then pick up
// here again for the next batch to avoid starvation, though we can opportunistically
// check if other sequences can still squeeze something in.
if len(batchInputs)+minBatch > batchSize {
if len(seq.pendingInputs) == 0 && resumeSeq == -1 {
resumeSeq = seqIdx
}
break
}
@@ -392,9 +432,17 @@ func (s *Server) processBatch() error {
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
var reprocess *ErrReprocessInputs
if errors.As(err, &reprocess) {
// Prepend these inputs to the sequence's inputs queue for reprocessing
seq.inputs = append(reprocess.Inputs, seq.inputs...)
// Skip this sequence but continue processing the rest
continue
} else {
return err
}
}
}
batchInputs = append(batchInputs, inp.Token)
if inp.Multimodal != nil {
@@ -405,7 +453,7 @@ func (s *Server) processBatch() error {
batch.Sequences = append(batch.Sequences, seq.cache.Id)
seq.iBatch = len(batch.Outputs)
if j+1 == len(seq.inputs) {
if i+1 == len(seq.inputs) {
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
}
seq.pendingInputs = append(seq.pendingInputs, inp)
@@ -414,6 +462,12 @@ func (s *Server) processBatch() error {
seq.inputs = seq.inputs[len(seq.pendingInputs):]
}
if resumeSeq != -1 {
s.nextSeq = resumeSeq
} else {
s.nextSeq = seqIdx + 1
}
if len(batchInputs) == 0 {
return nil
}
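Concretely, the round-robin bookkeeping above works like this (hypothetical numbers): with 4 sequences and nextSeq = 1, the loop visits sequences 1, 2, 3, 0. If sequence 3 is the first that could not fit any input into the batch, resumeSeq is set to 3 and the next batch starts there, so a sequence with a long prompt is not perpetually crowded out by earlier slots refilling the batch first.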
@@ -456,7 +510,7 @@ func (s *Server) processBatch() error {
if seq.embeddingOnly {
// TODO(jessegross): Embedding support
slog.Warn("generation of embedding outputs not yet supported")
s.removeSequence(i, "")
s.removeSequence(i, llm.DoneReasonStop)
continue
}
@@ -474,7 +528,7 @@ func (s *Server) processBatch() error {
// as it's important for the /api/generate context
// seq.responses <- piece
s.removeSequence(i, "stop")
s.removeSequence(i, llm.DoneReasonStop)
continue
}
@@ -510,7 +564,7 @@ func (s *Server) processBatch() error {
}
seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
s.removeSequence(i, "stop")
s.removeSequence(i, llm.DoneReasonStop)
continue
}
@@ -523,7 +577,7 @@ func (s *Server) processBatch() error {
}
if !flushPending(seq) {
s.removeSequence(i, "connection")
s.removeSequence(i, llm.DoneReasonConnectionClosed)
}
}
@@ -588,7 +642,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
if errors.Is(err, context.Canceled) {
slog.Info("aborting completion request due to client closing the connection")
} else {
slog.Error("Failed to acquire semaphore", "error", err)
http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
}
return
}
@@ -600,6 +654,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
if err != nil {
s.mu.Unlock()
s.seqsSem.Release(1)
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
return
}
@@ -613,6 +668,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
s.mu.Unlock()
if !found {
s.seqsSem.Release(1)
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
return
}
@@ -634,14 +690,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
flusher.Flush()
} else {
// Send the final response
doneReason := "stop"
if seq.doneReason == "limit" {
doneReason = "length"
}
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Done: true,
DoneReason: doneReason,
DoneReason: seq.doneReason,
PromptEvalCount: seq.numPromptInputs,
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
EvalCount: seq.numPredicted,
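The string done reasons used previously ("stop", "limit", "connection", "") are replaced throughout this change by the typed llm.DoneReason. A hedged sketch of that type, inferred from the constants and String() call sites in these hunks; the concrete representation and the string for the connection-closed case are assumptions, not something the diffs confirm:

package llm

// DoneReason records why a completion ended.
type DoneReason int

const (
    DoneReasonStop             DoneReason = iota // hit a stop condition or finished naturally
    DoneReasonLength                             // hit the num_predict limit
    DoneReasonConnectionClosed                   // client went away before the response finished
)

// String maps the reason to the value reported in API responses.
func (d DoneReason) String() string {
    switch d {
    case DoneReasonLength:
        return "length"
    case DoneReasonConnectionClosed:
        return "" // assumption: no reason is reported when the client disconnected
    default:
        return "stop"
    }
}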

View File

@@ -35,17 +35,11 @@ var (
errCapabilityCompletion = errors.New("completion")
errCapabilityTools = errors.New("tools")
errCapabilityInsert = errors.New("insert")
errCapabilityVision = errors.New("vision")
errCapabilityEmbedding = errors.New("embedding")
errInsecureProtocol = errors.New("insecure protocol http")
)
type Capability string
const (
CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools")
CapabilityInsert = Capability("insert")
)
type registryOptions struct {
Insecure bool
Username string
@@ -66,52 +60,83 @@ type Model struct {
System string
License []string
Digest string
Options map[string]interface{}
Options map[string]any
Messages []api.Message
Template *template.Template
}
// CheckCapabilities checks if the model has the specified capabilities returning an error describing
// any missing or unknown capabilities
func (m *Model) CheckCapabilities(caps ...Capability) error {
var errs []error
for _, cap := range caps {
switch cap {
case CapabilityCompletion:
// Capabilities returns the capabilities that the model supports
func (m *Model) Capabilities() []model.Capability {
capabilities := []model.Capability{}
// Check for completion capability
r, err := os.Open(m.ModelPath)
if err != nil {
slog.Error("couldn't open model file", "error", err)
continue
}
if err == nil {
defer r.Close()
// TODO(mxyng): decode the GGML into model to avoid doing this multiple times
f, _, err := ggml.Decode(r, 0)
if err != nil {
if err == nil {
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
capabilities = append(capabilities, model.CapabilityEmbedding)
} else {
capabilities = append(capabilities, model.CapabilityCompletion)
}
if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok {
capabilities = append(capabilities, model.CapabilityVision)
}
} else {
slog.Error("couldn't decode ggml", "error", err)
continue
}
} else {
slog.Error("couldn't open model file", "error", err)
}
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
errs = append(errs, errCapabilityCompletion)
if m.Template == nil {
return capabilities
}
case CapabilityTools:
if !slices.Contains(m.Template.Vars(), "tools") {
errs = append(errs, errCapabilityTools)
// Check for tools capability
if slices.Contains(m.Template.Vars(), "tools") {
capabilities = append(capabilities, model.CapabilityTools)
}
case CapabilityInsert:
vars := m.Template.Vars()
if !slices.Contains(vars, "suffix") {
errs = append(errs, errCapabilityInsert)
// Check for insert capability
if slices.Contains(m.Template.Vars(), "suffix") {
capabilities = append(capabilities, model.CapabilityInsert)
}
default:
return capabilities
}
// CheckCapabilities checks if the model has the specified capabilities returning an error describing
// any missing or unknown capabilities
func (m *Model) CheckCapabilities(want ...model.Capability) error {
available := m.Capabilities()
var errs []error
// Map capabilities to their corresponding error
capToErr := map[model.Capability]error{
model.CapabilityCompletion: errCapabilityCompletion,
model.CapabilityTools: errCapabilityTools,
model.CapabilityInsert: errCapabilityInsert,
model.CapabilityVision: errCapabilityVision,
model.CapabilityEmbedding: errCapabilityEmbedding,
}
for _, cap := range want {
err, ok := capToErr[cap]
if !ok {
slog.Error("unknown capability", "capability", cap)
return fmt.Errorf("unknown capability: %s", cap)
}
if !slices.Contains(available, cap) {
errs = append(errs, err)
}
}
if err := errors.Join(errs...); err != nil {
if len(errs) > 0 {
return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
}
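Because the old and new bodies of Capabilities and CheckCapabilities are interleaved in the diff above, here is a condensed restatement of the new detection rules as a hedged sketch (kv and vars stand in for f.KV() and m.Template.Vars(); error handling and the nil-template early return are omitted):

package server

import (
    "slices"

    "github.com/ollama/ollama/types/model"
)

// detect condenses the capability rules Capabilities applies after this change.
func detect(kv map[string]any, arch string, vars []string) []model.Capability {
    caps := []model.Capability{}
    if _, ok := kv[arch+".pooling_type"]; ok {
        caps = append(caps, model.CapabilityEmbedding) // embedding models expose a pooling type
    } else {
        caps = append(caps, model.CapabilityCompletion)
    }
    if _, ok := kv[arch+".vision.block_count"]; ok {
        caps = append(caps, model.CapabilityVision)
    }
    if slices.Contains(vars, "tools") {
        caps = append(caps, model.CapabilityTools)
    }
    if slices.Contains(vars, "suffix") {
        caps = append(caps, model.CapabilityInsert)
    }
    return caps
}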

360
server/images_test.go Normal file
View File

@@ -0,0 +1,360 @@
package server
import (
"bytes"
"encoding/binary"
"os"
"path/filepath"
"strings"
"testing"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
// Constants for GGUF magic bytes and version
var (
ggufMagic = []byte{0x47, 0x47, 0x55, 0x46} // "GGUF"
ggufVer = uint32(3) // Version 3
)
// Helper function to create mock GGUF data
func createMockGGUFData(architecture string, vision bool) []byte {
var buf bytes.Buffer
// Write GGUF header
buf.Write(ggufMagic)
binary.Write(&buf, binary.LittleEndian, ggufVer)
// Write tensor count (0 for our test)
var numTensors uint64 = 0
binary.Write(&buf, binary.LittleEndian, numTensors)
// Calculate number of metadata entries
numMetaEntries := uint64(1) // architecture entry
if vision {
numMetaEntries++
}
// Add embedding entry if architecture is "bert"
if architecture == "bert" {
numMetaEntries++
}
binary.Write(&buf, binary.LittleEndian, numMetaEntries)
// Write architecture metadata
archKey := "general.architecture"
keyLen := uint64(len(archKey))
binary.Write(&buf, binary.LittleEndian, keyLen)
buf.WriteString(archKey)
// String type (8)
var strType uint32 = 8
binary.Write(&buf, binary.LittleEndian, strType)
// String length
strLen := uint64(len(architecture))
binary.Write(&buf, binary.LittleEndian, strLen)
buf.WriteString(architecture)
if vision {
visionKey := architecture + ".vision.block_count"
keyLen = uint64(len(visionKey))
binary.Write(&buf, binary.LittleEndian, keyLen)
buf.WriteString(visionKey)
// uint32 type (4)
var uint32Type uint32 = 4
binary.Write(&buf, binary.LittleEndian, uint32Type)
// uint32 value (1)
var countVal uint32 = 1
binary.Write(&buf, binary.LittleEndian, countVal)
}
// Write embedding metadata if architecture is "bert"
if architecture == "bert" {
poolKey := architecture + ".pooling_type"
keyLen = uint64(len(poolKey))
binary.Write(&buf, binary.LittleEndian, keyLen)
buf.WriteString(poolKey)
// uint32 type (4)
var uint32Type uint32 = 4
binary.Write(&buf, binary.LittleEndian, uint32Type)
// uint32 value (1)
var poolingVal uint32 = 1
binary.Write(&buf, binary.LittleEndian, poolingVal)
}
return buf.Bytes()
}
func TestModelCapabilities(t *testing.T) {
// Create a temporary directory for test files
tempDir, err := os.MkdirTemp("", "model_capabilities_test")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)
// Create different types of mock model files
completionModelPath := filepath.Join(tempDir, "model.bin")
visionModelPath := filepath.Join(tempDir, "vision_model.bin")
embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
// Create a simple model file for tests that don't depend on GGUF content
simpleModelPath := filepath.Join(tempDir, "simple_model.bin")
err = os.WriteFile(completionModelPath, createMockGGUFData("llama", false), 0o644)
if err != nil {
t.Fatalf("Failed to create completion model file: %v", err)
}
err = os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644)
if err != nil {
t.Fatalf("Failed to create completion model file: %v", err)
}
err = os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644)
if err != nil {
t.Fatalf("Failed to create embedding model file: %v", err)
}
err = os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644)
if err != nil {
t.Fatalf("Failed to create simple model file: %v", err)
}
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
chatTemplate, err := template.Parse("{{ .prompt }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
testModels := []struct {
name string
model Model
expectedCaps []model.Capability
}{
{
name: "model with completion capability",
model: Model{
ModelPath: completionModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion},
},
{
name: "model with completion, tools, and insert capability",
model: Model{
ModelPath: completionModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with tools and insert capability",
model: Model{
ModelPath: simpleModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with tools capability",
model: Model{
ModelPath: simpleModelPath,
Template: toolsTemplate,
},
expectedCaps: []model.Capability{model.CapabilityTools},
},
{
name: "model with vision capability",
model: Model{
ModelPath: visionModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
},
{
name: "model with vision, tools, and insert capability",
model: Model{
ModelPath: visionModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with embedding capability",
model: Model{
ModelPath: embeddingModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityEmbedding},
},
}
// compare two slices of model.Capability regardless of order
compareCapabilities := func(a, b []model.Capability) bool {
if len(a) != len(b) {
return false
}
aCount := make(map[model.Capability]int)
for _, cap := range a {
aCount[cap]++
}
bCount := make(map[model.Capability]int)
for _, cap := range b {
bCount[cap]++
}
for cap, count := range aCount {
if bCount[cap] != count {
return false
}
}
return true
}
for _, tt := range testModels {
t.Run(tt.name, func(t *testing.T) {
// Test Capabilities method
caps := tt.model.Capabilities()
if !compareCapabilities(caps, tt.expectedCaps) {
t.Errorf("Expected capabilities %v, got %v", tt.expectedCaps, caps)
}
})
}
}
func TestModelCheckCapabilities(t *testing.T) {
// Create a temporary directory for test files
tempDir, err := os.MkdirTemp("", "model_check_capabilities_test")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)
visionModelPath := filepath.Join(tempDir, "vision_model.bin")
simpleModelPath := filepath.Join(tempDir, "model.bin")
embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
err = os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644)
if err != nil {
t.Fatalf("Failed to create simple model file: %v", err)
}
err = os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644)
if err != nil {
t.Fatalf("Failed to create vision model file: %v", err)
}
err = os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644)
if err != nil {
t.Fatalf("Failed to create embedding model file: %v", err)
}
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
chatTemplate, err := template.Parse("{{ .prompt }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
tests := []struct {
name string
model Model
checkCaps []model.Capability
expectedErrMsg string
}{
{
name: "completion model without tools capability",
model: Model{
ModelPath: simpleModelPath,
Template: chatTemplate,
},
checkCaps: []model.Capability{model.CapabilityTools},
expectedErrMsg: "does not support tools",
},
{
name: "model with all needed capabilities",
model: Model{
ModelPath: simpleModelPath,
Template: toolsInsertTemplate,
},
checkCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model missing insert capability",
model: Model{
ModelPath: simpleModelPath,
Template: toolsTemplate,
},
checkCaps: []model.Capability{model.CapabilityInsert},
expectedErrMsg: "does not support insert",
},
{
name: "model missing vision capability",
model: Model{
ModelPath: simpleModelPath,
Template: toolsTemplate,
},
checkCaps: []model.Capability{model.CapabilityVision},
expectedErrMsg: "does not support vision",
},
{
name: "model with vision capability",
model: Model{
ModelPath: visionModelPath,
Template: chatTemplate,
},
checkCaps: []model.Capability{model.CapabilityVision},
},
{
name: "model with embedding capability",
model: Model{
ModelPath: embeddingModelPath,
Template: chatTemplate,
},
checkCaps: []model.Capability{model.CapabilityEmbedding},
},
{
name: "unknown capability",
model: Model{
ModelPath: simpleModelPath,
Template: chatTemplate,
},
checkCaps: []model.Capability{"unknown"},
expectedErrMsg: "unknown capability",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test CheckCapabilities method
err := tt.model.CheckCapabilities(tt.checkCaps...)
if tt.expectedErrMsg == "" {
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
} else {
if err == nil {
t.Errorf("Expected error containing %q, got nil", tt.expectedErrMsg)
} else if !strings.Contains(err.Error(), tt.expectedErrMsg) {
t.Errorf("Expected error containing %q, got: %v", tt.expectedErrMsg, err)
}
}
})
}
}

View File

@@ -421,14 +421,6 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
return err
}
func canRetry(err error) bool {
var re *Error
if !errors.As(err, &re) {
return false
}
return re.Status >= 500
}
// trackingReader is an io.Reader that tracks the number of bytes read and
// calls the update function with the layer, the number of bytes read.
//
@@ -514,13 +506,40 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
break
}
cacheKey := fmt.Sprintf(
"v1 pull chunksum %s %s %d-%d",
l.Digest,
cs.Digest,
cs.Chunk.Start,
cs.Chunk.End,
)
cacheKeyDigest := blob.DigestFromBytes(cacheKey)
_, err := c.Get(cacheKeyDigest)
if err == nil {
received.Add(cs.Chunk.Size())
t.update(l, cs.Chunk.Size(), ErrCached)
continue
}
wg.Add(1)
g.Go(func() (err error) {
defer func() {
if err == nil {
// Ignore cache key write errors for now. We've already
// reported to trace that the chunk is complete.
//
// Ideally, we should only report completion to trace
// after successful cache commit. This current approach
// works but could trigger unnecessary redownloads if
// the checkpoint key is missing on next pull.
//
// Not incorrect, just suboptimal - fix this in a
// future update.
_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
received.Add(cs.Chunk.Size())
} else {
err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
t.update(l, 0, err)
}
wg.Done()
}()
@@ -563,7 +582,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
return err
}
if received.Load() != expected {
return fmt.Errorf("%w: received %d/%d", ErrIncomplete, received.Load(), expected)
return fmt.Errorf("%w: received %d/%d bytes", ErrIncomplete, received.Load(), expected)
}
md := blob.DigestFromBytes(m.Data)
@@ -608,6 +627,30 @@ func (m *Manifest) Layer(d blob.Digest) *Layer {
return nil
}
func (m *Manifest) All() iter.Seq[*Layer] {
return func(yield func(*Layer) bool) {
if !yield(m.Config) {
return
}
for _, l := range m.Layers {
if !yield(l) {
return
}
}
}
}
func (m *Manifest) Size() int64 {
var size int64
if m.Config != nil {
size += m.Config.Size
}
for _, l := range m.Layers {
size += l.Size
}
return size
}
// MarshalJSON implements json.Marshaler.
//
// NOTE: It adds an empty config object to the manifest, which is required by
@@ -750,20 +793,32 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
return
}
// A chunksums response is a sequence of chunksums in a
// simple, easy to parse line-oriented format.
// The response is a sequence of chunksums.
//
// Example:
// Chunksums are chunks of a larger blob that can be
// downloaded and verified independently.
//
// >> GET /v2/<namespace>/<model>/chunksums/<digest>
// The chunksums endpoint is a GET request that returns a
// sequence of chunksums in the following format:
//
// << HTTP/1.1 200 OK
// << Content-Location: <blobURL>
// <<
// << <digest> <start>-<end>
// << ...
// > GET /v2/<namespace>/<model>/chunksums/<digest>
//
// The blobURL is the URL to download the chunks from.
// < HTTP/1.1 200 OK
// < Content-Location: <blobURL>
// <
// < <digest> <start>-<end>
// < ...
//
// The <blobURL> is the URL to download the chunks from and
// each <digest> is the digest of the chunk, and <start>-<end>
// is the range of the chunk in the blob.
//
// Ranges may be used directly in Range headers like
// "bytes=<start>-<end>".
//
// The chunksums returned are guaranteed to be contiguous and
// include all bytes of the layer. If the stream is cut short,
// clients should retry.
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
scheme,
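To make the documented format concrete, a small sketch that turns one chunksums line into a digest plus a Range header value (this is not the client's actual parser; it only illustrates the "<digest> <start>-<end>" line format described above):

package main

import (
    "fmt"
    "strings"
)

// parseChunksum splits a "<digest> <start>-<end>" line into its digest and
// the equivalent HTTP Range header value.
func parseChunksum(line string) (digest, rangeHeader string, err error) {
    fields := strings.Fields(line)
    if len(fields) != 2 {
        return "", "", fmt.Errorf("malformed chunksum line: %q", line)
    }
    return fields[0], "bytes=" + fields[1], nil
}

func main() {
    d, r, _ := parseChunksum("sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad 0-1")
    fmt.Println(d) // chunk digest to verify against
    fmt.Println(r) // "bytes=0-1", usable directly in a Range request to <blobURL>
}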

View File

@@ -9,17 +9,14 @@ import (
"fmt"
"io"
"io/fs"
"math/rand/v2"
"net"
"net/http"
"net/http/httptest"
"os"
"path"
"reflect"
"slices"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/testutil"
@@ -338,15 +335,8 @@ func TestPushCommitRoundtripError(t *testing.T) {
}
}
func checkNotExist(t *testing.T, err error) {
t.Helper()
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("err = %v; want fs.ErrNotExist", err)
}
}
func TestRegistryPullInvalidName(t *testing.T) {
rc, _ := newClient(t, nil)
rc, _ := newRegistryClient(t, nil)
err := rc.Pull(t.Context(), "://")
if !errors.Is(err, ErrNameInvalid) {
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
@@ -362,197 +352,16 @@ func TestRegistryPullInvalidManifest(t *testing.T) {
}
for _, resp := range cases {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, resp)
})
err := rc.Pull(t.Context(), "x")
err := rc.Pull(t.Context(), "http://example.com/a/b")
if !errors.Is(err, ErrManifestInvalid) {
t.Errorf("err = %v; want invalid manifest", err)
}
}
}
func TestRegistryPullNotCached(t *testing.T) {
check := testutil.Checker(t)
var c *blob.DiskCache
var rc *Registry
d := blob.DigestFromBytes("some data")
rc, c = newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
io.WriteString(w, "some data")
return
}
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":9}]}`, d)
})
// Confirm that the layer does not exist locally
_, err := rc.ResolveLocal("model")
checkNotExist(t, err)
_, err = c.Get(d)
checkNotExist(t, err)
err = rc.Pull(t.Context(), "model")
check(err)
mw, err := rc.Resolve(t.Context(), "model")
check(err)
mg, err := rc.ResolveLocal("model")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
}
// Confirm successful download
info, err := c.Get(d)
check(err)
if info.Digest != d {
t.Errorf("info.Digest = %v; want %v", info.Digest, d)
}
if info.Size != 9 {
t.Errorf("info.Size = %v; want %v", info.Size, 9)
}
data, err := os.ReadFile(c.GetFile(d))
check(err)
if string(data) != "some data" {
t.Errorf("data = %q; want %q", data, "exists")
}
}
func TestRegistryPullCached(t *testing.T) {
cached := blob.DigestFromBytes("exists")
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
w.WriteHeader(499) // should not be called
return
}
if strings.Contains(r.URL.Path, "/manifests/") {
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, cached)
}
})
var errs []error
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(d *Layer, n int64, err error) {
t.Logf("update %v %d %v", d, n, err)
reads = append(reads, n)
errs = append(errs, err)
},
})
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
err := rc.Pull(ctx, "single")
testutil.Check(t, err)
want := []int64{0, 6}
if !errors.Is(errors.Join(errs...), ErrCached) {
t.Errorf("errs = %v; want %v", errs, ErrCached)
}
if !slices.Equal(reads, want) {
t.Errorf("pairs = %v; want %v", reads, want)
}
}
func TestRegistryPullManifestNotFound(t *testing.T) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
})
err := rc.Pull(t.Context(), "notfound")
checkErrCode(t, err, 404, "")
}
func TestRegistryPullResolveRemoteError(t *testing.T) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
})
err := rc.Pull(t.Context(), "single")
checkErrCode(t, err, 500, "an_error")
}
func TestRegistryPullResolveRoundtripError(t *testing.T) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/manifests/") {
w.WriteHeader(499) // force RoundTrip error
return
}
})
err := rc.Pull(t.Context(), "single")
if !errors.Is(err, errRoundTrip) {
t.Errorf("err = %v; want %v", err, errRoundTrip)
}
}
// TestRegistryPullMixedCachedNotCached tests that cached layers do not
// interfere with pulling layers that are not cached
func TestRegistryPullMixedCachedNotCached(t *testing.T) {
x := blob.DigestFromBytes("xxxxxx")
e := blob.DigestFromBytes("exists")
y := blob.DigestFromBytes("yyyyyy")
for i := range 10 {
t.Logf("iteration %d", i)
digests := []blob.Digest{x, e, y}
rand.Shuffle(len(digests), func(i, j int) {
digests[i], digests[j] = digests[j], digests[i]
})
manifest := fmt.Sprintf(`{
"layers": [
{"digest":"%s","size":6},
{"digest":"%s","size":6},
{"digest":"%s","size":6}
]
}`, digests[0], digests[1], digests[2])
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
switch path.Base(r.URL.Path) {
case "latest":
io.WriteString(w, manifest)
case x.String():
io.WriteString(w, "xxxxxx")
case e.String():
io.WriteString(w, "exists")
case y.String():
io.WriteString(w, "yyyyyy")
default:
panic(fmt.Sprintf("unexpected request: %v", r))
}
})
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Logf("update %v %d %v", l, n, err)
},
})
// Check that we pull all layers that we can.
err := rc.Pull(ctx, "mixed")
if err != nil {
t.Fatal(err)
}
for _, d := range digests {
info, err := c.Get(d)
if err != nil {
t.Fatalf("Get(%v): %v", d, err)
}
if info.Size != 6 {
t.Errorf("info.Size = %v; want %v", info.Size, 6)
}
}
}
}
func TestRegistryResolveByDigest(t *testing.T) {
check := testutil.Checker(t)
@@ -590,26 +399,6 @@ func TestInsecureSkipVerify(t *testing.T) {
testutil.Check(t, err)
}
func TestCanRetry(t *testing.T) {
cases := []struct {
err error
want bool
}{
{nil, false},
{errors.New("x"), false},
{ErrCached, false},
{ErrManifestInvalid, false},
{ErrNameInvalid, false},
{&Error{Status: 100}, false},
{&Error{Status: 500}, true},
}
for _, tt := range cases {
if got := canRetry(tt.err); got != tt.want {
t.Errorf("CanRetry(%v) = %v; want %v", tt.err, got, tt.want)
}
}
}
func TestErrorUnmarshal(t *testing.T) {
cases := []struct {
name string
@@ -761,17 +550,23 @@ func TestParseNameExtended(t *testing.T) {
func TestUnlink(t *testing.T) {
t.Run("found by name", func(t *testing.T) {
rc, _ := newClient(t, nil)
check := testutil.Checker(t)
rc, _ := newRegistryClient(t, nil)
// make a blob and link it
d := blob.DigestFromBytes("{}")
err := blob.PutBytes(rc.Cache, d, "{}")
check(err)
err = rc.Cache.Link("registry.ollama.ai/library/single:latest", d)
check(err)
// confirm linked
_, err := rc.ResolveLocal("single")
if err != nil {
t.Errorf("unexpected error: %v", err)
}
_, err = rc.ResolveLocal("single")
check(err)
// unlink
_, err = rc.Unlink("single")
testutil.Check(t, err)
check(err)
// confirm unlinked
_, err = rc.ResolveLocal("single")
@@ -780,7 +575,7 @@ func TestUnlink(t *testing.T) {
}
})
t.Run("not found by name", func(t *testing.T) {
rc, _ := newClient(t, nil)
rc, _ := newRegistryClient(t, nil)
ok, err := rc.Unlink("manifestNotFound")
if err != nil {
t.Fatal(err)
@@ -791,78 +586,368 @@ func TestUnlink(t *testing.T) {
})
}
func TestPullChunksums(t *testing.T) {
check := testutil.Checker(t)
// Many tests from here out, in this file are based on a single blob, "abc",
// with the checksum of its sha256 hash. The checksum is:
//
// "abc" -> sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad
//
// Using the literal value instead of a constant with fmt.Xprintf calls proved
// to be the most readable and maintainable approach. The sum is consistently
// used in the tests and unique so searches do not yield false positives.
content := "hello"
var chunksums string
contentDigest := func() blob.Digest {
return blob.DigestFromBytes(content)
func checkRequest(t *testing.T, req *http.Request, method, path string) {
t.Helper()
if got := req.URL.Path; got != path {
t.Errorf("URL = %q, want %q", got, path)
}
if req.Method != method {
t.Errorf("Method = %q, want %q", req.Method, method)
}
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/manifests/latest"):
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content))
case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()):
loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest())
w.Header().Set("Content-Location", loc)
io.WriteString(w, chunksums)
case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()):
http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content))
default:
t.Errorf("unexpected request: %v", r)
http.NotFound(w, r)
}
})
rc.MaxStreams = 1 // prevent concurrent chunk downloads
rc.ChunkingThreshold = 1 // for all blobs to be chunked
func newRegistryClient(t *testing.T, h http.HandlerFunc) (*Registry, context.Context) {
s := httptest.NewServer(h)
t.Cleanup(s.Close)
cache, err := blob.Open(t.TempDir())
if err != nil {
t.Fatal(err)
}
var mu sync.Mutex
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Logf("Update: %v %d %v", l, n, err)
mu.Lock()
reads = append(reads, n)
mu.Unlock()
t.Log("trace:", l.Digest.Short(), n, err)
},
})
chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n",
blob.DigestFromBytes("hel"),
blob.DigestFromBytes("lo"),
)
err := rc.Pull(ctx, "test")
check(err)
wantReads := []int64{
0, // initial signaling of layer pull starting
3, // first chunk read
2, // second chunk read
rc := &Registry{
Cache: cache,
HTTPClient: &http.Client{Transport: &http.Transport{
Dial: func(network, addr string) (net.Conn, error) {
return net.Dial(network, s.Listener.Addr().String())
},
}},
}
if !slices.Equal(reads, wantReads) {
t.Errorf("reads = %v; want %v", reads, wantReads)
return rc, ctx
}
mw, err := rc.Resolve(t.Context(), "test")
check(err)
mg, err := rc.ResolveLocal("test")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
func TestPullChunked(t *testing.T) {
var steps atomic.Int64
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
switch steps.Add(1) {
case 1:
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
case 2:
checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
case 3, 4:
checkRequest(t, r, "GET", "/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
switch rng := r.Header.Get("Range"); rng {
case "bytes=0-1":
io.WriteString(w, "ab")
case "bytes=2-2":
t.Logf("writing c")
io.WriteString(w, "c")
default:
t.Errorf("unexpected range %q", rng)
}
for i := range mg.Layers {
_, err = c.Get(mg.Layers[i].Digest)
if err != nil {
t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err)
default:
t.Errorf("unexpected steps %d: %v", steps.Load(), r)
http.Error(w, "unexpected steps", http.StatusInternalServerError)
}
})
c.ChunkingThreshold = 1 // force chunking
err := c.Pull(ctx, "http://o.com/library/abc")
testutil.Check(t, err)
_, err = c.Cache.Resolve("o.com/library/abc:latest")
testutil.Check(t, err)
if g := steps.Load(); g != 4 {
t.Fatalf("got %d steps, want 4", g)
}
}
// missing chunks
content = "llama"
chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll"))
err = rc.Pull(ctx, "missingchunks")
func TestPullCached(t *testing.T) {
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
})
check := testutil.Checker(t)
// Preemptively cache the blob
d, err := blob.ParseDigest("sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
check(err)
err = blob.PutBytes(c.Cache, d, []byte("abc"))
check(err)
// Pull only the manifest, which should be enough to resolve the cached blob
err = c.Pull(ctx, "http://o.com/library/abc")
check(err)
}
func TestPullManifestError(t *testing.T) {
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
w.WriteHeader(http.StatusNotFound)
io.WriteString(w, `{"errors":[{"code":"MANIFEST_UNKNOWN"}]}`)
})
err := c.Pull(ctx, "http://o.com/library/abc")
if err == nil {
t.Error("expected error because of missing chunks")
t.Fatalf("expected error")
}
if !errors.Is(err, ErrModelNotFound) {
t.Fatalf("err = %v, want %v", err, ErrModelNotFound)
}
}
func TestPullLayerError(t *testing.T) {
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `!`)
})
err := c.Pull(ctx, "http://o.com/library/abc")
if err == nil {
t.Fatalf("expected error")
}
var want *json.SyntaxError
if !errors.As(err, &want) {
t.Fatalf("err = %T, want %T", err, want)
}
}
func TestPullLayerChecksumError(t *testing.T) {
var step atomic.Int64
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
switch step.Add(1) {
case 1:
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
case 2:
checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
case 3:
w.WriteHeader(http.StatusNotFound)
io.WriteString(w, `{"errors":[{"code":"BLOB_UNKNOWN"}]}`)
case 4:
io.WriteString(w, "c")
default:
t.Errorf("unexpected steps %d: %v", step.Load(), r)
http.Error(w, "unexpected steps", http.StatusInternalServerError)
}
})
c.MaxStreams = 1
c.ChunkingThreshold = 1 // force chunking
var written atomic.Int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Log("trace:", l.Digest.Short(), n, err)
written.Add(n)
},
})
err := c.Pull(ctx, "http://o.com/library/abc")
var got *Error
if !errors.As(err, &got) || got.Code != "BLOB_UNKNOWN" {
t.Fatalf("err = %v, want %v", err, got)
}
if g := written.Load(); g != 1 {
t.Fatalf("wrote %d bytes, want 1", g)
}
}
func TestPullChunksumStreamError(t *testing.T) {
var step atomic.Int64
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
switch step.Add(1) {
case 1:
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
case 2:
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
// Write one valid chunksum and one invalid chunksum
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab")) // valid
fmt.Fprint(w, "sha256:!") // invalid
case 3:
io.WriteString(w, "ab")
default:
t.Errorf("unexpected steps %d: %v", step.Load(), r)
http.Error(w, "unexpected steps", http.StatusInternalServerError)
}
})
c.ChunkingThreshold = 1 // force chunking
got := c.Pull(ctx, "http://o.com/library/abc")
if !errors.Is(got, ErrIncomplete) {
t.Fatalf("err = %v, want %v", got, ErrIncomplete)
}
}
type flushAfterWriter struct {
w io.Writer
}
func (f *flushAfterWriter) Write(p []byte) (n int, err error) {
n, err = f.w.Write(p)
f.w.(http.Flusher).Flush() // panic if not a flusher
return
}
func TestPullChunksumStreaming(t *testing.T) {
csr, csw := io.Pipe()
defer csw.Close()
var step atomic.Int64
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
switch step.Add(1) {
case 1:
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
case 2:
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
fw := &flushAfterWriter{w} // ensure client gets data as it arrives by aggressively flushing
_, err := io.Copy(fw, csr)
if err != nil {
t.Errorf("copy: %v", err)
}
case 3:
io.WriteString(w, "ab")
case 4:
io.WriteString(w, "c")
default:
t.Errorf("unexpected steps %d: %v", step.Load(), r)
http.Error(w, "unexpected steps", http.StatusInternalServerError)
}
})
c.ChunkingThreshold = 1 // force chunking
update := make(chan int64, 1)
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Log("trace:", l.Digest.Short(), n, err)
if n > 0 {
update <- n
}
},
})
errc := make(chan error, 1)
go func() {
errc <- c.Pull(ctx, "http://o.com/library/abc")
}()
// Send first chunksum and ensure it kicks off work immediately
fmt.Fprintf(csw, "%s 0-1\n", blob.DigestFromBytes("ab"))
if g := <-update; g != 2 {
t.Fatalf("got %d, want 2", g)
}
// now send the second chunksum and ensure it kicks off work immediately
fmt.Fprintf(csw, "%s 2-2\n", blob.DigestFromBytes("c"))
if g := <-update; g != 1 {
t.Fatalf("got %d, want 1", g)
}
csw.Close()
testutil.Check(t, <-errc)
}
func TestPullChunksumsCached(t *testing.T) {
var step atomic.Int64
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
switch step.Add(1) {
case 1:
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
case 2:
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
case 3, 4:
switch rng := r.Header.Get("Range"); rng {
case "bytes=0-1":
io.WriteString(w, "ab")
case "bytes=2-2":
io.WriteString(w, "c")
default:
t.Errorf("unexpected range %q", rng)
}
default:
t.Errorf("unexpected steps %d: %v", step.Load(), r)
http.Error(w, "unexpected steps", http.StatusInternalServerError)
}
})
c.MaxStreams = 1 // force serial processing of chunksums
c.ChunkingThreshold = 1 // force chunking
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
// Cancel the pull after the first chunksum is processed, but before
// the second chunksum is processed (which is waiting because
// MaxStreams=1). This should cause the second chunksum to error out
// leaving the blob incomplete.
ctx = WithTrace(ctx, &Trace{
Update: func(l *Layer, n int64, err error) {
if n > 0 {
cancel()
}
},
})
err := c.Pull(ctx, "http://o.com/library/abc")
if !errors.Is(err, context.Canceled) {
t.Fatalf("err = %v, want %v", err, context.Canceled)
}
_, err = c.Cache.Resolve("o.com/library/abc:latest")
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("err = %v, want nil", err)
}
// Reset state and pull again to ensure the blob chunks that should
// have been cached are, and the remaining chunk was downloaded, making
// the blob complete.
step.Store(0)
var written atomic.Int64
var cached atomic.Int64
ctx = WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Log("trace:", l.Digest.Short(), n, err)
if errors.Is(err, ErrCached) {
cached.Add(n)
}
written.Add(n)
},
})
check := testutil.Checker(t)
err = c.Pull(ctx, "http://o.com/library/abc")
check(err)
_, err = c.Cache.Resolve("o.com/library/abc:latest")
check(err)
if g := written.Load(); g != 3 {
t.Fatalf("wrote %d bytes, want 3", g)
}
if g := cached.Load(); g != 2 { // "ab" should have been cached
t.Fatalf("cached %d bytes, want 3", g)
}
}

View File

@@ -72,7 +72,7 @@ var (
errBadTemplate = errors.New("template error")
)
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
func modelOptions(model *Model, requestOpts map[string]any) (api.Options, error) {
opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil {
return api.Options{}, err
@@ -87,7 +87,7 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
if name == "" {
return nil, nil, nil, fmt.Errorf("model %w", errRequired)
}
@@ -144,7 +144,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
model, err := GetModel(name.String())
m, err := GetModel(name.String())
if err != nil {
switch {
case errors.Is(err, fs.ErrNotExist):
@@ -159,7 +159,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
// expire the runner
if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
s.sched.expireRunner(model)
s.sched.expireRunner(m)
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
@@ -176,9 +176,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
caps := []Capability{CapabilityCompletion}
caps := []model.Capability{model.CapabilityCompletion}
if req.Suffix != "" {
caps = append(caps, CapabilityInsert)
caps = append(caps, model.CapabilityInsert)
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
@@ -203,7 +203,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
isMllama := checkMllamaModelFamily(model)
isMllama := checkMllamaModelFamily(m)
if isMllama && len(req.Images) > 1 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image: more than one image sent"})
return
@@ -211,7 +211,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
images := make([]llm.ImageData, len(req.Images))
for i := range req.Images {
if isMllama && len(model.ProjectorPaths) > 0 {
if isMllama && len(m.ProjectorPaths) > 0 {
data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i]))
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
@@ -312,7 +312,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
CreatedAt: time.Now().UTC(),
Response: cr.Content,
Done: cr.Done,
DoneReason: cr.DoneReason,
Metrics: api.Metrics{
PromptEvalCount: cr.PromptEvalCount,
PromptEvalDuration: cr.PromptEvalDuration,
@@ -326,6 +325,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
if cr.Done {
res.DoneReason = cr.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
@@ -422,7 +422,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
@@ -530,7 +530,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
@@ -818,6 +818,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
Template: m.Template.String(),
Details: modelDetails,
Messages: msgs,
Capabilities: m.Capabilities(),
ModifiedAt: manifest.fi.ModTime(),
}
@@ -825,7 +826,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
cs := 30
for k, v := range m.Options {
switch val := v.(type) {
case []interface{}:
case []any:
for _, nv := range val {
params = append(params, fmt.Sprintf("%-*s %#v", cs, k, nv))
}
@@ -1335,7 +1336,7 @@ func Serve(ln net.Listener) error {
return nil
}
func waitForStream(c *gin.Context, ch chan interface{}) {
func waitForStream(c *gin.Context, ch chan any) {
c.Header("Content-Type", "application/json")
for resp := range ch {
switch r := resp.(type) {
@@ -1468,9 +1469,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
caps := []Capability{CapabilityCompletion}
caps := []model.Capability{model.CapabilityCompletion}
if len(req.Tools) > 0 {
caps = append(caps, CapabilityTools)
caps = append(caps, model.CapabilityTools)
}
name := model.ParseName(req.Model)
@@ -1536,7 +1537,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content},
Done: r.Done,
DoneReason: r.DoneReason,
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
@@ -1546,6 +1546,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
if r.Done {
res.DoneReason = r.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}

View File

@@ -58,7 +58,7 @@ func TestGenerateChat(t *testing.T) {
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: "stop",
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
@@ -401,7 +401,7 @@ func TestGenerateChat(t *testing.T) {
mock.CompletionResponse = llm.CompletionResponse{
Content: `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
Done: true,
DoneReason: "done",
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
@@ -519,7 +519,7 @@ func TestGenerateChat(t *testing.T) {
{
Content: `, WA","unit":"celsius"}}`,
Done: true,
DoneReason: "tool_call",
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 3,
PromptEvalDuration: 1,
},
@@ -594,7 +594,7 @@ func TestGenerate(t *testing.T) {
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: "stop",
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,

View File

@@ -20,6 +20,7 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/types/model"
)
type LlmRequest struct {
@@ -37,7 +38,7 @@ type Scheduler struct {
pendingReqCh chan *LlmRequest
finishedReqCh chan *LlmRequest
expiredCh chan *runnerRef
unloadedCh chan interface{}
unloadedCh chan any
loaded map[string]*runnerRef
loadedMu sync.Mutex
@@ -67,7 +68,7 @@ func InitScheduler(ctx context.Context) *Scheduler {
pendingReqCh: make(chan *LlmRequest, maxQueue),
finishedReqCh: make(chan *LlmRequest, maxQueue),
expiredCh: make(chan *runnerRef, maxQueue),
unloadedCh: make(chan interface{}, maxQueue),
unloadedCh: make(chan any, maxQueue),
loaded: make(map[string]*runnerRef),
newServerFn: llm.NewLlamaServer,
getGpuFn: discover.GetGPUInfo,
@@ -195,7 +196,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
}
// Embedding models should always be loaded with parallel=1
if pending.model.CheckCapabilities(CapabilityCompletion) != nil {
if pending.model.CheckCapabilities(model.CapabilityCompletion) != nil {
numParallel = 1
}
@@ -617,8 +618,8 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
// a before and after GPU memory allocation. The returned channel
// will be notified when we're done waiting, or have timed out and should
// proceed anyway
func (runner *runnerRef) waitForVRAMRecovery() chan interface{} {
finished := make(chan interface{}, 1)
func (runner *runnerRef) waitForVRAMRecovery() chan any {
finished := make(chan any, 1)
// CPU or Metal don't need checking, so no waiting required
// windows can page VRAM, only cuda currently can report accurate used vram usage

15
types/model/capability.go Normal file
View File

@@ -0,0 +1,15 @@
package model
type Capability string
const (
CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools")
CapabilityInsert = Capability("insert")
CapabilityVision = Capability("vision")
CapabilityEmbedding = Capability("embedding")
)
func (c Capability) String() string {
return string(c)
}
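A short usage sketch tying the new type back to the server changes above (condensed from the routes.go hunks; the handler plumbing is illustrative):

package server

import "github.com/ollama/ollama/types/model"

// capsForChat mirrors how the chat handler builds its requirements after this
// change: completion is always required, tools only when the request carries
// tool definitions (hasTools stands in for len(req.Tools) > 0).
func capsForChat(hasTools bool) []model.Capability {
    caps := []model.Capability{model.CapabilityCompletion}
    if hasTools {
        caps = append(caps, model.CapabilityTools)
    }
    return caps
}

The scheduler then passes this slice to m.CheckCapabilities, which returns nil when every capability is present or an error naming each missing one (the server tests above assert on text like "does not support tools"); GetModelInfo separately surfaces m.Capabilities() in the show response.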