From f15ffc432061e3d96b3412219a3a0f673b579a12 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Tue, 13 May 2025 17:26:46 -0700 Subject: [PATCH 01/26] llm: Make "POST predict" error message more informative "POST predict" basically means that the runner has crashed, which can have many reasons. However, many people think this is a specific error and either report only this message or group together unrelated bugs. This replaces it with a more friendly and helpful message. --- llm/server.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llm/server.go b/llm/server.go index 4abb569f..373f6fae 100644 --- a/llm/server.go +++ b/llm/server.go @@ -797,7 +797,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu res, err := http.DefaultClient.Do(serverReq) if err != nil { - return fmt.Errorf("POST predict: %v", err) + slog.Error("post predict", "error", err) + return errors.New("model runner has unexpectedly stopped, this may be due to resource limitations or an internal error, check ollama server logs for details") } defer res.Body.Close() From aaa7818000c42a82fc030212c35ef83f9799efd7 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Thu, 24 Apr 2025 11:48:49 -0700 Subject: [PATCH 02/26] ggml: Export GPU UUIDs This enables matching up devices and information reported by the backend with system management libraries such as nvml to get accurate free memory reporting. --- .../patches/0017-ggml-Export-GPU-UUIDs.patch | 102 ++++++++++++++++++ ml/backend.go | 8 ++ ml/backend/ggml/ggml.go | 6 ++ ml/backend/ggml/ggml/include/ggml-backend.h | 1 + .../ggml/ggml/src/ggml-cuda/ggml-cuda.cu | 33 ++++++ .../ggml/ggml/src/ggml-metal/ggml-metal.m | 1 + 6 files changed, 151 insertions(+) create mode 100644 llama/patches/0017-ggml-Export-GPU-UUIDs.patch diff --git a/llama/patches/0017-ggml-Export-GPU-UUIDs.patch b/llama/patches/0017-ggml-Export-GPU-UUIDs.patch new file mode 100644 index 00000000..a2539034 --- /dev/null +++ b/llama/patches/0017-ggml-Export-GPU-UUIDs.patch @@ -0,0 +1,102 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Jesse Gross +Date: Thu, 24 Apr 2025 14:48:51 -0700 +Subject: [PATCH] ggml: Export GPU UUIDs + +This enables matching up devices and information reported by the backend +with tools (e.g. nvidia-smi) and system management libraries (e.g. nvml). +--- + ggml/include/ggml-backend.h | 1 + + ggml/src/ggml-cuda/ggml-cuda.cu | 33 ++++++++++++++++++++++++++++++++ + ggml/src/ggml-metal/ggml-metal.m | 1 + + 3 files changed, 35 insertions(+) + +diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h +index 74e46716..a880df33 100644 +--- a/ggml/include/ggml-backend.h ++++ b/ggml/include/ggml-backend.h +@@ -152,6 +152,7 @@ extern "C" { + struct ggml_backend_dev_props { + const char * name; + const char * description; ++ const char * uuid; + size_t memory_free; + size_t memory_total; + enum ggml_backend_dev_type type; +diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu +index cb0d8528..4c829153 100644 +--- a/ggml/src/ggml-cuda/ggml-cuda.cu ++++ b/ggml/src/ggml-cuda/ggml-cuda.cu +@@ -2884,6 +2884,7 @@ struct ggml_backend_cuda_device_context { + int device; + std::string name; + std::string description; ++ std::string uuid; + }; + + static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { +@@ -2896,6 +2897,11 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t + return ctx->description.c_str(); + } + ++static const char * ggml_backend_cuda_device_get_uuid(ggml_backend_dev_t dev) { ++ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; ++ return ctx->uuid.c_str(); ++} ++ + static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + ggml_cuda_set_device(ctx->device); +@@ -2910,6 +2916,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend + static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + props->name = ggml_backend_cuda_device_get_name(dev); + props->description = ggml_backend_cuda_device_get_description(dev); ++ props->uuid = ggml_backend_cuda_device_get_uuid(dev); + props->type = ggml_backend_cuda_device_get_type(dev); + ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total); + +@@ -3458,6 +3465,32 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { + CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); + dev_ctx->description = prop.name; + ++ #if !defined(GGML_USE_HIP) ++ char uuid[64]; ++ snprintf(uuid, sizeof(uuid), ++ "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", ++ (unsigned char)prop.uuid.bytes[0], ++ (unsigned char)prop.uuid.bytes[1], ++ (unsigned char)prop.uuid.bytes[2], ++ (unsigned char)prop.uuid.bytes[3], ++ (unsigned char)prop.uuid.bytes[4], ++ (unsigned char)prop.uuid.bytes[5], ++ (unsigned char)prop.uuid.bytes[6], ++ (unsigned char)prop.uuid.bytes[7], ++ (unsigned char)prop.uuid.bytes[8], ++ (unsigned char)prop.uuid.bytes[9], ++ (unsigned char)prop.uuid.bytes[10], ++ (unsigned char)prop.uuid.bytes[11], ++ (unsigned char)prop.uuid.bytes[12], ++ (unsigned char)prop.uuid.bytes[13], ++ (unsigned char)prop.uuid.bytes[14], ++ (unsigned char)prop.uuid.bytes[15] ++ ); ++ dev_ctx->uuid = uuid; ++ #else ++ dev_ctx->uuid = "GPU-" + std::string(prop.uuid.bytes, 16); ++ #endif ++ + ggml_backend_dev_t dev = new ggml_backend_device { + /* .iface = */ ggml_backend_cuda_device_interface, + /* .reg = */ ®, +diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m +index 1b56f858..ee4f2dcb 100644 +--- a/ggml/src/ggml-metal/ggml-metal.m ++++ b/ggml/src/ggml-metal/ggml-metal.m +@@ -5703,6 +5703,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen + static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_metal_device_get_name(dev); + props->description = ggml_backend_metal_device_get_description(dev); ++ props->uuid = "0"; + props->type = ggml_backend_metal_device_get_type(dev); + ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = (struct ggml_backend_dev_caps) { diff --git a/ml/backend.go b/ml/backend.go index 65f16948..2df6c892 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -124,6 +124,10 @@ type DeviceMemory struct { // may not be persistent across instances of the runner. Name string + // UUID is a unique persistent identifier for the device for matching + // with system management libraries + UUID string + // Weights is the per-layer memory needed for the model weights. Weights []Memory @@ -152,6 +156,10 @@ func (m DeviceMemory) LogValue() slog.Value { attrs = append(attrs, slog.Any("Graph", m.Graph)) } + if len(attrs) > 0 && m.UUID != "" { + attrs = append([]slog.Attr{slog.String("UUID", m.UUID)}, attrs...) + } + return slog.GroupValue(attrs...) } diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 76172ae1..5a9fe67e 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -136,6 +136,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { } requiredMemory.CPU.Name = C.GoString(C.ggml_backend_dev_name(cpuDeviceBufferType.d)) + var props C.struct_ggml_backend_dev_props + C.ggml_backend_dev_get_props(cpuDeviceBufferType.d, &props) + requiredMemory.CPU.UUID = C.GoString(props.uuid) requiredMemory.CPU.Weights = make([]ml.Memory, blocks+1) requiredMemory.CPU.Cache = make([]ml.Memory, blocks+1) @@ -150,6 +153,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { }) btDeviceMemory[bt] = &requiredMemory.GPUs[i] requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d)) + var props C.struct_ggml_backend_dev_props + C.ggml_backend_dev_get_props(d, &props) + requiredMemory.GPUs[i].UUID = C.GoString(props.uuid) requiredMemory.GPUs[i].Weights = make([]ml.Memory, blocks+1) requiredMemory.GPUs[i].Cache = make([]ml.Memory, blocks+1) } diff --git a/ml/backend/ggml/ggml/include/ggml-backend.h b/ml/backend/ggml/ggml/include/ggml-backend.h index 74e46716..a880df33 100644 --- a/ml/backend/ggml/ggml/include/ggml-backend.h +++ b/ml/backend/ggml/ggml/include/ggml-backend.h @@ -152,6 +152,7 @@ extern "C" { struct ggml_backend_dev_props { const char * name; const char * description; + const char * uuid; size_t memory_free; size_t memory_total; enum ggml_backend_dev_type type; diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu index cb0d8528..4c829153 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2884,6 +2884,7 @@ struct ggml_backend_cuda_device_context { int device; std::string name; std::string description; + std::string uuid; }; static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { @@ -2896,6 +2897,11 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t return ctx->description.c_str(); } +static const char * ggml_backend_cuda_device_get_uuid(ggml_backend_dev_t dev) { + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + return ctx->uuid.c_str(); +} + static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; ggml_cuda_set_device(ctx->device); @@ -2910,6 +2916,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { props->name = ggml_backend_cuda_device_get_name(dev); props->description = ggml_backend_cuda_device_get_description(dev); + props->uuid = ggml_backend_cuda_device_get_uuid(dev); props->type = ggml_backend_cuda_device_get_type(dev); ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total); @@ -3458,6 +3465,32 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); dev_ctx->description = prop.name; + #if !defined(GGML_USE_HIP) + char uuid[64]; + snprintf(uuid, sizeof(uuid), + "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", + (unsigned char)prop.uuid.bytes[0], + (unsigned char)prop.uuid.bytes[1], + (unsigned char)prop.uuid.bytes[2], + (unsigned char)prop.uuid.bytes[3], + (unsigned char)prop.uuid.bytes[4], + (unsigned char)prop.uuid.bytes[5], + (unsigned char)prop.uuid.bytes[6], + (unsigned char)prop.uuid.bytes[7], + (unsigned char)prop.uuid.bytes[8], + (unsigned char)prop.uuid.bytes[9], + (unsigned char)prop.uuid.bytes[10], + (unsigned char)prop.uuid.bytes[11], + (unsigned char)prop.uuid.bytes[12], + (unsigned char)prop.uuid.bytes[13], + (unsigned char)prop.uuid.bytes[14], + (unsigned char)prop.uuid.bytes[15] + ); + dev_ctx->uuid = uuid; + #else + dev_ctx->uuid = "GPU-" + std::string(prop.uuid.bytes, 16); + #endif + ggml_backend_dev_t dev = new ggml_backend_device { /* .iface = */ ggml_backend_cuda_device_interface, /* .reg = */ ®, diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m index 1b56f858..ee4f2dcb 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m @@ -5703,6 +5703,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { props->name = ggml_backend_metal_device_get_name(dev); props->description = ggml_backend_metal_device_get_description(dev); + props->uuid = "0"; props->type = ggml_backend_metal_device_get_type(dev); ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total); props->caps = (struct ggml_backend_dev_caps) { From 65f10c2823d540837a9e79202522957194377735 Mon Sep 17 00:00:00 2001 From: Parth Sareen Date: Fri, 30 May 2025 15:18:09 -0700 Subject: [PATCH 03/26] tools: resiliency upgrade to name and arg extraction from template (#10917) --- tools/tools_utils.go | 37 +++++------ tools/tools_utils_test.go | 131 ++++++++++++++++++++++++-------------- 2 files changed, 98 insertions(+), 70 deletions(-) diff --git a/tools/tools_utils.go b/tools/tools_utils.go index 48531b78..b6f80729 100644 --- a/tools/tools_utils.go +++ b/tools/tools_utils.go @@ -166,31 +166,26 @@ func extractToolArgs(tmpl *gotmpl.Template) (name, arguments string, err error) return "", "", err } - var obj any - err = json.Unmarshal(b.Bytes(), &obj) - if err != nil { + // Extract JSON object between curly braces + // JSON arrays are also valid as they will not be repeated in the template + output := b.String() + start := strings.Index(output, "{") + end := strings.LastIndex(output, "}") + if start == -1 || end == -1 || start > end { + return "", "", errors.New("no valid JSON object found in template output") + } + jsonStr := output[start : end+1] + + var obj map[string]any + if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil { return "", "", err } - var objs []map[string]any - switch v := obj.(type) { - case map[string]any: - objs = []map[string]any{v} - case []map[string]any: - objs = v - case []any: - objs = collect(v) - } - if len(objs) == 0 { - return "", "", errors.New("no template objects found") - } - - // find the keys that correspond to the name and arguments fields - for k, v := range objs[0] { - switch v.(type) { - case string: + // Find name and arguments fields + for k, v := range obj { + if str, ok := v.(string); ok && str == "@@name@@" { name = k - case map[string]any: + } else if _, ok := v.(map[string]any); ok { arguments = k } } diff --git a/tools/tools_utils_test.go b/tools/tools_utils_test.go index 769183b7..e346117a 100644 --- a/tools/tools_utils_test.go +++ b/tools/tools_utils_test.go @@ -271,74 +271,99 @@ func TestExtractToolArgs(t *testing.T) { cases := []struct { name string template string - want string - ok bool + wantName string + wantArgs string + wantErr bool }{ { - name: "basic tool call with text after", - template: `{{if .ToolCalls}}tool response{{end}}`, - want: "tool response", - ok: true, + name: "basic tool call", + template: `{{ range .ToolCalls }} +{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}`, + wantName: "name", + wantArgs: "parameters", + wantErr: false, }, { - name: "tool call with mixed content after", - template: `{{if .ToolCalls}}{{.Something}}{{end}}`, - want: "", - ok: true, + name: "tool call with whitespace", + template: `{{range .ToolCalls}} + {"name": "{{.Function.Name}}", "parameters": {{.Function.Arguments}}} +{{end}}`, + wantName: "name", + wantArgs: "parameters", + wantErr: false, }, { - name: "tool call with no text after", - template: `{{if .ToolCalls}}{{.Something}}{{end}}`, - want: "", - ok: true, - }, - { - name: "nested tool call", - template: `{{if .Something}}{{if .ToolCalls}}[TOOL_CALL]{{end}}{{end}}`, - want: "[TOOL_CALL]", - ok: true, + name: "tool call with extra content", + template: `Before {{range .ToolCalls}} +{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}} After`, + wantName: "name", + wantArgs: "arguments", + wantErr: false, }, { name: "no tool calls", template: `{{if .Something}}no tools here{{end}}`, - want: "", - ok: false, + wantName: "", + wantArgs: "", + wantErr: true, }, { name: "empty template", template: ``, - want: "", - ok: false, + wantName: "", + wantArgs: "", + wantErr: true, }, { - name: "multiple tool calls sections", - template: `{{if .ToolCalls}}first{{end}}{{if .ToolCalls}}second{{end}}`, - want: "first", - ok: true, + name: "prefix within tool call", + template: `{{- if .ToolCalls }} +{{ range .ToolCalls }} + +{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} +{{ end }}{{- end }}`, + wantName: "name", + wantArgs: "arguments", + wantErr: false, }, { - name: "range over tool calls", - template: `{{if .ToolCalls}}{{range .ToolCalls}}tool{{end}}{{end}}`, - want: "", - ok: true, + name: "JSON array", + template: `{{ range .ToolCalls }} +[{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}]{{ end }}`, + wantName: "name", + wantArgs: "arguments", + wantErr: false, }, { - name: "tool calls with pipe delimiters", - template: `{{if .ToolCalls}}<|tool|>{{end}}`, - want: "<|tool|>", - ok: true, + name: "invalid JSON", + template: `{{ range .ToolCalls }} +{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}, invalid}{{ end }}`, + wantName: "", + wantArgs: "", + wantErr: true, }, { - name: "tool calls with nested template", - template: `{{if .ToolCalls}}{{template "tool" .}}{{end}}`, - want: "", - ok: true, + name: "missing name field", + template: `{{ range .ToolCalls }} +{"parameters": {{ .Function.Arguments }}}{{ end }}`, + wantName: "", + wantArgs: "", + wantErr: true, }, { - name: "tool calls with whitespace variations", - template: `{{if .ToolCalls}} tool {{end}}`, - want: " tool ", - ok: true, + name: "missing arguments field", + template: `{{ range .ToolCalls }} +{"name": "{{ .Function.Name }}"}{{ end }}`, + wantName: "", + wantArgs: "", + wantErr: true, + }, + { + name: "malformed JSON", + template: `{{ range .ToolCalls }} +{"name": {{ .Function.Name }}, "arguments": {{ .Function.Arguments }}{{ end }}`, + wantName: "", + wantArgs: "", + wantErr: true, }, } @@ -349,12 +374,20 @@ func TestExtractToolArgs(t *testing.T) { t.Fatalf("failed to parse template: %v", err) } - got, ok := extractToolCallsFormat(tmpl) - if got != tt.want { - t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want) + gotName, gotArgs, err := extractToolArgs(tmpl) + if (err != nil) != tt.wantErr { + t.Errorf("extractToolArgs() error = %v, wantErr %v", err, tt.wantErr) + return } - if ok != tt.ok { - t.Errorf("TextAfterToolCalls() ok = %v, want %v", ok, tt.ok) + if err != nil { + return + } + + if gotName != tt.wantName { + t.Errorf("extractToolArgs() gotName = %q, want %q", gotName, tt.wantName) + } + if gotArgs != tt.wantArgs { + t.Errorf("extractToolArgs() gotArgs = %q, want %q", gotArgs, tt.wantArgs) } }) } From 5c42800fca4da07d1c362c0f190429993e53c3b5 Mon Sep 17 00:00:00 2001 From: HardCodeDev Date: Sat, 31 May 2025 06:50:16 +0400 Subject: [PATCH 04/26] readme: add SimpleOllamaUnity to community integrations (#10817) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index d7cf5bfe..22de41e6 100644 --- a/README.md +++ b/README.md @@ -587,6 +587,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai) - [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c) - [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs) +- [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime) - [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama) ### Supported backends From 09430011936652cf55925184aaed6f2cebf62a75 Mon Sep 17 00:00:00 2001 From: JasonHonKL <148705846+JasonHonKL@users.noreply.github.com> Date: Thu, 5 Jun 2025 02:39:48 +0800 Subject: [PATCH 05/26] server: add model capabilities to the list endpoint (#10174) --- api/types.go | 13 +++++++------ docs/api.md | 29 +++++++++++++++++++---------- server/routes.go | 14 +++++++++++--- 3 files changed, 37 insertions(+), 19 deletions(-) diff --git a/api/types.go b/api/types.go index 94d49200..a1f896db 100644 --- a/api/types.go +++ b/api/types.go @@ -457,12 +457,13 @@ type ProcessResponse struct { // ListModelResponse is a single model description in [ListResponse]. type ListModelResponse struct { - Name string `json:"name"` - Model string `json:"model"` - ModifiedAt time.Time `json:"modified_at"` - Size int64 `json:"size"` - Digest string `json:"digest"` - Details ModelDetails `json:"details,omitempty"` + Name string `json:"name"` + Model string `json:"model"` + ModifiedAt time.Time `json:"modified_at"` + Size int64 `json:"size"` + Digest string `json:"digest"` + Capabilities []model.Capability `json:"capabilities,omitempty"` + Details ModelDetails `json:"details,omitempty"` } // ProcessModelResponse is a single model description in [ProcessResponse]. diff --git a/docs/api.md b/docs/api.md index 11eaf73a..31e18bd5 100644 --- a/docs/api.md +++ b/docs/api.md @@ -1157,11 +1157,15 @@ A single JSON object will be returned. { "models": [ { - "name": "deepseek-r1:latest", - "model": "deepseek-r1:latest", - "modified_at": "2025-05-10T08:06:48.639712648-07:00", - "size": 4683075271, - "digest": "0a8c266910232fd3291e71e5ba1e058cc5af9d411192cf88b6d30e92b6e73163", + + "model": "codellama:13b", + "modified_at": "2023-11-04T14:56:49.277302595-07:00", + "size": 7365960935, + "digest": "9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697", + "capabilities": [ + "completion" + ], + "details": { "parent_model": "", "format": "gguf", @@ -1174,11 +1178,16 @@ A single JSON object will be returned. } }, { - "name": "llama3.2:latest", - "model": "llama3.2:latest", - "modified_at": "2025-05-04T17:37:44.706015396-07:00", - "size": 2019393189, - "digest": "a80c4f17acd55265feec403c7aef86be0c25983ab279d83f3bcd3abbcb5b8b72", + + "model": "llama4:latest", + "modified_at": "2023-12-07T09:32:18.757212583-08:00", + "size": 3825819519, + "digest": "fe938a131f40e6f6d40083c9f0f430a515233eb2edaa6d72eb85c50d64f2300e", + "capabilities": [ + "completion", + "vision" + ], + "details": { "parent_model": "", "format": "gguf", diff --git a/server/routes.go b/server/routes.go index 236f92e2..924ba06c 100644 --- a/server/routes.go +++ b/server/routes.go @@ -928,8 +928,7 @@ func (s *Server) ListHandler(c *gin.Context) { } } - // tag should never be masked - models = append(models, api.ListModelResponse{ + r := api.ListModelResponse{ Model: n.DisplayShortest(), Name: n.DisplayShortest(), Size: m.Size(), @@ -942,7 +941,16 @@ func (s *Server) ListHandler(c *gin.Context) { ParameterSize: cf.ModelType, QuantizationLevel: cf.FileType, }, - }) + } + + model, err := GetModel(n.String()) + if err != nil { + slog.Warn("bad model details", "name", n, "error", err) + } else { + r.Capabilities = model.Capabilities() + } + + models = append(models, r) } slices.SortStableFunc(models, func(i, j api.ListModelResponse) int { From 0683efa6379ba69384bb5876f1013d9ef38f1ab0 Mon Sep 17 00:00:00 2001 From: Devon Rifkin Date: Thu, 5 Jun 2025 10:22:32 -0700 Subject: [PATCH 06/26] export ThinkingParser --- server/routes.go | 28 ++++++++++++++-------------- server/thinking.go | 26 +++++++++++++------------- server/thinking_test.go | 16 ++++++++-------- 3 files changed, 35 insertions(+), 35 deletions(-) diff --git a/server/routes.go b/server/routes.go index 924ba06c..d03ac2ec 100644 --- a/server/routes.go +++ b/server/routes.go @@ -282,12 +282,12 @@ func (s *Server) GenerateHandler(c *gin.Context) { prompt = b.String() } - var thinkingState *thinkingParser + var thinkingState *ThinkingParser openingTag, closingTag := inferThinkingTags(m.Template.Template) if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" { - thinkingState = &thinkingParser{ - openingTag: openingTag, - closingTag: closingTag, + thinkingState = &ThinkingParser{ + OpeningTag: openingTag, + ClosingTag: closingTag, } } @@ -316,7 +316,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { } if thinkingState != nil { - thinking, content := thinkingState.addContent(cr.Content) + thinking, content := thinkingState.AddContent(cr.Content) res.Thinking = thinking res.Response = content } @@ -1522,12 +1522,12 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - var thinkingState *thinkingParser + var thinkingState *ThinkingParser openingTag, closingTag := inferThinkingTags(m.Template.Template) if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" { - thinkingState = &thinkingParser{ - openingTag: openingTag, - closingTag: closingTag, + thinkingState = &ThinkingParser{ + OpeningTag: openingTag, + ClosingTag: closingTag, } } @@ -1565,7 +1565,7 @@ func (s *Server) ChatHandler(c *gin.Context) { } if thinkingState != nil { - thinkingContent, remainingContent := thinkingState.addContent(res.Message.Content) + thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content) if thinkingContent == "" && remainingContent == "" && !r.Done { // need to accumulate more to decide what to send return @@ -1676,11 +1676,11 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message { // change the user output), we should probably perform this filtering // for all thinking models (not just qwen3 & deepseek-r1) since it tends // to save tokens and improve quality. - thinkingState := &thinkingParser{ - openingTag: "", - closingTag: "", + thinkingState := &ThinkingParser{ + OpeningTag: "", + ClosingTag: "", } - _, content := thinkingState.addContent(msg.Content) + _, content := thinkingState.AddContent(msg.Content) msgs[i].Content = content } } diff --git a/server/thinking.go b/server/thinking.go index 2213b6b6..4ef3c184 100644 --- a/server/thinking.go +++ b/server/thinking.go @@ -46,17 +46,17 @@ func (s thinkingState) String() string { } } -type thinkingParser struct { +type ThinkingParser struct { state thinkingState - openingTag string - closingTag string + OpeningTag string + ClosingTag string acc strings.Builder } -// addContent returns the thinking content and the non-thinking content that +// AddContent returns the thinking content and the non-thinking content that // should be immediately sent to the user. It will internally buffer if it needs // to see more raw content to disambiguate -func (s *thinkingParser) addContent(content string) (string, string) { +func (s *ThinkingParser) AddContent(content string) (string, string) { s.acc.WriteString(content) var thinkingSb, remainingSb strings.Builder @@ -76,12 +76,12 @@ func (s *thinkingParser) addContent(content string) (string, string) { } // the additional bool return is true iff we should continue eating -func eat(s *thinkingParser) (string, string, bool) { +func eat(s *ThinkingParser) (string, string, bool) { switch s.state { case thinkingState_LookingForOpening: trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace) - if strings.HasPrefix(trimmed, s.openingTag) { - after := strings.Join(strings.Split(trimmed, s.openingTag)[1:], s.openingTag) + if strings.HasPrefix(trimmed, s.OpeningTag) { + after := strings.Join(strings.Split(trimmed, s.OpeningTag)[1:], s.OpeningTag) after = strings.TrimLeftFunc(after, unicode.IsSpace) // after might contain more than just thinking tokens, so we continue // parsing instead of returning it as thinking tokens here @@ -93,7 +93,7 @@ func eat(s *thinkingParser) (string, string, bool) { s.state = thinkingState_Thinking } return "", "", true - } else if strings.HasPrefix(s.openingTag, trimmed) { + } else if strings.HasPrefix(s.OpeningTag, trimmed) { // partial opening seen, so let's keep accumulating return "", "", false } else if trimmed == "" { @@ -119,10 +119,10 @@ func eat(s *thinkingParser) (string, string, bool) { } case thinkingState_Thinking: acc := s.acc.String() - if strings.Contains(acc, s.closingTag) { - split := strings.Split(acc, s.closingTag) + if strings.Contains(acc, s.ClosingTag) { + split := strings.Split(acc, s.ClosingTag) thinking := split[0] - remaining := strings.Join(split[1:], s.closingTag) + remaining := strings.Join(split[1:], s.ClosingTag) remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace) s.acc.Reset() if remaining == "" { @@ -131,7 +131,7 @@ func eat(s *thinkingParser) (string, string, bool) { s.state = thinkingState_ThinkingDone } return thinking, remaining, false - } else if overlapLen := overlap(acc, s.closingTag); overlapLen > 0 { + } else if overlapLen := overlap(acc, s.ClosingTag); overlapLen > 0 { thinking := acc[:len(acc)-overlapLen] remaining := acc[len(acc)-overlapLen:] s.acc.Reset() diff --git a/server/thinking_test.go b/server/thinking_test.go index a2055635..90d3f961 100644 --- a/server/thinking_test.go +++ b/server/thinking_test.go @@ -26,11 +26,11 @@ func TestExtractThinking(t *testing.T) { }, } for i, tt := range tests { - parser := thinkingParser{ - openingTag: "", - closingTag: "", + parser := ThinkingParser{ + OpeningTag: "", + ClosingTag: "", } - gotThinking, gotContent := parser.addContent(tt.in) + gotThinking, gotContent := parser.AddContent(tt.in) if gotContent != tt.wantContent || gotThinking != tt.wantThink { t.Errorf("case %d: got (%q,%q), want (%q,%q)", i, gotThinking, gotContent, tt.wantThink, tt.wantContent) } @@ -259,15 +259,15 @@ func TestThinkingStreaming(t *testing.T) { } for _, c := range cases { - parser := thinkingParser{ - openingTag: "", - closingTag: "", + parser := ThinkingParser{ + OpeningTag: "", + ClosingTag: "", } if c.skip { continue } for i, step := range c.steps { - thinking, content := parser.addContent(step.input) + thinking, content := parser.AddContent(step.input) if content != step.wantContent || thinking != step.wantThinking { t.Errorf("case %q (step %d): got (%q,%q), want (%q,%q)", c.desc, i, content, thinking, step.wantContent, step.wantThinking) } From c6a6d7294dd50b9216918fe72fd92bc4ae572ac0 Mon Sep 17 00:00:00 2001 From: Hunter Wittenborn Date: Fri, 6 Jun 2025 11:07:29 -0500 Subject: [PATCH 07/26] docs: fix typo in development.md (#10998) --- docs/development.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/development.md b/docs/development.md index cf6d91e2..24bcba19 100644 --- a/docs/development.md +++ b/docs/development.md @@ -118,7 +118,7 @@ To run tests, use `go test`: go test ./... ``` -> NOTE: In rare cirumstances, you may nedd to change a package using the new +> NOTE: In rare cirumstances, you may need to change a package using the new > "synctest" package in go1.24. > > If you do not have the "synctest" package enabled, you will not see build or From a3b6886b7da0339e63ebf41e6ba5c6b06438a123 Mon Sep 17 00:00:00 2001 From: Devon Rifkin Date: Fri, 6 Jun 2025 12:02:20 -0700 Subject: [PATCH 08/26] move thinking logic into its own package (#10990) move thinking logic into its own package --- server/images.go | 3 +- server/routes.go | 15 +- server/thinking.go => thinking/parser.go | 137 +----------------- .../parser_test.go | 131 +---------------- thinking/template.go | 134 +++++++++++++++++ thinking/template_test.go | 130 +++++++++++++++++ 6 files changed, 281 insertions(+), 269 deletions(-) rename server/thinking.go => thinking/parser.go (59%) rename server/thinking_test.go => thinking/parser_test.go (66%) create mode 100644 thinking/template.go create mode 100644 thinking/template_test.go diff --git a/server/images.go b/server/images.go index 58fb87dc..d6cceff4 100644 --- a/server/images.go +++ b/server/images.go @@ -26,6 +26,7 @@ import ( "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/template" + "github.com/ollama/ollama/thinking" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" ) @@ -113,7 +114,7 @@ func (m *Model) Capabilities() []model.Capability { } // Check for thinking capability - openingTag, closingTag := inferThinkingTags(m.Template.Template) + openingTag, closingTag := thinking.InferTags(m.Template.Template) if openingTag != "" && closingTag != "" { capabilities = append(capabilities, model.CapabilityThinking) } diff --git a/server/routes.go b/server/routes.go index d03ac2ec..70cb6cef 100644 --- a/server/routes.go +++ b/server/routes.go @@ -37,6 +37,7 @@ import ( "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/registry" "github.com/ollama/ollama/template" + "github.com/ollama/ollama/thinking" "github.com/ollama/ollama/tools" "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" @@ -282,10 +283,10 @@ func (s *Server) GenerateHandler(c *gin.Context) { prompt = b.String() } - var thinkingState *ThinkingParser - openingTag, closingTag := inferThinkingTags(m.Template.Template) + var thinkingState *thinking.Parser + openingTag, closingTag := thinking.InferTags(m.Template.Template) if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" { - thinkingState = &ThinkingParser{ + thinkingState = &thinking.Parser{ OpeningTag: openingTag, ClosingTag: closingTag, } @@ -1522,10 +1523,10 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - var thinkingState *ThinkingParser - openingTag, closingTag := inferThinkingTags(m.Template.Template) + var thinkingState *thinking.Parser + openingTag, closingTag := thinking.InferTags(m.Template.Template) if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" { - thinkingState = &ThinkingParser{ + thinkingState = &thinking.Parser{ OpeningTag: openingTag, ClosingTag: closingTag, } @@ -1676,7 +1677,7 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message { // change the user output), we should probably perform this filtering // for all thinking models (not just qwen3 & deepseek-r1) since it tends // to save tokens and improve quality. - thinkingState := &ThinkingParser{ + thinkingState := &thinking.Parser{ OpeningTag: "", ClosingTag: "", } diff --git a/server/thinking.go b/thinking/parser.go similarity index 59% rename from server/thinking.go rename to thinking/parser.go index 4ef3c184..a4d05e35 100644 --- a/server/thinking.go +++ b/thinking/parser.go @@ -1,9 +1,7 @@ -package server +package thinking import ( "strings" - "text/template" - "text/template/parse" "unicode" ) @@ -46,7 +44,7 @@ func (s thinkingState) String() string { } } -type ThinkingParser struct { +type Parser struct { state thinkingState OpeningTag string ClosingTag string @@ -56,7 +54,7 @@ type ThinkingParser struct { // AddContent returns the thinking content and the non-thinking content that // should be immediately sent to the user. It will internally buffer if it needs // to see more raw content to disambiguate -func (s *ThinkingParser) AddContent(content string) (string, string) { +func (s *Parser) AddContent(content string) (string, string) { s.acc.WriteString(content) var thinkingSb, remainingSb strings.Builder @@ -76,7 +74,7 @@ func (s *ThinkingParser) AddContent(content string) (string, string) { } // the additional bool return is true iff we should continue eating -func eat(s *ThinkingParser) (string, string, bool) { +func eat(s *Parser) (string, string, bool) { switch s.state { case thinkingState_LookingForOpening: trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace) @@ -171,130 +169,3 @@ func overlap(s, delim string) int { } return 0 } - -func templateVisit(n parse.Node, enterFn func(parse.Node) bool, exitFn func(parse.Node)) { - if n == nil { - return - } - shouldContinue := enterFn(n) - if !shouldContinue { - return - } - switch x := n.(type) { - case *parse.ListNode: - for _, c := range x.Nodes { - templateVisit(c, enterFn, exitFn) - } - case *parse.BranchNode: - if x.Pipe != nil { - templateVisit(x.Pipe, enterFn, exitFn) - } - if x.List != nil { - templateVisit(x.List, enterFn, exitFn) - } - if x.ElseList != nil { - templateVisit(x.ElseList, enterFn, exitFn) - } - case *parse.ActionNode: - templateVisit(x.Pipe, enterFn, exitFn) - case *parse.WithNode: - templateVisit(&x.BranchNode, enterFn, exitFn) - case *parse.RangeNode: - templateVisit(&x.BranchNode, enterFn, exitFn) - case *parse.IfNode: - templateVisit(&x.BranchNode, enterFn, exitFn) - case *parse.TemplateNode: - templateVisit(x.Pipe, enterFn, exitFn) - case *parse.PipeNode: - for _, c := range x.Cmds { - templateVisit(c, enterFn, exitFn) - } - case *parse.CommandNode: - for _, a := range x.Args { - templateVisit(a, enterFn, exitFn) - } - // text, field, number, etc. are leaves – nothing to recurse into - } - if exitFn != nil { - exitFn(n) - } -} - -// We use a heuristic to infer the tags that surround thinking traces: -// We look for a range node that iterates over "Messages" and then look for a -// reference to "Thinking" like `{{.Thinking}}`. We then go up to the nearest -// ListNode and take the first and last TextNodes as the opening and closing -// tags. -func inferThinkingTags(t *template.Template) (string, string) { - ancestors := []parse.Node{} - - openingTag := "" - closingTag := "" - - enterFn := func(n parse.Node) bool { - ancestors = append(ancestors, n) - - switch x := n.(type) { - case *parse.FieldNode: - if len(x.Ident) > 0 && x.Ident[0] == "Thinking" { - var mostRecentRange *parse.RangeNode - for i := len(ancestors) - 1; i >= 0; i-- { - if r, ok := ancestors[i].(*parse.RangeNode); ok { - mostRecentRange = r - break - } - } - if mostRecentRange == nil || !rangeUsesField(mostRecentRange, "Messages") { - return true - } - - // TODO(drifkin): to be more robust, check that it's in the action - // part, not the `if`'s pipeline part. We do match on the nearest list - // that starts and ends with text nodes, which makes this not strictly - // necessary for our heuristic - - // go up to the nearest ancestor that is a *parse.ListNode - for i := len(ancestors) - 1; i >= 0; i-- { - if l, ok := ancestors[i].(*parse.ListNode); ok { - firstNode := l.Nodes[0] - if t, ok := firstNode.(*parse.TextNode); ok { - openingTag = strings.TrimSpace(t.String()) - } - lastNode := l.Nodes[len(l.Nodes)-1] - if t, ok := lastNode.(*parse.TextNode); ok { - closingTag = strings.TrimSpace(t.String()) - } - - break - } - } - } - } - - return true - } - - exitFn := func(n parse.Node) { - ancestors = ancestors[:len(ancestors)-1] - } - - templateVisit(t.Root, enterFn, exitFn) - - return openingTag, closingTag -} - -// checks to see if the given field name is present in the pipeline of the given range node -func rangeUsesField(rangeNode *parse.RangeNode, field string) bool { - found := false - enterFn := func(n parse.Node) bool { - switch x := n.(type) { - case *parse.FieldNode: - if x.Ident[0] == field { - found = true - } - } - return true - } - templateVisit(rangeNode.BranchNode.Pipe, enterFn, nil) - return found -} diff --git a/server/thinking_test.go b/thinking/parser_test.go similarity index 66% rename from server/thinking_test.go rename to thinking/parser_test.go index 90d3f961..78c297cd 100644 --- a/server/thinking_test.go +++ b/thinking/parser_test.go @@ -1,8 +1,7 @@ -package server +package thinking import ( "testing" - "text/template" ) func TestExtractThinking(t *testing.T) { @@ -26,7 +25,7 @@ func TestExtractThinking(t *testing.T) { }, } for i, tt := range tests { - parser := ThinkingParser{ + parser := Parser{ OpeningTag: "", ClosingTag: "", } @@ -259,7 +258,7 @@ func TestThinkingStreaming(t *testing.T) { } for _, c := range cases { - parser := ThinkingParser{ + parser := Parser{ OpeningTag: "", ClosingTag: "", } @@ -277,127 +276,3 @@ func TestThinkingStreaming(t *testing.T) { } } } - -func TestInferThinkingTags(t *testing.T) { - cases := []struct { - desc string - tmplString string - wantOpeningTag string - wantClosingTag string - }{ - { - desc: "basic", - tmplString: ` - {{ if .Thinking}} - /think - {{ end }} - {{- range $i, $_ := .Messages }} - {{- $last := eq (len (slice $.Messages $i)) 1 -}} - {{ if and $last .Thinking }} - {{ .Thinking }} - {{ end }} - {{ end }} - `, - wantOpeningTag: "", - wantClosingTag: "", - }, - { - desc: "doubly nested range", - tmplString: ` - {{ if .Thinking}} - /think - {{ end }} - {{- range $i, $_ := .Messages }} - {{- range $j, $_ := .NotMessages }} - {{- $last := eq (len (slice $.Messages $i)) 1 -}} - {{ if and $last .Thinking }} - {{ .Thinking }} - {{ end }} - {{ end }} - {{ end }} - `, - wantOpeningTag: "", - wantClosingTag: "", - }, - { - desc: "whitespace is trimmed", - tmplString: ` - {{ if .Thinking}} - /think - {{ end }} - {{- range $i, $_ := .Messages }} - {{- $last := eq (len (slice $.Messages $i)) 1 -}} - {{ if and $last .Thinking }} - Some text before {{ .Thinking }} Some text after - {{ end }} - {{ end }} - `, - wantOpeningTag: "Some text before", - wantClosingTag: "Some text after", - }, - { - desc: "qwen3", - tmplString: ` -{{- if or .System .Tools .Thinking }}<|im_start|>system -{{- if .System }} -{{ .System }} -{{- end }} -{{- if .Tools }} - -# Tools - -You may call one or more functions to assist with the user query. - -You are provided with function signatures within XML tags: - -{{- range .Tools }} -{"type": "function", "function": {{ .Function }}} -{{- end }} - - -For each function call, return a json object with function name and arguments within XML tags: - -{"name": , "arguments": } - -{{- end }} -{{- if .Thinking }} -/think -{{- else }} -/no_think -{{- end }}<|im_end|> -{{ end }} -{{- range $i, $_ := .Messages }} -{{- $last := eq (len (slice $.Messages $i)) 1 -}} -{{- if eq .Role "user" }}<|im_start|>user -{{ .Content }}<|im_end|> -{{ else if eq .Role "assistant" }}<|im_start|>assistant -{{ if and $last .Thinking }} -{{ .Thinking }} -{{ end }} -{{ if .Content }}{{ .Content }} -{{- else if .ToolCalls }} -{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} -{{ end }} -{{- end }}{{ if not $last }}<|im_end|> -{{ end }} -{{- else if eq .Role "tool" }}<|im_start|>user - -{{ .Content }} -<|im_end|> -{{ end }} -{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant -{{ end }} -{{- end }} - `, - wantOpeningTag: "", - wantClosingTag: "", - }, - } - for _, c := range cases { - tmpl := template.Must(template.New("test").Parse(c.tmplString)) - openingTag, closingTag := inferThinkingTags(tmpl) - if openingTag != c.wantOpeningTag || closingTag != c.wantClosingTag { - t.Errorf("case %q: got (%q,%q), want (%q,%q)", c.desc, openingTag, closingTag, c.wantOpeningTag, c.wantClosingTag) - } - } -} diff --git a/thinking/template.go b/thinking/template.go new file mode 100644 index 00000000..20bd65ec --- /dev/null +++ b/thinking/template.go @@ -0,0 +1,134 @@ +package thinking + +import ( + "strings" + "text/template" + "text/template/parse" +) + +func templateVisit(n parse.Node, enterFn func(parse.Node) bool, exitFn func(parse.Node)) { + if n == nil { + return + } + shouldContinue := enterFn(n) + if !shouldContinue { + return + } + switch x := n.(type) { + case *parse.ListNode: + for _, c := range x.Nodes { + templateVisit(c, enterFn, exitFn) + } + case *parse.BranchNode: + if x.Pipe != nil { + templateVisit(x.Pipe, enterFn, exitFn) + } + if x.List != nil { + templateVisit(x.List, enterFn, exitFn) + } + if x.ElseList != nil { + templateVisit(x.ElseList, enterFn, exitFn) + } + case *parse.ActionNode: + templateVisit(x.Pipe, enterFn, exitFn) + case *parse.WithNode: + templateVisit(&x.BranchNode, enterFn, exitFn) + case *parse.RangeNode: + templateVisit(&x.BranchNode, enterFn, exitFn) + case *parse.IfNode: + templateVisit(&x.BranchNode, enterFn, exitFn) + case *parse.TemplateNode: + templateVisit(x.Pipe, enterFn, exitFn) + case *parse.PipeNode: + for _, c := range x.Cmds { + templateVisit(c, enterFn, exitFn) + } + case *parse.CommandNode: + for _, a := range x.Args { + templateVisit(a, enterFn, exitFn) + } + // text, field, number, etc. are leaves – nothing to recurse into + } + if exitFn != nil { + exitFn(n) + } +} + +// InferTags uses a heuristic to infer the tags that surround thinking traces: +// We look for a range node that iterates over "Messages" and then look for a +// reference to "Thinking" like `{{.Thinking}}`. We then go up to the nearest +// ListNode and take the first and last TextNodes as the opening and closing +// tags. +func InferTags(t *template.Template) (string, string) { + ancestors := []parse.Node{} + + openingTag := "" + closingTag := "" + + enterFn := func(n parse.Node) bool { + ancestors = append(ancestors, n) + + switch x := n.(type) { + case *parse.FieldNode: + if len(x.Ident) > 0 && x.Ident[0] == "Thinking" { + var mostRecentRange *parse.RangeNode + for i := len(ancestors) - 1; i >= 0; i-- { + if r, ok := ancestors[i].(*parse.RangeNode); ok { + mostRecentRange = r + break + } + } + if mostRecentRange == nil || !rangeUsesField(mostRecentRange, "Messages") { + return true + } + + // TODO(drifkin): to be more robust, check that it's in the action + // part, not the `if`'s pipeline part. We do match on the nearest list + // that starts and ends with text nodes, which makes this not strictly + // necessary for our heuristic + + // go up to the nearest ancestor that is a *parse.ListNode + for i := len(ancestors) - 1; i >= 0; i-- { + if l, ok := ancestors[i].(*parse.ListNode); ok { + firstNode := l.Nodes[0] + if t, ok := firstNode.(*parse.TextNode); ok { + openingTag = strings.TrimSpace(t.String()) + } + lastNode := l.Nodes[len(l.Nodes)-1] + if t, ok := lastNode.(*parse.TextNode); ok { + closingTag = strings.TrimSpace(t.String()) + } + + break + } + } + } + } + + return true + } + + exitFn := func(n parse.Node) { + ancestors = ancestors[:len(ancestors)-1] + } + + templateVisit(t.Root, enterFn, exitFn) + + return openingTag, closingTag +} + +// checks to see if the given field name is present in the pipeline of the given range node +func rangeUsesField(rangeNode *parse.RangeNode, field string) bool { + found := false + enterFn := func(n parse.Node) bool { + switch x := n.(type) { + case *parse.FieldNode: + if x.Ident[0] == field { + found = true + } + } + return true + } + templateVisit(rangeNode.BranchNode.Pipe, enterFn, nil) + return found +} diff --git a/thinking/template_test.go b/thinking/template_test.go new file mode 100644 index 00000000..e63558e2 --- /dev/null +++ b/thinking/template_test.go @@ -0,0 +1,130 @@ +package thinking + +import ( + "testing" + "text/template" +) + +func TestInferThinkingTags(t *testing.T) { + cases := []struct { + desc string + tmplString string + wantOpeningTag string + wantClosingTag string + }{ + { + desc: "basic", + tmplString: ` + {{ if .Thinking}} + /think + {{ end }} + {{- range $i, $_ := .Messages }} + {{- $last := eq (len (slice $.Messages $i)) 1 -}} + {{ if and $last .Thinking }} + {{ .Thinking }} + {{ end }} + {{ end }} + `, + wantOpeningTag: "", + wantClosingTag: "", + }, + { + desc: "doubly nested range", + tmplString: ` + {{ if .Thinking}} + /think + {{ end }} + {{- range $i, $_ := .Messages }} + {{- range $j, $_ := .NotMessages }} + {{- $last := eq (len (slice $.Messages $i)) 1 -}} + {{ if and $last .Thinking }} + {{ .Thinking }} + {{ end }} + {{ end }} + {{ end }} + `, + wantOpeningTag: "", + wantClosingTag: "", + }, + { + desc: "whitespace is trimmed", + tmplString: ` + {{ if .Thinking}} + /think + {{ end }} + {{- range $i, $_ := .Messages }} + {{- $last := eq (len (slice $.Messages $i)) 1 -}} + {{ if and $last .Thinking }} + Some text before {{ .Thinking }} Some text after + {{ end }} + {{ end }} + `, + wantOpeningTag: "Some text before", + wantClosingTag: "Some text after", + }, + { + desc: "qwen3", + tmplString: ` +{{- if or .System .Tools .Thinking }}<|im_start|>system +{{- if .System }} +{{ .System }} +{{- end }} +{{- if .Tools }} + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{{- range .Tools }} +{"type": "function", "function": {{ .Function }}} +{{- end }} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } + +{{- end }} +{{- if .Thinking }} +/think +{{- else }} +/no_think +{{- end }}<|im_end|> +{{ end }} +{{- range $i, $_ := .Messages }} +{{- $last := eq (len (slice $.Messages $i)) 1 -}} +{{- if eq .Role "user" }}<|im_start|>user +{{ .Content }}<|im_end|> +{{ else if eq .Role "assistant" }}<|im_start|>assistant +{{ if and $last .Thinking }} +{{ .Thinking }} +{{ end }} +{{ if .Content }}{{ .Content }} +{{- else if .ToolCalls }} +{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} +{{ end }} +{{- end }}{{ if not $last }}<|im_end|> +{{ end }} +{{- else if eq .Role "tool" }}<|im_start|>user + +{{ .Content }} +<|im_end|> +{{ end }} +{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant +{{ end }} +{{- end }} + `, + wantOpeningTag: "", + wantClosingTag: "", + }, + } + for _, c := range cases { + tmpl := template.Must(template.New("test").Parse(c.tmplString)) + openingTag, closingTag := InferTags(tmpl) + if openingTag != c.wantOpeningTag || closingTag != c.wantClosingTag { + t.Errorf("case %q: got (%q,%q), want (%q,%q)", c.desc, openingTag, closingTag, c.wantOpeningTag, c.wantClosingTag) + } + } +} From 2ae65ae471c9d51d343f401da16c05b98b99a842 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Fri, 6 Jun 2025 14:06:09 -0700 Subject: [PATCH 09/26] win: handle more than 2048 processes (#10997) Fix an array out of bounds crash --- cmd/start_windows.go | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/cmd/start_windows.go b/cmd/start_windows.go index bcc51057..1b648d9d 100644 --- a/cmd/start_windows.go +++ b/cmd/start_windows.go @@ -74,7 +74,16 @@ func isProcRunning(procName string) []uint32 { slog.Debug("failed to check for running installers", "error", err) return nil } - pids = pids[:ret] + if ret > uint32(len(pids)) { + pids = make([]uint32, ret+10) + if err := windows.EnumProcesses(pids, &ret); err != nil || ret == 0 { + slog.Debug("failed to check for running installers", "error", err) + return nil + } + } + if ret < uint32(len(pids)) { + pids = pids[:ret] + } var matches []uint32 for _, pid := range pids { if pid == 0 { From a8ed68bd9383ffec346fd1b3cf60d94c032bbec8 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Fri, 6 Jun 2025 14:06:29 -0700 Subject: [PATCH 10/26] launch app hidden (#10962) When starting the app in the background, start it hidden. --- cmd/start_darwin.go | 2 +- cmd/start_windows.go | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/cmd/start_darwin.go b/cmd/start_darwin.go index 1a9a1ae8..3cb726ea 100644 --- a/cmd/start_darwin.go +++ b/cmd/start_darwin.go @@ -23,7 +23,7 @@ func startApp(ctx context.Context, client *api.Client) error { return errors.New("could not find ollama app") } path := strings.Split(link, "Ollama.app") - if err := exec.Command("/usr/bin/open", "-a", path[0]+"Ollama.app").Run(); err != nil { + if err := exec.Command("/usr/bin/open", "-j", "-a", path[0]+"Ollama.app").Run(); err != nil { return err } return waitForServer(ctx, client) diff --git a/cmd/start_windows.go b/cmd/start_windows.go index 1b648d9d..635b5077 100644 --- a/cmd/start_windows.go +++ b/cmd/start_windows.go @@ -45,14 +45,11 @@ func startApp(ctx context.Context, client *api.Client) error { } } } - // log.Printf("XXX attempting to start app %s", appExe) cmd_path := "c:\\Windows\\system32\\cmd.exe" - cmd := exec.Command(cmd_path, "/c", appExe) - // TODO - these hide flags aren't working - still pops up a command window for some reason + cmd := exec.Command(cmd_path, "/c", appExe, "hidden") cmd.SysProcAttr = &syscall.SysProcAttr{CreationFlags: 0x08000000, HideWindow: true} - // TODO this didn't help either... cmd.Stdin = strings.NewReader("") cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr From 09d308d6b6c7995e3fb565e0ecfa184d49205f00 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Fri, 6 Jun 2025 23:29:14 -0400 Subject: [PATCH 11/26] Revert "server: add model capabilities to the list endpoint (#10174)" (#11004) This reverts commit 09430011936652cf55925184aaed6f2cebf62a75. --- api/types.go | 13 ++++++------- docs/api.md | 29 ++++++++++------------------- server/routes.go | 14 +++----------- 3 files changed, 19 insertions(+), 37 deletions(-) diff --git a/api/types.go b/api/types.go index a1f896db..94d49200 100644 --- a/api/types.go +++ b/api/types.go @@ -457,13 +457,12 @@ type ProcessResponse struct { // ListModelResponse is a single model description in [ListResponse]. type ListModelResponse struct { - Name string `json:"name"` - Model string `json:"model"` - ModifiedAt time.Time `json:"modified_at"` - Size int64 `json:"size"` - Digest string `json:"digest"` - Capabilities []model.Capability `json:"capabilities,omitempty"` - Details ModelDetails `json:"details,omitempty"` + Name string `json:"name"` + Model string `json:"model"` + ModifiedAt time.Time `json:"modified_at"` + Size int64 `json:"size"` + Digest string `json:"digest"` + Details ModelDetails `json:"details,omitempty"` } // ProcessModelResponse is a single model description in [ProcessResponse]. diff --git a/docs/api.md b/docs/api.md index 31e18bd5..11eaf73a 100644 --- a/docs/api.md +++ b/docs/api.md @@ -1157,15 +1157,11 @@ A single JSON object will be returned. { "models": [ { - - "model": "codellama:13b", - "modified_at": "2023-11-04T14:56:49.277302595-07:00", - "size": 7365960935, - "digest": "9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697", - "capabilities": [ - "completion" - ], - + "name": "deepseek-r1:latest", + "model": "deepseek-r1:latest", + "modified_at": "2025-05-10T08:06:48.639712648-07:00", + "size": 4683075271, + "digest": "0a8c266910232fd3291e71e5ba1e058cc5af9d411192cf88b6d30e92b6e73163", "details": { "parent_model": "", "format": "gguf", @@ -1178,16 +1174,11 @@ A single JSON object will be returned. } }, { - - "model": "llama4:latest", - "modified_at": "2023-12-07T09:32:18.757212583-08:00", - "size": 3825819519, - "digest": "fe938a131f40e6f6d40083c9f0f430a515233eb2edaa6d72eb85c50d64f2300e", - "capabilities": [ - "completion", - "vision" - ], - + "name": "llama3.2:latest", + "model": "llama3.2:latest", + "modified_at": "2025-05-04T17:37:44.706015396-07:00", + "size": 2019393189, + "digest": "a80c4f17acd55265feec403c7aef86be0c25983ab279d83f3bcd3abbcb5b8b72", "details": { "parent_model": "", "format": "gguf", diff --git a/server/routes.go b/server/routes.go index 70cb6cef..8eda5c73 100644 --- a/server/routes.go +++ b/server/routes.go @@ -929,7 +929,8 @@ func (s *Server) ListHandler(c *gin.Context) { } } - r := api.ListModelResponse{ + // tag should never be masked + models = append(models, api.ListModelResponse{ Model: n.DisplayShortest(), Name: n.DisplayShortest(), Size: m.Size(), @@ -942,16 +943,7 @@ func (s *Server) ListHandler(c *gin.Context) { ParameterSize: cf.ModelType, QuantizationLevel: cf.FileType, }, - } - - model, err := GetModel(n.String()) - if err != nil { - slog.Warn("bad model details", "name", n, "error", err) - } else { - r.Capabilities = model.Capabilities() - } - - models = append(models, r) + }) } slices.SortStableFunc(models, func(i, j api.ListModelResponse) int { From fc0309615e42c32989e060e733d871e16617874e Mon Sep 17 00:00:00 2001 From: Krzysztof Jeziorny <872730+krzysztofjeziorny@users.noreply.github.com> Date: Sat, 7 Jun 2025 05:30:04 +0200 Subject: [PATCH 12/26] docs: update link to AMD drivers in linux.md (#10973) --- docs/linux.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/linux.md b/docs/linux.md index 2dda87f3..72a5ff01 100644 --- a/docs/linux.md +++ b/docs/linux.md @@ -112,8 +112,8 @@ sudo systemctl status ollama > While AMD has contributed the `amdgpu` driver upstream to the official linux > kernel source, the version is older and may not support all ROCm features. We > recommend you install the latest driver from -> https://www.amd.com/en/support/linux-drivers for best support of your Radeon -> GPU. +> [AMD](https://www.amd.com/en/support/download/linux-drivers.html) for best support +> of your Radeon GPU. ## Customizing From feeabdadd2b272b40747f3e7e74957c40ba2800c Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Sun, 8 Jun 2025 09:34:52 -0700 Subject: [PATCH 13/26] spawn desktop quickly (#11011) Give the desktop app a hint to start fast. --- cmd/start_darwin.go | 2 +- cmd/start_windows.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/start_darwin.go b/cmd/start_darwin.go index 3cb726ea..83af1235 100644 --- a/cmd/start_darwin.go +++ b/cmd/start_darwin.go @@ -23,7 +23,7 @@ func startApp(ctx context.Context, client *api.Client) error { return errors.New("could not find ollama app") } path := strings.Split(link, "Ollama.app") - if err := exec.Command("/usr/bin/open", "-j", "-a", path[0]+"Ollama.app").Run(); err != nil { + if err := exec.Command("/usr/bin/open", "-j", "-a", path[0]+"Ollama.app", "--args", "--fast-startup").Run(); err != nil { return err } return waitForServer(ctx, client) diff --git a/cmd/start_windows.go b/cmd/start_windows.go index 635b5077..9505e1bb 100644 --- a/cmd/start_windows.go +++ b/cmd/start_windows.go @@ -47,7 +47,7 @@ func startApp(ctx context.Context, client *api.Client) error { } cmd_path := "c:\\Windows\\system32\\cmd.exe" - cmd := exec.Command(cmd_path, "/c", appExe, "hidden") + cmd := exec.Command(cmd_path, "/c", appExe, "--hide", "--fast-startup") cmd.SysProcAttr = &syscall.SysProcAttr{CreationFlags: 0x08000000, HideWindow: true} cmd.Stdin = strings.NewReader("") From 82ad1dbc07dc2db39c6f502eb148ff3ce00d96b8 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Mon, 9 Jun 2025 16:29:57 -0700 Subject: [PATCH 14/26] mac: handle "keep" named apps (#11031) When a user elects to keep the existing app, the new Ollama is named `Ollama 2.app` This fixes the app startup flow to handle this naming pattern. --- cmd/start_darwin.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/cmd/start_darwin.go b/cmd/start_darwin.go index 83af1235..05a4551e 100644 --- a/cmd/start_darwin.go +++ b/cmd/start_darwin.go @@ -5,7 +5,7 @@ import ( "errors" "os" "os/exec" - "strings" + "regexp" "github.com/ollama/ollama/api" ) @@ -19,11 +19,12 @@ func startApp(ctx context.Context, client *api.Client) error { if err != nil { return err } - if !strings.Contains(link, "Ollama.app") { + r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`) + m := r.FindStringSubmatch(link) + if len(m) != 1 { return errors.New("could not find ollama app") } - path := strings.Split(link, "Ollama.app") - if err := exec.Command("/usr/bin/open", "-j", "-a", path[0]+"Ollama.app", "--args", "--fast-startup").Run(); err != nil { + if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil { return err } return waitForServer(ctx, client) From f63d7f68eb206cf403fb9c7dca7978d16204e268 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Tue, 10 Jun 2025 09:33:54 -0700 Subject: [PATCH 15/26] readme: update quickstart example to Gemma 3 --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 22de41e6..af064a63 100644 --- a/README.md +++ b/README.md @@ -40,10 +40,10 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla ## Quickstart -To run and chat with [Llama 3.2](https://ollama.com/library/llama3.2): +To run and chat with [Llama 3.2](https://ollama.com/library/gemma3): ```shell -ollama run llama3.2 +ollama run gemma3 ``` ## Model library From af21a5ac397c1d2ce62881e0962bbff9d6da31f2 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Tue, 10 Jun 2025 09:34:23 -0700 Subject: [PATCH 16/26] readme: update quickstart link text to Gemma 3 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index af064a63..90e41be8 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla ## Quickstart -To run and chat with [Llama 3.2](https://ollama.com/library/gemma3): +To run and chat with [Gemma 3](https://ollama.com/library/gemma3): ```shell ollama run gemma3 From deaabe292d86b712e061bebe7fdd6be6690f539b Mon Sep 17 00:00:00 2001 From: Attogram Project Date: Tue, 10 Jun 2025 23:14:51 +0200 Subject: [PATCH 17/26] readme: add ollama-multirun to community integrations (#11038) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 90e41be8..ffaec628 100644 --- a/README.md +++ b/README.md @@ -451,6 +451,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull, and download models from Ollama Registry in your terminal. - [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform) - [AWS-Strands-With-Ollama](https://github.com/rapidarchitect/ollama_strands) - AWS Strands Agents with Ollama Examples +- [ollama-multirun](https://github.com/attogram/ollama-multirun) - A bash shell script to run a single prompt against any or all of your locally installed ollama models, saving the output and performance statistics as easily navigable web pages. ([Demo](https://attogram.github.io/ai_test_zone/)) ### Apple Vision Pro From 2e77aa1ae70372388bd4b08b9957e5198d566a22 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 11 Jun 2025 12:10:15 -0700 Subject: [PATCH 18/26] use nn.Linear in place of ml.Tensor (#11049) while nn.Linear.Forward isn't applicable for sparse MLP, it's still a nice container for the tensors --- model/models/llama4/model_text.go | 12 ++++++------ model/models/qwen3/model.go | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go index 27935f40..045ab403 100644 --- a/model/models/llama4/model_text.go +++ b/model/models/llama4/model_text.go @@ -63,9 +63,9 @@ func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOp } type TextExperts struct { - Gate ml.Tensor `gguf:"ffn_gate_exps.weight"` - Up ml.Tensor `gguf:"ffn_up_exps.weight"` - Down ml.Tensor `gguf:"ffn_down_exps.weight"` + Gate *nn.Linear `gguf:"ffn_gate_exps"` + Up *nn.Linear `gguf:"ffn_up_exps"` + Down *nn.Linear `gguf:"ffn_down_exps"` } func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor { @@ -76,9 +76,9 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed) hiddenStates = hiddenStates.Mul(ctx, scores) - upStates := e.Up.MulmatID(ctx, hiddenStates, experts) - gateStates := e.Gate.MulmatID(ctx, hiddenStates, experts) - downStates := e.Down.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts) + upStates := e.Up.Weight.MulmatID(ctx, hiddenStates, experts) + gateStates := e.Gate.Weight.MulmatID(ctx, hiddenStates, experts) + downStates := e.Down.Weight.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts) nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)) for i := 1; i < opts.numExpertsUsed; i++ { diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go index 1930da7e..7a83e0d0 100644 --- a/model/models/qwen3/model.go +++ b/model/models/qwen3/model.go @@ -66,9 +66,9 @@ type MLP interface { type sparse struct { Router *nn.Linear `gguf:"ffn_gate_inp"` - Gate ml.Tensor `gguf:"ffn_gate_exps.weight"` - Up ml.Tensor `gguf:"ffn_up_exps.weight"` - Down ml.Tensor `gguf:"ffn_down_exps.weight"` + Gate *nn.Linear `gguf:"ffn_gate_exps"` + Up *nn.Linear `gguf:"ffn_up_exps"` + Down *nn.Linear `gguf:"ffn_down_exps"` } func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { @@ -87,13 +87,13 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1)) - upStates := mlp.Up.MulmatID(ctx, hiddenStates, selectedExperts) + upStates := mlp.Up.Weight.MulmatID(ctx, hiddenStates, selectedExperts) - hiddenStates = mlp.Gate.MulmatID(ctx, hiddenStates, selectedExperts) + hiddenStates = mlp.Gate.Weight.MulmatID(ctx, hiddenStates, selectedExperts) hiddenStates = hiddenStates.SILU(ctx) hiddenStates = hiddenStates.Mul(ctx, upStates) - experts := mlp.Down.MulmatID(ctx, hiddenStates, selectedExperts) + experts := mlp.Down.Weight.MulmatID(ctx, hiddenStates, selectedExperts) experts = experts.Mul(ctx, routingWeights) nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2)) From 0dabb4ef6a1aab240a59b6bb4ef82372d335e3a9 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 11 Jun 2025 12:10:35 -0700 Subject: [PATCH 19/26] skip tokenizer.model if possible (#11050) if tokenizer.json is already copied, skip tokenizer.model --- parser/parser.go | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/parser/parser.go b/parser/parser.go index 96eae9c0..d40a79c2 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -292,13 +292,18 @@ func filesForModel(path string) ([]string, error) { } files = append(files, js...) - if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 { - // add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob - // tokenizer.model might be a unresolved git lfs reference; error if it is - files = append(files, tks...) - } else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 { - // some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B) - files = append(files, tks...) + // only include tokenizer.model is tokenizer.json is not present + if !slices.ContainsFunc(files, func(s string) bool { + return slices.Contains(strings.Split(s, string(os.PathSeparator)), "tokenizer.json") + }) { + if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 { + // add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob + // tokenizer.model might be a unresolved git lfs reference; error if it is + files = append(files, tks...) + } else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 { + // some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B) + files = append(files, tks...) + } } return files, nil From 45f56355d557b7130c7c07bbd6e1b634a758d946 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 11 Jun 2025 12:10:54 -0700 Subject: [PATCH 20/26] feat: uneven splits (#11048) The current splitDim function only operates on tensors that are split evenly which isn't always the case, e.g. a QKV tensor. This change allows the function to be used for arbitrary splits --- convert/convert_qwen25vl.go | 10 +- convert/tensor.go | 50 ++++-- convert/tensor_test.go | 304 ++++++++++++++++++++++++++++++++++++ 3 files changed, 344 insertions(+), 20 deletions(-) create mode 100644 convert/tensor_test.go diff --git a/convert/convert_qwen25vl.go b/convert/convert_qwen25vl.go index c2d5a633..6e4c9640 100644 --- a/convert/convert_qwen25vl.go +++ b/convert/convert_qwen25vl.go @@ -65,17 +65,17 @@ func (q *qwen25VLModel) Tensors(ts []Tensor) []*ggml.Tensor { for _, t := range ts { if strings.Contains(t.Name(), "patch_embed.proj") { for t := range splitDim(t, 2, - strings.NewReplacer("patch_embed.proj", "patch_embd_0"), - strings.NewReplacer("patch_embed.proj", "patch_embd_1"), + split{Replacer: strings.NewReplacer("patch_embed.proj", "patch_embd_0")}, + split{Replacer: strings.NewReplacer("patch_embed.proj", "patch_embd_1")}, ) { t.Shape = slices.DeleteFunc(t.Shape, func(i uint64) bool { return i == 1 }) out = append(out, t) } } else if strings.Contains(t.Name(), "attn.qkv") { out = append(out, slices.Collect(splitDim(t, 0, - strings.NewReplacer("attn.qkv", "attn_q"), - strings.NewReplacer("attn.qkv", "attn_k"), - strings.NewReplacer("attn.qkv", "attn_v"), + split{Replacer: strings.NewReplacer("attn.qkv", "attn_q")}, + split{Replacer: strings.NewReplacer("attn.qkv", "attn_k")}, + split{Replacer: strings.NewReplacer("attn.qkv", "attn_v")}, ))...) } else { out = append(out, &ggml.Tensor{ diff --git a/convert/tensor.go b/convert/tensor.go index ffb22ead..9d6919e3 100644 --- a/convert/tensor.go +++ b/convert/tensor.go @@ -1,53 +1,73 @@ package convert import ( + "cmp" "iter" "slices" "strings" - "github.com/ollama/ollama/fs/ggml" "github.com/pdevine/tensor" "github.com/pdevine/tensor/native" + + "github.com/ollama/ollama/fs/ggml" ) +type split struct { + *strings.Replacer + dim int + + // fn is an optional function to apply to the tensor after slicing + fn func(tensor.Tensor) (tensor.Tensor, error) +} + // splitDim splits a tensor along a specified dimension into multiple tensors. The dimension -// is split evenly based on the number of replacers provided. -func splitDim(t Tensor, dim int, replacers ...*strings.Replacer) iter.Seq[*ggml.Tensor] { +// is split evenly based on the number of replacers provided unless a specific count is given. +func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] { return func(yield func(*ggml.Tensor) bool) { - for i, replacer := range replacers { + var offset int + for _, split := range splits { + t := t.Clone() shape := slices.Clone(t.Shape()) - shape[dim] = shape[dim] / uint64(len(replacers)) + shape[dim] = cmp.Or(uint64(split.dim), shape[dim]/uint64(len(splits))) slice := slices.Repeat([]tensor.Slice{nil}, len(shape)) - slice[dim] = tensor.S(i*int(shape[dim]), (i+1)*int(shape[dim])) + slice[dim] = tensor.S(offset, offset+int(shape[dim])) + offset += int(shape[dim]) - tt := t.Clone() - tt.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) { + t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) { dims := make([]int, len(shape)) for i := range shape { dims[i] = int(shape[i]) } - var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data)) - t, err := t.Slice(slice...) + var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data)) + tt, err := tt.Slice(slice...) if err != nil { return nil, err } - t = tensor.Materialize(t) + tt = tensor.Materialize(tt) + + if split.fn != nil { + tt, err = split.fn(tt) + if err != nil { + return nil, err + } + } + // flatten tensor so it can be written as a vector - if err := t.Reshape(t.Shape().TotalSize()); err != nil { + if err := tt.Reshape(tt.Shape().TotalSize()); err != nil { return nil, err } - return native.VectorF32(t.(*tensor.Dense)) + return native.VectorF32(tt.(*tensor.Dense)) }) if !yield(&ggml.Tensor{ - Name: replacer.Replace(t.Name()), + Name: split.Replace(t.Name()), Kind: t.Kind(), Shape: shape, - WriterTo: tt, + WriterTo: t, }) { break } diff --git a/convert/tensor_test.go b/convert/tensor_test.go new file mode 100644 index 00000000..ea12d0f5 --- /dev/null +++ b/convert/tensor_test.go @@ -0,0 +1,304 @@ +package convert + +import ( + "bytes" + "encoding/binary" + "io" + "iter" + "slices" + "strings" + "testing" + + "github.com/pdevine/tensor" +) + +type fakeTensor struct { + name string + shape []uint64 + data []float32 + + repacker Repacker +} + +func (f fakeTensor) Name() string { + return f.name +} + +func (f fakeTensor) Shape() []uint64 { + return f.shape +} + +func (f fakeTensor) Kind() uint32 { + return 0 +} + +func (f *fakeTensor) SetRepacker(fn Repacker) { + f.repacker = fn +} + +func (f fakeTensor) Clone() Tensor { + return &fakeTensor{ + name: f.name, + shape: slices.Clone(f.shape), + data: slices.Clone(f.data), + repacker: f.repacker, + } +} + +func (f fakeTensor) WriteTo(w io.Writer) (n int64, err error) { + data := f.data + if f.repacker != nil { + data, err = f.repacker(f.name, data, f.shape) + if err != nil { + return 0, err + } + } + + if err := binary.Write(w, binary.LittleEndian, data); err != nil { + return 0, err + } + + return int64(len(data) * 4), nil +} + +func mul(shape []uint64) int { + n := 1 + for _, dim := range shape { + n *= int(dim) + } + return n +} + +func TestSplitDim(t *testing.T) { + r := fakeTensor{ + name: "a.b", + shape: []uint64{3, 4}, + data: []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + } + + t.Run("no split", func(t *testing.T) { + for tt := range splitDim(&r, 0, split{Replacer: strings.NewReplacer("a", "x")}) { + if tt.Name != "x.b" { + t.Fatalf("expected name 'x', got '%s'", tt.Name) + } + + if !slices.Equal(tt.Shape, []uint64{3, 4}) { + t.Fatalf("expected shape [3, 4], got %v", tt.Shape) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if !slices.Equal(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}) { + t.Fatalf("expected data [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], got %v", f32s) + } + } + }) + + t.Run("even split", func(t *testing.T) { + next, stop := iter.Pull(splitDim(&r, 1, + split{Replacer: strings.NewReplacer("a", "x")}, + split{Replacer: strings.NewReplacer("b", "y")}, + )) + defer stop() + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "x.b" { + t.Fatal("expected name 'x.b', got", tt.Name) + } + + if !slices.Equal(tt.Shape, []uint64{3, 2}) { + t.Fatal("expected shape [3, 2], got", tt.Shape) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if !slices.Equal(f32s, []float32{0, 1, 4, 5, 8, 9}) { + t.Fatal("expected data [0, 1, 4, 5, 8, 9], got", f32s) + } + } + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "a.y" { + t.Fatal("expected name 'a.y', got", tt.Name) + } + + if !slices.Equal(tt.Shape, []uint64{3, 2}) { + t.Fatal("expected shape [3, 2], got", tt.Shape) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if !slices.Equal(f32s, []float32{2, 3, 6, 7, 10, 11}) { + t.Fatal("expected data [2, 3, 6, 7, 10, 11], got", f32s) + } + } + }) + + t.Run("uneven split", func(t *testing.T) { + next, stop := iter.Pull(splitDim(&r, 0, + split{Replacer: strings.NewReplacer("a", "x"), dim: 2}, + split{Replacer: strings.NewReplacer("b", "y"), dim: 1}, + )) + defer stop() + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "x.b" { + t.Fatal("expected name 'x.b', got", tt.Name) + } + + if !slices.Equal(tt.Shape, []uint64{2, 4}) { + t.Fatal("expected shape [2, 4], got", tt.Shape) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if !slices.Equal(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7}) { + t.Fatal("expected data [0, 1, 2, 3, 4, 5, 6, 7], got", f32s) + } + } + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "a.y" { + t.Fatal("expected name 'a.y', got", tt.Name) + } + + if !slices.Equal(tt.Shape, []uint64{1, 4}) { + t.Fatal("expected shape [1, 4], got", tt.Shape) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if !slices.Equal(f32s, []float32{8, 9, 10, 11}) { + t.Fatal("expected data [8, 9, 10, 11], got", f32s) + } + } + }) + + t.Run("split with transpose", func(t *testing.T) { + next, stop := iter.Pull(splitDim(&r, 1, + split{Replacer: strings.NewReplacer("a", "x")}, + split{Replacer: strings.NewReplacer("b", "y"), fn: func(tt tensor.Tensor) (tensor.Tensor, error) { + return tensor.Transpose(tt, 1, 0) + }}, + )) + defer stop() + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "x.b" { + t.Fatal("expected name 'x.b', got", tt.Name) + } + + if !slices.Equal(tt.Shape, []uint64{3, 2}) { + t.Fatal("expected shape [3, 2], got", tt.Shape) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if !slices.Equal(f32s, []float32{0, 1, 4, 5, 8, 9}) { + t.Fatal("expected data [0, 1, 4, 5, 8, 9], got", f32s) + } + } + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "a.y" { + t.Fatal("expected name 'a.y', got", tt.Name) + } + + if !slices.Equal(tt.Shape, []uint64{3, 2}) { + t.Fatal("expected shape [3, 2], got", tt.Shape) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if !slices.Equal(f32s, []float32{2, 6, 10, 3, 7, 11}) { + t.Fatal("expected data [2, 6, 10, 3, 7, 11], got", f32s) + } + } + }) +} From 6b04cad7e816d1a119559e092d59f4fbaa6c3a0b Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 12 Jun 2025 11:04:11 -0700 Subject: [PATCH 21/26] feat: incremental gguf parser (#10822) * incremental gguf parser * gguf: update test to not rely on gguf on disc * re-use existing create gguf * read capabilities from gguf kv * kv exists * update tests * s/doneFunc/successFunc/g * new buffered reader --------- Co-authored-by: Bruce MacDonald --- fs/gguf/gguf.go | 347 ++++++++++++++++++++++++++++++++++++ fs/gguf/gguf_test.go | 249 ++++++++++++++++++++++++++ fs/gguf/keyvalue.go | 90 ++++++++++ fs/gguf/keyvalue_test.go | 208 +++++++++++++++++++++ fs/gguf/lazy.go | 89 +++++++++ fs/gguf/reader.go | 23 +++ fs/gguf/tensor.go | 288 ++++++++++++++++++++++++++++++ go.mod | 2 +- go.sum | 4 +- server/images.go | 24 ++- server/images_test.go | 165 +++++------------ server/quantization_test.go | 12 +- server/sched_test.go | 20 +-- 13 files changed, 1357 insertions(+), 164 deletions(-) create mode 100644 fs/gguf/gguf.go create mode 100644 fs/gguf/gguf_test.go create mode 100644 fs/gguf/keyvalue.go create mode 100644 fs/gguf/keyvalue_test.go create mode 100644 fs/gguf/lazy.go create mode 100644 fs/gguf/reader.go create mode 100644 fs/gguf/tensor.go diff --git a/fs/gguf/gguf.go b/fs/gguf/gguf.go new file mode 100644 index 00000000..ebb9286f --- /dev/null +++ b/fs/gguf/gguf.go @@ -0,0 +1,347 @@ +package gguf + +import ( + "bytes" + "cmp" + "encoding/binary" + "errors" + "fmt" + "io" + "iter" + "os" + "slices" + "strings" +) + +const ( + typeUint8 uint32 = iota + typeInt8 + typeUint16 + typeInt16 + typeUint32 + typeInt32 + typeFloat32 + typeBool + typeString + typeArray + typeUint64 + typeInt64 + typeFloat64 +) + +var ErrUnsupported = errors.New("unsupported") + +type File struct { + Magic [4]byte + Version uint32 + + keyValues *lazy[KeyValue] + tensors *lazy[TensorInfo] + offset int64 + + file *os.File + reader *bufferedReader + bts []byte +} + +func Open(path string) (f *File, err error) { + f = &File{bts: make([]byte, 4096)} + f.file, err = os.Open(path) + if err != nil { + return nil, err + } + + f.reader = newBufferedReader(f.file, 32<<10) + + if err := binary.Read(f.reader, binary.LittleEndian, &f.Magic); err != nil { + return nil, err + } + + if bytes.Equal(f.Magic[:], []byte("gguf")) { + return nil, fmt.Errorf("%w file type %v", ErrUnsupported, f.Magic) + } + + if err := binary.Read(f.reader, binary.LittleEndian, &f.Version); err != nil { + return nil, err + } + + if f.Version != 3 { + return nil, fmt.Errorf("%w version %v", ErrUnsupported, f.Version) + } + + f.tensors, err = newLazy(f, f.readTensor) + if err != nil { + return nil, err + } + + f.tensors.successFunc = func() error { + offset := f.reader.offset + + alignment := cmp.Or(f.KeyValue("general.alignment").Int(), 32) + f.offset = offset + (alignment-offset%alignment)%alignment + return nil + } + + f.keyValues, err = newLazy(f, f.readKeyValue) + if err != nil { + return nil, err + } + + return f, nil +} + +func (f *File) readTensor() (TensorInfo, error) { + name, err := readString(f) + if err != nil { + return TensorInfo{}, err + } + + dims, err := read[uint32](f) + if err != nil { + return TensorInfo{}, err + } + + shape := make([]uint64, dims) + for i := range dims { + shape[i], err = read[uint64](f) + if err != nil { + return TensorInfo{}, err + } + } + + type_, err := read[uint32](f) + if err != nil { + return TensorInfo{}, err + } + + offset, err := read[uint64](f) + if err != nil { + return TensorInfo{}, err + } + + return TensorInfo{ + Name: name, + Offset: offset, + Shape: shape, + Type: TensorType(type_), + }, nil +} + +func (f *File) readKeyValue() (KeyValue, error) { + key, err := readString(f) + if err != nil { + return KeyValue{}, err + } + + t, err := read[uint32](f) + if err != nil { + return KeyValue{}, err + } + + value, err := func() (any, error) { + switch t { + case typeUint8: + return read[uint8](f) + case typeInt8: + return read[int8](f) + case typeUint16: + return read[uint16](f) + case typeInt16: + return read[int16](f) + case typeUint32: + return read[uint32](f) + case typeInt32: + return read[int32](f) + case typeUint64: + return read[uint64](f) + case typeInt64: + return read[int64](f) + case typeFloat32: + return read[float32](f) + case typeFloat64: + return read[float64](f) + case typeBool: + return read[bool](f) + case typeString: + return readString(f) + case typeArray: + return readArray(f) + default: + return nil, fmt.Errorf("%w type %d", ErrUnsupported, t) + } + }() + if err != nil { + return KeyValue{}, err + } + + return KeyValue{ + Key: key, + Value: Value{value}, + }, nil +} + +func read[T any](f *File) (t T, err error) { + err = binary.Read(f.reader, binary.LittleEndian, &t) + return t, err +} + +func readString(f *File) (string, error) { + n, err := read[uint64](f) + if err != nil { + return "", err + } + + if int(n) > len(f.bts) { + f.bts = make([]byte, n) + } + + bts := f.bts[:n] + if _, err := io.ReadFull(f.reader, bts); err != nil { + return "", err + } + defer clear(bts) + + return string(bts), nil +} + +func readArray(f *File) (any, error) { + t, err := read[uint32](f) + if err != nil { + return nil, err + } + + n, err := read[uint64](f) + if err != nil { + return nil, err + } + + switch t { + case typeUint8: + return readArrayData[uint8](f, n) + case typeInt8: + return readArrayData[int8](f, n) + case typeUint16: + return readArrayData[uint16](f, n) + case typeInt16: + return readArrayData[int16](f, n) + case typeUint32: + return readArrayData[uint32](f, n) + case typeInt32: + return readArrayData[int32](f, n) + case typeUint64: + return readArrayData[uint64](f, n) + case typeInt64: + return readArrayData[int64](f, n) + case typeFloat32: + return readArrayData[float32](f, n) + case typeFloat64: + return readArrayData[float64](f, n) + case typeBool: + return readArrayData[bool](f, n) + case typeString: + return readArrayString(f, n) + default: + return nil, fmt.Errorf("%w type %d", ErrUnsupported, t) + } +} + +func readArrayData[T any](f *File, n uint64) (s []T, err error) { + s = make([]T, n) + for i := range n { + e, err := read[T](f) + if err != nil { + return nil, err + } + + s[i] = e + } + + return s, nil +} + +func readArrayString(f *File, n uint64) (s []string, err error) { + s = make([]string, n) + for i := range n { + e, err := readString(f) + if err != nil { + return nil, err + } + + s[i] = e + } + + return s, nil +} + +func (f *File) Close() error { + f.keyValues.stop() + f.tensors.stop() + return f.file.Close() +} + +func (f *File) KeyValue(key string) KeyValue { + if !strings.HasPrefix(key, "general.") && !strings.HasPrefix(key, "tokenizer.") { + key = f.KeyValue("general.architecture").String() + "." + key + } + + if index := slices.IndexFunc(f.keyValues.values, func(kv KeyValue) bool { + return kv.Key == key + }); index >= 0 { + return f.keyValues.values[index] + } + + for keyValue, ok := f.keyValues.next(); ok; keyValue, ok = f.keyValues.next() { + if keyValue.Key == key { + return keyValue + } + } + + return KeyValue{} +} + +func (f *File) NumKeyValues() int { + return int(f.keyValues.count) +} + +func (f *File) KeyValues() iter.Seq2[int, KeyValue] { + return f.keyValues.All() +} + +func (f *File) TensorInfo(name string) TensorInfo { + if index := slices.IndexFunc(f.tensors.values, func(t TensorInfo) bool { + return t.Name == name + }); index >= 0 { + return f.tensors.values[index] + } + + // fast-forward through key values if we haven't already + _ = f.keyValues.rest() + for tensor, ok := f.tensors.next(); ok; tensor, ok = f.tensors.next() { + if tensor.Name == name { + return tensor + } + } + + return TensorInfo{} +} + +func (f *File) NumTensors() int { + return int(f.tensors.count) +} + +func (f *File) TensorInfos() iter.Seq2[int, TensorInfo] { + // fast forward through key values if we haven't already + f.keyValues.rest() + return f.tensors.All() +} + +func (f *File) TensorReader(name string) (TensorInfo, io.Reader, error) { + t := f.TensorInfo(name) + if t.NumBytes() == 0 { + return TensorInfo{}, nil, fmt.Errorf("tensor %s not found", name) + } + + // fast forward through tensor info if we haven't already + _ = f.tensors.rest() + return t, io.NewSectionReader(f.file, f.offset+int64(t.Offset), t.NumBytes()), nil +} diff --git a/fs/gguf/gguf_test.go b/fs/gguf/gguf_test.go new file mode 100644 index 00000000..eea28a48 --- /dev/null +++ b/fs/gguf/gguf_test.go @@ -0,0 +1,249 @@ +package gguf_test + +import ( + "bytes" + "os" + "strconv" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/fs/gguf" +) + +func createBinFile(tb testing.TB) string { + tb.Helper() + f, err := os.CreateTemp(tb.TempDir(), "") + if err != nil { + tb.Fatal(err) + } + defer f.Close() + + kv := ggml.KV{ + "general.architecture": "llama", + "llama.block_count": uint32(8), + "llama.embedding_length": uint32(3), + "llama.attention.head_count": uint32(2), + "llama.attention.head_count_kv": uint32(2), + "llama.attention.key_length": uint32(3), + "llama.rope.dimension_count": uint32(4), + "llama.rope.freq_base": float32(10000.0), + "llama.rope.freq_scale": float32(1.0), + "llama.attention.layer_norm_rms_epsilon": float32(1e-6), + "tokenizer.ggml.eos_token_id": uint32(0), + "tokenizer.ggml.eos_token_ids": []int32{1, 2, 3}, + "tokenizer.ggml.tokens": []string{"hello", "world"}, + "tokenizer.ggml.scores": []float32{0, 1}, + } + + tensors := []*ggml.Tensor{ + { + Name: "token_embd.weight", + Kind: 0, + Shape: []uint64{2, 3}, + WriterTo: bytes.NewBuffer(make([]byte, 4*2*3)), + }, + { + Name: "output.weight", + Kind: 0, + Shape: []uint64{3, 2}, + WriterTo: bytes.NewBuffer(make([]byte, 4*3*2)), + }, + } + + for i := range 8 { + tensors = append(tensors, &ggml.Tensor{ + Name: "blk." + strconv.Itoa(i) + ".attn_q.weight", + Kind: 0, + Shape: []uint64{3, 3}, + WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)), + }, &ggml.Tensor{ + Name: "blk." + strconv.Itoa(i) + ".attn_k.weight", + Kind: 0, + Shape: []uint64{3, 3}, + WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)), + }, &ggml.Tensor{ + Name: "blk." + strconv.Itoa(i) + ".attn_v.weight", + Kind: 0, + Shape: []uint64{3, 3}, + WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)), + }, &ggml.Tensor{ + Name: "blk." + strconv.Itoa(i) + ".attn_output.weight", + Kind: 0, + Shape: []uint64{3, 3}, + WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)), + }) + } + + if err := ggml.WriteGGUF(f, kv, tensors); err != nil { + tb.Fatal(err) + } + + return f.Name() +} + +func TestRead(t *testing.T) { + f, err := gguf.Open(createBinFile(t)) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + if got := f.KeyValue("does.not.exist").Valid(); got { + t.Errorf(`KeyValue("does.not.exist").Exists() = %v, want false`, got) + } + + if got := f.KeyValue("general.architecture").String(); got != "llama" { + t.Errorf(`KeyValue("general.architecture").String() = %q, want %q`, got, "llama") + } + + if got := f.TensorInfo("token_embd.weight"); got.Name != "token_embd.weight" { + t.Errorf(`TensorInfo("token_embd.weight").Name = %q, want %q`, got.Name, "token_embd.weight") + } else if diff := cmp.Diff(got.Shape, []uint64{2, 3}); diff != "" { + t.Errorf(`TensorInfo("token_embd.weight").Shape mismatch (-got +want):\n%s`, diff) + } else if got.Type != gguf.TensorTypeF32 { + t.Errorf(`TensorInfo("token_embd.weight").Type = %d, want %d`, got.Type, gguf.TensorTypeF32) + } + + if got := f.KeyValue("block_count").Uint(); got != 8 { + t.Errorf(`KeyValue("block_count").Uint() = %d, want %d`, got, 8) + } + + if diff := cmp.Diff(f.KeyValue("tokenizer.ggml.tokens").Strings(), []string{"hello", "world"}); diff != "" { + t.Errorf("KeyValue(\"tokenizer.ggml.tokens\").Strings() mismatch (-got +want):\n%s", diff) + } + + if diff := cmp.Diff(f.KeyValue("tokenizer.ggml.scores").Floats(), []float64{0, 1}); diff != "" { + t.Errorf("KeyValue(\"tokenizer.ggml.scores\").Ints() mismatch (-got +want):\n%s", diff) + } + + var kvs []string + for _, kv := range f.KeyValues() { + if !kv.Valid() { + t.Error("found invalid key-value pair:", kv) + } + + kvs = append(kvs, kv.Key) + } + + if len(kvs) != f.NumKeyValues() { + t.Errorf("iterated key count = %d, want %d", len(kvs), f.NumKeyValues()) + } + + if diff := cmp.Diff(kvs, []string{ + "general.architecture", + "llama.block_count", + "llama.embedding_length", + "llama.attention.head_count", + "llama.attention.head_count_kv", + "llama.attention.key_length", + "llama.rope.dimension_count", + "llama.rope.freq_base", + "llama.rope.freq_scale", + "llama.attention.layer_norm_rms_epsilon", + "tokenizer.ggml.eos_token_id", + "tokenizer.ggml.eos_token_ids", + "tokenizer.ggml.tokens", + "tokenizer.ggml.scores", + }, cmpopts.SortSlices(strings.Compare)); diff != "" { + t.Errorf("KeyValues() mismatch (-got +want):\n%s", diff) + } + + var tis []string + for _, ti := range f.TensorInfos() { + if !ti.Valid() { + t.Error("found invalid tensor info:", ti) + } + + tis = append(tis, ti.Name) + } + + if len(tis) != f.NumTensors() { + t.Errorf("iterated tensor count = %d, want %d", len(tis), f.NumTensors()) + } + + if diff := cmp.Diff(tis, []string{ + "token_embd.weight", + "output.weight", + "blk.0.attn_q.weight", + "blk.0.attn_k.weight", + "blk.0.attn_v.weight", + "blk.0.attn_output.weight", + "blk.1.attn_q.weight", + "blk.1.attn_k.weight", + "blk.1.attn_v.weight", + "blk.1.attn_output.weight", + "blk.2.attn_q.weight", + "blk.2.attn_k.weight", + "blk.2.attn_v.weight", + "blk.2.attn_output.weight", + "blk.3.attn_q.weight", + "blk.3.attn_k.weight", + "blk.3.attn_v.weight", + "blk.3.attn_output.weight", + "blk.4.attn_q.weight", + "blk.4.attn_k.weight", + "blk.4.attn_v.weight", + "blk.4.attn_output.weight", + "blk.5.attn_q.weight", + "blk.5.attn_k.weight", + "blk.5.attn_v.weight", + "blk.5.attn_output.weight", + "blk.6.attn_q.weight", + "blk.6.attn_k.weight", + "blk.6.attn_v.weight", + "blk.6.attn_output.weight", + "blk.7.attn_q.weight", + "blk.7.attn_k.weight", + "blk.7.attn_v.weight", + "blk.7.attn_output.weight", + }, cmpopts.SortSlices(strings.Compare)); diff != "" { + t.Errorf("TensorInfos() mismatch (-got +want):\n%s", diff) + } + + ti, r, err := f.TensorReader("output.weight") + if err != nil { + t.Fatalf(`TensorReader("output.weight") error: %v`, err) + } + + if ti.Name != "output.weight" { + t.Errorf(`TensorReader("output.weight").Name = %q, want %q`, ti.Name, "output.weight") + } else if diff := cmp.Diff(ti.Shape, []uint64{3, 2}); diff != "" { + t.Errorf(`TensorReader("output.weight").Shape mismatch (-got +want):\n%s`, diff) + } else if ti.Type != gguf.TensorTypeF32 { + t.Errorf(`TensorReader("output.weight").Type = %d, want %d`, ti.Type, gguf.TensorTypeF32) + } + + var b bytes.Buffer + if _, err := b.ReadFrom(r); err != nil { + t.Fatalf(`ReadFrom TensorReader("output.weight") error: %v`, err) + } + + if b.Len() != int(ti.NumBytes()) { + t.Errorf(`ReadFrom TensorReader("output.weight") length = %d, want %d`, b.Len(), ti.NumBytes()) + } +} + +func BenchmarkRead(b *testing.B) { + b.ReportAllocs() + + p := createBinFile(b) + for b.Loop() { + f, err := gguf.Open(p) + if err != nil { + b.Fatal(err) + } + + if got := f.KeyValue("general.architecture").String(); got != "llama" { + b.Errorf("got = %q, want %q", got, "llama") + } + + // Iterate through some tensors + for range f.TensorInfos() { + } + + f.Close() + } +} diff --git a/fs/gguf/keyvalue.go b/fs/gguf/keyvalue.go new file mode 100644 index 00000000..5843326c --- /dev/null +++ b/fs/gguf/keyvalue.go @@ -0,0 +1,90 @@ +package gguf + +import ( + "reflect" + "slices" +) + +type KeyValue struct { + Key string + Value +} + +func (kv KeyValue) Valid() bool { + return kv.Key != "" && kv.Value.value != nil +} + +type Value struct { + value any +} + +func value[T any](v Value, kinds ...reflect.Kind) (t T) { + vv := reflect.ValueOf(v.value) + if slices.Contains(kinds, vv.Kind()) { + t = vv.Convert(reflect.TypeOf(t)).Interface().(T) + } + return +} + +func values[T any](v Value, kinds ...reflect.Kind) (ts []T) { + switch vv := reflect.ValueOf(v.value); vv.Kind() { + case reflect.Slice: + if slices.Contains(kinds, vv.Type().Elem().Kind()) { + ts = make([]T, vv.Len()) + for i := range vv.Len() { + ts[i] = vv.Index(i).Convert(reflect.TypeOf(ts[i])).Interface().(T) + } + } + } + return +} + +// Int returns Value as a signed integer. If it is not a signed integer, it returns 0. +func (v Value) Int() int64 { + return value[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64) +} + +// Ints returns Value as a signed integer slice. If it is not a signed integer slice, it returns nil. +func (v Value) Ints() (i64s []int64) { + return values[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64) +} + +// Uint converts an unsigned integer value to uint64. If the value is not a unsigned integer, it returns 0. +func (v Value) Uint() uint64 { + return value[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64) +} + +// Uints returns Value as a unsigned integer slice. If it is not a unsigned integer slice, it returns nil. +func (v Value) Uints() (u64s []uint64) { + return values[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64) +} + +// Float returns Value as a float. If it is not a float, it returns 0. +func (v Value) Float() float64 { + return value[float64](v, reflect.Float32, reflect.Float64) +} + +// Floats returns Value as a float slice. If it is not a float slice, it returns nil. +func (v Value) Floats() (f64s []float64) { + return values[float64](v, reflect.Float32, reflect.Float64) +} + +// Bool returns Value as a boolean. If it is not a boolean, it returns false. +func (v Value) Bool() bool { + return value[bool](v, reflect.Bool) +} + +// Bools returns Value as a boolean slice. If it is not a boolean slice, it returns nil. +func (v Value) Bools() (bools []bool) { + return values[bool](v, reflect.Bool) +} + +// String returns Value as a string. If it is not a string, it returns an empty string. +func (v Value) String() string { + return value[string](v, reflect.String) +} + +// Strings returns Value as a string slice. If it is not a string slice, it returns nil. +func (v Value) Strings() (strings []string) { + return values[string](v, reflect.String) +} diff --git a/fs/gguf/keyvalue_test.go b/fs/gguf/keyvalue_test.go new file mode 100644 index 00000000..2caacc53 --- /dev/null +++ b/fs/gguf/keyvalue_test.go @@ -0,0 +1,208 @@ +package gguf + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func split(name string, values map[string][]any) (matched []any, unmatched []any) { + for key, value := range values { + if key == name { + matched = value + } else { + unmatched = append(unmatched, value...) + } + } + return +} + +func TestValue(t *testing.T) { + values := map[string][]any{ + "int64": {int(42), int8(42), int16(42), int32(42), int64(42)}, + "uint64": {uint(42), uint8(42), uint16(42), uint32(42), uint64(42)}, + "float64": {float32(42), float64(42)}, + "string": {"42", "hello"}, + "bool": {true, false}, + } + + t.Run("int64", func(t *testing.T) { + matched, unmatched := split("int64", values) + for _, v := range matched { + kv := KeyValue{"key", Value{v}} + if i64 := kv.Int(); i64 != 42 { + t.Errorf("expected 42, got %d", i64) + } + } + + for _, v := range unmatched { + kv := KeyValue{"key", Value{v}} + if i64 := kv.Int(); i64 != 0 { + t.Errorf("expected 42, got %d", i64) + } + } + }) + + t.Run("uint64", func(t *testing.T) { + matched, unmatched := split("uint64", values) + for _, v := range matched { + kv := KeyValue{"key", Value{v}} + if u64 := kv.Uint(); u64 != 42 { + t.Errorf("expected 42, got %d", u64) + } + } + + for _, v := range unmatched { + kv := KeyValue{"key", Value{v}} + if u64 := kv.Uint(); u64 != 0 { + t.Errorf("expected 42, got %d", u64) + } + } + }) + + t.Run("float64", func(t *testing.T) { + matched, unmatched := split("float64", values) + for _, v := range matched { + kv := KeyValue{"key", Value{v}} + if f64 := kv.Float(); f64 != 42 { + t.Errorf("expected 42, got %f", f64) + } + } + + for _, v := range unmatched { + kv := KeyValue{"key", Value{v}} + if f64 := kv.Float(); f64 != 0 { + t.Errorf("expected 42, got %f", f64) + } + } + }) + + t.Run("string", func(t *testing.T) { + matched, unmatched := split("string", values) + for _, v := range matched { + kv := KeyValue{"key", Value{v}} + if s := kv.String(); s != v { + t.Errorf("expected 42, got %s", s) + } + } + + for _, v := range unmatched { + kv := KeyValue{"key", Value{v}} + if s := kv.String(); s != "" { + t.Errorf("expected 42, got %s", s) + } + } + }) + + t.Run("bool", func(t *testing.T) { + matched, unmatched := split("bool", values) + for _, v := range matched { + kv := KeyValue{"key", Value{v}} + if b := kv.Bool(); b != v { + t.Errorf("expected true, got %v", b) + } + } + + for _, v := range unmatched { + kv := KeyValue{"key", Value{v}} + if b := kv.Bool(); b != false { + t.Errorf("expected false, got %v", b) + } + } + }) +} + +func TestValues(t *testing.T) { + values := map[string][]any{ + "int64s": {[]int{42}, []int8{42}, []int16{42}, []int32{42}, []int64{42}}, + "uint64s": {[]uint{42}, []uint8{42}, []uint16{42}, []uint32{42}, []uint64{42}}, + "float64s": {[]float32{42}, []float64{42}}, + "strings": {[]string{"42"}, []string{"hello"}}, + "bools": {[]bool{true}, []bool{false}}, + } + + t.Run("int64s", func(t *testing.T) { + matched, unmatched := split("int64s", values) + for _, v := range matched { + kv := KeyValue{"key", Value{v}} + if diff := cmp.Diff(kv.Ints(), []int64{42}); diff != "" { + t.Errorf("diff: %s", diff) + } + } + + for _, v := range unmatched { + kv := KeyValue{"key", Value{v}} + if i64s := kv.Ints(); i64s != nil { + t.Errorf("expected nil, got %v", i64s) + } + } + }) + + t.Run("uint64s", func(t *testing.T) { + matched, unmatched := split("uint64s", values) + for _, v := range matched { + kv := KeyValue{"key", Value{v}} + if diff := cmp.Diff(kv.Uints(), []uint64{42}); diff != "" { + t.Errorf("diff: %s", diff) + } + } + + for _, v := range unmatched { + kv := KeyValue{"key", Value{v}} + if u64s := kv.Uints(); u64s != nil { + t.Errorf("expected nil, got %v", u64s) + } + } + }) + + t.Run("float64s", func(t *testing.T) { + matched, unmatched := split("float64s", values) + for _, v := range matched { + kv := KeyValue{"key", Value{v}} + if diff := cmp.Diff(kv.Floats(), []float64{42}); diff != "" { + t.Errorf("diff: %s", diff) + } + } + + for _, v := range unmatched { + kv := KeyValue{"key", Value{v}} + if f64s := kv.Floats(); f64s != nil { + t.Errorf("expected nil, got %v", f64s) + } + } + }) + + t.Run("strings", func(t *testing.T) { + matched, unmatched := split("strings", values) + for _, v := range matched { + kv := KeyValue{"key", Value{v}} + if diff := cmp.Diff(kv.Strings(), v); diff != "" { + t.Errorf("diff: %s", diff) + } + } + + for _, v := range unmatched { + kv := KeyValue{"key", Value{v}} + if s := kv.Strings(); s != nil { + t.Errorf("expected nil, got %v", s) + } + } + }) + + t.Run("bools", func(t *testing.T) { + matched, unmatched := split("bools", values) + for _, v := range matched { + kv := KeyValue{"key", Value{v}} + if diff := cmp.Diff(kv.Bools(), v); diff != "" { + t.Errorf("diff: %s", diff) + } + } + + for _, v := range unmatched { + kv := KeyValue{"key", Value{v}} + if b := kv.Bools(); b != nil { + t.Errorf("expected nil, got %v", b) + } + } + }) +} diff --git a/fs/gguf/lazy.go b/fs/gguf/lazy.go new file mode 100644 index 00000000..16ab9909 --- /dev/null +++ b/fs/gguf/lazy.go @@ -0,0 +1,89 @@ +package gguf + +import ( + "encoding/binary" + "iter" + "log/slog" +) + +type lazy[T any] struct { + count uint64 + next func() (T, bool) + stop func() + values []T + + // successFunc is called when all values have been successfully read. + successFunc func() error +} + +func newLazy[T any](f *File, fn func() (T, error)) (*lazy[T], error) { + it := lazy[T]{} + if err := binary.Read(f.reader, binary.LittleEndian, &it.count); err != nil { + return nil, err + } + + it.values = make([]T, 0) + it.next, it.stop = iter.Pull(func(yield func(T) bool) { + for i := range it.count { + t, err := fn() + if err != nil { + slog.Error("error reading tensor", "index", i, "error", err) + return + } + + it.values = append(it.values, t) + if !yield(t) { + break + } + } + + if it.successFunc != nil { + it.successFunc() + } + }) + + return &it, nil +} + +func (g *lazy[T]) Values() iter.Seq[T] { + return func(yield func(T) bool) { + for _, v := range g.All() { + if !yield(v) { + break + } + } + } +} + +func (g *lazy[T]) All() iter.Seq2[int, T] { + return func(yield func(int, T) bool) { + for i := range int(g.count) { + if i < len(g.values) { + if !yield(i, g.values[i]) { + break + } + } else { + t, ok := g.next() + if !ok { + break + } + + if !yield(i, t) { + break + } + } + } + } +} + +func (g *lazy[T]) rest() (collected bool) { + for { + _, ok := g.next() + collected = collected || ok + if !ok { + break + } + } + + return collected +} diff --git a/fs/gguf/reader.go b/fs/gguf/reader.go new file mode 100644 index 00000000..0bd76184 --- /dev/null +++ b/fs/gguf/reader.go @@ -0,0 +1,23 @@ +package gguf + +import ( + "bufio" + "io" +) + +type bufferedReader struct { + offset int64 + *bufio.Reader +} + +func newBufferedReader(rs io.ReadSeeker, size int) *bufferedReader { + return &bufferedReader{ + Reader: bufio.NewReaderSize(rs, size), + } +} + +func (rs *bufferedReader) Read(p []byte) (n int, err error) { + n, err = rs.Reader.Read(p) + rs.offset += int64(n) + return n, err +} diff --git a/fs/gguf/tensor.go b/fs/gguf/tensor.go new file mode 100644 index 00000000..194c1d73 --- /dev/null +++ b/fs/gguf/tensor.go @@ -0,0 +1,288 @@ +package gguf + +import ( + "log/slog" + "strings" +) + +type TensorInfo struct { + Name string + Offset uint64 + Shape []uint64 + Type TensorType +} + +func (ti TensorInfo) Valid() bool { + return ti.Name != "" && ti.NumBytes() > 0 +} + +func (ti TensorInfo) NumValues() int64 { + var numItems int64 = 1 + for _, dim := range ti.Shape { + numItems *= int64(dim) + } + return numItems +} + +// NumBytes returns the number of bytes in the tensor. +func (ti TensorInfo) NumBytes() int64 { + return int64(float64(ti.NumValues()) * ti.Type.NumBytes()) +} + +func (ti TensorInfo) LogValue() slog.Value { + return slog.GroupValue( + slog.String("name", ti.Name), + slog.Int64("offset", int64(ti.Offset)), + slog.Any("shape", ti.Shape), + slog.Int64("num_values", ti.NumValues()), + slog.Int64("num_bytes", ti.NumBytes()), + slog.Any("type", ti.Type), + ) +} + +type TensorType uint32 + +const ( + TensorTypeF32 TensorType = iota + TensorTypeF16 + TensorTypeQ4_0 + TensorTypeQ4_1 + + // unexported // unused in gguf + tensorTypeQ4_2 + tensorTypeQ4_3 + + TensorTypeQ5_0 + TensorTypeQ5_1 + TensorTypeQ8_0 + TensorTypeQ8_1 + TensorTypeQ2_K + TensorTypeQ3_K + TensorTypeQ4_K + TensorTypeQ5_K + TensorTypeQ6_K + TensorTypeQ8_K + + // unexported // unquantizable by ollama + tensorTypeIQ2_XXS + tensorTypeIQ2_XS + tensorTypeIQ3_XXS + tensorTypeIQ1_S + tensorTypeIQ4_NL + tensorTypeIQ3_S + tensorTypeIQ2_S + tensorTypeIQ4_XS + + TensorTypeI8 + TensorTypeI16 + TensorTypeI32 + TensorTypeI64 + TensorTypeF64 + + // unexported // unquantizable by ollama + tensorTypeIQ1_M + + TensorTypeBF16 + + // unexported // unused in gguf + tensorTypeQ4_0_4_4 + tensorTypeQ4_0_4_8 + tensorTypeQ4_0_8_8 + + // unexported // unquantizable by ollama + tensorTypeTQ1_0 + tensorTypeTQ2_0 + + // unexported // unused in gguf + tensorTypeIQ4_NL_4_4 + tensorTypeIQ4_NL_4_8 + tensorTypeIQ4_NL_8_8 +) + +func (tt TensorType) NumBytes() float64 { + return float64(tt.typeSize()) / float64(tt.blockSize()) +} + +func (tt TensorType) typeSize() int64 { + switch tt { + case TensorTypeF32: + return 4 + case TensorTypeF16: + return 2 + case TensorTypeQ4_0: + return 2 + tt.blockSize()/2 + case TensorTypeQ4_1: + return 2 + 2 + tt.blockSize()/2 + case TensorTypeQ5_0: + return 2 + 4 + tt.blockSize()/2 + case TensorTypeQ5_1: + return 2 + 2 + 4 + tt.blockSize()/2 + case TensorTypeQ8_0: + return 2 + tt.blockSize() + case TensorTypeQ8_1: + return 2 + 2 + tt.blockSize() + case TensorTypeQ2_K: + return tt.blockSize()/16 + tt.blockSize()/4 + 2 + 2 + case TensorTypeQ3_K: + return tt.blockSize()/8 + tt.blockSize()/4 + 12 + 2 + case TensorTypeQ4_K: + return 2 + 2 + 12 + tt.blockSize()/2 + case TensorTypeQ5_K: + return 2 + 2 + 12 + tt.blockSize()/8 + tt.blockSize()/2 + case TensorTypeQ6_K: + return tt.blockSize()/2 + tt.blockSize()/4 + tt.blockSize()/16 + 2 + case TensorTypeQ8_K: + return 4 + tt.blockSize() + 2*tt.blockSize()/16 + case tensorTypeIQ2_XXS: + return 2 + 2*tt.blockSize()/8 + case tensorTypeIQ2_XS: + return 2 + 2*tt.blockSize()/8 + tt.blockSize()/32 + case tensorTypeIQ3_XXS: + return 2 + tt.blockSize()/4 + tt.blockSize()/8 + case tensorTypeIQ1_S: + return 2 + tt.blockSize()/8 + tt.blockSize()/16 + case tensorTypeIQ4_NL: + return 2 + tt.blockSize()/2 + case tensorTypeIQ3_S: + return 2 + tt.blockSize()/4 + tt.blockSize()/8 + tt.blockSize()/32 + 4 + case tensorTypeIQ2_S: + return 2 + tt.blockSize()/4 + tt.blockSize()/16 + case tensorTypeIQ4_XS: + return 2 + 2 + tt.blockSize()/2 + tt.blockSize()/64 + case TensorTypeI8: + return 1 + case TensorTypeI16: + return 2 + case TensorTypeI32: + return 4 + case TensorTypeI64: + return 8 + case TensorTypeF64: + return 8 + case tensorTypeIQ1_M: + return tt.blockSize()/8 + tt.blockSize()/16 + tt.blockSize()/32 + case TensorTypeBF16: + return 2 + default: + return 0 + } +} + +func (tt TensorType) blockSize() int64 { + switch tt { + case TensorTypeF32, + TensorTypeF16, + TensorTypeI8, + TensorTypeI16, + TensorTypeI32, + TensorTypeI64, + TensorTypeF64, + TensorTypeBF16: + return 1 + case TensorTypeQ4_0, + TensorTypeQ4_1, + TensorTypeQ5_0, + TensorTypeQ5_1, + TensorTypeQ8_0, + TensorTypeQ8_1, + tensorTypeIQ4_NL: + return 32 + default: + return 256 + } +} + +func (tt TensorType) String() string { + switch tt { + case TensorTypeF32: + return "f32" + case TensorTypeF16: + return "f16" + case TensorTypeQ4_0: + return "q4_0" + case TensorTypeQ4_1: + return "q4_1" + case tensorTypeQ4_2: + return "q4_2" + case tensorTypeQ4_3: + return "q4_3" + case TensorTypeQ5_0: + return "q5_0" + case TensorTypeQ5_1: + return "q5_1" + case TensorTypeQ8_0: + return "q8_0" + case TensorTypeQ8_1: + return "q8_1" + case TensorTypeQ2_K: + return "q2_k" + case TensorTypeQ3_K: + return "q3_k" + case TensorTypeQ4_K: + return "q4_k" + case TensorTypeQ5_K: + return "q5_k" + case TensorTypeQ6_K: + return "q6_k" + case TensorTypeQ8_K: + return "q8_k" + case tensorTypeIQ2_XXS: + return "iq2_xxs" + case tensorTypeIQ2_XS: + return "iq2_xs" + case tensorTypeIQ3_XXS: + return "iq3_xxs" + case tensorTypeIQ1_S: + return "iq1_s" + case tensorTypeIQ4_NL: + return "iq4_nl" + case tensorTypeIQ3_S: + return "iq3_s" + case tensorTypeIQ2_S: + return "iq2_s" + case tensorTypeIQ4_XS: + return "iq4_xs" + case TensorTypeI8: + return "i8" + case TensorTypeI16: + return "i16" + case TensorTypeI32: + return "i32" + case TensorTypeI64: + return "i64" + case TensorTypeF64: + return "f64" + case tensorTypeIQ1_M: + return "iq1_m" + case TensorTypeBF16: + return "bf16" + case tensorTypeQ4_0_4_4: + return "q4_0_4_4" + case tensorTypeQ4_0_4_8: + return "q4_0_4_8" + case tensorTypeQ4_0_8_8: + return "q4_0_8_8" + case tensorTypeTQ1_0: + return "tq1_0" + case tensorTypeTQ2_0: + return "tq2_0" + case tensorTypeIQ4_NL_4_4: + return "iq4_nl_4_4" + case tensorTypeIQ4_NL_4_8: + return "iq4_nl_4_8" + case tensorTypeIQ4_NL_8_8: + return "iq4_nl_8_8" + default: + return "unknown" + } +} + +func (tt TensorType) LogValue() slog.Value { + return slog.GroupValue( + slog.Uint64("value", uint64(tt)), + slog.String("name", strings.ToUpper(tt.String())), + slog.Int64("size", tt.typeSize()), + slog.Int64("block_size", tt.blockSize()), + slog.Float64("num_bytes", tt.NumBytes()), + ) +} diff --git a/go.mod b/go.mod index 283286b7..6de5959b 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,7 @@ require ( github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 github.com/dlclark/regexp2 v1.11.4 github.com/emirpasic/gods/v2 v2.0.0-alpha - github.com/google/go-cmp v0.6.0 + github.com/google/go-cmp v0.7.0 github.com/mattn/go-runewidth v0.0.14 github.com/nlpodyssey/gopickle v0.3.0 github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c diff --git a/go.sum b/go.sum index 5755616f..c0ab53aa 100644 --- a/go.sum +++ b/go.sum @@ -112,8 +112,8 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= diff --git a/server/images.go b/server/images.go index d6cceff4..38505cc5 100644 --- a/server/images.go +++ b/server/images.go @@ -23,7 +23,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" - "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/fs/gguf" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/template" "github.com/ollama/ollama/thinking" @@ -73,22 +73,18 @@ func (m *Model) Capabilities() []model.Capability { capabilities := []model.Capability{} // Check for completion capability - r, err := os.Open(m.ModelPath) + f, err := gguf.Open(m.ModelPath) if err == nil { - defer r.Close() + defer f.Close() - f, err := ggml.Decode(r, 1024) - if err == nil { - if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok { - capabilities = append(capabilities, model.CapabilityEmbedding) - } else { - capabilities = append(capabilities, model.CapabilityCompletion) - } - if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok { - capabilities = append(capabilities, model.CapabilityVision) - } + if f.KeyValue("pooling_type").Valid() { + capabilities = append(capabilities, model.CapabilityEmbedding) } else { - slog.Error("couldn't decode ggml", "error", err) + // If no embedding is specified, we assume the model supports completion + capabilities = append(capabilities, model.CapabilityCompletion) + } + if f.KeyValue("vision.block_count").Valid() { + capabilities = append(capabilities, model.CapabilityVision) } } else { slog.Error("couldn't open model file", "error", err) diff --git a/server/images_test.go b/server/images_test.go index 363b298e..a2fba8d9 100644 --- a/server/images_test.go +++ b/server/images_test.go @@ -1,123 +1,42 @@ package server import ( - "bytes" - "encoding/binary" - "errors" - "os" - "path/filepath" "strings" "testing" + "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/template" "github.com/ollama/ollama/types/model" ) -// Constants for GGUF magic bytes and version -var ( - ggufMagic = []byte{0x47, 0x47, 0x55, 0x46} // "GGUF" - ggufVer = uint32(3) // Version 3 -) - -// Helper function to create mock GGUF data -func createMockGGUFData(architecture string, vision bool) []byte { - var buf bytes.Buffer - - // Write GGUF header - buf.Write(ggufMagic) - binary.Write(&buf, binary.LittleEndian, ggufVer) - - // Write tensor count (0 for our test) - var numTensors uint64 = 0 - binary.Write(&buf, binary.LittleEndian, numTensors) - - // Calculate number of metadata entries - numMetaEntries := uint64(1) // architecture entry - if vision { - numMetaEntries++ - } - // Add embedding entry if architecture is "bert" - if architecture == "bert" { - numMetaEntries++ - } - binary.Write(&buf, binary.LittleEndian, numMetaEntries) - - // Write architecture metadata - archKey := "general.architecture" - keyLen := uint64(len(archKey)) - binary.Write(&buf, binary.LittleEndian, keyLen) - buf.WriteString(archKey) - - // String type (8) - var strType uint32 = 8 - binary.Write(&buf, binary.LittleEndian, strType) - - // String length - strLen := uint64(len(architecture)) - binary.Write(&buf, binary.LittleEndian, strLen) - buf.WriteString(architecture) - - if vision { - visionKey := architecture + ".vision.block_count" - keyLen = uint64(len(visionKey)) - binary.Write(&buf, binary.LittleEndian, keyLen) - buf.WriteString(visionKey) - - // uint32 type (4) - var uint32Type uint32 = 4 - binary.Write(&buf, binary.LittleEndian, uint32Type) - - // uint32 value (1) - var countVal uint32 = 1 - binary.Write(&buf, binary.LittleEndian, countVal) - } - // Write embedding metadata if architecture is "bert" - if architecture == "bert" { - poolKey := architecture + ".pooling_type" - keyLen = uint64(len(poolKey)) - binary.Write(&buf, binary.LittleEndian, keyLen) - buf.WriteString(poolKey) - - // uint32 type (4) - var uint32Type uint32 = 4 - binary.Write(&buf, binary.LittleEndian, uint32Type) - - // uint32 value (1) - var poolingVal uint32 = 1 - binary.Write(&buf, binary.LittleEndian, poolingVal) - } - - return buf.Bytes() -} - func TestModelCapabilities(t *testing.T) { - // Create a temporary directory for test files - tempDir := t.TempDir() + // Create completion model (llama architecture without vision) + completionModelPath, _ := createBinFile(t, ggml.KV{ + "general.architecture": "llama", + }, []*ggml.Tensor{}) - // Create different types of mock model files - completionModelPath := filepath.Join(tempDir, "model.bin") - visionModelPath := filepath.Join(tempDir, "vision_model.bin") - embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin") - // Create a simple model file for tests that don't depend on GGUF content - simpleModelPath := filepath.Join(tempDir, "simple_model.bin") + // Create vision model (llama architecture with vision block count) + visionModelPath, _ := createBinFile(t, ggml.KV{ + "general.architecture": "llama", + "llama.vision.block_count": uint32(1), + }, []*ggml.Tensor{}) - if err := errors.Join( - os.WriteFile(completionModelPath, createMockGGUFData("llama", false), 0o644), - os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644), - os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644), - os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644), - ); err != nil { - t.Fatalf("Failed to create model files: %v", err) - } + // Create embedding model (bert architecture with pooling type) + embeddingModelPath, _ := createBinFile(t, ggml.KV{ + "general.architecture": "bert", + "bert.pooling_type": uint32(1), + }, []*ggml.Tensor{}) toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}") if err != nil { t.Fatalf("Failed to parse template: %v", err) } + chatTemplate, err := template.Parse("{{ .prompt }}") if err != nil { t.Fatalf("Failed to parse template: %v", err) } + toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}") if err != nil { t.Fatalf("Failed to parse template: %v", err) @@ -145,21 +64,13 @@ func TestModelCapabilities(t *testing.T) { }, expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert}, }, - { - name: "model with tools and insert capability", - model: Model{ - ModelPath: simpleModelPath, - Template: toolsInsertTemplate, - }, - expectedCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert}, - }, { name: "model with tools capability", model: Model{ - ModelPath: simpleModelPath, + ModelPath: completionModelPath, Template: toolsTemplate, }, - expectedCaps: []model.Capability{model.CapabilityTools}, + expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools}, }, { name: "model with vision capability", @@ -224,29 +135,33 @@ func TestModelCapabilities(t *testing.T) { } func TestModelCheckCapabilities(t *testing.T) { - // Create a temporary directory for test files - tempDir := t.TempDir() + // Create simple model file for tests that don't depend on GGUF content + completionModelPath, _ := createBinFile(t, ggml.KV{ + "general.architecture": "llama", + }, []*ggml.Tensor{}) - visionModelPath := filepath.Join(tempDir, "vision_model.bin") - simpleModelPath := filepath.Join(tempDir, "model.bin") - embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin") + // Create vision model (llama architecture with vision block count) + visionModelPath, _ := createBinFile(t, ggml.KV{ + "general.architecture": "llama", + "llama.vision.block_count": uint32(1), + }, []*ggml.Tensor{}) - if err := errors.Join( - os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644), - os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644), - os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644), - ); err != nil { - t.Fatalf("Failed to create model files: %v", err) - } + // Create embedding model (bert architecture with pooling type) + embeddingModelPath, _ := createBinFile(t, ggml.KV{ + "general.architecture": "bert", + "bert.pooling_type": uint32(1), + }, []*ggml.Tensor{}) toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}") if err != nil { t.Fatalf("Failed to parse template: %v", err) } + chatTemplate, err := template.Parse("{{ .prompt }}") if err != nil { t.Fatalf("Failed to parse template: %v", err) } + toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}") if err != nil { t.Fatalf("Failed to parse template: %v", err) @@ -261,7 +176,7 @@ func TestModelCheckCapabilities(t *testing.T) { { name: "completion model without tools capability", model: Model{ - ModelPath: simpleModelPath, + ModelPath: completionModelPath, Template: chatTemplate, }, checkCaps: []model.Capability{model.CapabilityTools}, @@ -270,7 +185,7 @@ func TestModelCheckCapabilities(t *testing.T) { { name: "model with all needed capabilities", model: Model{ - ModelPath: simpleModelPath, + ModelPath: completionModelPath, Template: toolsInsertTemplate, }, checkCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert}, @@ -278,7 +193,7 @@ func TestModelCheckCapabilities(t *testing.T) { { name: "model missing insert capability", model: Model{ - ModelPath: simpleModelPath, + ModelPath: completionModelPath, Template: toolsTemplate, }, checkCaps: []model.Capability{model.CapabilityInsert}, @@ -287,7 +202,7 @@ func TestModelCheckCapabilities(t *testing.T) { { name: "model missing vision capability", model: Model{ - ModelPath: simpleModelPath, + ModelPath: completionModelPath, Template: toolsTemplate, }, checkCaps: []model.Capability{model.CapabilityVision}, @@ -312,7 +227,7 @@ func TestModelCheckCapabilities(t *testing.T) { { name: "unknown capability", model: Model{ - ModelPath: simpleModelPath, + ModelPath: completionModelPath, Template: chatTemplate, }, checkCaps: []model.Capability{"unknown"}, diff --git a/server/quantization_test.go b/server/quantization_test.go index 4f717c2c..8b726c83 100644 --- a/server/quantization_test.go +++ b/server/quantization_test.go @@ -257,16 +257,8 @@ func TestQuantizeModel(t *testing.T) { for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { - f, err := os.CreateTemp(t.TempDir(), tt.name) - if err != nil { - t.Fatal(err.Error()) - } - defer f.Close() - err = fsggml.WriteGGUF(f, tt.kv, tt.tensors) - if err != nil { - t.Fatalf("failed to create initial model: %s", err) - } - fp, err := os.Open(f.Name()) + p, _ := createBinFile(t, tt.kv, tt.tensors) + fp, err := os.Open(p) if err != nil { t.Fatal(err.Error()) } diff --git a/server/sched_test.go b/server/sched_test.go index 01fb9a70..3892fbba 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -112,11 +112,7 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est b.ctx, b.ctxDone = context.WithCancel(ctx) t.Helper() - f, err := os.CreateTemp(t.TempDir(), modelName) - require.NoError(t, err) - defer f.Close() - - require.NoError(t, ggml.WriteGGUF(f, ggml.KV{ + p, _ := createBinFile(t, ggml.KV{ "general.architecture": "llama", "llama.context_length": uint32(32), "llama.embedding_length": uint32(4096), @@ -129,14 +125,14 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est }, []*ggml.Tensor{ {Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))}, {Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))}, - })) - require.NoError(t, err) - - fname := f.Name() - model := &Model{Name: modelName, ModelPath: fname} - b.f, err = llm.LoadModel(model.ModelPath, 0) - require.NoError(t, err) + }) + model := &Model{Name: modelName, ModelPath: p} + f, err := llm.LoadModel(model.ModelPath, 0) + if err != nil { + t.Fatal(err) + } + b.f = f if duration == nil { duration = &api.Duration{Duration: 5 * time.Millisecond} } From 9f8a18ec050ef67fca11d4f9bea0508eece93a68 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Thu, 12 Jun 2025 14:18:54 -0700 Subject: [PATCH 22/26] tools: loosen tool parsing to allow for more formats (#11030) --- server/routes.go | 8 +- tools/template.go | 156 +++ tools/template_test.go | 139 +++ tools/testdata/command-r-plus.gotmpl | 67 -- tools/testdata/command-r-plus.out | 39 - tools/testdata/firefunction.gotmpl | 31 - tools/testdata/firefunction.out | 17 - tools/testdata/llama3-groq-tool-use.gotmpl | 43 - tools/testdata/llama3-groq-tool-use.out | 24 - tools/testdata/llama3.2.gotmpl | 44 - tools/testdata/llama3.2.out | 24 - tools/testdata/messages.json | 39 - tools/testdata/mistral.gotmpl | 15 - tools/testdata/mistral.out | 3 - tools/testdata/nemotron.gotmpl | 33 - tools/testdata/nemotron.out | 18 - tools/testdata/qwen2.5.gotmpl | 51 - tools/testdata/qwen2.5.out | 31 - tools/testdata/qwen3.gotmpl | 50 - tools/testdata/qwen3.out | 31 - tools/testdata/tools.json | 30 - tools/testdata/xlam.gotmpl | 45 - tools/testdata/xlam.out | 40 - tools/tools.go | 470 ++++---- tools/tools_test.go | 1246 +++++++++++--------- tools/tools_utils.go | 222 ---- tools/tools_utils_test.go | 497 -------- 27 files changed, 1238 insertions(+), 2175 deletions(-) create mode 100644 tools/template.go create mode 100644 tools/template_test.go delete mode 100644 tools/testdata/command-r-plus.gotmpl delete mode 100644 tools/testdata/command-r-plus.out delete mode 100644 tools/testdata/firefunction.gotmpl delete mode 100644 tools/testdata/firefunction.out delete mode 100644 tools/testdata/llama3-groq-tool-use.gotmpl delete mode 100644 tools/testdata/llama3-groq-tool-use.out delete mode 100644 tools/testdata/llama3.2.gotmpl delete mode 100644 tools/testdata/llama3.2.out delete mode 100644 tools/testdata/messages.json delete mode 100644 tools/testdata/mistral.gotmpl delete mode 100644 tools/testdata/mistral.out delete mode 100644 tools/testdata/nemotron.gotmpl delete mode 100644 tools/testdata/nemotron.out delete mode 100644 tools/testdata/qwen2.5.gotmpl delete mode 100644 tools/testdata/qwen2.5.out delete mode 100644 tools/testdata/qwen3.gotmpl delete mode 100644 tools/testdata/qwen3.out delete mode 100644 tools/testdata/tools.json delete mode 100644 tools/testdata/xlam.gotmpl delete mode 100644 tools/testdata/xlam.out delete mode 100644 tools/tools_utils.go delete mode 100644 tools/tools_utils_test.go diff --git a/server/routes.go b/server/routes.go index 8eda5c73..cb46cef1 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1526,12 +1526,7 @@ func (s *Server) ChatHandler(c *gin.Context) { var toolParser *tools.Parser if len(req.Tools) > 0 { - toolParser, err = tools.NewParser(m.Template.Template) - if err != nil { - slog.Error("failed to create tool parser", "error", err) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } + toolParser = tools.NewParser(m.Template.Template, req.Tools) } ch := make(chan any) @@ -1584,6 +1579,7 @@ func (s *Server) ChatHandler(c *gin.Context) { // don't return } else { if r.Done { + res.Message.Content = toolParser.Content() ch <- res } return diff --git a/tools/template.go b/tools/template.go new file mode 100644 index 00000000..e22f0675 --- /dev/null +++ b/tools/template.go @@ -0,0 +1,156 @@ +package tools + +import ( + "bytes" + "log/slog" + "slices" + "strings" + "text/template" + "text/template/parse" +) + +// parseTag finds the tool calling tag from a Go template +// often [TOOL_CALL] or similar by finding the +// first text node after .ToolCalls and returning the content +// if no tag is found, return "{" to indicate that json objects +// should be attempted to be parsed as tool calls +func parseTag(tmpl *template.Template) string { + if tmpl == nil || tmpl.Tree == nil { + slog.Debug("template or tree is nil") + return "{" + } + + tc := findToolCallNode(tmpl.Tree.Root.Nodes) + if tc == nil { + return "{" + } + + tn := findTextNode(tc.List.Nodes) + if tn == nil { + return "{" + } + + tag := string(tn.Text) + tag = strings.ReplaceAll(tag, "\r\n", "\n") + + // avoid parsing { onwards as this may be a tool call + // however keep '{' as a prefix if there is no tag + // so that all json objects will be attempted to + // be parsed as tool calls + tag, _, _ = strings.Cut(tag, "{") + tag = strings.TrimSpace(tag) + if tag == "" { + tag = "{" + } + + return tag +} + +// findToolCallNode searches for and returns an IfNode with .ToolCalls +func findToolCallNode(nodes []parse.Node) *parse.IfNode { + isToolCallsNode := func(n *parse.IfNode) bool { + for _, cmd := range n.Pipe.Cmds { + for _, arg := range cmd.Args { + if field, ok := arg.(*parse.FieldNode); ok { + if slices.Contains(field.Ident, "ToolCalls") { + return true + } + } + } + } + return false + } + + for _, node := range nodes { + switch n := node.(type) { + case *parse.IfNode: + if isToolCallsNode(n) { + return n + } + // Recursively search in nested IfNodes + if result := findToolCallNode(n.List.Nodes); result != nil { + return result + } + if n.ElseList != nil { + if result := findToolCallNode(n.ElseList.Nodes); result != nil { + return result + } + } + case *parse.ListNode: + if result := findToolCallNode(n.Nodes); result != nil { + return result + } + case *parse.RangeNode: + if result := findToolCallNode(n.List.Nodes); result != nil { + return result + } + if n.ElseList != nil { + if result := findToolCallNode(n.ElseList.Nodes); result != nil { + return result + } + } + case *parse.WithNode: + if result := findToolCallNode(n.List.Nodes); result != nil { + return result + } + if n.ElseList != nil { + if result := findToolCallNode(n.ElseList.Nodes); result != nil { + return result + } + } + } + } + return nil +} + +// findTextNode does a depth-first search for the first text content in nodes, +// stopping at template constructs to avoid parsing text after the tool calls +func findTextNode(nodes []parse.Node) *parse.TextNode { + for _, node := range nodes { + switch n := node.(type) { + case *parse.TextNode: + // skip whitespace-only text nodes + if len(bytes.TrimSpace(n.Text)) == 0 { + continue + } + return n + case *parse.IfNode: + if text := findTextNode(n.List.Nodes); text != nil { + return text + } + if n.ElseList != nil { + if text := findTextNode(n.ElseList.Nodes); text != nil { + return text + } + } + return nil + case *parse.ListNode: + if text := findTextNode(n.Nodes); text != nil { + return text + } + case *parse.RangeNode: + if text := findTextNode(n.List.Nodes); text != nil { + return text + } + if n.ElseList != nil { + if text := findTextNode(n.ElseList.Nodes); text != nil { + return text + } + } + return nil + case *parse.WithNode: + if text := findTextNode(n.List.Nodes); text != nil { + return text + } + if n.ElseList != nil { + if text := findTextNode(n.ElseList.Nodes); text != nil { + return text + } + } + return nil + case *parse.ActionNode: + return nil + } + } + return nil +} diff --git a/tools/template_test.go b/tools/template_test.go new file mode 100644 index 00000000..970c0d59 --- /dev/null +++ b/tools/template_test.go @@ -0,0 +1,139 @@ +package tools + +import ( + "testing" + "text/template" +) + +func TestParseTag(t *testing.T) { + cases := []struct { + name string + template string + want string + }{ + { + name: "empty", + template: "", + want: "{", + }, + { + name: "no tag", + template: "{{if .ToolCalls}}{{end}}", + want: "{", + }, + { + name: "no tag with range", + template: "{{if .ToolCalls}}{{range .ToolCalls}}{{ . }}{{end}}{{end}}", + want: "{", + }, + { + name: "tool call with json format", + template: "{{if .ToolCalls}}```json\n{{end}}", + want: "```json", + }, + { + name: "square brackets", + template: "{{if .ToolCalls}}[{{range .ToolCalls}}{{ . }}{{end}}]{{end}}", + want: "[", + }, + { + name: "square brackets with whitespace", + template: "{{if .ToolCalls}}\n [ {{range .ToolCalls}}{{ . }}{{end}}]{{end}}", + want: "[", + }, + { + name: "tailing ]", + template: "{{if .ToolCalls}}{{range .ToolCalls}}{{ . }}{{end}}]{{end}}", + want: "{", + }, + { + name: "whitespace only", + template: "{{if .ToolCalls}} {{range .ToolCalls}}{{ . }}{{end}}{{end}}", + want: "{", + }, + { + name: "whitespace only in range", + template: "{{if .ToolCalls}}{{range .ToolCalls}}\n{{ . }}\n{{end}}{{end}}", + want: "{", + }, + { + name: "json objects", + template: `{{if .ToolCalls}}{{range .ToolCalls}}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{end}}{{end}}`, + want: "{", + }, + { + name: "json objects with whitespace", + template: "{{if .ToolCalls}}{{range .ToolCalls}}\n{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}{{end}}{{end}}", + want: "{", + }, + { + name: "json objects with CRLF", + template: "{{if .ToolCalls}}{{range .ToolCalls}}\r\n{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}{{end}}{{end}}", + want: "{", + }, + { + name: "json objects with whitespace before and after range", + template: "{{if .ToolCalls}}\n{{range .ToolCalls}}\n{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}\r\n{{end}}\r\n{{end}}", + want: "{", + }, + { + name: "before and after range", + template: "{{if .ToolCalls}}<|tool▁calls▁begin|>{{range .ToolCalls}}<|tool▁call▁begin|>functionget_current_weather\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|>\n{{end}}<|tool▁calls▁end|>{{end}}", + want: "<|tool▁calls▁begin|>", + }, + { + name: "after range", + template: "{{if .ToolCalls}}{{range .ToolCalls}}{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}{{end}}{{end}}", + want: "", + }, + { + name: "after range with leading whitespace before range", + template: "{{if .ToolCalls}}\n{{range .ToolCalls}}{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}{{end}}{{end}}", + want: "", + }, + { + name: "tool call in range with {", + template: `{{if .ToolCalls}}{{range .ToolCalls}}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{end}}{{end}}`, + want: "", + }, + { + name: "tool call with multiple text nodes", + template: "{{if .ToolCalls}}First text{{if .Something}}inner{{end}}Second text{{end}}", + want: "First text", + }, + { + name: "action tag", + template: "{{if .ToolCalls}}Action: ```json{{end}}", + want: "Action: ```json", + }, + { + name: "incomplete functools bracket", + template: "{{if .ToolCalls}}functools[{{end}}", + want: "functools[", + }, + { + name: "uppercase tool call with incomplete bracket", + template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}", + want: "[TOOL_CALL] [", + }, + { + name: "uppercase tool call with adjacent bracket", + template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}", + want: "[TOOL_CALL][", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + tmpl, err := template.New("test").Parse(tc.template) + if err != nil && tc.template != "" { + t.Fatalf("failed to parse template: %v", err) + } + + got := parseTag(tmpl) + if got != tc.want { + t.Errorf("got text %q, want %q", got, tc.want) + } + }) + } +} diff --git a/tools/testdata/command-r-plus.gotmpl b/tools/testdata/command-r-plus.gotmpl deleted file mode 100644 index f30124e3..00000000 --- a/tools/testdata/command-r-plus.gotmpl +++ /dev/null @@ -1,67 +0,0 @@ -{{- if or .Tools .System }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|> -{{- if .Tools }}# Safety Preamble -The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral. - -# System Preamble -## Basic Rules -You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions. - -{{ if .System }}# User Preamble -{{ .System }} -{{- end }} - -## Available Tools -Here is a list of tools that you have available to you: -{{- range .Tools }} - -```python -def {{ .Function.Name }}( -{{- range $name, $property := .Function.Parameters.Properties }}{{ $name }}: {{ $property.Type }}, {{ end }}) -> List[Dict]: - """{{ .Function.Description }} - -{{- if .Function.Parameters.Properties }} - - Args: -{{- range $name, $property := .Function.Parameters.Properties }} - {{ $name }} ({{ $property.Type }}): {{ $property.Description }} -{{- end }} -{{- end }} - """ - pass -``` -{{- end }} -{{- else if .System }}{{ .System }} -{{- end }}<|END_OF_TURN_TOKEN|> -{{- end }} -{{- range .Messages }} -{{- if eq .Role "system" }} -{{- continue }} -{{- end }}<|START_OF_TURN_TOKEN|> -{{- if eq .Role "user" }}<|USER_TOKEN|>{{ .Content }} -{{- else if eq .Role "assistant" }}<|CHATBOT_TOKEN|> -{{- if .Content }}{{ .Content }} -{{- else if .ToolCalls }} -Action: ```json -[ -{{- range .ToolCalls }} - { - "tool_name": "{{ .Function.Name }}", - "parameters": {{ .Function.Arguments }} - } -{{- end }} -]``` -{{ continue }} -{{ end }} -{{- else if eq .Role "tool" }}<|SYSTEM_TOKEN|> -{{ .Content }} -{{- end }}<|END_OF_TURN_TOKEN|> -{{- end }} -{{- if .Tools }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example: -```json -[ - { - "tool_name": title of the tool in the specification, - "parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters - } -]``` -{{- end }}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|> \ No newline at end of file diff --git a/tools/testdata/command-r-plus.out b/tools/testdata/command-r-plus.out deleted file mode 100644 index 8193d40c..00000000 --- a/tools/testdata/command-r-plus.out +++ /dev/null @@ -1,39 +0,0 @@ -<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># Safety Preamble -The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral. - -# System Preamble -## Basic Rules -You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions. - -# User Preamble -You are a knowledgeable assistant. You can answer questions and perform tasks. - -## Available Tools -Here is a list of tools that you have available to you: - -```python -def get_current_weather(format: string, location: string, ) -> List[Dict]: - """Get the current weather - - Args: - format (string): The temperature unit to use. Infer this from the user's location. - location (string): The city and state, e.g. San Francisco, CA - """ - pass -```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in Paris?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|> -Action: ```json -[ - { - "tool_name": "get_current_weather", - "parameters": {"format":"celsius","location":"Paris, France"} - } -]``` -<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|> -22<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>The current temperature in Paris, France is 22 degrees Celsius.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in San Francisco and Toronto?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example: -```json -[ - { - "tool_name": title of the tool in the specification, - "parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters - } -]```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|> \ No newline at end of file diff --git a/tools/testdata/firefunction.gotmpl b/tools/testdata/firefunction.gotmpl deleted file mode 100644 index 312be205..00000000 --- a/tools/testdata/firefunction.gotmpl +++ /dev/null @@ -1,31 +0,0 @@ -{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|> -{{- if .System }} -{{ .System }} -{{- end }} -In addition to plain text responses, you can chose to call one or more of the provided functions. - -Use the following rule to decide when to call a function: - * if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so - * if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls - -If you decide to call functions: - * prefix function calls with functools marker (no closing marker required) - * all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...] - * follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples - * respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0 - * make sure you pick the right functions that match the user intent - -Available functions as JSON spec: -{{- if .Tools }} -{{ .Tools }} -{{- end }}<|eot_id|> -{{- end }} -{{- range .Messages }}<|start_header_id|> -{{- if or (eq .Role "user") (eq .Role "assistant") (eq .Role "tool") }}{{ .Role }} -{{- end }}<|end_header_id|> -{{- if .Content }}{{ .Content }} -{{- else if .ToolCalls }} functools[ -{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}{{ "}" }} -{{- end }}] -{{- end }}<|eot_id|> -{{- end }}<|start_header_id|>assistant<|end_header_id|> \ No newline at end of file diff --git a/tools/testdata/firefunction.out b/tools/testdata/firefunction.out deleted file mode 100644 index 144f5e42..00000000 --- a/tools/testdata/firefunction.out +++ /dev/null @@ -1,17 +0,0 @@ -<|start_header_id|>system<|end_header_id|> -You are a knowledgeable assistant. You can answer questions and perform tasks. -In addition to plain text responses, you can chose to call one or more of the provided functions. - -Use the following rule to decide when to call a function: - * if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so - * if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls - -If you decide to call functions: - * prefix function calls with functools marker (no closing marker required) - * all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...] - * follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples - * respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0 - * make sure you pick the right functions that match the user intent - -Available functions as JSON spec: -[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]<|eot_id|><|start_header_id|><|end_header_id|>You are a knowledgeable assistant. You can answer questions and perform tasks.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> functools[{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]<|eot_id|><|start_header_id|>tool<|end_header_id|>22<|eot_id|><|start_header_id|>assistant<|end_header_id|>The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|> \ No newline at end of file diff --git a/tools/testdata/llama3-groq-tool-use.gotmpl b/tools/testdata/llama3-groq-tool-use.gotmpl deleted file mode 100644 index 45e9b462..00000000 --- a/tools/testdata/llama3-groq-tool-use.gotmpl +++ /dev/null @@ -1,43 +0,0 @@ -{{- if .Messages }} -{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|> - -{{ .System }} -{{- if .Tools }} You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within XML tags as follows: - -{"name": ,"arguments": } - - -Here are the available tools: - -{{- range .Tools }} {{ .Function }} -{{- end }} -{{- end }} -{{- end }}<|eot_id|> -{{- range .Messages }} -{{- if ne .Role "system" }}<|start_header_id|>{{ .Role }}<|end_header_id|> - -{{ if eq .Role "user" }}{{ .Content }} -{{- else if eq .Role "assistant" }} -{{- if .Content }}{{ .Content }} -{{- else if .ToolCalls }} -{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} -{{- end }} - -{{- end }} -{{- else if eq .Role "tool" }} -{{ .Content }} - -{{- end }}<|eot_id|> -{{- end }} -{{- end }}<|start_header_id|>assistant<|end_header_id|> - -{{ else }} -{{ if .System }}<|start_header_id|>system<|end_header_id|> - -{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|> - -{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|> - -{{ end }}{{ .Response }} -{{- if .Response }}<|eot_id|> -{{- end }} \ No newline at end of file diff --git a/tools/testdata/llama3-groq-tool-use.out b/tools/testdata/llama3-groq-tool-use.out deleted file mode 100644 index 912ad11c..00000000 --- a/tools/testdata/llama3-groq-tool-use.out +++ /dev/null @@ -1,24 +0,0 @@ -<|start_header_id|>system<|end_header_id|> - -You are a knowledgeable assistant. You can answer questions and perform tasks. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within XML tags as follows: - -{"name": ,"arguments": } - - -Here are the available tools: - {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}} <|eot_id|><|start_header_id|>user<|end_header_id|> - -What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> - - -{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}} -<|eot_id|><|start_header_id|>tool<|end_header_id|> - - -22 -<|eot_id|><|start_header_id|>assistant<|end_header_id|> - -The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|> - -What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|> - diff --git a/tools/testdata/llama3.2.gotmpl b/tools/testdata/llama3.2.gotmpl deleted file mode 100644 index b132423e..00000000 --- a/tools/testdata/llama3.2.gotmpl +++ /dev/null @@ -1,44 +0,0 @@ -<|start_header_id|>system<|end_header_id|> - -Cutting Knowledge Date: December 2023 - -{{ if .System }}{{ .System }} -{{- end }} -{{- if .Tools }}When you receive a tool call response, use the output to format an answer to the orginal user question. - -You are a helpful assistant with tool calling capabilities. -{{- end }}<|eot_id|> -{{- range $i, $_ := .Messages }} -{{- $last := eq (len (slice $.Messages $i)) 1 }} -{{- if eq .Role "user" }}<|start_header_id|>user<|end_header_id|> -{{- if and $.Tools $last }} - -Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. - -Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables. - -{{ range $.Tools }} -{{- . }} -{{ end }} -{{ .Content }}<|eot_id|> -{{- else }} - -{{ .Content }}<|eot_id|> -{{- end }}{{ if $last }}<|start_header_id|>assistant<|end_header_id|> - -{{ end }} -{{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|> -{{- if .ToolCalls }} -{{ range .ToolCalls }} -{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }} -{{- else }} - -{{ .Content }} -{{- end }}{{ if not $last }}<|eot_id|>{{ end }} -{{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|> - -{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|> - -{{ end }} -{{- end }} -{{- end }} \ No newline at end of file diff --git a/tools/testdata/llama3.2.out b/tools/testdata/llama3.2.out deleted file mode 100644 index a27c6eaf..00000000 --- a/tools/testdata/llama3.2.out +++ /dev/null @@ -1,24 +0,0 @@ -<|start_header_id|>system<|end_header_id|> - -Cutting Knowledge Date: December 2023 - -You are a knowledgeable assistant. You can answer questions and perform tasks.When you receive a tool call response, use the output to format an answer to the orginal user question. - -You are a helpful assistant with tool calling capabilities.<|eot_id|><|start_header_id|>user<|end_header_id|> - -What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> - -{"name": "get_current_weather", "parameters": {"format":"celsius","location":"Paris, France"}}<|eot_id|><|start_header_id|>ipython<|end_header_id|> - -22<|eot_id|><|start_header_id|>assistant<|end_header_id|> - -The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|> - -Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. - -Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables. - -{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} - -What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|> - diff --git a/tools/testdata/messages.json b/tools/testdata/messages.json deleted file mode 100644 index 42de4711..00000000 --- a/tools/testdata/messages.json +++ /dev/null @@ -1,39 +0,0 @@ -[ - { - "role": "system", - "content": "You are a knowledgeable assistant. You can answer questions and perform tasks." - }, - { - "role": "user", - "content": "What's the weather like today in Paris?" - }, - { - "role": "assistant", - "tool_calls": [ - { - "id": "89a1e453-0bce-4de3-a456-c54bed09c520", - "type": "function", - "function": { - "name": "get_current_weather", - "arguments": { - "location": "Paris, France", - "format": "celsius" - } - } - } - ] - }, - { - "role": "tool", - "tool_call_id": "89a1e453-0bce-4de3-a456-c54bed09c520", - "content": "22" - }, - { - "role": "assistant", - "content": "The current temperature in Paris, France is 22 degrees Celsius." - }, - { - "role": "user", - "content": "What's the weather like today in San Francisco and Toronto?" - } -] diff --git a/tools/testdata/mistral.gotmpl b/tools/testdata/mistral.gotmpl deleted file mode 100644 index b08d6c2c..00000000 --- a/tools/testdata/mistral.gotmpl +++ /dev/null @@ -1,15 +0,0 @@ -{{- range $index, $_ := .Messages }} -{{- if eq .Role "user" }} -{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS] -{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }} - -{{ end }}{{ .Content }}[/INST] -{{- else if eq .Role "assistant" }} -{{- if .Content }} {{ .Content }} -{{- else if .ToolCalls }}[TOOL_CALLS] [ -{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} -{{- end }}] -{{- end }} -{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS] -{{- end }} -{{- end }} \ No newline at end of file diff --git a/tools/testdata/mistral.out b/tools/testdata/mistral.out deleted file mode 100644 index 6956e392..00000000 --- a/tools/testdata/mistral.out +++ /dev/null @@ -1,3 +0,0 @@ -[INST] What's the weather like today in Paris?[/INST][TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}][TOOL_RESULTS] {"content": 22}[/TOOL_RESULTS] The current temperature in Paris, France is 22 degrees Celsius.[AVAILABLE_TOOLS] [{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}][/AVAILABLE_TOOLS][INST] You are a knowledgeable assistant. You can answer questions and perform tasks. - -What's the weather like today in San Francisco and Toronto?[/INST] \ No newline at end of file diff --git a/tools/testdata/nemotron.gotmpl b/tools/testdata/nemotron.gotmpl deleted file mode 100644 index 1b6b89ec..00000000 --- a/tools/testdata/nemotron.gotmpl +++ /dev/null @@ -1,33 +0,0 @@ -{{- if (or .Tools .System) }}System -{{ if .System }}{{ .System }} - - -{{ end }} -{{- if .Tools }} -{{- range .Tools }} {{ . }} {{ end }} - - -{{ end }} -{{- end }} -{{- range $i, $m := .Messages }} -{{- $last := eq (len (slice $.Messages $i)) 1 -}} -{{- if eq .Role "user" }}User -{{ .Content }} -{{- if $last }} -Assistant -{{- end }} -{{ else if eq .Role "tool" }}Tool -{{ .Content }} -{{- if $last }} -Assistant -{{- end }} -{{ else if eq .Role "assistant" }}Assistant -{{- if .ToolCalls }} -{{ range .ToolCalls }} {"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} {{ end }} -{{ else }} -{{ .Content }} -{{- if not $last }} -{{ end }} -{{- end }} -{{- end }} -{{- end }} \ No newline at end of file diff --git a/tools/testdata/nemotron.out b/tools/testdata/nemotron.out deleted file mode 100644 index 486889ca..00000000 --- a/tools/testdata/nemotron.out +++ /dev/null @@ -1,18 +0,0 @@ -System -You are a knowledgeable assistant. You can answer questions and perform tasks. - - - {"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} - - -User -What's the weather like today in Paris? -Assistant - {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}} -Tool -22 -Assistant -The current temperature in Paris, France is 22 degrees Celsius. -User -What's the weather like today in San Francisco and Toronto? -Assistant diff --git a/tools/testdata/qwen2.5.gotmpl b/tools/testdata/qwen2.5.gotmpl deleted file mode 100644 index cbd7302c..00000000 --- a/tools/testdata/qwen2.5.gotmpl +++ /dev/null @@ -1,51 +0,0 @@ -{{- if .Suffix }}<|fim_prefix|>{{ .Prompt }}<|fim_suffix|>{{ .Suffix }}<|fim_middle|> -{{- else if .Messages }} -{{- if or .System .Tools }}<|im_start|>system -{{- if .System }} -{{ .System }} -{{- end }} -{{- if .Tools }} - -# Tools - -You may call one or more functions to assist with the user query. - -You are provided with function signatures within XML tags: - -{{- range .Tools }} -{"type": "function", "function": {{ .Function }}} -{{- end }} - - -For each function call, return a json object with function name and arguments within XML tags: - -{"name": , "arguments": } - -{{- end }}<|im_end|> -{{ end }} -{{- range $i, $_ := .Messages }} -{{- $last := eq (len (slice $.Messages $i)) 1 -}} -{{- if eq .Role "user" }}<|im_start|>user -{{ .Content }}<|im_end|> -{{ else if eq .Role "assistant" }}<|im_start|>assistant -{{ if .Content }}{{ .Content }} -{{- else if .ToolCalls }} -{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} -{{ end }} -{{- end }}{{ if not $last }}<|im_end|> -{{ end }} -{{- else if eq .Role "tool" }}<|im_start|>user - -{{ .Content }} -<|im_end|> -{{ end }} -{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant -{{ end }} -{{- end }} -{{- else }} -{{- if .System }}<|im_start|>system -{{ .System }}<|im_end|> -{{ end }}{{ if .Prompt }}<|im_start|>user -{{ .Prompt }}<|im_end|> -{{ end }}<|im_start|>assistant -{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }} \ No newline at end of file diff --git a/tools/testdata/qwen2.5.out b/tools/testdata/qwen2.5.out deleted file mode 100644 index 76bfbfa9..00000000 --- a/tools/testdata/qwen2.5.out +++ /dev/null @@ -1,31 +0,0 @@ -<|im_start|>system -You are a knowledgeable assistant. You can answer questions and perform tasks. - -# Tools - -You may call one or more functions to assist with the user query. - -You are provided with function signatures within XML tags: - -{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} - - -For each function call, return a json object with function name and arguments within XML tags: - -{"name": , "arguments": } -<|im_end|> -<|im_start|>user -What's the weather like today in Paris?<|im_end|> -<|im_start|>assistant - -{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}} -<|im_end|> -<|im_start|>user - -22 -<|im_end|> -<|im_start|>assistant -The current temperature in Paris, France is 22 degrees Celsius.<|im_end|> -<|im_start|>user -What's the weather like today in San Francisco and Toronto?<|im_end|> -<|im_start|>assistant diff --git a/tools/testdata/qwen3.gotmpl b/tools/testdata/qwen3.gotmpl deleted file mode 100644 index 26f6656f..00000000 --- a/tools/testdata/qwen3.gotmpl +++ /dev/null @@ -1,50 +0,0 @@ -{{- if .Messages }} -{{- if or .System .Tools }}<|im_start|>system -{{- if .System }} -{{ .System }} -{{- end }} -{{- if .Tools }} - -# Tools - -You may call one or more functions to assist with the user query. - -You are provided with function signatures within XML tags: - -{{- range .Tools }} -{"type": "function", "function": {{ .Function }}} -{{- end }} - - -For each function call, return a json object with function name and arguments within XML tags: - -{"name": , "arguments": } - -{{- end }}<|im_end|> -{{ end }} -{{- range $i, $_ := .Messages }} -{{- $last := eq (len (slice $.Messages $i)) 1 -}} -{{- if eq .Role "user" }}<|im_start|>user -{{ .Content }}<|im_end|> -{{ else if eq .Role "assistant" }}<|im_start|>assistant -{{ if .Content }}{{ .Content }} -{{- else if .ToolCalls }} -{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} -{{ end }} -{{- end }}{{ if not $last }}<|im_end|> -{{ end }} -{{- else if eq .Role "tool" }}<|im_start|>user - -{{ .Content }} -<|im_end|> -{{ end }} -{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant -{{ end }} -{{- end }} -{{- else }} -{{- if .System }}<|im_start|>system -{{ .System }}<|im_end|> -{{ end }}{{ if .Prompt }}<|im_start|>user -{{ .Prompt }}<|im_end|> -{{ end }}<|im_start|>assistant -{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }} \ No newline at end of file diff --git a/tools/testdata/qwen3.out b/tools/testdata/qwen3.out deleted file mode 100644 index 76bfbfa9..00000000 --- a/tools/testdata/qwen3.out +++ /dev/null @@ -1,31 +0,0 @@ -<|im_start|>system -You are a knowledgeable assistant. You can answer questions and perform tasks. - -# Tools - -You may call one or more functions to assist with the user query. - -You are provided with function signatures within XML tags: - -{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} - - -For each function call, return a json object with function name and arguments within XML tags: - -{"name": , "arguments": } -<|im_end|> -<|im_start|>user -What's the weather like today in Paris?<|im_end|> -<|im_start|>assistant - -{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}} -<|im_end|> -<|im_start|>user - -22 -<|im_end|> -<|im_start|>assistant -The current temperature in Paris, France is 22 degrees Celsius.<|im_end|> -<|im_start|>user -What's the weather like today in San Francisco and Toronto?<|im_end|> -<|im_start|>assistant diff --git a/tools/testdata/tools.json b/tools/testdata/tools.json deleted file mode 100644 index edde4ae0..00000000 --- a/tools/testdata/tools.json +++ /dev/null @@ -1,30 +0,0 @@ -[ - { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA" - }, - "format": { - "type": "string", - "enum": [ - "celsius", - "fahrenheit" - ], - "description": "The temperature unit to use. Infer this from the user's location." - } - }, - "required": [ - "location", - "format" - ] - } - } - } -] diff --git a/tools/testdata/xlam.gotmpl b/tools/testdata/xlam.gotmpl deleted file mode 100644 index 51513d69..00000000 --- a/tools/testdata/xlam.gotmpl +++ /dev/null @@ -1,45 +0,0 @@ -{{- if .System }}{{ .System }} -{{ end }} -{{- range $i, $_ := .Messages }} -{{- if eq .Role "user" }}### Instruction: -{{- if and $.Tools (le (len (slice $.Messages $i)) 2) }} -[BEGIN OF TASK INSTRUCTION] -You are an expert in composing functions. You are given a question and a set of possible functions. -Based on the question, you will need to make one or more function/tool calls to achieve the purpose. -If none of the functions can be used, point it out and refuse to answer. -If the given question lacks the parameters required by the function, also point it out. -[END OF TASK INSTRUCTION] - -[BEGIN OF AVAILABLE TOOLS] -{{ $.Tools }} -[END OF AVAILABLE TOOLS] - -[BEGIN OF FORMAT INSTRUCTION] -The output MUST strictly adhere to the following JSON format, and NO other text MUST be included. -The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'. -``` -{ - "tool_calls": [ - {"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}}, - ... (more tool calls as required) - ] -} -``` -[END OF FORMAT INSTRUCTION] - -[BEGIN OF QUERY] -{{ .Content }} -[END OF QUERY] - - -{{ else }} -{{ .Content }} -{{ end }} -{{- else if .ToolCalls }}### Response: -{"tool_calls": [{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{ end }}]} -<|EOT|> -{{ else if eq .Role "assistant" }}### Response: -{{ .Content }} -<|EOT|> -{{ end }} -{{- end }}### Response: \ No newline at end of file diff --git a/tools/testdata/xlam.out b/tools/testdata/xlam.out deleted file mode 100644 index 5d806532..00000000 --- a/tools/testdata/xlam.out +++ /dev/null @@ -1,40 +0,0 @@ -You are a knowledgeable assistant. You can answer questions and perform tasks. -### Instruction: -What's the weather like today in Paris? -### Response: -{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]} -<|EOT|> -### Response: -The current temperature in Paris, France is 22 degrees Celsius. -<|EOT|> -### Instruction: -[BEGIN OF TASK INSTRUCTION] -You are an expert in composing functions. You are given a question and a set of possible functions. -Based on the question, you will need to make one or more function/tool calls to achieve the purpose. -If none of the functions can be used, point it out and refuse to answer. -If the given question lacks the parameters required by the function, also point it out. -[END OF TASK INSTRUCTION] - -[BEGIN OF AVAILABLE TOOLS] -[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}] -[END OF AVAILABLE TOOLS] - -[BEGIN OF FORMAT INSTRUCTION] -The output MUST strictly adhere to the following JSON format, and NO other text MUST be included. -The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'. -``` -{ - "tool_calls": [ - {"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}}, - ... (more tool calls as required) - ] -} -``` -[END OF FORMAT INSTRUCTION] - -[BEGIN OF QUERY] -What's the weather like today in San Francisco and Toronto? -[END OF QUERY] - - -### Response: \ No newline at end of file diff --git a/tools/tools.go b/tools/tools.go index 914a5eaf..efeaeee0 100644 --- a/tools/tools.go +++ b/tools/tools.go @@ -1,253 +1,287 @@ package tools import ( + "bytes" "encoding/json" - "errors" - "log/slog" "strings" - gotmpl "text/template" + "text/template" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/template" ) -var ( - errInvalidToolCall = errors.New("invalid tool call format") - errAccumulateMore = errors.New("need to accumulate more content") +type toolsState int + +const ( + toolsState_LookingForTag toolsState = iota + toolsState_ToolCalling + toolsState_Done ) type Parser struct { - greedyParseJSON bool - prefix string - prefixFound bool - tmpl gotmpl.Template - sb strings.Builder - index int - name string - arguments string + tag string + names []string + properties []string + + state toolsState + buffer []byte + n int } -// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls. -// -// Parameters: -// - s: The string to parse -// - name: The field name from template that identifies the tool call name -// - arguments: The field name from template that identifies the tool call arguments -// -// Returns: -// - []api.ToolCall: The parsed tool calls if successful -// - error: ErrAccumulateMore if braces unbalanced, ErrInvalidToolCall if invalid, or nil if successful -func parseJSONToolCalls(s string, name, arguments string, prefix string) ([]api.ToolCall, error) { - // Check for balanced braces before attempting to parse - braceCount := 0 - squareCount := 0 - startIndex := -1 - var rawToolCalls []string - s = strings.TrimSpace(s) - - // Only track these if we don't have a prefix as it will be cut off from the prefix. Also track in the parseLeadingJSON case. - trackSquareBrackets := prefix == "" || !strings.HasSuffix(prefix, "[") || strings.HasPrefix(s, "[") - for i, c := range s { - switch c { - case '{': - braceCount++ - if startIndex == -1 { - startIndex = i - } - case '}': - braceCount-- - if braceCount == 0 { - rawToolCalls = append(rawToolCalls, s[startIndex:i+1]) - startIndex = -1 - } - case '[': - if trackSquareBrackets { - squareCount++ - } - case ']': - if trackSquareBrackets { - squareCount-- - } - } - - // Negative means we have an extra closing brace/bracket - if braceCount < 0 || squareCount < 0 { - return nil, errInvalidToolCall - } - } - - // If braces/brackets aren't balanced, need more input - if braceCount > 0 || squareCount > 0 { - return nil, errAccumulateMore - } - - t := strings.TrimSpace(s) - if len(t) == 0 { - return nil, errAccumulateMore - } - // If the input is a single square bracket, it's not a valid tool call - if t[0] == '[' && len(t) == 1 { - return nil, errAccumulateMore - } - - // Attempt full unmarshal of the JSON - var toolCalls []api.ToolCall - for _, rawToolCall := range rawToolCalls { - var resp map[string]any - if err := json.Unmarshal([]byte(rawToolCall), &resp); err != nil { - continue - } - - // Collect nested objects that could contain tool calls - objs := collect(resp) - if len(objs) == 0 { - continue - } - - // Extract tool calls from objects - for _, kv := range objs { - n, nok := kv[name].(string) - a, aok := kv[arguments].(map[string]any) - if nok && aok { - toolCalls = append(toolCalls, api.ToolCall{ - Function: api.ToolCallFunction{ - Name: n, - Arguments: a, - }, - }) - } else { - slog.Debug("No valid tool call found in object.", "object", kv) - } - } - } - - // Valid JSON, no tool calls found - if len(toolCalls) == 0 { - slog.Debug("No valid tool calls found in any raw tool calls.", "rawToolCalls", rawToolCalls) - return nil, errInvalidToolCall - } - - return toolCalls, nil +// NewParser creates a new tool call parser from a model's chat +// template and a list of provided tools. +func NewParser(tmpl *template.Template, tools []api.Tool) *Parser { + return NewParserWithTag(tools, parseTag(tmpl)) } -// checkPrefix processes a string to find and handle a prefix pattern. -// -// Returns: -// - The processed string with prefix removed if found -// - error: ErrAccumulateMore if prefix is incomplete, or nil if successful -func (p *Parser) checkPrefix(s string) (string, error) { - if s == "" || p.prefix == "" { - return s, nil +func NewParserWithTag(tools []api.Tool, tag string) *Parser { + var p Parser + for _, t := range tools { + p.names = append(p.names, t.Function.Name) + for r := range t.Function.Parameters.Properties { + p.properties = append(p.properties, r) + } } - - // Check for prefix at start of string - if cut, hasPrefix := strings.CutPrefix(s, p.prefix); hasPrefix { - // Found prefix at start - accumulate for potential tool - p.prefixFound = true - return cut, nil - } - - // Check if prefix overlaps end of string - if idx := suffixOverlap(s, p.prefix); idx != -1 { - // Return everything except overlapping portion - p.sb.Reset() - p.sb.WriteString(s[idx:]) - return s[:idx], errAccumulateMore - } - - // Check if prefix appears in middle of string - if idx := strings.Index(s, p.prefix); idx != -1 { - // Save remainder starting at prefix for next pass - p.sb.Reset() - p.sb.WriteString(strings.TrimSpace(s[idx:])) - // Return everything before prefix - return s[:idx], errAccumulateMore - } - - // No partial prefix found - return s, nil + p.tag = tag + return &p } -// Add processes a string input to parse tool calls and content. -// It handles prefix detection and JSON parsing to extract tool calls. -// -// Returns: -// - tools: Any parsed tool calls -// - content: Non-tool call content -func (p *Parser) Add(s string) (tools []api.ToolCall, content string) { - p.sb.WriteString(s) - s = p.sb.String() - - // Check for prefix pattern in input - s, err := p.checkPrefix(s) - if err != nil { - // Need more input to complete prefix +// Add processes a string input to parse tool calls and content that +// should be sent back to the user. +func (p *Parser) Add(s string) (calls []api.ToolCall, content string) { + if p.state == toolsState_Done { return nil, s } - // Exit if prefix exists in template, greedy parsing is off, and prefix not found - if !p.greedyParseJSON && !p.prefixFound { - p.sb.Reset() - return nil, s + p.buffer = append(p.buffer, s...) + + if p.state == toolsState_LookingForTag { + i, found := p.findTag() + if i == -1 { + content = string(p.buffer) + p.buffer = []byte{} + } else { + content = string(p.buffer[:i]) + p.buffer = p.buffer[i:] + } + + // for models where { or [ are used as tool calling + // tags, we only support parsing tools if the first non- + // whitespace character is { or [ + if p.tag == "{" || p.tag == "[" { + if strings.TrimSpace(content) != "" { + p.state = toolsState_Done + return nil, content + string(p.buffer) + } + } + + if !found { + return nil, content + } + + p.state = toolsState_ToolCalling } - toolCalls, err := parseJSONToolCalls(s, p.name, p.arguments, p.prefix) - if err != nil { - if errors.Is(err, errAccumulateMore) { - return nil, "" + for { + call := p.parseToolCall() + if call == nil { + break } - p.sb.Reset() - // Only do greedy JSON parsing if there is no prefix from template - if p.prefix != "" { - p.greedyParseJSON = false - } - if p.index != 0 && p.prefix == "" { - return nil, "" - } - if p.prefixFound { - // Drop tokens since prefix was found - return nil, "" - } - return nil, s + + calls = append(calls, *call) } - for _, tc := range toolCalls { - tc.Function.Index = p.index - p.index++ + if p.done() { + p.state = toolsState_Done + content = string(p.buffer) + p.buffer = []byte{} } - p.sb.Reset() - return toolCalls, "" + return calls, content } -// NewParser creates a new tool call parser from a template. It extracts the tool call format, -// prefix, and field names from the template to use for parsing tool calls from model output. -// -// Returns an error if the template does not contain valid tool call formatting. -func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) { - parsed, err := template.Parse(templateToProcess.Root.String()) - if err != nil { - return nil, err +// findTag searches the buffer to find and handle a tool calling tag +// returning true if the tag was found and false otherwise, and +// a string content signaling any content that should be sent back to the user +func (p *Parser) findTag() (int, bool) { + // First check for complete substring anywhere in s + if i := bytes.Index(p.buffer, []byte(p.tag)); i > -1 { + return i, true } - tt, err := toolTemplate(parsed) - if err != nil { - return nil, err + // Then check for partial suffix overlap + max := min(len(p.buffer), len(p.tag)) + for i := max; i > 0; i-- { + if bytes.HasSuffix(p.buffer, []byte(p.tag[:i])) { + return len(p.buffer) - i, false + } } - - tp := toolPrefix(templateToProcess) - - name, arguments, err := extractToolArgs(tt) - if err != nil { - return nil, err - } - - return &Parser{ - tmpl: *tt, - sb: strings.Builder{}, - prefix: tp, - greedyParseJSON: true, - name: name, - arguments: arguments, - }, nil + return -1, false +} + +// parseToolCall finds the next complete tool call in the buffer +// incrementing n and advancing the buffer. +func (p *Parser) parseToolCall() *api.ToolCall { + var name string + var args map[string]any + var end int = len(p.buffer) + + // find tool name + var i int + for _, n := range p.names { + if i = bytes.Index(p.buffer, []byte(n)); i != -1 { + if i+len(n) < end { + name = n + end = i + len(n) + } + } + } + + if name == "" { + return nil + } + + if args, i = p.findArguments(); args == nil { + return nil + } + + if i > end { + end = i + } + + tc := &api.ToolCall{ + Function: api.ToolCallFunction{ + Name: name, + Arguments: args, + Index: p.n, + }, + } + + p.n++ + p.buffer = p.buffer[end:] + return tc +} + +// findArguments returns the first object that appears to be +// arguments and the position where the arguments end, returning nil and 0 if +// an invalid JSON object or non-arguments object is found first +func (p *Parser) findArguments() (map[string]any, int) { + if len(p.buffer) == 0 { + return nil, 0 + } + + var braces int + var start int = -1 + var end int + var object []byte + + // find any outer json object + for i, c := range p.buffer { + if c == '{' { + braces++ + if start == -1 { + start = i + } + } + + if c == '}' { + braces-- + if braces == 0 && start != -1 { + end = i + 1 + object = p.buffer[start:end] + break + } + } + } + + if braces > 0 { + return nil, 0 + } + + var data map[string]any + + // not valid json + if err := json.Unmarshal(object, &data); err != nil { + return nil, 0 + } + + var find func(obj any) map[string]any + find = func(obj any) map[string]any { + switch v := obj.(type) { + case map[string]any: + // check if the object keys are valid tool properties + // TODO (jmorganca): check only sets of properties that + // go together instead of the entire set + for _, prop := range p.properties { + if _, exists := v[prop]; exists { + return v + } + } + + for _, value := range v { + if result := find(value); result != nil { + return result + } + } + case []any: + for _, item := range v { + if result := find(item); result != nil { + return result + } + } + } + + return nil + } + + result := find(data) + if result != nil { + return result, end + } + + return nil, 0 +} + +// done checks if the parser is done parsing by looking +// for closing tag. currently only } and ] are supported +// for closing tags as {} or [] pairs may not always +// represent tool calls and we need to send the content back +func (p *Parser) done() bool { + var open, close rune + switch p.tag { + case "{": + open, close = '{', '}' + case "[": + open, close = '[', ']' + default: + return false + } + + var count int + for _, c := range p.buffer { + if c == byte(open) { + count++ + } else if c == byte(close) { + count-- + if count == 0 { + return true + } + } + } + + return false +} + +// Content returns any remaining content that +// should be sent to the user. This should be the empty string +// string unless the tag is { or [ and a tool call was not found +func (p *Parser) Content() string { + if p.n > 0 { + return "" + } + + if p.tag == "{" || p.tag == "[" { + return string(p.buffer) + } + + return "" } diff --git a/tools/tools_test.go b/tools/tools_test.go index 5fee8f57..67864168 100644 --- a/tools/tools_test.go +++ b/tools/tools_test.go @@ -1,673 +1,805 @@ package tools import ( - "bytes" - "encoding/json" - "fmt" - "os" - "path/filepath" - "strings" "testing" + "text/template" "github.com/google/go-cmp/cmp" - "github.com/ollama/ollama/api" - "github.com/ollama/ollama/template" ) -func readFile(t *testing.T, base, name string) *bytes.Buffer { - t.Helper() - - bts, err := os.ReadFile(filepath.Join(base, name)) +func TestParser(t *testing.T) { + qwen, err := template.New("qwen").Parse(`{{if .ToolCalls}}{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}{{end}}`) if err != nil { - t.Fatal(err) + t.Fatalf("Failed to parse template: %v", err) } - return bytes.NewBuffer(bts) -} + deepseek, err := template.New("deepseek").Parse("{{if .ToolCalls}}<|tool▁calls▁begin|>{{range .ToolCalls}}<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|>{{end}}<|tool▁calls▁end|><|end▁of▁sentence|>{{end}}") + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + + json, err := template.New("json").Parse(`{{if .ToolCalls}}{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}{{end}}`) + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + + mistral, err := template.New("mistral").Parse(`{{if .ToolCalls}}[TOOL_CALLS] [{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}][/TOOL_CALLS]{{end}}`) + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + + list, err := template.New("list").Parse(`{{if .ToolCalls}}[{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}]{{end}}`) + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + + tools := []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_temperature", + Description: "Retrieve the temperature for a given location", + Parameters: struct { + Type string `json:"type"` + Defs any `json:"$defs,omitempty"` + Items any `json:"items,omitempty"` + Required []string `json:"required"` + Properties map[string]struct { + Type api.PropertyType `json:"type"` + Items any `json:"items,omitempty"` + Description string `json:"description"` + Enum []any `json:"enum,omitempty"` + } `json:"properties"` + }{ + Type: "object", + Properties: map[string]struct { + Type api.PropertyType `json:"type"` + Items any `json:"items,omitempty"` + Description string `json:"description"` + Enum []any `json:"enum,omitempty"` + }{ + "format": { + Type: api.PropertyType{"string"}, + Description: "The format to return the temperature in", + Enum: []any{"fahrenheit", "celsius"}, + }, + "city": { + Type: api.PropertyType{"string"}, + Description: "The city to get the temperature for", + }, + }, + }, + }, + }, + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_conditions", + Description: "Retrieve the current weather conditions for a given location", + Parameters: struct { + Type string `json:"type"` + Defs any `json:"$defs,omitempty"` + Items any `json:"items,omitempty"` + Required []string `json:"required"` + Properties map[string]struct { + Type api.PropertyType `json:"type"` + Items any `json:"items,omitempty"` + Description string `json:"description"` + Enum []any `json:"enum,omitempty"` + } `json:"properties"` + }{ + Type: "object", + Properties: map[string]struct { + Type api.PropertyType `json:"type"` + Items any `json:"items,omitempty"` + Description string `json:"description"` + Enum []any `json:"enum,omitempty"` + }{ + "location": { + Type: api.PropertyType{"string"}, + Description: "The location to get the weather conditions for", + }, + }, + }, + }, + }, + } -func TestParseJSONToolCalls(t *testing.T) { tests := []struct { - name string - input string - nameField string - argsField string - wantToolCalls []api.ToolCall - wantErr error - prefix string + name string + inputs []string + tmpl *template.Template + content string + calls []api.ToolCall }{ { - name: "valid single tool call", - input: `{"name": "test_tool", "arguments": {"arg1": "value1"}}`, - nameField: "name", - argsField: "arguments", - wantToolCalls: []api.ToolCall{ + name: "no tool calls - just text", + inputs: []string{"Hello, how can I help you today?"}, + content: "Hello, how can I help you today?", + tmpl: qwen, + calls: nil, + }, + { + name: "empty input", + inputs: []string{""}, + content: "", + tmpl: qwen, + calls: nil, + }, + { + name: "tool call", + inputs: []string{`{"name": "get_conditions", "arguments": {"location": "San Francisco"}}`}, + content: "", + tmpl: qwen, + calls: []api.ToolCall{ { Function: api.ToolCallFunction{ - Name: "test_tool", - Arguments: map[string]any{ - "arg1": "value1", + Index: 0, + Name: "get_conditions", + Arguments: api.ToolCallFunctionArguments{ + "location": "San Francisco", }, }, }, }, - wantErr: nil, - prefix: "", }, { - name: "incomplete JSON", - input: `{"name": "test_tool", "arguments": {"arg1": `, - nameField: "name", - argsField: "arguments", - wantToolCalls: nil, - wantErr: errAccumulateMore, - prefix: "", - }, - { - name: "invalid JSON", - input: `not json at all`, - nameField: "name", - argsField: "arguments", - wantToolCalls: nil, - wantErr: errInvalidToolCall, - prefix: "", - }, - { - name: "missing required fields", - input: `{"other": "field"}`, - nameField: "name", - argsField: "arguments", - wantToolCalls: nil, - wantErr: errInvalidToolCall, - prefix: "", - }, - { - name: "multiple tool calls in array", - input: `[ - {"name": "tool1", "arguments": {"arg1": 1}}, - {"name": "tool2", "arguments": {"arg2": "value"}} - ]`, - nameField: "name", - argsField: "arguments", - wantToolCalls: []api.ToolCall{ + name: "text before tool call", + inputs: []string{`Let me check the weather. {"name": "get_temperature", "arguments": {"city": "New York"}}`}, + content: "Let me check the weather. ", + tmpl: qwen, + calls: []api.ToolCall{ { Function: api.ToolCallFunction{ - Name: "tool1", - Arguments: map[string]any{ - "arg1": float64(1), - }, - }, - }, - { - Function: api.ToolCallFunction{ - Name: "tool2", - Arguments: map[string]any{ - "arg2": "value", + Index: 0, + Name: "get_temperature", + Arguments: api.ToolCallFunctionArguments{ + "city": "New York", }, }, }, }, - wantErr: nil, - prefix: "", }, { - name: "multiple tool calls without array", - input: ` - {"name": "tool1", "arguments": {"arg1": 1}}, - {"name": "tool2", "arguments": {"arg2": "value"}} - `, - nameField: "name", - argsField: "arguments", - wantToolCalls: []api.ToolCall{ + name: "two tool calls in a list", + inputs: []string{`[TOOL_CALLS] [{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}, {"name": "get_conditions", "arguments": {"location": "Tokyo"}}][/TOOL_CALLS]`}, + content: "", + tmpl: mistral, + calls: []api.ToolCall{ { Function: api.ToolCallFunction{ - Name: "tool1", - Arguments: map[string]any{ - "arg1": float64(1), + Index: 0, + Name: "get_temperature", + Arguments: api.ToolCallFunctionArguments{ + "city": "London", + "format": "fahrenheit", }, }, }, { Function: api.ToolCallFunction{ - Name: "tool2", - Arguments: map[string]any{ - "arg2": "value", + Index: 1, + Name: "get_conditions", + Arguments: api.ToolCallFunctionArguments{ + "location": "Tokyo", }, }, }, }, - wantErr: nil, - prefix: "", }, { - name: "multiple tool calls with text after", - input: ` - {"name": "tool1", "arguments": {"arg1": 1}} text - {"name": "tool2", "arguments": {"arg2": "value"}} text - `, - nameField: "name", - argsField: "arguments", - wantToolCalls: []api.ToolCall{ + name: "two tool calls", + inputs: []string{`Okay, let's call both tools! {"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}{"name": "get_conditions", "arguments": {"location": "Tokyo"}}`}, + content: "Okay, let's call both tools! ", + tmpl: qwen, + calls: []api.ToolCall{ { Function: api.ToolCallFunction{ - Name: "tool1", - Arguments: map[string]any{ - "arg1": float64(1), + Index: 0, + Name: "get_temperature", + Arguments: api.ToolCallFunctionArguments{ + "city": "London", + "format": "fahrenheit", }, }, }, { Function: api.ToolCallFunction{ - Name: "tool2", - Arguments: map[string]any{ - "arg2": "value", + Index: 1, + Name: "get_conditions", + Arguments: api.ToolCallFunctionArguments{ + "location": "Tokyo", }, }, }, }, - wantErr: nil, - prefix: "", }, { - name: "second tool call in array", - input: ` - , {"name": "tool2", "arguments": {"arg2": "value"}} - `, - nameField: "name", - argsField: "arguments", - wantToolCalls: []api.ToolCall{ + name: "deepseek", + inputs: []string{"Wait, I need to call a tool<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_temperature\n```json\n{\"city\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"}, + content: "Wait, I need to call a tool", + tmpl: deepseek, + calls: []api.ToolCall{ { Function: api.ToolCallFunction{ - Name: "tool2", - Arguments: map[string]any{ - "arg2": "value", + Index: 0, + Name: "get_temperature", + Arguments: api.ToolCallFunctionArguments{ + "city": "Tokyo", }, }, }, }, - wantErr: nil, - prefix: "", - }, - // a bad JSON would not return any tool calls or content as it would always accumulate more - { - name: "unbalanced square brackets", - input: `[{"name": "tool1", "arguments": {"arg1": [1, 2}]`, - nameField: "name", - argsField: "arguments", - wantToolCalls: nil, - wantErr: errAccumulateMore, - prefix: "", }, { - name: "incomplete square brackets", - input: `[{"name": "tool1", "arguments": {"arg1": [1, 2, 3`, - nameField: "name", - argsField: "arguments", - wantToolCalls: nil, - wantErr: errAccumulateMore, - prefix: "", - }, - { - name: "nested arrays in arguments", - input: `{"name": "tool1", "arguments": {"arg1": [1, 2, ["nested", "array"]]}}`, - nameField: "name", - argsField: "arguments", - wantToolCalls: []api.ToolCall{ + name: "deepseek incremental", + inputs: []string{ + "Wait", + ", I need", + " to call", + " a tool<|too", + "l▁calls▁begin", + "|>", + "<|tool▁call▁begin|>function<|tool▁sep|>get_temperature\n", + "```json\n", + "{\"city\": \"Tokyo\"}\n", + "```", + "<|tool▁c", "all▁end|>", + "<|tool▁calls▁end|>", + "<|end▁of▁sentence|>", + }, + content: "Wait, I need to call a tool", + tmpl: deepseek, + calls: []api.ToolCall{ { Function: api.ToolCallFunction{ - Name: "tool1", - Arguments: map[string]any{ - "arg1": []any{float64(1), float64(2), []any{"nested", "array"}}, + Index: 0, + Name: "get_temperature", + Arguments: api.ToolCallFunctionArguments{ + "city": "Tokyo", }, }, }, }, - wantErr: nil, - prefix: "", + }, + { + name: "json", + inputs: []string{ + "{", + "\"name\": \"get_temperature\",", + "\"arguments\": {", + "\"city\": \"Tokyo\"", + "}", + "}", + }, + content: "", + tmpl: json, + calls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Index: 0, + Name: "get_temperature", + Arguments: api.ToolCallFunctionArguments{ + "city": "Tokyo", + }, + }, + }, + }, + }, + { + name: "json maybe a tool call", + inputs: []string{ + "{", + "\"name\": \"get_temperature\",", + "\"arguments\": {", + }, + content: "", + tmpl: json, + calls: nil, + }, + { + name: "json not a tool call", + inputs: []string{ + "{", + "\"name\": \"search\", ", + "\"arguments\": {", + "\"query\": \"What is the capital of Canada?\"", + "}", + "}", + }, + content: "{\"name\": \"search\", \"arguments\": {\"query\": \"What is the capital of Canada?\"}}", + tmpl: json, + calls: nil, + }, + { + name: "json object followed by tool call", + inputs: []string{ + "{\"name\": \"jeff\"}", + "{\"name\": \"get_conditions\", \"arguments\": {\"location\": \"San Francisco\"}}", + }, + content: "{\"name\": \"jeff\"}{\"name\": \"get_conditions\", \"arguments\": {\"location\": \"San Francisco\"}}", + tmpl: json, + }, + { + name: "json object followed by tool call split", + inputs: []string{ + "{\"name\": \"jeff\"} {", + "\"name\": \"get_conditions\", \"arguments\": {\"location\": \"San Francisco\"}}", + }, + content: "{\"name\": \"jeff\"} {\"name\": \"get_conditions\", \"arguments\": {\"location\": \"San Francisco\"}}", + tmpl: json, + }, + { + name: "json code", + inputs: []string{ + "for { fmt.Println(\"hello\") }", + }, + content: "for { fmt.Println(\"hello\") }", + tmpl: json, + }, + { + name: "list multiple", + inputs: []string{ + "[", + "{", + "\"name\": \"get_temperature\", ", + "\"arguments\": {", + "\"city\": \"London\"", + "}", + "},", + "{", + "\"name\": \"get_conditions\", ", + "\"arguments\": {", + "\"location\": \"Tokyo\"", + "}", + "}]", + }, + content: "", + tmpl: list, + calls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Index: 0, + Name: "get_temperature", + Arguments: api.ToolCallFunctionArguments{ + "city": "London", + }, + }, + }, + { + Function: api.ToolCallFunction{ + Index: 1, + Name: "get_conditions", + Arguments: api.ToolCallFunctionArguments{ + "location": "Tokyo", + }, + }, + }, + }, + }, + { + name: "list partial", + inputs: []string{ + "[", + "{", + "\"name\": \"search\", ", + "\"arguments\": {", + "\"query\": \"What is the capital of Canada?\"", + "}", + "}", + }, + content: "", + tmpl: list, + calls: nil, + }, + { + name: "list not a tool call", + inputs: []string{ + "[special", + " del", + "ivery]", + }, + content: "[special delivery]", + tmpl: list, + calls: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotCalls, err := parseJSONToolCalls(tt.input, tt.nameField, tt.argsField, tt.prefix) + parser := NewParser(tt.tmpl, tools) - if err != tt.wantErr { - t.Errorf("parseJSONToolCalls() error = %v, want %v", err, tt.wantErr) + var calls []api.ToolCall + var content string + for _, input := range tt.inputs { + tcs, c := parser.Add(input) + calls = append(calls, tcs...) + content += c } - if len(gotCalls) != 0 && tt.wantErr != nil { - t.Errorf("parseJSONToolCalls() valid = %v, want %v", len(gotCalls) == 0, tt.wantErr == nil) + if content != tt.content { + t.Errorf("Expected content %q, got %q", tt.content, content) } - if diff := cmp.Diff(gotCalls, tt.wantToolCalls); diff != "" { - t.Errorf("parseJSONToolCalls() tool calls mismatch (-got +want):\n%s", diff) + if len(calls) != len(tt.calls) { + t.Fatalf("Expected %d tool calls, got %d", len(tt.calls), len(calls)) + } + + for i, want := range tt.calls { + if diff := cmp.Diff(calls[i], want); diff != "" { + t.Errorf("Tool call %d mismatch (-got +want):\n%s", i, diff) + } } }) } } -func TestParseToolCalls(t *testing.T) { - p := filepath.Join("testdata") - t1 := api.ToolCall{ - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: api.ToolCallFunctionArguments{ - "format": "fahrenheit", - "location": "San Francisco, CA", - }, - }, - } - t2 := api.ToolCall{ - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: api.ToolCallFunctionArguments{ - "format": "celsius", - "location": "Toronto, Canada", - }, - }, - } - - cases := []struct { - name string - model string - output string - expectedToolCall []api.ToolCall - expectedTokens string +func TestDone(t *testing.T) { + tests := []struct { + name string + tag string + buffer []byte + want bool }{ { - name: "mistral malformed json with tool calls prefix", - model: "mistral", - output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`, - expectedToolCall: []api.ToolCall{t1}, - expectedTokens: "", + name: "empty", + tag: "", + buffer: []byte{}, + want: false, }, { - name: "mistral multiple tool calls without prefix", - model: "mistral", - output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}} ]`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", + name: "empty", + tag: "", + buffer: []byte{}, + want: false, }, { - name: "mistral tool calls with text between no prefix", - model: "mistral", - output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] - model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: `model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + name: "json open", + tag: "{", + buffer: []byte("{\"name\": \"get_weather\""), + want: false, }, { - name: "mistral valid json with tool calls prefix", - model: "mistral", - output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", + name: "json closed", + tag: "{", + buffer: []byte("{\"name\": \"get_weather\"}"), + want: true, }, { - name: "mistral multiple tool calls with text between and prefix", - model: "mistral", - output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] - model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{t1, t2, t1, t2}, - expectedTokens: "", + name: "json empty", + tag: "{", + buffer: []byte("{}"), + want: true, }, { - name: "mistral incomplete json with tool calls prefix", - model: "mistral", - output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `, - expectedToolCall: []api.ToolCall{}, - expectedTokens: "", + name: "list open", + tag: "[", + buffer: []byte("[{\"name\": \"get_weather\""), + want: false, }, { - name: "mistral invalid tool call with explanatory text no prefix", - model: "mistral", - output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: - - [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{}, - expectedTokens: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + name: "list closed", + tag: "[", + buffer: []byte("[{\"name\": \"get_weather\"}]"), + want: true, }, { - name: "mistral tool calls without prefix", - model: "mistral", - output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", - }, - { - name: "command r plus tool calls with json block format", - model: "command-r-plus", - output: "Action: ```json" + ` - [ - { - "tool_name": "get_current_weather", - "parameters": { - "format": "fahrenheit", - "location": "San Francisco, CA" - } - }, - { - "tool_name": "get_current_weather", - "parameters": { - "format": "celsius", - "location": "Toronto, Canada" - } - } - ] - ` + "```", - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", - }, - { - name: "firefunction tool calls with functools prefix", - model: "firefunction", - output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", - }, - { - name: "llama3 groq single tool call with xml tags", - model: "llama3-groq-tool-use", - output: ` - {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} - `, - expectedToolCall: []api.ToolCall{t1}, - expectedTokens: "", - }, - { - name: "xlam tool calls with wrapper object", - model: "xlam", - output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", - }, - { - name: "qwen2.5 single tool call with prefix", - model: "qwen2.5", - output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}`, - expectedToolCall: []api.ToolCall{t1}, - expectedTokens: "", - }, - { - name: "qwen2.5 multiple tool calls with and without prefix", - model: "qwen2.5", - output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}`, - expectedToolCall: []api.ToolCall{t1, t1, t2}, - expectedTokens: "", - }, - { - name: "qwen2.5 plain text response no tool calls", - model: "qwen2.5", - output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", - expectedToolCall: []api.ToolCall{}, - expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", - }, - { - name: "qwen2.5 tool calls with trailing text", - model: "qwen2.5", - output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "some tokens after call", - }, - { - name: "qwen2.5 tool calls with initial text", - model: "qwen2.5", - output: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{}, - expectedTokens: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - }, - { - name: "qwen2.5 tool calls with prefix and trailing text", - model: "qwen2.5", - output: ` [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", - }, - { - name: "qwen2.5 tool calls with prefix and initial text", - model: "qwen2.5", - output: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] `, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "some tokens before call", - }, - { - name: "qwen2.5 tool calls without and with prefix", - model: "qwen2.5", - output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", - }, - { - name: "qwen2.5 tool calls without and with prefix and text between", - model: "qwen2.5", - output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} some tokens between {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}} some tokens after call`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "some tokens between", - }, - { - name: "qwen2.5 tool calls without prefix and invalid tool call with other tokens", - model: "qwen2.5", - output: `hi [{"options": "foo"}]`, - expectedToolCall: []api.ToolCall{}, - expectedTokens: `hi [{"options": "foo"}]`, - }, - { - name: "qwen2.5 tool calls with prefix and invalid tool call", - model: "qwen2.5", - output: ` [{"options": "foo"}] `, - expectedToolCall: []api.ToolCall{}, - expectedTokens: ``, - }, - { - name: "qwen3 tool call with think prefix and tool prefix (sent as a single token)", - model: "qwen3", - output: `Okay, let me think what tool we should use...{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}`, - expectedToolCall: []api.ToolCall{t1}, - expectedTokens: "Okay, let me think what tool we should use...", - }, - { - name: "qwen3 tool call with think prefix, tool prefix, and whitespace (sent as separate tokens)", - model: "qwen3", - output: `Okay, let me think what tool we should use... { "name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - expectedToolCall: []api.ToolCall{t1}, - expectedTokens: "Okay, let me think what tool we should use...", - }, - { - name: "qwen3 empty think prefix without tool prefix and invalid tool call", - model: "qwen3", - output: ` {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - expectedToolCall: []api.ToolCall{}, - expectedTokens: ` {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - }, - { - name: "qwen3 empty think prefix with tool prefix and valid tool call", - model: "qwen3", - output: `{ "name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - expectedToolCall: []api.ToolCall{t1}, - expectedTokens: ``, - }, - { - name: "qwen3 invalid tool call with fake tool prefix (single rune suffix match)", - model: "qwen3", - output: `< fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - expectedToolCall: []api.ToolCall{}, - expectedTokens: `< fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - }, - { - name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)", - model: "qwen3", - output: ``, - expectedToolCall: []api.ToolCall{}, - expectedTokens: ``, - }, - { - name: "qwen3 invalid tool call with malformed tool prefix", - model: "qwen3", - output: ``, - expectedToolCall: []api.ToolCall{}, - expectedTokens: ``, - }, - { - name: "model with prefix in template, no prefix in output", - model: "qwen2.5", - output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", - }, - { - name: "model with prefix in template, prefix in output", - model: "qwen2.5", - output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", - }, - { - name: "model without prefix in template, no prefix in output", - model: "llama3.2", - output: `[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", - }, - { - name: "model without prefix in template, no prefix in output, single tool call", - model: "llama3.2", - output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, - expectedToolCall: []api.ToolCall{t1}, - expectedTokens: "", - }, - { - name: "model without prefix in template, prefix in output, multiple tool calls in list", - model: "llama3.2", - output: ` [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: ``, - }, - { - name: "model without prefix in template, prefix in output, individual tool calls", - model: "llama3.2", - output: ` {"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: ``, - }, - { - name: "model with prefix in template, no prefix in output, tokens before", - model: "qwen2.5", - output: `some tokens before [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{}, - expectedTokens: `some tokens before [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - }, - { - name: "model with prefix in template, prefix in output, tokens after", - model: "qwen2.5", - output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", - }, - { - name: "model without prefix in template, no prefix in output, tokens after", - model: "llama3.2", - output: `[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", - }, - { - name: "model without prefix in template, no prefix in output, tokens before", - model: "llama3.2", - output: `some tokens before [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: `some tokens before`, - }, - { - name: "model without prefix in template, prefix in output, tokens after", - model: "llama3.2", - output: ` - [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: ``, - }, - { - name: "model without without prefix, match all jsons", - model: "llama3.2", - output: `model outputs some text [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "model outputs some text", - }, - { - name: "model flushes tokens if tool call doesn't match", - model: "llama3.2", - output: `{ "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}`, - expectedToolCall: []api.ToolCall{}, - expectedTokens: `{ "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}`, - }, - { - name: "model flushes tokens if tool call doesn't match array", - model: "llama3.2", - output: `[ { "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}]`, - expectedToolCall: []api.ToolCall{}, - expectedTokens: `[ { "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}]`, + name: "list empty", + tag: "[", + buffer: []byte("[]"), + want: true, }, } - var tools []api.Tool - if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil { - t.Fatal(err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := &Parser{ + tag: tt.tag, + buffer: tt.buffer, + } + got := parser.done() + if got != tt.want { + t.Errorf("done() = %t, want %t", got, tt.want) + } + }) + } +} + +func TestContent(t *testing.T) { + tests := []struct { + name string + tag string + content []byte + want string + n int + }{ + { + name: "empty", + content: []byte{}, + tag: "{", + want: "", + n: 0, + }, + { + name: "tag", + tag: "", + content: []byte("{\"name\": \"get_temperature\""), + want: "", + n: 0, + }, + { + name: "json object", + tag: "{", + content: []byte("{\"name\": \"get_temperature\"}"), + want: "{\"name\": \"get_temperature\"}", + n: 0, + }, + { + name: "json object after called", + tag: "{", + content: []byte("{\"hello\": \"world\"}"), + want: "{\"hello\": \"world\"}", + n: 0, + }, + { + name: "json object after called", + tag: "{", + content: []byte("{\"hello\": \"world\"}"), + want: "", + n: 1, + }, + { + name: "list", + tag: "[", + content: []byte("[{\"name\": \"get_temperature\"}]"), + want: "[{\"name\": \"get_temperature\"}]", + n: 0, + }, + { + name: "code", + tag: "{", + content: []byte("{ fmt.Println(\"hello\")"), + want: "{ fmt.Println(\"hello\")", + n: 0, + }, } - var messages []api.Message - if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil { - t.Fatal(err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := &Parser{ + tag: tt.tag, + buffer: tt.content, + n: tt.n, + } + got := parser.Content() + if got != tt.want { + t.Errorf("Content() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestFindTag(t *testing.T) { + cases := []struct { + name string + buffer []byte + tag string + i int + found bool + }{ + { + name: "no overlap", + buffer: []byte("hello world"), + tag: "", + i: -1, + found: false, + }, + { + name: "full overlap", + buffer: []byte(""), + tag: "", + i: 0, + found: true, + }, + { + name: "whitespace", + buffer: []byte(" \n {\"name\": \"bob\"}"), + tag: "", + i: 4, + found: true, + }, + { + name: "over", + buffer: []byte("{\"name\""), + tag: "", + i: 0, + found: true, + }, + { + name: "partial overlap", + buffer: []byte("text "), + tag: "", + i: 5, + found: true, + }, + { + name: "overlap with extra", + buffer: []byte(""), + tag: "", + i: 0, + found: true, + }, + { + name: "delimiter longer than string", + buffer: []byte(""), + tag: "", + i: -1, + found: false, + }, + { + name: "empty string", + buffer: []byte{}, + tag: "", + i: -1, + found: false, + }, + { + name: "single char overlap", + buffer: []byte("test<"), + tag: "", + i: 4, + found: false, + }, + { + name: "partial tool call", + buffer: []byte("hello ", + i: 6, + found: false, + }, + { + name: "square bracket", + buffer: []byte("calling tools: ["), + tag: "[", + i: 15, + found: true, + }, + { + name: "bracket", + buffer: []byte("{\"name\": \"bob\""), + tag: "{", + i: 0, + found: true, + }, + { + name: "bracket with whitespace", + buffer: []byte("\n\n{\n\"name\": \"bob\""), + tag: "{", + i: 2, + found: true, + }, } for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { - tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String()) - if err != nil { - t.Fatal(err) + parser := &Parser{ + tag: tt.tag, + buffer: tt.buffer, + n: 0, + } + i, found := parser.findTag() + if i != tt.i { + t.Errorf("findTag(%q, %q) = %d; want %d", tt.buffer, tt.tag, i, tt.i) + } + if found != tt.found { + t.Errorf("findTag(%q, %q) = %t; want %t", tt.buffer, tt.tag, found, tt.found) + } + }) + } +} + +func TestFindArguments(t *testing.T) { + tests := []struct { + name string + buffer []byte + want map[string]any + }{ + { + name: "empty string", + buffer: []byte{}, + want: nil, + }, + { + name: "whitespace only", + buffer: []byte(" \n\t "), + want: nil, + }, + { + name: "unbalanced braces - missing closing", + buffer: []byte(`{"format": "fahrenheit", "location": "San Francisco"`), + want: nil, + }, + { + name: "unbalanced braces - extra closing", + buffer: []byte(`{"format": "fahrenheit"}}`), + want: map[string]any{ + "format": "fahrenheit", + }, + }, + { + name: "invalid JSON", + buffer: []byte(`{format: fahrenheit, location: "San Francisco"}`), + want: nil, + }, + { + name: "valid json", + buffer: []byte(`{"name": "get_temperature", "arguments": {"format": "fahrenheit", "location": "San Francisco, CA"}}`), + want: map[string]any{ + "format": "fahrenheit", + "location": "San Francisco, CA", + }, + }, + { + name: "valid arguments with special tokens", + buffer: []byte(`[tool]get_temperature[args]{"format": "fahrenheit", "location": "San Francisco, CA"}[end]`), + want: map[string]any{ + "format": "fahrenheit", + "location": "San Francisco, CA", + }, + }, + { + name: "valid arguments in array", + buffer: []byte(`[{"arguments": {"format": "fahrenheit", "location": "San Francisco, CA"}}`), + want: map[string]any{ + "format": "fahrenheit", + "location": "San Francisco, CA", + }, + }, + { + name: "nested deep", + buffer: []byte(`{"function": {"name": "get_temperature", "arguments": {"format": "fahrenheit", "location": "San Francisco, CA"}}}`), + want: map[string]any{ + "format": "fahrenheit", + "location": "San Francisco, CA", + }, + }, + { + name: "one arg", + buffer: []byte(`get_weather({"location": "San Francisco, CA"})`), + want: map[string]any{ + "location": "San Francisco, CA", + }, + }, + { + name: "two args", + buffer: []byte(`[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}]`), + want: map[string]any{ + "location": "San Francisco, CA", + "format": "fahrenheit", + }, + }, + { + name: "deepseek", + buffer: []byte("<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"), + want: map[string]any{ + "location": "Tokyo", + }, + }, + } + + for _, tt := range tests { + parser := &Parser{ + buffer: tt.buffer, + properties: []string{"format", "location"}, + } + + t.Run(tt.name, func(t *testing.T) { + got, _ := parser.findArguments() + + if diff := cmp.Diff(got, tt.want); diff != "" { + t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff) } - - t.Run("template", func(t *testing.T) { - actual := &bytes.Buffer{} // Create new buffer for each test - if err := tmpl.Execute(actual, template.Values{Tools: tools, Messages: messages}); err != nil { - t.Fatal(err) - } - - if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) - } - }) - - t.Run("parse", func(t *testing.T) { - tp, err := NewParser(tmpl.Template) - if err != nil { - t.Fatal(err) - } - got := []api.ToolCall{} - var gotTokens strings.Builder - - tokens := strings.Fields(tt.output) - for _, tok := range tokens { - s := " " + tok - - toolCalls, content := tp.Add(s) - if len(content) > 0 { - gotTokens.WriteString(content) - } else if len(toolCalls) > 0 { - got = append(got, toolCalls...) - } - } - - // Compare tool calls if we expect any - if diff := cmp.Diff(got, tt.expectedToolCall); diff != "" { - t.Errorf("tool calls mismatch (-got +want):\n%s", diff) - } - - // Compare tokens if we expect any - stripped := strings.TrimSpace(gotTokens.String()) - if diff := cmp.Diff(stripped, tt.expectedTokens); diff != "" { - t.Log("actualTokens", stripped, "expectedTokens", tt.expectedTokens) - t.Errorf("tokens mismatch (-got +want):\n%s", diff) - } - }) }) } } diff --git a/tools/tools_utils.go b/tools/tools_utils.go deleted file mode 100644 index b6f80729..00000000 --- a/tools/tools_utils.go +++ /dev/null @@ -1,222 +0,0 @@ -package tools - -import ( - "bytes" - "encoding/json" - "errors" - "log/slog" - "slices" - "strings" - gotmpl "text/template" - "text/template/parse" - - "github.com/ollama/ollama/api" - "github.com/ollama/ollama/template" -) - -// extractToolCallsFormat traverses a template AST to find text that follows a ".ToolCalls" condition. -// It walks the template nodes looking for if-statements containing ".ToolCalls" and extracts any -// immediate text nodes that follow. This is used to identify tool call prefixes and formatting. -// -// Returns: -// - string: The extracted text following the first ".ToolCalls" condition found -// - bool: Whether a ".ToolCalls" condition was found in the template -func extractToolCallsFormat(tmpl *gotmpl.Template) (string, bool) { - if tmpl == nil || tmpl.Tree == nil { - slog.Debug("template or tree is nil") - return "", false - } - - var result string - var found bool - - var walk func(nodes []parse.Node) - walk = func(nodes []parse.Node) { - for _, node := range nodes { - if found { - return - } - - switch n := node.(type) { - case *parse.IfNode: - if isToolCallsNode(n) { - // Collect immediate TextNode(s) at start of IfNode's list - var sb strings.Builder - for _, innerNode := range n.List.Nodes { - if tn, ok := innerNode.(*parse.TextNode); ok { - sb.Write(tn.Text) - } else { - // Stop at first non-text node - break - } - } - result = sb.String() - found = true - return - } - // Recurse into child nodes - walk(n.List.Nodes) - if n.ElseList != nil { - walk(n.ElseList.Nodes) - } - case *parse.ListNode: - walk(n.Nodes) - case *parse.RangeNode: - walk(n.List.Nodes) - if n.ElseList != nil { - walk(n.ElseList.Nodes) - } - case *parse.WithNode: - walk(n.List.Nodes) - if n.ElseList != nil { - walk(n.ElseList.Nodes) - } - default: - // Continue to next node - continue - } - } - } - - walk(tmpl.Tree.Root.Nodes) - return result, found -} - -// isToolCallsNode detects if a node's condition includes ".ToolCalls" -func isToolCallsNode(n *parse.IfNode) bool { - for _, cmd := range n.Pipe.Cmds { - for _, arg := range cmd.Args { - if field, ok := arg.(*parse.FieldNode); ok { - if slices.Contains(field.Ident, "ToolCalls") { - return true - } - } - } - } - return false -} - -func toolPrefix(tmpl *gotmpl.Template) string { - tokenText, ok := extractToolCallsFormat(tmpl) - if !ok { - return "" - } - tokenText = strings.TrimSpace(tokenText) - tokenText = strings.ReplaceAll(tokenText, "\r", "") - tokenText = strings.ReplaceAll(tokenText, "\n", " ") - - return tokenText -} - -// toolTemplate creates a subtree from the node that ranges over .ToolCalls -// -// Returns: -// - *gotmpl.Template: The subtree containing the .ToolCalls range -// - error: Error if parsing failed -func toolTemplate(t *template.Template) (*gotmpl.Template, error) { - tmpl := t.Subtree(func(n parse.Node) bool { - if t, ok := n.(*parse.RangeNode); ok { - return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") - } - - return false - }) - - if tmpl == nil { - return nil, errors.New("failed to find tool template") - } - - return tmpl, nil -} - -// suffixOverlap returns the index in s where the longest suffix overlap with prefix begins -// -// Returns: -// - int: The starting index in s where the suffix overlap begins -func suffixOverlap(s, prefix string) int { - max := min(len(prefix), len(s)) - for i := max; i > 0; i-- { - if strings.HasSuffix(s, prefix[:i]) { - return len(s) - i - } - } - return -1 -} - -// extractToolArgs executes a template with a known tool call format to extract the name and arguments -// -// Returns: -// - string: The name of the tool call -// - string: The arguments of the tool call -// - error: Error if parsing failed -func extractToolArgs(tmpl *gotmpl.Template) (name, arguments string, err error) { - var b bytes.Buffer - if err := tmpl.Execute(&b, map[string][]api.ToolCall{ - "ToolCalls": { - { - Function: api.ToolCallFunction{ - Name: "@@name@@", - Arguments: api.ToolCallFunctionArguments{ - "@@argument@@": 1, - }, - }, - }, - }, - }); err != nil { - return "", "", err - } - - // Extract JSON object between curly braces - // JSON arrays are also valid as they will not be repeated in the template - output := b.String() - start := strings.Index(output, "{") - end := strings.LastIndex(output, "}") - if start == -1 || end == -1 || start > end { - return "", "", errors.New("no valid JSON object found in template output") - } - jsonStr := output[start : end+1] - - var obj map[string]any - if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil { - return "", "", err - } - - // Find name and arguments fields - for k, v := range obj { - if str, ok := v.(string); ok && str == "@@name@@" { - name = k - } else if _, ok := v.(map[string]any); ok { - arguments = k - } - } - - if name == "" || arguments == "" { - slog.Debug("missing required fields in tool call template", "name", name, "arguments", arguments) - return "", "", errors.New("missing required fields in tool call template") - } - - return name, arguments, nil -} - -// collect recursively traverses an object to collect all nested maps -// -// Returns: -// - []map[string]any: A slice of all nested maps found in the object -func collect(obj any) []map[string]any { - var all []map[string]any - switch o := obj.(type) { - case map[string]any: - all = append(all, o) - for _, v := range o { - all = append(all, collect(v)...) - } - case []any: - for _, v := range o { - all = append(all, collect(v)...) - } - default: - return nil - } - - return all -} diff --git a/tools/tools_utils_test.go b/tools/tools_utils_test.go deleted file mode 100644 index e346117a..00000000 --- a/tools/tools_utils_test.go +++ /dev/null @@ -1,497 +0,0 @@ -package tools - -import ( - "testing" - gotmpl "text/template" - - "github.com/ollama/ollama/template" -) - -func TestExtractToolCallsFormat(t *testing.T) { - cases := []struct { - name string - template string - want string - found bool - }{ - { - name: "nil template", - template: "", - want: "", - found: false, - }, - { - name: "basic tool call with text", - template: "{{if .ToolCalls}}Hello world{{end}}", - want: "Hello world", - found: true, - }, - { - name: "tool call with json format", - template: "{{if .ToolCalls}}```json\n{{end}}", - want: "```json\n", - found: true, - }, - { - name: "tool call in range", - template: "{{range .ToolCalls}}tool: {{.}}{{end}}", - want: "", - found: false, - }, - { - name: "tool call with multiple text nodes", - template: "{{if .ToolCalls}}First text{{if .Something}}inner{{end}}Second text{{end}}", - want: "First text", - found: true, - }, - { - name: "nested if without tool calls", - template: "{{if .Something}}{{if .OtherThing}}text{{end}}{{end}}", - want: "", - found: false, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - tmpl, err := gotmpl.New("test").Parse(tc.template) - if err != nil && tc.template != "" { - t.Fatalf("failed to parse template: %v", err) - } - - got, found := extractToolCallsFormat(tmpl) - if got != tc.want { - t.Errorf("got text %q, want %q", got, tc.want) - } - if found != tc.found { - t.Errorf("got found %v, want %v", found, tc.found) - } - }) - } -} - -func TestToolPrefix(t *testing.T) { - cases := []struct { - name string - template string - want string - }{ - { - name: "basic tool call with action prefix", - template: "{{if .ToolCalls}}Action: ```json{{end}}", - want: "Action: ```json", - }, - { - name: "incomplete functools bracket", - template: "{{if .ToolCalls}}functools[{{end}}", - want: "functools[", - }, - { - name: "tool call with angle brackets", - template: "{{if .ToolCalls}}Hello, world! {{end}}", - want: "Hello, world! ", - }, - { - name: "multiple tool call formats", - template: "{{if .ToolCalls}}[tool_call] {{end}}", - want: "[tool_call] ", - }, - { - name: "single angle bracket tool call", - template: "{{if .ToolCalls}}{{end}}", - want: "", - }, - { - name: "incomplete angle bracket after tool call", - template: "{{if .ToolCalls}}[tool_call] <{{end}}", - want: "[tool_call] <", - }, - { - name: "angle bracket prefix with tool call", - template: "{{if .ToolCalls}}> {{end}}", - want: "> ", - }, - { - name: "uppercase tool call with incomplete bracket", - template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}", - want: "[TOOL_CALL] [", - }, - { - name: "uppercase tool call with adjacent bracket", - template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}", - want: "[TOOL_CALL][", - }, - { - name: "tool call with pipe delimiters", - template: "{{if .ToolCalls}}<|tool_call|>{{end}}", - want: "<|tool_call|>", - }, - { - name: "tool with no prefix", - template: "{{if .ToolCalls}}{{end}}", - want: "", - }, - } - - for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - tmpl, err := gotmpl.New("test").Parse(tt.template) - if err != nil { - t.Fatalf("failed to parse template: %v", err) - } - got := toolPrefix(tmpl) - if got != tt.want { - t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want) - } - }) - } -} - -func TestToolTemplate(t *testing.T) { - cases := []struct { - name string - template string - want bool - }{ - { - name: "basic tool call range", - template: "{{range .ToolCalls}}test{{end}}", - want: true, - }, - { - name: "no tool calls", - template: "{{range .Other}}test{{end}}", - want: false, - }, - { - name: "nested tool calls", - template: "{{range .Outer}}{{range .ToolCalls}}test{{end}}{{end}}", - want: true, - }, - { - name: "empty template", - template: "", - want: false, - }, - { - name: "tool calls in if statement", - template: "{{if .ToolCalls}}test{{end}}", - want: false, - }, - } - - for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - tmpl, err := gotmpl.New("test").Parse(tt.template) - if err != nil { - t.Fatalf("failed to parse template: %v", err) - } - - parsed, err := template.Parse(tmpl.Root.String()) - if err != nil { - t.Fatalf("failed to parse template: %v", err) - } - - _, err = toolTemplate(parsed) - if err != nil && tt.want { - t.Errorf("toolTemplate() = %v; want %v", err, tt.want) - } - }) - } -} - -func TestSuffixOverlap(t *testing.T) { - cases := []struct { - name string - s string - d string - want int - }{ - { - name: "no overlap", - s: "hello world", - d: "", - want: -1, - }, - { - name: "full overlap", - s: "", - d: "", - want: 0, - }, - { - name: "partial overlap", - s: "text ", - d: "", - want: 5, - }, - { - name: "delimiter longer than string", - s: "", - d: "", - want: -1, - }, - { - name: "empty string", - s: "", - d: "", - want: -1, - }, - { - name: "empty delimiter", - s: "", - d: "", - want: -1, - }, - { - name: "single char overlap", - s: "test<", - d: "", - want: 4, - }, - { - name: "partial tool call", - s: "hello ", - want: 6, - }, - } - - for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - got := suffixOverlap(tt.s, tt.d) - if got != tt.want { - t.Errorf("suffixOverlap(%q, %q) = %d; want %d", tt.s, tt.d, got, tt.want) - } - }) - } -} - -func TestExtractToolArgs(t *testing.T) { - cases := []struct { - name string - template string - wantName string - wantArgs string - wantErr bool - }{ - { - name: "basic tool call", - template: `{{ range .ToolCalls }} -{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}`, - wantName: "name", - wantArgs: "parameters", - wantErr: false, - }, - { - name: "tool call with whitespace", - template: `{{range .ToolCalls}} - {"name": "{{.Function.Name}}", "parameters": {{.Function.Arguments}}} -{{end}}`, - wantName: "name", - wantArgs: "parameters", - wantErr: false, - }, - { - name: "tool call with extra content", - template: `Before {{range .ToolCalls}} -{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}} After`, - wantName: "name", - wantArgs: "arguments", - wantErr: false, - }, - { - name: "no tool calls", - template: `{{if .Something}}no tools here{{end}}`, - wantName: "", - wantArgs: "", - wantErr: true, - }, - { - name: "empty template", - template: ``, - wantName: "", - wantArgs: "", - wantErr: true, - }, - { - name: "prefix within tool call", - template: `{{- if .ToolCalls }} -{{ range .ToolCalls }} - -{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} -{{ end }}{{- end }}`, - wantName: "name", - wantArgs: "arguments", - wantErr: false, - }, - { - name: "JSON array", - template: `{{ range .ToolCalls }} -[{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}]{{ end }}`, - wantName: "name", - wantArgs: "arguments", - wantErr: false, - }, - { - name: "invalid JSON", - template: `{{ range .ToolCalls }} -{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}, invalid}{{ end }}`, - wantName: "", - wantArgs: "", - wantErr: true, - }, - { - name: "missing name field", - template: `{{ range .ToolCalls }} -{"parameters": {{ .Function.Arguments }}}{{ end }}`, - wantName: "", - wantArgs: "", - wantErr: true, - }, - { - name: "missing arguments field", - template: `{{ range .ToolCalls }} -{"name": "{{ .Function.Name }}"}{{ end }}`, - wantName: "", - wantArgs: "", - wantErr: true, - }, - { - name: "malformed JSON", - template: `{{ range .ToolCalls }} -{"name": {{ .Function.Name }}, "arguments": {{ .Function.Arguments }}{{ end }}`, - wantName: "", - wantArgs: "", - wantErr: true, - }, - } - - for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - tmpl, err := gotmpl.New("test").Parse(tt.template) - if err != nil { - t.Fatalf("failed to parse template: %v", err) - } - - gotName, gotArgs, err := extractToolArgs(tmpl) - if (err != nil) != tt.wantErr { - t.Errorf("extractToolArgs() error = %v, wantErr %v", err, tt.wantErr) - return - } - if err != nil { - return - } - - if gotName != tt.wantName { - t.Errorf("extractToolArgs() gotName = %q, want %q", gotName, tt.wantName) - } - if gotArgs != tt.wantArgs { - t.Errorf("extractToolArgs() gotArgs = %q, want %q", gotArgs, tt.wantArgs) - } - }) - } -} - -func TestCollect(t *testing.T) { - cases := []struct { - name string - obj any - want []map[string]any - }{ - { - name: "simple map", - obj: map[string]any{ - "key": "value", - }, - want: []map[string]any{ - {"key": "value"}, - }, - }, - { - name: "nested map", - obj: map[string]any{ - "outer": map[string]any{ - "inner": "value", - }, - }, - want: []map[string]any{ - {"outer": map[string]any{"inner": "value"}}, - {"inner": "value"}, - }, - }, - { - name: "array of maps", - obj: []any{ - map[string]any{"key1": "val1"}, - map[string]any{"key2": "val2"}, - }, - want: []map[string]any{ - {"key1": "val1"}, - {"key2": "val2"}, - }, - }, - { - name: "deeply nested", - obj: map[string]any{ - "l1": map[string]any{ - "l2": map[string]any{ - "l3": "value", - }, - }, - }, - want: []map[string]any{ - {"l1": map[string]any{"l2": map[string]any{"l3": "value"}}}, - {"l2": map[string]any{"l3": "value"}}, - {"l3": "value"}, - }, - }, - { - name: "non-map value", - obj: "string", - want: nil, - }, - } - - for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - got := collect(tt.obj) - if len(got) != len(tt.want) { - t.Errorf("collect() got %d maps, want %d", len(got), len(tt.want)) - return - } - - // Compare each map in the result - for i := range tt.want { - if !mapsEqual(got[i], tt.want[i]) { - t.Errorf("collect() map[%d] = %v, want %v", i, got[i], tt.want[i]) - } - } - }) - } -} - -// mapsEqual compares two maps for deep equality -func mapsEqual(m1, m2 map[string]any) bool { - if len(m1) != len(m2) { - return false - } - for k, v1 := range m1 { - v2, ok := m2[k] - if !ok { - return false - } - switch val1 := v1.(type) { - case map[string]any: - val2, ok := v2.(map[string]any) - if !ok || !mapsEqual(val1, val2) { - return false - } - default: - if v1 != v2 { - return false - } - } - } - return true -} From 5a8eb0e1510a5a35b80649f2b88e9231716b6850 Mon Sep 17 00:00:00 2001 From: Phil Date: Sat, 14 Jun 2025 17:54:03 +0200 Subject: [PATCH 23/26] readme: add GPTranslate to community integrations (#11071) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index ffaec628..a216589a 100644 --- a/README.md +++ b/README.md @@ -407,6 +407,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers) - [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI) - [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.) +- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.) ### Cloud From 502028968ddca04bd19c0859a73fb4e0cbeac3e1 Mon Sep 17 00:00:00 2001 From: NGC13009 Date: Mon, 16 Jun 2025 12:27:49 +0800 Subject: [PATCH 24/26] readme: add ollama-launcher to community integrations (#11080) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index a216589a..e148f9af 100644 --- a/README.md +++ b/README.md @@ -408,6 +408,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI) - [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.) - [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.) +- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.) ### Cloud From a6fbfc880c3de9b57e341db374907e2fedda9fa6 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 16 Jun 2025 10:42:32 -0700 Subject: [PATCH 25/26] gguf: fix write order (#11068) * ggml: test write gguf order * ggml: fix write tensor order --- fs/ggml/gguf.go | 12 ++--- fs/ggml/gguf_test.go | 110 +++++++++++++++++++++++++------------------ 2 files changed, 68 insertions(+), 54 deletions(-) diff --git a/fs/ggml/gguf.go b/fs/ggml/gguf.go index 8e75625e..33b596cc 100644 --- a/fs/ggml/gguf.go +++ b/fs/ggml/gguf.go @@ -527,23 +527,17 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error { return err } - keys := slices.Collect(maps.Keys(kv)) - slices.Sort(keys) - - for _, key := range keys { + for _, key := range slices.Sorted(maps.Keys(kv)) { if err := ggufWriteKV(f, key, kv[key]); err != nil { return err } } slices.SortStableFunc(ts, func(a, b *Tensor) int { - if i, j := a.block(), b.block(); i < 0 && j > 0 { - return 1 - } else if i > 0 && j < 0 { - return -1 - } else { + if i, j := a.block(), b.block(); i > 0 && j > 0 { return cmp.Compare(i, j) } + return cmp.Compare(a.Name, b.Name) }) var s uint64 diff --git a/fs/ggml/gguf_test.go b/fs/ggml/gguf_test.go index 0e071800..bf767918 100644 --- a/fs/ggml/gguf_test.go +++ b/fs/ggml/gguf_test.go @@ -2,62 +2,82 @@ package ggml import ( "bytes" + "math/rand/v2" "os" - "slices" + "strings" "testing" "github.com/google/go-cmp/cmp" ) func TestWriteGGUF(t *testing.T) { - w, err := os.CreateTemp(t.TempDir(), "*.bin") - if err != nil { - t.Fatal(err) - } - defer w.Close() + r := rand.New(rand.NewPCG(0, 0)) + for range 8 { + t.Run("shuffle", func(t *testing.T) { + t.Parallel() - if err := WriteGGUF(w, KV{ - "general.alignment": uint32(16), - }, []*Tensor{ - {Name: "test.0", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))}, - {Name: "test.1", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))}, - {Name: "test.2", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))}, - {Name: "test.3", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))}, - {Name: "test.4", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))}, - {Name: "test.5", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))}, - }); err != nil { - t.Fatal(err) - } + ts := []*Tensor{ + {Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))}, + {Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))}, + {Name: "blk.1.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))}, + {Name: "blk.2.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))}, + {Name: "blk.3.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))}, + {Name: "blk.4.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))}, + {Name: "blk.5.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))}, + {Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))}, + {Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))}, + } - r, err := os.Open(w.Name()) - if err != nil { - t.Fatal(err) - } - defer r.Close() + r.Shuffle(len(ts), func(i, j int) { + ts[i], ts[j] = ts[j], ts[i] + }) - ff, err := Decode(r, 0) - if err != nil { - t.Fatal(err) - } + w, err := os.CreateTemp(t.TempDir(), strings.ReplaceAll(t.Name(), "/", "_")+"*.bin") + if err != nil { + t.Fatal(err) + } + defer w.Close() - if diff := cmp.Diff(ff.KV(), KV{ - "general.alignment": uint32(16), - "general.parameter_count": uint64(36), - }); diff != "" { - t.Errorf("Mismatch (-want +got):\n%s", diff) - } + if err := WriteGGUF(w, KV{ + "general.alignment": uint32(16), + }, ts); err != nil { + t.Fatal(err) + } - if diff := cmp.Diff(ff.Tensors(), Tensors{ - Offset: 336, - items: []*Tensor{ - {Name: "test.0", Offset: 0, Shape: []uint64{2, 3}}, - {Name: "test.1", Offset: 32, Shape: []uint64{2, 3}}, - {Name: "test.2", Offset: 64, Shape: []uint64{2, 3}}, - {Name: "test.3", Offset: 96, Shape: []uint64{2, 3}}, - {Name: "test.4", Offset: 128, Shape: []uint64{2, 3}}, - {Name: "test.5", Offset: 160, Shape: []uint64{2, 3}}, - }, - }, cmp.AllowUnexported(Tensors{})); diff != "" { - t.Errorf("Mismatch (-want +got):\n%s", diff) + r, err := os.Open(w.Name()) + if err != nil { + t.Fatal(err) + } + defer r.Close() + + ff, err := Decode(r, 0) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(KV{ + "general.alignment": uint32(16), + "general.parameter_count": uint64(54), + }, ff.KV()); diff != "" { + t.Errorf("Mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(Tensors{ + Offset: 608, + items: []*Tensor{ + {Name: "blk.0.attn_norm.weight", Offset: 0, Shape: []uint64{2, 3}}, + {Name: "blk.1.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}}, + {Name: "blk.2.attn_norm.weight", Offset: 64, Shape: []uint64{2, 3}}, + {Name: "blk.3.attn_norm.weight", Offset: 96, Shape: []uint64{2, 3}}, + {Name: "blk.4.attn_norm.weight", Offset: 128, Shape: []uint64{2, 3}}, + {Name: "blk.5.attn_norm.weight", Offset: 160, Shape: []uint64{2, 3}}, + {Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}}, + {Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}}, + {Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}}, + }, + }, ff.Tensors(), cmp.AllowUnexported(Tensors{})); diff != "" { + t.Errorf("Mismatch (-want +got):\n%s", diff) + } + }) } } From 9e125d884cf995dfae7fcd74690d525e4326a517 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Mon, 16 Jun 2025 16:03:16 -0700 Subject: [PATCH 26/26] model: treat 'user defined' tokens as special tokens (#11077) --- model/vocabulary.go | 2 +- model/vocabulary_test.go | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) create mode 100644 model/vocabulary_test.go diff --git a/model/vocabulary.go b/model/vocabulary.go index 24adbaca..a86de58d 100644 --- a/model/vocabulary.go +++ b/model/vocabulary.go @@ -87,7 +87,7 @@ func (v *Vocabulary) Decode(id int32) string { func (v *Vocabulary) SpecialVocabulary() []string { v.specialOnce.Do(func() { for i := range v.Values { - if v.Types[i] == TOKEN_TYPE_CONTROL { + if v.Types[i] == TOKEN_TYPE_CONTROL || v.Types[i] == TOKEN_TYPE_USER_DEFINED { v.special = append(v.special, v.Values[i]) } } diff --git a/model/vocabulary_test.go b/model/vocabulary_test.go new file mode 100644 index 00000000..46f0ead2 --- /dev/null +++ b/model/vocabulary_test.go @@ -0,0 +1,16 @@ +package model + +import "testing" + +func TestVocabulary_SpecialVocabulary(t *testing.T) { + vocab := &Vocabulary{ + Values: []string{"<|startoftext|>", "<|endoftext|>", "<|tool_call_start|>", "<|tool_call_end|>", "hi"}, + Types: []int32{TOKEN_TYPE_CONTROL, TOKEN_TYPE_CONTROL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_NORMAL}, + } + + specialVocab := vocab.SpecialVocabulary() + + if len(specialVocab) != 4 { + t.Errorf("expected 4 special tokens, got %d", len(specialVocab)) + } +}