From e7f56ef3d8ac70280b05ec66989dfe0845f8f114 Mon Sep 17 00:00:00 2001 From: Devon Rifkin Date: Thu, 18 Sep 2025 14:55:59 -0700 Subject: [PATCH 01/16] harmony: remove special casing in routes.go Now that we have a built-in parser abstraction, which was introduced in , we can modify our harmony parser to match this and then get rid of nearly all of the harmony-specific logic in routes.go. We do have a small amount of code that turns the parser on by default if the architecture matches and no other built-in parser was provided. The built-in parser interface was modified in order to handle harmony's prefill and tool name translation requirements. --- .gitignore | 1 + harmony/harmonyparser.go | 77 +++++++++++++++++++++ model/parsers/parsers.go | 16 ++++- model/parsers/qwen3coder.go | 14 ++-- server/routes.go | 132 ++++++++++++------------------------ 5 files changed, 144 insertions(+), 96 deletions(-) diff --git a/.gitignore b/.gitignore index 3a2af0bd..eabf94c2 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ dist build .cache +.gocache *.exe .idea test_data diff --git a/harmony/harmonyparser.go b/harmony/harmonyparser.go index a51819dd..b365b763 100644 --- a/harmony/harmonyparser.go +++ b/harmony/harmonyparser.go @@ -1,6 +1,7 @@ package harmony import ( + "encoding/json" "fmt" "log/slog" "strings" @@ -265,6 +266,8 @@ type HarmonyMessageHandler struct { state harmonyMessageState HarmonyParser *HarmonyParser FunctionNameMap *FunctionNameMap + toolAccumulator *HarmonyToolCallAccumulator + convertedTools map[string]struct{} } // NewHarmonyMessageHandler creates a new message handler @@ -277,6 +280,7 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler { HeaderEndTag: "<|message|>", }, FunctionNameMap: NewFunctionNameMap(), + convertedTools: make(map[string]struct{}), } } @@ -384,6 +388,79 @@ func NewFunctionNameMap() *FunctionNameMap { } } +// Init initializes the handler with tools and optional last message +// Implements the Parser interface +func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool { + // Initialize the harmony parser + if h.HarmonyParser == nil { + h.HarmonyParser = &HarmonyParser{ + MessageStartTag: "<|start|>", + MessageEndTag: "<|end|>", + HeaderEndTag: "<|message|>", + } + } + + // Handle prefill for chat mode + if lastMessage != nil { + h.HarmonyParser.AddImplicitStartOrPrefill(lastMessage) + } else { + h.HarmonyParser.AddImplicitStart() + } + + // Initialize tool accumulator + h.toolAccumulator = h.CreateToolParser() + + // Process tools and return renamed versions + if len(tools) == 0 { + return tools + } + + processedTools := make([]api.Tool, len(tools)) + copy(processedTools, tools) + for i, tool := range processedTools { + if tool.Function.Name != "" { + processedTools[i].Function.Name = h.FunctionNameMap.ConvertAndAdd(tool.Function.Name) + h.convertedTools[tool.Function.Name] = struct{}{} + } + } + return processedTools +} + +// Add implements the Parser interface - processes streamed content and extracts content, thinking, and tool calls +func (h *HarmonyMessageHandler) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { + content, thinking, toolContent := h.AddContent(s, h.toolAccumulator) + if toolContent != "" { + h.toolAccumulator.Add(toolContent) + } + + // tool calls always happen one at a time, and always at the end of a message, + // so for simplicity we defer parsing them until we know we're done + if done { + toolName, raw := h.toolAccumulator.Drain() + if toolName != nil { + name := strings.TrimPrefix(*toolName, "functions.") + name = h.FunctionNameMap.OriginalFromConverted(name) + var args api.ToolCallFunctionArguments + if err := json.Unmarshal([]byte(raw), &args); err != nil { + return "", "", nil, fmt.Errorf("error parsing tool call: raw='%s', err=%w", raw, err) + } + calls = append(calls, api.ToolCall{Function: api.ToolCallFunction{Name: name, Arguments: args}}) + } + } + + return content, thinking, calls, nil +} + +// HasToolSupport implements the Parser interface +func (h *HarmonyMessageHandler) HasToolSupport() bool { + return true +} + +// HasThinkingSupport implements the Parser interface +func (h *HarmonyMessageHandler) HasThinkingSupport() bool { + return true +} + func (m *FunctionNameMap) ConvertAndAdd(userFunctionName string) string { harmonyFunctionName := m.deriveName(userFunctionName) m.userToHarmony[userFunctionName] = harmonyFunctionName diff --git a/model/parsers/parsers.go b/model/parsers/parsers.go index e6dbd1f4..a1d4e812 100644 --- a/model/parsers/parsers.go +++ b/model/parsers/parsers.go @@ -2,10 +2,16 @@ package parsers import ( "github.com/ollama/ollama/api" + "github.com/ollama/ollama/harmony" ) type Parser interface { - Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) + // Init initializes the parser with tools and optional last message for chat prefill + // Returns processed tools if the parser needs to modify them (e.g., harmony renames them) + Init(tools []api.Tool, lastMessage *api.Message) []api.Tool + // Add processes streamed content and returns parsed content, thinking, and tool calls + // The done flag indicates if this is the last chunk (used for draining accumulators) + Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) HasToolSupport() bool HasThinkingSupport() bool } @@ -17,6 +23,8 @@ func ParserForName(name string) Parser { return parser case "passthrough": return &PassthroughParser{} + case "harmony": + return harmony.NewHarmonyMessageHandler() default: return nil } @@ -24,7 +32,11 @@ func ParserForName(name string) Parser { type PassthroughParser struct{} -func (p *PassthroughParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) { +func (p *PassthroughParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool { + return tools // passthrough doesn't modify tools +} + +func (p *PassthroughParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { return s, "", nil, nil } diff --git a/model/parsers/qwen3coder.go b/model/parsers/qwen3coder.go index b0e8ec48..b3629a5c 100644 --- a/model/parsers/qwen3coder.go +++ b/model/parsers/qwen3coder.go @@ -31,6 +31,7 @@ const ( type Qwen3CoderParser struct { state qwenParserState acc strings.Builder + tools []api.Tool } func (p *Qwen3CoderParser) HasToolSupport() bool { @@ -41,7 +42,12 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool { return false } -func (p *Qwen3CoderParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) { +func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool { + p.tools = tools + return tools // Qwen doesn't modify tools +} + +func (p *Qwen3CoderParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { p.acc.WriteString(s) events := p.parseEvents() @@ -51,7 +57,7 @@ func (p *Qwen3CoderParser) Add(s string, tools []api.Tool) (content string, thin for _, event := range events { switch event := event.(type) { case qwenEventRawToolCall: - toolCall, err := parseToolCall(event, tools) + toolCall, err := parseToolCall(event, p.tools) if err != nil { slog.Warn("qwen tool call parsing failed", "error", err) return "", "", nil, err @@ -359,7 +365,7 @@ func parseValue(raw string, paramType api.PropertyType) any { // Try array if typeSet["array"] { - var arr []interface{} + var arr []any if err := json.Unmarshal([]byte(raw), &arr); err == nil { return arr } @@ -371,7 +377,7 @@ func parseValue(raw string, paramType api.PropertyType) any { // Try object if typeSet["object"] { - var obj map[string]interface{} + var obj map[string]any if err := json.Unmarshal([]byte(raw), &obj); err == nil { return obj } diff --git a/server/routes.go b/server/routes.go index c0204531..3e940702 100644 --- a/server/routes.go +++ b/server/routes.go @@ -34,7 +34,6 @@ import ( "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/fs/ggml" - "github.com/ollama/ollama/harmony" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/model/parsers" @@ -288,17 +287,21 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - useHarmony := shouldUseHarmony(m) && !req.Raw - var harmonyMessageHandler *harmony.HarmonyMessageHandler - var harmonyToolParser *harmony.HarmonyToolCallAccumulator - if useHarmony { - harmonyMessageHandler = harmony.NewHarmonyMessageHandler() - harmonyMessageHandler.HarmonyParser.AddImplicitStart() - harmonyToolParser = harmonyMessageHandler.CreateToolParser() + var builtinParser parsers.Parser + if shouldUseHarmony(m) && m.Config.Parser == "" { + m.Config.Parser = "harmony" } - // Validate Think value: string values currently only allowed for gptoss models - if req.Think != nil && req.Think.IsString() && !useHarmony { + if !req.Raw && m.Config.Parser != "" { + builtinParser = parsers.ParserForName(m.Config.Parser) + if builtinParser != nil { + // no tools or last message for generate endpoint + builtinParser.Init(nil, nil) + } + } + + // Validate Think value: string values currently only allowed for harmony/gptoss models + if req.Think != nil && req.Think.IsString() && m.Config.Parser != "harmony" { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())}) return } @@ -422,7 +425,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { } var thinkingState *thinking.Parser - if !useHarmony { + if builtinParser == nil { openingTag, closingTag := thinking.InferTags(m.Template.Template) if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" { thinkingState = &thinking.Parser{ @@ -459,11 +462,17 @@ func (s *Server) GenerateHandler(c *gin.Context) { }, } - if useHarmony { - content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser) + if builtinParser != nil { + content, thinking, toolCalls, err := builtinParser.Add(cr.Content, cr.Done) + if err != nil { + ch <- gin.H{"error": err.Error()} + return + } res.Response = content res.Thinking = thinking - harmonyToolParser.Add(toolContent) + if cr.Done && len(toolCalls) > 0 { + res.ToolCalls = toolCalls + } } else if thinkingState != nil { thinking, content := thinkingState.AddContent(cr.Content) res.Thinking = thinking @@ -475,26 +484,6 @@ func (s *Server) GenerateHandler(c *gin.Context) { } if cr.Done { - if useHarmony { - toolName, toolContent := harmonyToolParser.Drain() - if toolName != nil { - *toolName = strings.TrimPrefix(*toolName, "functions.") - var args api.ToolCallFunctionArguments - if err := json.Unmarshal([]byte(toolContent), &args); err != nil { - errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error()) - ch <- gin.H{"error": errStr} - return - } - - res.ToolCalls = append(res.ToolCalls, api.ToolCall{ - Function: api.ToolCallFunction{ - Name: *toolName, - Arguments: args, - }, - }) - } - } - res.DoneReason = cr.DoneReason.String() res.TotalDuration = time.Since(checkpointStart) res.LoadDuration = checkpointLoaded.Sub(checkpointStart) @@ -509,7 +498,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { } } - if useHarmony { + if builtinParser != nil { // only send messages with meaningful content (empty messages confuse clients) if res.Response != "" || res.Thinking != "" || res.Done || len(res.ToolCalls) > 0 { ch <- res @@ -1853,32 +1842,23 @@ func (s *Server) ChatHandler(c *gin.Context) { } msgs = filterThinkTags(msgs, m) - var builtinParser parsers.Parser - if m.Config.Parser != "" { - builtinParser = parsers.ParserForName(m.Config.Parser) + if shouldUseHarmony(m) && m.Config.Parser == "" { + m.Config.Parser = "harmony" } - var harmonyMessageHandler *harmony.HarmonyMessageHandler - var harmonyToolParser *harmony.HarmonyToolCallAccumulator - - useHarmony := shouldUseHarmony(m) || m.Config.Parser == "harmony" - + var builtinParser parsers.Parser processedTools := req.Tools - if useHarmony { - harmonyMessageHandler = harmony.NewHarmonyMessageHandler() - var lastMessage *api.Message - if len(msgs) > 0 { - lastMessage = &msgs[len(msgs)-1] - } - harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(lastMessage) - harmonyToolParser = harmonyMessageHandler.CreateToolParser() - // make a copy of tools to pass to the chat prompt. Function names may be - // renamed to be valid Harmony function names. - processedTools = make([]api.Tool, len(req.Tools)) - copy(processedTools, req.Tools) - for i, tool := range processedTools { - processedTools[i].Function.Name = harmonyMessageHandler.FunctionNameMap.ConvertAndAdd(tool.Function.Name) + if m.Config.Parser != "" { + builtinParser = parsers.ParserForName(m.Config.Parser) + if builtinParser != nil { + // Determine last message for chat prefill + var lastMessage *api.Message + if len(msgs) > 0 { + lastMessage = &msgs[len(msgs)-1] + } + // Initialize parser and get processed tools + processedTools = builtinParser.Init(req.Tools, lastMessage) } } @@ -1902,8 +1882,8 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - // Validate Think value: string values currently only allowed for gptoss models - if req.Think != nil && req.Think.IsString() && !useHarmony { + // Validate Think value: string values currently only allowed for harmony/gptoss models + if req.Think != nil && req.Think.IsString() && m.Config.Parser != "harmony" { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())}) return } @@ -1922,7 +1902,7 @@ func (s *Server) ChatHandler(c *gin.Context) { } var toolParser *tools.Parser - if len(req.Tools) > 0 && !useHarmony { + if len(req.Tools) > 0 && (builtinParser == nil || !builtinParser.HasToolSupport()) { toolParser = tools.NewParser(m.Template.Template, req.Tools) } @@ -1954,38 +1934,10 @@ func (s *Server) ChatHandler(c *gin.Context) { res.LoadDuration = checkpointLoaded.Sub(checkpointStart) } - // TODO(drifkin): fold this as much as possibleinto the generic m.Config.Parser logic - if useHarmony { - content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser) - res.Message.Content = content - res.Message.Thinking = thinking - harmonyToolParser.Add(toolContent) - - if r.Done { - toolName, toolContent := harmonyToolParser.Drain() - if toolName != nil { - *toolName = strings.TrimPrefix(*toolName, "functions.") - *toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName) - var args api.ToolCallFunctionArguments - if err := json.Unmarshal([]byte(toolContent), &args); err != nil { - errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error()) - ch <- gin.H{"error": errStr} - return - } - res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}} - } - } - - // only send messages with meaningful content (empty messages confuse clients) - if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done { - ch <- res - } - - return - } else if builtinParser != nil { + if builtinParser != nil { slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content) - content, thinking, toolCalls, err := builtinParser.Add(r.Content, req.Tools) + content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done) if err != nil { ch <- gin.H{"error": err.Error()} return From ae5c33008e53c1db1465b722b60835916e375d15 Mon Sep 17 00:00:00 2001 From: jmorganca Date: Fri, 19 Sep 2025 15:49:56 -0700 Subject: [PATCH 02/16] docs: move turbo.md to cloud.md --- docs/{turbo.md => cloud.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename docs/{turbo.md => cloud.md} (100%) diff --git a/docs/turbo.md b/docs/cloud.md similarity index 100% rename from docs/turbo.md rename to docs/cloud.md From af060eb2508e8bed25241163243bdd7471cb7fd6 Mon Sep 17 00:00:00 2001 From: jmorganca Date: Fri, 19 Sep 2025 15:50:41 -0700 Subject: [PATCH 03/16] docs: update cloud.md for cloud models --- docs/cloud.md | 113 ++++++++++---------------------------------------- 1 file changed, 23 insertions(+), 90 deletions(-) diff --git a/docs/cloud.md b/docs/cloud.md index d75d9557..300e6f5e 100644 --- a/docs/cloud.md +++ b/docs/cloud.md @@ -1,107 +1,40 @@ -# Turbo +# Cloud -> ⚠️ Turbo is preview +| Ollama's cloud is currently in preview. For full documentation, see [Ollama's documentation](https://docs.ollama.com/cloud). -Ollama’s [Turbo](https://ollama.com/turbo) is a new way to run open-source models with acceleration from datacenter-grade hardware. +## Cloud Models -Currently, the following models are available in Turbo: +[Cloud models](https://ollama.com/cloud) are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldn’t fit on a personal computer. -- `gpt-oss:20b` -- `gpt-oss:120b` +Ollama currently supports the following cloud models, with more coming soon: -## Get started +- `gpt-oss:20b-cloud` +- `gpt-oss:120b-cloud` +- `deepseek-v3.1:671b-cloud` +- `qwen3-coder:480b-cloud` -### Ollama for macOS & Windows +### Get started -Download Ollama +To run a cloud model, open the terminal and run: -- Select a model such as `gpt-oss:20b` or `gpt-oss:120b` -- Click on **Turbo**. You’ll be prompted to create an account or sign in - -### Ollama’s CLI - -- [Sign up](https://ollama.com/signup) for an Ollama account -- Add your Ollama key [to ollama.com](https://ollama.com/settings/keys). - - On macOS and Linux: - - ```shell - cat ~/.ollama/id_ed25519.pub - ``` - - On Windows: - - ``` - type "%USERPROFILE%\.ollama\id_ed25519.pub" - ``` - -- Then run a model setting `OLLAMA_HOST` to `ollama.com`: - ```shell - OLLAMA_HOST=ollama.com ollama run gpt-oss:120b - ``` - -### Ollama’s Python library - -- Download Ollama's [Python library](https://github.com/ollama/ollama-python) -- [Sign up](https://ollama.com/signup) for an Ollama account -- Create an API key by visiting https://ollama.com/settings/keys - -```python -from ollama import Client - -client = Client( - host="https://ollama.com", - headers={'Authorization': ''} -) - -messages = [ - { - 'role': 'user', - 'content': 'Why is the sky blue?', - }, -] - -for part in client.chat('gpt-oss:120b', messages=messages, stream=True): - print(part['message']['content'], end='', flush=True) +``` +ollama run gpt-oss:120b-cloud ``` -### Ollama’s JavaScript library +To run cloud models with integrations that work with Ollama, first download the cloud model: -- Download Ollama's [JavaScript library](https://github.com/ollama/ollama-js) -- [Sign up](https://ollama.com/signup) for an Ollama account -- Create an API key by visiting https://ollama.com/settings/keys - -```typescript -import { Ollama } from 'ollama'; - -const ollama = new Ollama({ - host: 'https://ollama.com', - headers: { - Authorization: "Bearer " - } -}); - -const response = await ollama.chat({ - model: 'gpt-oss:120b', - messages: [{ role: 'user', content: 'Explain quantum computing' }], - stream: true -}); - -for await (const part of response) { - process.stdout.write(part.message.content) -} +``` +ollama pull qwen3-coder:480b-cloud ``` -### Community integrations +Then sign in to Ollama: -Turbo mode is also compatible with several community integrations. +``` +ollama signin +``` -#### Open WebUI +Finally, access the model using the model name `qwen3-coder:480b-cloud` via Ollama's local API or tooling. -- Go to **settings** → **Admin settings** → **Connections** -- Under **Ollama API,** click **+** -- For the **URL** put `https://ollama.com` -- For the **API key,** create an API key on https://ollama.com/settings/keys and add it. -- Click **Save** +## Cloud API access -Now, if you navigate to the model selector, Turbo models should be available under **External**. +Cloud models can also be accessed directly on ollama.com's API. For more information, see the [docs](https://docs.ollama.com/cloud). From c23e6f4cae3cbf62db68c2c9bf993925626fbe7c Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Mon, 22 Sep 2025 11:23:14 -0700 Subject: [PATCH 04/16] tests: add single threaded history test (#12295) * tests: add single threaded history test Also tidies up some existing tests to handle more model output variation * test: add support for testing specific architectures --- integration/README.md | 3 + integration/api_test.go | 14 ++--- integration/basic_test.go | 8 +-- integration/context_test.go | 98 +++++++++++++++++++++++++++++- integration/library_models_test.go | 17 +++++- integration/model_arch_test.go | 5 +- integration/model_perf_test.go | 34 ++++++++--- integration/quantization_test.go | 5 +- integration/utils_test.go | 28 +++++++-- 9 files changed, 173 insertions(+), 39 deletions(-) diff --git a/integration/README.md b/integration/README.md index e52ba71e..1dfd0e35 100644 --- a/integration/README.md +++ b/integration/README.md @@ -12,3 +12,6 @@ The integration tests have 2 modes of operating. > [!IMPORTANT] > Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree. + + +Many tests use a default small model suitable to run on many systems. You can override this default model by setting `OLLAMA_TEST_DEFAULT_MODEL` \ No newline at end of file diff --git a/integration/api_test.go b/integration/api_test.go index c39192c9..48572085 100644 --- a/integration/api_test.go +++ b/integration/api_test.go @@ -22,13 +22,12 @@ func TestAPIGenerate(t *testing.T) { // Set up the test data req := api.GenerateRequest{ Model: smol, - Prompt: "why is the sky blue? be brief", + Prompt: blueSkyPrompt, Options: map[string]interface{}{ "temperature": 0, "seed": 123, }, } - anyResp := []string{"rayleigh", "scattering"} client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() @@ -120,14 +119,14 @@ func TestAPIGenerate(t *testing.T) { // Verify the response contains the expected data response := buf.String() atLeastOne := false - for _, resp := range anyResp { + for _, resp := range blueSkyExpected { if strings.Contains(strings.ToLower(response), resp) { atLeastOne = true break } } if !atLeastOne { - t.Errorf("none of %v found in %s", anyResp, response) + t.Errorf("none of %v found in %s", blueSkyExpected, response) } case <-ctx.Done(): t.Error("outer test context done while waiting for generate") @@ -181,7 +180,7 @@ func TestAPIChat(t *testing.T) { Messages: []api.Message{ { Role: "user", - Content: "why is the sky blue? be brief", + Content: blueSkyPrompt, }, }, Options: map[string]interface{}{ @@ -189,7 +188,6 @@ func TestAPIChat(t *testing.T) { "seed": 123, }, } - anyResp := []string{"rayleigh", "scattering"} client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() @@ -279,14 +277,14 @@ func TestAPIChat(t *testing.T) { // Verify the response contains the expected data response := buf.String() atLeastOne := false - for _, resp := range anyResp { + for _, resp := range blueSkyExpected { if strings.Contains(strings.ToLower(response), resp) { atLeastOne = true break } } if !atLeastOne { - t.Errorf("none of %v found in %s", anyResp, response) + t.Errorf("none of %v found in %s", blueSkyExpected, response) } case <-ctx.Done(): t.Error("outer test context done while waiting for chat") diff --git a/integration/basic_test.go b/integration/basic_test.go index 60cff172..0a6b9253 100644 --- a/integration/basic_test.go +++ b/integration/basic_test.go @@ -19,14 +19,14 @@ func TestBlueSky(t *testing.T) { // Set up the test data req := api.GenerateRequest{ Model: smol, - Prompt: "why is the sky blue?", + Prompt: blueSkyPrompt, Stream: &stream, Options: map[string]any{ "temperature": 0, "seed": 123, }, } - GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"}) + GenerateTestHelper(ctx, t, req, blueSkyExpected) } func TestUnicode(t *testing.T) { @@ -110,12 +110,12 @@ func TestUnicodeModelDir(t *testing.T) { req := api.GenerateRequest{ Model: smol, - Prompt: "why is the sky blue?", + Prompt: blueSkyPrompt, Stream: &stream, Options: map[string]any{ "temperature": 0, "seed": 123, }, } - GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"}) + GenerateTestHelper(ctx, t, req, blueSkyExpected) } diff --git a/integration/context_test.go b/integration/context_test.go index 15c15785..9d13f7ac 100644 --- a/integration/context_test.go +++ b/integration/context_test.go @@ -63,11 +63,11 @@ func TestContextExhaustion(t *testing.T) { if err := PullIfMissing(ctx, client, req.Model); err != nil { t.Fatalf("PullIfMissing failed: %v", err) } - DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water"}, 120*time.Second, 10*time.Second) + DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water", "time", "travel", "world"}, 120*time.Second, 10*time.Second) } // Send multiple generate requests with prior context and ensure the response is coherant and expected -func TestGenerateWithHistory(t *testing.T) { +func TestParallelGenerateWithHistory(t *testing.T) { modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model req, resp := GenerateRequests() numParallel := 2 @@ -113,8 +113,48 @@ func TestGenerateWithHistory(t *testing.T) { wg.Wait() } +// Send generate requests with prior context and ensure the response is coherant and expected +func TestGenerateWithHistory(t *testing.T) { + req := api.GenerateRequest{ + Model: smol, + Prompt: rainbowPrompt, + Stream: &stream, + KeepAlive: &api.Duration{Duration: 10 * time.Second}, + Options: map[string]any{ + "num_ctx": 16384, + }, + } + + softTimeout, hardTimeout := getTimeouts(t) + ctx, cancel := context.WithTimeout(context.Background(), hardTimeout) + defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + // Get the server running (if applicable) warm the model up with a single initial request + slog.Info("loading", "model", req.Model) + err := client.Generate(ctx, + &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 10 * time.Second}, Options: req.Options}, + func(response api.GenerateResponse) error { return nil }, + ) + if err != nil { + t.Fatalf("failed to load model %s: %s", req.Model, err) + } + + req.Context = DoGenerate(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second) + + for i := 0; i < len(rainbowFollowups); i++ { + req.Prompt = rainbowFollowups[i] + if time.Now().Sub(started) > softTimeout { + slog.Info("exceeded soft timeout, winding down test") + return + } + req.Context = DoGenerate(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second) + } +} + // Send multiple chat requests with prior context and ensure the response is coherant and expected -func TestChatWithHistory(t *testing.T) { +func TestParallelChatWithHistory(t *testing.T) { modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model req, resp := ChatRequests() numParallel := 2 @@ -164,3 +204,55 @@ func TestChatWithHistory(t *testing.T) { } wg.Wait() } + +// Send generate requests with prior context and ensure the response is coherant and expected +func TestChatWithHistory(t *testing.T) { + req := api.ChatRequest{ + Model: smol, + Stream: &stream, + KeepAlive: &api.Duration{Duration: 10 * time.Second}, + Options: map[string]any{ + "num_ctx": 16384, + }, + Messages: []api.Message{ + { + Role: "user", + Content: rainbowPrompt, + }, + }, + } + + softTimeout, hardTimeout := getTimeouts(t) + ctx, cancel := context.WithTimeout(context.Background(), hardTimeout) + defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + // Get the server running (if applicable) warm the model up with a single initial request + slog.Info("loading", "model", req.Model) + err := client.Generate(ctx, + &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 10 * time.Second}, Options: req.Options}, + func(response api.GenerateResponse) error { return nil }, + ) + if err != nil { + t.Fatalf("failed to load model %s: %s", req.Model, err) + } + + assistant := DoChat(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second) + + for i := 0; i < len(rainbowFollowups); i++ { + if time.Now().Sub(started) > softTimeout { + slog.Info("exceeded soft timeout, winding down test") + return + } + req.Messages = append(req.Messages, + *assistant, + api.Message{Role: "user", Content: rainbowFollowups[i]}, + ) + + assistant = DoChat(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second) + if assistant == nil { + t.Fatalf("didn't get an assistant response for context") + } + } +} diff --git a/integration/library_models_test.go b/integration/library_models_test.go index cdf65efc..49e1097b 100644 --- a/integration/library_models_test.go +++ b/integration/library_models_test.go @@ -4,7 +4,9 @@ package integration import ( "context" + "fmt" "log/slog" + "os" "testing" "time" @@ -20,6 +22,7 @@ func TestLibraryModelsGenerate(t *testing.T) { defer cancel() client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() + targetArch := os.Getenv("OLLAMA_TEST_ARCHITECTURE") chatModels := libraryChatModels for _, model := range chatModels { @@ -30,16 +33,26 @@ func TestLibraryModelsGenerate(t *testing.T) { if err := PullIfMissing(ctx, client, model); err != nil { t.Fatalf("pull failed %s", err) } + if targetArch != "" { + resp, err := client.Show(ctx, &api.ShowRequest{Name: model}) + if err != nil { + t.Fatalf("unable to show model: %s", err) + } + arch := resp.ModelInfo["general.architecture"].(string) + if arch != targetArch { + t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch)) + } + } req := api.GenerateRequest{ Model: model, - Prompt: "why is the sky blue?", + Prompt: blueSkyPrompt, KeepAlive: &api.Duration{Duration: 10 * time.Second}, Options: map[string]interface{}{ "temperature": 0.1, "seed": 123, }, } - anyResp := []string{"rayleigh", "scatter", "atmosphere", "nitrogen", "oxygen", "wavelength"} + anyResp := blueSkyExpected // Special cases if model == "duckdb-nsql" { anyResp = []string{"select", "from"} diff --git a/integration/model_arch_test.go b/integration/model_arch_test.go index 9fc2e01d..721d95c5 100644 --- a/integration/model_arch_test.go +++ b/integration/model_arch_test.go @@ -68,14 +68,13 @@ func TestModelsGenerate(t *testing.T) { // TODO - fiddle with context size req := api.GenerateRequest{ Model: model, - Prompt: "why is the sky blue?", + Prompt: blueSkyPrompt, Options: map[string]interface{}{ "temperature": 0, "seed": 123, }, } - anyResp := []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"} - DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second) + DoGenerate(ctx, t, client, req, blueSkyExpected, 120*time.Second, 30*time.Second) }) } } diff --git a/integration/model_perf_test.go b/integration/model_perf_test.go index 759e8b9a..3d6ba923 100644 --- a/integration/model_perf_test.go +++ b/integration/model_perf_test.go @@ -40,6 +40,18 @@ var ( // cat int.log | grep MODEL_PERF_HEADER | head -1| cut -f2- -d: > perf.csv // cat int.log | grep MODEL_PERF_DATA | cut -f2- -d: >> perf.csv func TestModelsPerf(t *testing.T) { + if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" { + doModelPerfTest(t, ollamaEngineChatModels) + } else { + doModelPerfTest(t, append(ollamaEngineChatModels, llamaRunnerChatModels...)) + } +} + +func TestLibraryModelsPerf(t *testing.T) { + doModelPerfTest(t, libraryChatModels) +} + +func doModelPerfTest(t *testing.T, chatModels []string) { softTimeout, hardTimeout := getTimeouts(t) slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout) ctx, cancel := context.WithTimeout(context.Background(), hardTimeout) @@ -65,14 +77,12 @@ func TestModelsPerf(t *testing.T) { } longPrompt := "summarize the following: " + string(data) - var chatModels []string - if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" { - chatModels = ollamaEngineChatModels - } else { - chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...) - } + targetArch := os.Getenv("OLLAMA_TEST_ARCHITECTURE") for _, model := range chatModels { + if !strings.Contains(model, ":") { + model = model + ":latest" + } t.Run(model, func(t *testing.T) { if time.Now().Sub(started) > softTimeout { t.Skip("skipping remaining tests to avoid excessive runtime") @@ -88,6 +98,9 @@ func TestModelsPerf(t *testing.T) { } arch := resp.ModelInfo["general.architecture"].(string) maxContext = int(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64)) + if targetArch != "" && arch != targetArch { + t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch)) + } if maxVram > 0 { resp, err := client.List(ctx) @@ -151,8 +164,8 @@ func TestModelsPerf(t *testing.T) { prompt string anyResp []string }{ - {"why is the sky blue?", []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}}, - {maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy"}}, + {blueSkyPrompt, blueSkyExpected}, + {maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy", "love", "sorrow", "beauty"}}, } var gpuPercent int for _, tc := range testCases { @@ -241,11 +254,12 @@ func TestModelsPerf(t *testing.T) { } } } + // Round the logged prompt count for comparisons across versions/configurations which can vary slightly fmt.Fprintf(os.Stderr, "MODEL_PERF_HEADER:%s,%s,%s,%s,%s,%s,%s\n", "MODEL", "CONTEXT", "GPU PERCENT", - "PROMPT COUNT", + "APPROX PROMPT COUNT", "LOAD TIME", "PROMPT EVAL TPS", "EVAL TPS", @@ -254,7 +268,7 @@ func TestModelsPerf(t *testing.T) { model, numCtx, gpuPercent, - resp.PromptEvalCount, + (resp.PromptEvalCount/10)*10, float64(resp.LoadDuration)/1000000000.0, float64(resp.PromptEvalCount)/(float64(resp.PromptEvalDuration)/1000000000.0), float64(resp.EvalCount)/(float64(resp.EvalDuration)/1000000000.0), diff --git a/integration/quantization_test.go b/integration/quantization_test.go index af9da0b6..30564749 100644 --- a/integration/quantization_test.go +++ b/integration/quantization_test.go @@ -76,7 +76,7 @@ func TestQuantization(t *testing.T) { stream := true genReq := api.GenerateRequest{ Model: newName, - Prompt: "why is the sky blue?", + Prompt: blueSkyPrompt, KeepAlive: &api.Duration{Duration: 3 * time.Second}, Options: map[string]any{ "seed": 42, @@ -88,14 +88,13 @@ func TestQuantization(t *testing.T) { // Some smaller quantizations can cause models to have poor quality // or get stuck in repetition loops, so we stop as soon as we have any matches - anyResp := []string{"rayleigh", "scattering", "day", "sun", "moon", "color", "nitrogen", "oxygen"} reqCtx, reqCancel := context.WithCancel(ctx) atLeastOne := false var buf bytes.Buffer genfn := func(response api.GenerateResponse) error { buf.Write([]byte(response.Response)) fullResp := strings.ToLower(buf.String()) - for _, resp := range anyResp { + for _, resp := range blueSkyExpected { if strings.Contains(fullResp, resp) { atLeastOne = true t.Log(fullResp) diff --git a/integration/utils_test.go b/integration/utils_test.go index 7901fed3..f8ec13f3 100644 --- a/integration/utils_test.go +++ b/integration/utils_test.go @@ -256,13 +256,29 @@ var ( "snowflake-arctic-embed", "snowflake-arctic-embed2", } + + blueSkyPrompt = "why is the sky blue? Be brief but factual in your reply" + blueSkyExpected = []string{"rayleigh", "scatter", "atmosphere", "nitrogen", "oxygen", "wavelength", "interact"} + + rainbowPrompt = "how do rainbows form? Be brief but factual in your reply" + rainbowFollowups = []string{ + "Explain the physics involved in them. Be breif in your reply", + "Explain the chemistry involved in them. Be breif in your reply", + "Explain the quantum mechanics involved in them. Be breif in your reply", + "What are common myths related to them? Be brief in your reply", + "What are common fairytales related to them? Be brief in your reply", + "Can they form if there is no rain? Be breif in your reply", + "Can they form if there are no clouds? Be breif in your reply", + "Do they happen on other planets? Be brief in your reply", + } + rainbowExpected = []string{"water", "droplet", "mist", "glow", "refracted", "reflect", "color", "spectrum", "frequency", "end", "gold", "fortune", "blessing", "prosperity"} ) func init() { lifecycle.InitLogging() - custom := os.Getenv("OLLAMA_TEST_SMOL_MODEL") + custom := os.Getenv("OLLAMA_TEST_DEFAULT_MODEL") if custom != "" { - slog.Info("setting smol test model to " + custom) + slog.Info("setting default test model to " + custom) smol = custom } } @@ -577,11 +593,11 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) { }, }, [][]string{ - {"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"}, - {"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"}, - {"water", "droplet", "refracted", "reflect", "color", "spectrum"}, + {"sunlight", "scatter", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorb", "wavelength", "water", "molecule"}, + {"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigment", "particle", "iron oxide", "rust", "air", "water", "wet", "mixture", "mixing", "mineral", "element", "decomposed", "matter", "wavelength"}, + {"water", "droplet", "refract", "reflect", "color", "spectrum", "raindrop"}, {"fourth", "july", "declaration", "independence"}, - {"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor"}, + {"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor", "fluid", "particles", "gas"}, } } From 64883e3c4c0238dc70fddcc456af569d1489415d Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Mon, 22 Sep 2025 23:20:20 -0700 Subject: [PATCH 05/16] auth: fix problems with the ollama keypairs (#12373) * auth: fix problems with the ollama keypairs This change adds several fixes including: - reading in the pubkey files correctly - fixing the push unit test to create a keypair file in a temp directory - not return 500 errors for normal status error --- api/client.go | 24 ++++++++---- api/types.go | 2 +- auth/auth.go | 40 ++------------------ cmd/cmd.go | 56 +++++++++++++--------------- cmd/cmd_test.go | 3 ++ server/routes.go | 96 +++++++++++++++++++++++++++++++++++------------- 6 files changed, 119 insertions(+), 102 deletions(-) diff --git a/api/client.go b/api/client.go index 20e6d795..0d4c97ba 100644 --- a/api/client.go +++ b/api/client.go @@ -45,6 +45,12 @@ func checkError(resp *http.Response, body []byte) error { return nil } + if resp.StatusCode == http.StatusUnauthorized { + authError := AuthorizationError{StatusCode: resp.StatusCode} + json.Unmarshal(body, &authError) + return authError + } + apiError := StatusError{StatusCode: resp.StatusCode} err := json.Unmarshal(body, &apiError) @@ -214,7 +220,8 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f scanner.Buffer(scanBuf, maxBufferSize) for scanner.Scan() { var errorResponse struct { - Error string `json:"error,omitempty"` + Error string `json:"error,omitempty"` + SigninURL string `json:"signin_url,omitempty"` } bts := scanner.Bytes() @@ -223,14 +230,10 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f } if response.StatusCode == http.StatusUnauthorized { - pubKey, pkErr := auth.GetPublicKey() - if pkErr != nil { - return pkErr - } return AuthorizationError{ StatusCode: response.StatusCode, Status: response.Status, - PublicKey: pubKey, + SigninURL: errorResponse.SigninURL, } } else if response.StatusCode >= http.StatusBadRequest { return StatusError{ @@ -439,8 +442,13 @@ func (c *Client) Version(ctx context.Context) (string, error) { return version.Version, nil } -// Signout will disconnect an ollama instance from ollama.com -func (c *Client) Signout(ctx context.Context, encodedKey string) error { +// Signout will signout a client for a local ollama server. +func (c *Client) Signout(ctx context.Context) error { + return c.do(ctx, http.MethodPost, "/api/signout", nil, nil) +} + +// Disconnect will disconnect an ollama instance from ollama.com. +func (c *Client) Disconnect(ctx context.Context, encodedKey string) error { return c.do(ctx, http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey), nil, nil) } diff --git a/api/types.go b/api/types.go index 5b8e034c..8cc7752c 100644 --- a/api/types.go +++ b/api/types.go @@ -41,7 +41,7 @@ func (e StatusError) Error() string { type AuthorizationError struct { StatusCode int Status string - PublicKey string `json:"public_key"` + SigninURL string `json:"signin_url"` } func (e AuthorizationError) Error() string { diff --git a/auth/auth.go b/auth/auth.go index b26e2315..f820964e 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -18,46 +18,13 @@ import ( const defaultPrivateKey = "id_ed25519" -func keyPath() (string, error) { - fileIsReadable := func(fp string) bool { - info, err := os.Stat(fp) - if err != nil { - return false - } - - // Check that it's a regular file, not a directory or other file type - if !info.Mode().IsRegular() { - return false - } - - // Try to open it to check readability - file, err := os.Open(fp) - if err != nil { - return false - } - file.Close() - return true - } - - systemPath := filepath.Join("/usr/share/ollama/.ollama", defaultPrivateKey) - if fileIsReadable(systemPath) { - return systemPath, nil - } - +func GetPublicKey() (string, error) { home, err := os.UserHomeDir() if err != nil { return "", err } - return filepath.Join(home, ".ollama", defaultPrivateKey), nil -} - -func GetPublicKey() (string, error) { - keyPath, err := keyPath() - if err != nil { - return "", err - } - + keyPath := filepath.Join(home, ".ollama", defaultPrivateKey) privateKeyFile, err := os.ReadFile(keyPath) if err != nil { slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) @@ -84,11 +51,12 @@ func NewNonce(r io.Reader, length int) (string, error) { } func Sign(ctx context.Context, bts []byte) (string, error) { - keyPath, err := keyPath() + home, err := os.UserHomeDir() if err != nil { return "", err } + keyPath := filepath.Join(home, ".ollama", defaultPrivateKey) privateKeyFile, err := os.ReadFile(keyPath) if err != nil { slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) diff --git a/cmd/cmd.go b/cmd/cmd.go index 294e1662..e8cfa134 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -5,7 +5,6 @@ import ( "context" "crypto/ed25519" "crypto/rand" - "encoding/base64" "encoding/json" "encoding/pem" "errors" @@ -15,7 +14,6 @@ import ( "math" "net" "net/http" - "net/url" "os" "os/signal" "path/filepath" @@ -37,7 +35,6 @@ import ( "golang.org/x/term" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/auth" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/parser" @@ -50,7 +47,7 @@ import ( "github.com/ollama/ollama/version" ) -const ConnectInstructions = "To sign in, navigate to:\n https://ollama.com/connect?name=%s&key=%s\n\n" +const ConnectInstructions = "To sign in, navigate to:\n %s\n\n" // ensureThinkingSupport emits a warning if the model does not advertise thinking support func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) { @@ -452,16 +449,10 @@ func RunHandler(cmd *cobra.Command, args []string) error { if err := loadOrUnloadModel(cmd, &opts); err != nil { var sErr api.AuthorizationError if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized { - pubKey, pkErr := auth.GetPublicKey() - if pkErr != nil { - return pkErr - } - // the server and the client both have the same public key - if pubKey == sErr.PublicKey { - h, _ := os.Hostname() - encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey)) - fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n") - fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey) + fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n") + + if sErr.SigninURL != "" { + fmt.Printf(ConnectInstructions, sErr.SigninURL) } return nil } @@ -493,6 +484,16 @@ func SigninHandler(cmd *cobra.Command, args []string) error { user, err := client.Whoami(cmd.Context()) if err != nil { + var aErr api.AuthorizationError + if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized { + fmt.Println("You need to be signed in to Ollama to run Cloud models.") + fmt.Println() + + if aErr.SigninURL != "" { + fmt.Printf(ConnectInstructions, aErr.SigninURL) + } + return nil + } return err } @@ -502,34 +503,27 @@ func SigninHandler(cmd *cobra.Command, args []string) error { return nil } - pubKey, pkErr := auth.GetPublicKey() - if pkErr != nil { - return pkErr - } - encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey)) - - h, _ := os.Hostname() - fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey) - return nil } func SignoutHandler(cmd *cobra.Command, args []string) error { - pubKey, pkErr := auth.GetPublicKey() - if pkErr != nil { - return pkErr - } - encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey)) - client, err := api.ClientFromEnvironment() if err != nil { return err } - err = client.Signout(cmd.Context(), encKey) + err = client.Signout(cmd.Context()) if err != nil { - return err + var aErr api.AuthorizationError + if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized { + fmt.Println("You are not signed in to ollama.com") + fmt.Println() + return nil + } else { + return err + } } + fmt.Println("You have signed out of ollama.com") fmt.Println() return nil diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index bb793572..24d28705 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -525,6 +525,9 @@ func TestPushHandler(t *testing.T) { defer mockServer.Close() t.Setenv("OLLAMA_HOST", mockServer.URL) + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) initializeKeypair() cmd := &cobra.Command{} diff --git a/server/routes.go b/server/routes.go index a2078ec1..21a1b2b3 100644 --- a/server/routes.go +++ b/server/routes.go @@ -4,6 +4,7 @@ import ( "bytes" "cmp" "context" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -48,6 +49,8 @@ import ( "github.com/ollama/ollama/version" ) +const signinURLStr = "https://ollama.com/connect?name=%s&key=%s" + func shouldUseHarmony(model *Model) bool { if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) { // heuristic to check whether the template expects to be parsed via harmony: @@ -150,6 +153,17 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C return runner.llama, model, &opts, nil } +func signinURL() (string, error) { + pubKey, err := auth.GetPublicKey() + if err != nil { + return "", err + } + + encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey)) + h, _ := os.Hostname() + return fmt.Sprintf(signinURLStr, url.PathEscape(h), encKey), nil +} + func (s *Server) GenerateHandler(c *gin.Context) { checkpointStart := time.Now() var req api.GenerateRequest @@ -250,18 +264,21 @@ func (s *Server) GenerateHandler(c *gin.Context) { client := api.NewClient(remoteURL, http.DefaultClient) err = client.Generate(c, &req, fn) if err != nil { - var sErr api.AuthorizationError - if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized { - pk, pkErr := auth.GetPublicKey() - if pkErr != nil { - slog.Error("couldn't get public key", "error", pkErr) - c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"}) + var authError api.AuthorizationError + if errors.As(err, &authError) { + sURL, sErr := signinURL() + if sErr != nil { + slog.Error(sErr.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"}) return } - c.JSON(http.StatusUnauthorized, gin.H{ - "error": "unauthorized", - "public_key": pk, - }) + + c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL}) + return + } + var apiError api.StatusError + if errors.As(err, &apiError) { + c.JSON(apiError.StatusCode, apiError) return } c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -1412,9 +1429,12 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { r.POST("/api/show", s.ShowHandler) r.DELETE("/api/delete", s.DeleteHandler) - r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler) r.POST("/api/me", s.WhoamiHandler) + r.POST("/api/signout", s.SignoutHandler) + // deprecated + r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler) + // Create r.POST("/api/create", s.CreateHandler) r.POST("/api/blobs/:digest", s.CreateBlobHandler) @@ -1625,11 +1645,32 @@ func (s *Server) WhoamiHandler(c *gin.Context) { if err != nil { slog.Error(err.Error()) } + + // user isn't signed in + if user != nil && user.Name == "" { + sURL, sErr := signinURL() + if sErr != nil { + slog.Error(sErr.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"}) + return + } + + c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": sURL}) + return + } + c.JSON(http.StatusOK, user) } func (s *Server) SignoutHandler(c *gin.Context) { - encodedKey := c.Param("encodedKey") + pubKey, err := auth.GetPublicKey() + if err != nil { + slog.Error("couldn't get public key", "error", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"}) + return + } + + encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey)) // todo allow other hosts u, err := url.Parse("https://ollama.com") @@ -1640,11 +1681,11 @@ func (s *Server) SignoutHandler(c *gin.Context) { } client := api.NewClient(u, http.DefaultClient) - err = client.Signout(c, encodedKey) + err = client.Disconnect(c, encKey) if err != nil { - slog.Error(err.Error()) - if strings.Contains(err.Error(), "page not found") || strings.Contains(err.Error(), "invalid credentials") { - c.JSON(http.StatusNotFound, gin.H{"error": "you are not currently signed in"}) + var authError api.AuthorizationError + if errors.As(err, &authError) { + c.JSON(http.StatusUnauthorized, gin.H{"error": "you are not currently signed in"}) return } c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"}) @@ -1802,18 +1843,21 @@ func (s *Server) ChatHandler(c *gin.Context) { client := api.NewClient(remoteURL, http.DefaultClient) err = client.Chat(c, &req, fn) if err != nil { - var sErr api.AuthorizationError - if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized { - pk, pkErr := auth.GetPublicKey() - if pkErr != nil { - slog.Error("couldn't get public key", "error", pkErr) - c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"}) + var authError api.AuthorizationError + if errors.As(err, &authError) { + sURL, sErr := signinURL() + if sErr != nil { + slog.Error(sErr.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"}) return } - c.JSON(http.StatusUnauthorized, gin.H{ - "error": "unauthorized", - "public_key": pk, - }) + + c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL}) + return + } + var apiError api.StatusError + if errors.As(err, &apiError) { + c.JSON(apiError.StatusCode, apiError) return } c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) From a40d427bcea52ad5c7e93780564fc15e5ef80473 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 23 Sep 2025 13:21:47 -0700 Subject: [PATCH 06/16] multi-regexp pretokenizer (#12325) --- model/bytepairencoding.go | 54 ++++++++++++++++++++++++++++------ model/bytepairencoding_test.go | 40 ++++++++++++++++++++++++- model/models/gptoss/model.go | 20 ++++++------- model/models/llama/model.go | 28 +++++++++++++++--- model/models/llama4/model.go | 3 +- model/models/mistral3/model.go | 2 +- model/models/mllama/model.go | 2 +- model/models/qwen2/model.go | 2 +- model/models/qwen25vl/model.go | 2 +- model/models/qwen3/embed.go | 2 +- model/models/qwen3/model.go | 2 +- sample/samplers_test.go | 1 - 12 files changed, 124 insertions(+), 34 deletions(-) diff --git a/model/bytepairencoding.go b/model/bytepairencoding.go index e21564aa..3d51f70e 100644 --- a/model/bytepairencoding.go +++ b/model/bytepairencoding.go @@ -5,6 +5,7 @@ import ( "fmt" "iter" "log/slog" + "slices" "strings" "github.com/dlclark/regexp2" @@ -13,16 +14,28 @@ import ( ) type BytePairEncoding struct { - pre *regexp2.Regexp - vocab *Vocabulary + vocab *Vocabulary + regexps []*regexp2.Regexp } var _ TextProcessor = (*BytePairEncoding)(nil) -func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding { +func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding { + if len(pretokenizers) == 0 { + // set default byte-level pretokenizer if none provided, e.g. + // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44 + pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`} + } + return BytePairEncoding{ - pre: regexp2.MustCompile(pre, regexp2.None), vocab: vocab, + regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) { + for _, p := range pretokenizers { + if !yield(regexp2.MustCompile(p, regexp2.RE2)) { + return + } + } + }), } } @@ -35,13 +48,36 @@ func (bpe BytePairEncoding) Is(id int32, special Special) bool { } func (bpe *BytePairEncoding) split(s string) iter.Seq[string] { - return func(yield func(string) bool) { - for m, _ := bpe.pre.FindStringMatch(s); m != nil; m, _ = bpe.pre.FindNextMatch(m) { - if !yield(m.String()) { - break + parts := []string{s} + for _, re := range bpe.regexps { + parts = slices.Collect(func(yield func(string) bool) { + for _, part := range parts { + r := []rune(part) + var offset int + for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) { + if offset-m.Index != 0 { + if !yield(string(r[:m.Index])) { + return + } + } + + if !yield(m.String()) { + return + } + + offset = m.Index + m.Length + } + + if offset < len(r) { + if !yield(string(r[offset:])) { + return + } + } } - } + }) } + + return slices.Values(parts) } // fragment is a string fragment and their corresponding token IDs diff --git a/model/bytepairencoding_test.go b/model/bytepairencoding_test.go index 71947be9..39e5ab45 100644 --- a/model/bytepairencoding_test.go +++ b/model/bytepairencoding_test.go @@ -59,12 +59,12 @@ func llama(t testing.TB) BytePairEncoding { } return NewBytePairEncoding( - `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, &Vocabulary{ Values: tokens, Types: types, Merges: merges, }, + "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", ) } @@ -282,3 +282,41 @@ func BenchmarkBytePairEncoding(b *testing.B) { }) } } + +func TestSplit(t *testing.T) { + cases := []struct { + name string + patterns, + want []string + }{ + { + name: "default", + want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " 123", " 一二三"}, + }, + { + name: "unicode", + patterns: []string{ + "\\p{N}{1,3}", + `[一-龥぀-ゟ゠-ヿ]+`, + "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+", + }, + want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "123", " ", "一二三"}, + }, + { + name: "individual digits", + patterns: []string{ + "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }, + want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "1", "2", "3", " 一二三"}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + tokenizer := NewBytePairEncoding(nil, tt.patterns...) + if diff := cmp.Diff(tt.want, slices.Collect(tokenizer.split("Hello, WORLD!! How's it going? 123 一二三"))); diff != "" { + t.Errorf("no match (-theirs +ours):\n%s", diff) + } + }) + } +} diff --git a/model/models/gptoss/model.go b/model/models/gptoss/model.go index 8456ea5f..6a327065 100644 --- a/model/models/gptoss/model.go +++ b/model/models/gptoss/model.go @@ -227,17 +227,6 @@ func New(c fs.Config) (model.Model, error) { m := Transformer{ TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")), BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", - strings.Join([]string{ - `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?`, - `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?`, - `\p{N}{1,3}`, - ` ?[^\s\p{L}\p{N}]+[\r\n/]*`, - `\s*[\r\n]+`, - `\s+(?!\S)`, - `\s+`, - }, "|"), - ), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -250,6 +239,15 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + strings.Join([]string{ + `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?`, + `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?`, + `\p{N}{1,3}`, + ` ?[^\s\p{L}\p{N}]+[\r\n/]*`, + `\s*[\r\n]+`, + `\s+(?!\S)`, + `\s+`, + }, "|"), ), Options: Options{ hiddenSize: int(c.Uint("embedding_length")), diff --git a/model/models/llama/model.go b/model/models/llama/model.go index f6ec0227..c03f04a0 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -54,10 +54,30 @@ func New(c fs.Config) (model.Model, error) { } switch c.String("tokenizer.ggml.model") { case "gpt2": - processor = model.NewBytePairEncoding( - `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, - &vocabulary, - ) + var pretokenizers []string + switch c.String("tokenizer.ggml.pre") { + case "default": + // no-op use the default bpe pretokenizer + case "qwen2": + pretokenizers = []string{ + "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + } + case "refact": + pretokenizers = []string{ + `\p{N}`, + `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`, + } + case "tekken": + pretokenizers = []string{ + "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + } + default: + // use a llama-style pretokenizer + pretokenizers = []string{ + "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + } + } + processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...) case "llama": processor = model.NewSentencePiece(&vocabulary) default: diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go index 9cb2efc8..e80fbaed 100644 --- a/model/models/llama4/model.go +++ b/model/models/llama4/model.go @@ -34,8 +34,6 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor { func New(c fs.Config) (model.Model, error) { m := Model{ BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", - `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -48,6 +46,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), ImageProcessor: newImageProcessor(c), VisionModel: newVisionModel(c), diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go index 435b1a30..5c46615e 100644 --- a/model/models/mistral3/model.go +++ b/model/models/mistral3/model.go @@ -33,7 +33,6 @@ var _ model.TextProcessor = (*Model)(nil) func New(c fs.Config) (model.Model, error) { m := &Model{ BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -46,6 +45,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), TextModel: newTextModel(c), VisionModel: newVisionModel(c), diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index 239d999d..76974369 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -33,7 +33,6 @@ const ( func New(c fs.Config) (model.Model, error) { m := Model{ BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -46,6 +45,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), ImageProcessor: newImageProcessor(c), VisionModel: newVisionModel(c), diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go index 5a345837..2e234710 100644 --- a/model/models/qwen2/model.go +++ b/model/models/qwen2/model.go @@ -139,7 +139,6 @@ func New(c fs.Config) (model.Model, error) { m := Model{ Layers: make([]DecoderLayer, c.Uint("block_count")), BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -152,6 +151,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), Options: Options{ hiddenSize: int(c.Uint("embedding_length")), diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go index 6c76305d..6898e38c 100644 --- a/model/models/qwen25vl/model.go +++ b/model/models/qwen25vl/model.go @@ -29,7 +29,6 @@ var _ model.MultimodalProcessor = (*Model)(nil) func New(c fs.Config) (model.Model, error) { m := &Model{ BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -42,6 +41,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), TextModel: NewTextModel(c), VisionModel: newVisionModel(c), diff --git a/model/models/qwen3/embed.go b/model/models/qwen3/embed.go index 9a77efea..c03888d4 100644 --- a/model/models/qwen3/embed.go +++ b/model/models/qwen3/embed.go @@ -35,7 +35,6 @@ func newEmbed(c fs.Config) (model.Model, error) { } m := embedModel{ BytePairEncoding: model.NewBytePairEncoding( - `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -48,6 +47,7 @@ func newEmbed(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), Model: &Model{ Layers: layers, diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go index 35226834..cc58e4a2 100644 --- a/model/models/qwen3/model.go +++ b/model/models/qwen3/model.go @@ -200,7 +200,6 @@ func New(c fs.Config) (model.Model, error) { m := Model{ BytePairEncoding: model.NewBytePairEncoding( - `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -213,6 +212,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), Layers: layers, Options: &Options{ diff --git a/sample/samplers_test.go b/sample/samplers_test.go index b720f027..eb10295d 100644 --- a/sample/samplers_test.go +++ b/sample/samplers_test.go @@ -82,7 +82,6 @@ func modelHelper(t testing.TB) model.BytePairEncoding { merges := make([]string, 0, 1) // Only need vocab for Grammar Test return model.NewBytePairEncoding( - ``, &model.Vocabulary{ Values: tokens, Types: make([]int32, len(vocab)), From bf78ed6ee94e593a7edae2e277a736379cbc2413 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 23 Sep 2025 16:08:57 -0700 Subject: [PATCH 07/16] add pre:, suf: to tags (#12274) --- model/model.go | 67 ++++++++++++++++++++----------- model/model_test.go | 61 +++++++++++++++++++++++++--- model/models/llama4/model_text.go | 14 +------ 3 files changed, 101 insertions(+), 41 deletions(-) diff --git a/model/model.go b/model/model.go index f3d6bb3d..2b6ad731 100644 --- a/model/model.go +++ b/model/model.go @@ -5,6 +5,7 @@ import ( "fmt" _ "image/jpeg" _ "image/png" + "log/slog" "os" "reflect" "strconv" @@ -171,35 +172,42 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { // make a copy tagsCopy := tags if tag := t.Field(i).Tag.Get("gguf"); tag != "" { - tagsCopy = append(tagsCopy, ParseTags(tag)) + tagsCopy = append(tagsCopy, parseTag(tag)) } if tt == reflect.TypeOf((*Base)(nil)).Elem() { vv.Set(reflect.ValueOf(base)) } else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() { - var fn func([]Tag) [][]string - fn = func(tags []Tag) (names [][]string) { + var fn func([]Tag, string, string) [][]string + fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) { if len(tags) > 0 { - localNames := []string{tags[0].Name} - localNames = append(localNames, tags[0].Alternate...) + var names []string + if tags[0].name != "" { + for _, n := range append([]string{tags[0].name}, tags[0].alternatives...) { + names = append(names, prefix+n+suffix) + } + } - for _, localName := range localNames { - fullName := []string{localName} - nested := fn(tags[1:]) - if len(nested) > 0 { - for _, rest := range nested { - names = append(names, append(fullName, rest...)) + if childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix); len(childNames) == 0 { + // no child names, append current names + fullNames = append(fullNames, names) + } else if len(names) == 0 { + // no current names, append child names + fullNames = append(fullNames, childNames...) + } else { + // combine current and child names + for _, name := range names { + for _, childName := range childNames { + fullNames = append(fullNames, append([]string{name}, childName...)) } - } else { - names = append(names, fullName) } } } - return names + return fullNames } - names := fn(tagsCopy) + names := fn(tagsCopy, "", "") for _, name := range names { if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil { logutil.Trace("found tensor", "", tensor) @@ -213,9 +221,9 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { for i := range vv.Len() { vvv := vv.Index(i) if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface { - setPointer(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})) + setPointer(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})) } else { - vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...)) + vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})...)) } } } @@ -254,18 +262,31 @@ func setPointer(base Base, v reflect.Value, tags []Tag) { } type Tag struct { - Name string - Alternate []string + name, + // prefix and suffix are applied to child tags + prefix, + suffix string + alternatives []string } -func ParseTags(s string) (tag Tag) { +func parseTag(s string) (tag Tag) { parts := strings.Split(s, ",") if len(parts) > 0 { - tag.Name = parts[0] + tag.name = parts[0] for _, part := range parts[1:] { - if value, ok := strings.CutPrefix(part, "alt:"); ok { - tag.Alternate = append(tag.Alternate, value) + if value, ok := strings.CutPrefix(part, "alt:"); ok && tag.name == "" { + // elevate alternative to primary if no primary given + tag.name = value + slog.Warn("gguf tag has alt: but no primary name", "tag", s) + } else if ok { + tag.alternatives = append(tag.alternatives, value) + } + if value, ok := strings.CutPrefix(part, "pre:"); ok { + tag.prefix = value + } + if value, ok := strings.CutPrefix(part, "suf:"); ok { + tag.suffix = value } } } diff --git a/model/model_test.go b/model/model_test.go index 01080ffd..e4727854 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -22,14 +22,14 @@ func TestParseTags(t *testing.T) { { value: "output", want: Tag{ - Name: "output", + name: "output", }, }, { value: "output,alt:token_embd", want: Tag{ - Name: "output", - Alternate: []string{ + name: "output", + alternatives: []string{ "token_embd", }, }, @@ -38,8 +38,8 @@ func TestParseTags(t *testing.T) { for _, tt := range cases { t.Run(tt.value, func(t *testing.T) { - got := ParseTags(tt.value) - if diff := cmp.Diff(tt.want, got); diff != "" { + got := parseTag(tt.value) + if diff := cmp.Diff(tt.want, got, cmp.AllowUnexported((Tag{}))); diff != "" { t.Errorf("ParseTags() returned unexpected values (-want +got):\n%s", diff) } }) @@ -147,6 +147,57 @@ func TestPopulateFieldsAlternateName(t *testing.T) { } } +func TestPopulateFieldsPrefixSuffixName(t *testing.T) { + type fakeBlock struct { + A *nn.Linear `gguf:"a"` + B *nn.Linear `gguf:",pre:b_"` + C *nn.Linear `gguf:",suf:_c"` + XY *nn.Linear `gguf:",pre:x_,suf:_y"` + } + + type fakeModel struct { + Blocks []fakeBlock `gguf:"blk"` + } + + m := fakeModel{ + Blocks: make([]fakeBlock, 2), + } + v := reflect.ValueOf(&m) + v.Elem().Set(populateFields(Base{b: &fakeBackend{ + names: []string{ + "blk.0.a.weight", + "blk.0.b_weight", + "blk.0.b_bias", + "blk.0.weight_c", + "blk.0.x_weight_y", + "blk.1.a.weight", + "blk.1.b_weight", + "blk.1.b_bias", + "blk.1.weight_c", + "blk.1.x_weight_y", + }, + }}, v.Elem())) + + if diff := cmp.Diff(fakeModel{ + Blocks: []fakeBlock{ + { + A: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.a.weight"}}, + B: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.b_weight"}, Bias: &fakeTensor{Name: "blk.0.b_bias"}}, + C: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.weight_c"}}, + XY: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.x_weight_y"}}, + }, + { + A: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.a.weight"}}, + B: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.b_weight"}, Bias: &fakeTensor{Name: "blk.1.b_bias"}}, + C: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.weight_c"}}, + XY: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.x_weight_y"}}, + }, + }, + }, m); diff != "" { + t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff) + } +} + func TestModelForArch(t *testing.T) { type fakeModel struct { Model diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go index e0f93260..e056391f 100644 --- a/model/models/llama4/model_text.go +++ b/model/models/llama4/model_text.go @@ -88,22 +88,10 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens return nextStates } -// TextSharedExpert is TextMLP with different tensor names -type TextSharedExpert struct { - Gate *nn.Linear `gguf:"ffn_gate_shexp"` - Up *nn.Linear `gguf:"ffn_up_shexp"` - Down *nn.Linear `gguf:"ffn_down_shexp"` -} - -func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { - hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) - return mlp.Down.Forward(ctx, hiddenStates) -} - type TextMOE struct { Router *nn.Linear `gguf:"ffn_gate_inp"` Experts *TextExperts - SharedExpert *TextSharedExpert + SharedExpert *TextMLP `gguf:",suf:_shexp"` } func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { From e1979c571aff857568c9c35f5994da40568ef15c Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 23 Sep 2025 17:50:53 -0700 Subject: [PATCH 08/16] fix: leaf alt name (#12390) a leaf node with an alternative name gets all its alternatives names added into the same branch rather than creating branches themselves --- model/model.go | 16 +++++++++------- model/model_test.go | 3 +++ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/model/model.go b/model/model.go index 2b6ad731..0af16da8 100644 --- a/model/model.go +++ b/model/model.go @@ -187,15 +187,17 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { names = append(names, prefix+n+suffix) } } - - if childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix); len(childNames) == 0 { - // no child names, append current names - fullNames = append(fullNames, names) - } else if len(names) == 0 { - // no current names, append child names + childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix) + if len(names) == 0 { + // current tag has no name, use child names only fullNames = append(fullNames, childNames...) + } else if len(childNames) == 0 { + // current tag has names but no children, create branches for each name + for _, name := range names { + fullNames = append(fullNames, []string{name}) + } } else { - // combine current and child names + // merge each name with each child for _, name := range names { for _, childName := range childNames { fullNames = append(fullNames, append([]string{name}, childName...)) diff --git a/model/model_test.go b/model/model_test.go index e4727854..f6d75b23 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -125,6 +125,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) { Input *nn.Embedding `gguf:"input"` Output *nn.Linear `gguf:"output,alt:input"` Nested *nested `gguf:"nested"` + Tensor ml.Tensor `gguf:"leaf,alt:tensor"` } var m fakeModel @@ -133,6 +134,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) { names: []string{ "input.weight", "nested.b.weight", + "leaf", }, }}, v.Elem())) @@ -142,6 +144,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) { Nested: &nested{ Weight: &nn.Linear{Weight: &fakeTensor{Name: "nested.b.weight"}}, }, + Tensor: &fakeTensor{Name: "leaf"}, }, m); diff != "" { t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff) } From fd88cd7cb0966a26f41ec41bc012f2c4d725ab98 Mon Sep 17 00:00:00 2001 From: Devon Rifkin Date: Tue, 23 Sep 2025 23:34:55 -0700 Subject: [PATCH 09/16] harmony: don't sanitize built-ins In #11910 we started sanitizing function names, but we accidentally were modifying built-ins like `browser.open` to `browser_open`. This was removing the special prompt rendering for built-ins, but this wasn't immediately apparent since the models seem to be reasonably good at remembering the built-ins even when presented with these slightly renamed version. This fix prevents built-ins from ever being renamed. --- harmony/harmonyparser.go | 4 ++++ harmony/harmonyparser_test.go | 1 + 2 files changed, 5 insertions(+) diff --git a/harmony/harmonyparser.go b/harmony/harmonyparser.go index b365b763..da9fe3e9 100644 --- a/harmony/harmonyparser.go +++ b/harmony/harmonyparser.go @@ -463,6 +463,10 @@ func (h *HarmonyMessageHandler) HasThinkingSupport() bool { func (m *FunctionNameMap) ConvertAndAdd(userFunctionName string) string { harmonyFunctionName := m.deriveName(userFunctionName) + // built-in functions should not be renamed + if userFunctionName == "browser.open" || userFunctionName == "browser.search" || userFunctionName == "browser.find" || userFunctionName == "python" { + harmonyFunctionName = userFunctionName + } m.userToHarmony[userFunctionName] = harmonyFunctionName m.harmonyToUser[harmonyFunctionName] = userFunctionName return harmonyFunctionName diff --git a/harmony/harmonyparser_test.go b/harmony/harmonyparser_test.go index b988a018..e56178c6 100644 --- a/harmony/harmonyparser_test.go +++ b/harmony/harmonyparser_test.go @@ -513,6 +513,7 @@ func TestFunctionConvertAndAdd(t *testing.T) { {name: "dupes from different user-specified names", in: []string{"get weather", "get_weather", "get-weather"}, want: []string{"get_weather", "get_weather_2", "get_weather_3"}}, {name: "non dupes after dupes", in: []string{"get weather", "get_weather", "get-weather", "something-different"}, want: []string{"get_weather", "get_weather_2", "get_weather_3", "something_different"}}, {name: "multiple sets of dupes", in: []string{"a", "a", "b", "a", "a", "b", "a"}, want: []string{"a", "a_2", "b", "a_3", "a_4", "b_2", "a_5"}}, + {name: "built-in functions should not be renamed", in: []string{"browser.open", "python", "not.a.built-in.function", "browser.not_a_real_built_in"}, want: []string{"browser.open", "python", "not_a_built_in_function", "browser_not_a_real_built_in"}}, } for i, tt := range tests { From 2e742544bfc5242be4d76c6fee5082c7e41b3df2 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 24 Sep 2025 11:21:32 -0700 Subject: [PATCH 10/16] prefer ollama engine for qwen3moe (#12374) --- fs/ggml/ggml.go | 1 + 1 file changed, 1 insertion(+) diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 5da902bc..58803f58 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -244,6 +244,7 @@ func (kv KV) OllamaEngineRequired() bool { "gemma3n", "mistral3", "qwen3", + "qwen3moe", "llama4", "mllama", "qwen25vl", From fbd82ba5bb35c42a6b09f5bd50ff1aa0690b9626 Mon Sep 17 00:00:00 2001 From: Grace <88872231+gr4ceG@users.noreply.github.com> Date: Wed, 24 Sep 2025 15:19:47 -0700 Subject: [PATCH 11/16] Grace/deepseek v3 migration (#12385) * init deepseek model file * temp removal of flash attention implementation * shapes and proper, can make a pass * query, key, value have good cosine similarity, but the max diff is a bit high * Attention block is working! ** with eager for now, have not added the mask line * Attention block is working! ** with eager for now, have not added the mask line * working MoE at around 0.95 cosine sim * added cosine similarity function * Starting end to end structure * Trying (and failing) to get rope to work, going to test full thing on tater * running on tater36... just not the right outputs * we have the right values for rope... but its still not working? * chnage Extrapolation Factor to 1 * removed adding residuals twice, removed normalization from shared expert, refactored Norms (Attention, MLP) to be outside the (Attention, MLP) blocks and in the Transformer block instead, add cache setLayer * Temporary modelfiles for cpu * change kpass intermediate step to kv, two layer outputs [0,1] look fine * this calls for 16 chicken nuggets * whoops * cleaning up code * delete stuff we dont need * getting rid of debug statements for llama cpp * working with long contexts * fix long context view error * reverting some changes I made for files that are not apart of pr * Added proper tokenizer for deeepseek3 * clean up model and go test * remove Modelfile * not passing the tests * whoops * how to pass the ci tests * resolving some of the comments * rename * linted and renamed deepseek3 -> deepseek2 * remove name go * addressed changes - main change was adopting qwen3 naming scheme * I cannot with linters * clean up logs * clean up logs --------- Co-authored-by: Grace Guo Co-authored-by: Grace Guo Co-authored-by: graceguo --- model/models/deepseek2/model.go | 324 ++++++++++++++++++++++++++++++++ model/models/models.go | 1 + 2 files changed, 325 insertions(+) create mode 100644 model/models/deepseek2/model.go diff --git a/model/models/deepseek2/model.go b/model/models/deepseek2/model.go new file mode 100644 index 00000000..7b88711b --- /dev/null +++ b/model/models/deepseek2/model.go @@ -0,0 +1,324 @@ +package deepseek2 + +// uses deepseek 2 architecture but written based on deepseek 3 model + +import ( + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" + "github.com/ollama/ollama/ml/nn/rope" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type Options struct { + numExpertsUsed int + numExperts int + normTopKProb bool + routedScalingFactor float32 + + kvLoraRank, + qkNopeHeadDim, + qkRopeHeadDim, + kqNopeHeadDim, + qkHeadDim int + qLoraRank int + vHeadDim int + + hiddenSize, + numHeads, + numKVHeads, + keyLength, + valueLength, + originalContextLength int + + eps, + ropeBase, + ropeScale float32 + kqScale float64 +} + +func (o Options) RoPEOptions() []func(*rope.Options) { + attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale)))) + return []func(*rope.Options){ + rope.WithOriginalContextLength(o.originalContextLength), + rope.WithExtrapolationFactor(1.), + rope.WithAttentionFactor(attnFactor), + } +} + +type Attention struct { + Q *nn.Linear `gguf:"attn_q"` + + QA *nn.Linear `gguf:"attn_q_a"` + QANorm *nn.RMSNorm `gguf:"attn_q_a_norm"` + QB *nn.Linear `gguf:"attn_q_b"` + + KVA *nn.Linear `gguf:"attn_kv_a_mqa"` + KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"` + KVB *nn.Linear `gguf:"attn_kv_b"` + + Output *nn.Linear `gguf:"attn_out,alt:attn_output"` +} + +func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + seqLength := hiddenStates.Dim(1) + + var query ml.Tensor + if opts.qLoraRank == 0 { // nil { + query = attn.Q.Forward(ctx, hiddenStates) + } else { + query = attn.QA.Forward(ctx, hiddenStates) + query = attn.QANorm.Forward(ctx, query, opts.eps) + query = attn.QB.Forward(ctx, query) + } + + query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength) + + qPass := query.View(ctx, 0, + opts.qkNopeHeadDim, query.Stride(1), + query.Dim(1), query.Stride(2), + query.Dim(2)) + + qRot := query.View(ctx, opts.qkNopeHeadDim*query.Stride(0), + opts.qkRopeHeadDim, query.Stride(1), + query.Dim(1), query.Stride(2), + query.Dim(2)) + + compressedKV := attn.KVA.Forward(ctx, hiddenStates) + + kPass := compressedKV.View(ctx, 0, opts.kvLoraRank, compressedKV.Stride(1), compressedKV.Dim(1)) + kRot := compressedKV.View(ctx, opts.kvLoraRank*compressedKV.Stride(0), + opts.qkRopeHeadDim, compressedKV.Stride(1), + 1, compressedKV.Stride(1), + compressedKV.Dim(1)) + + kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps) + kPass = attn.KVB.Forward(ctx, kPass) + + kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength) + kPass = kv.View(ctx, 0, opts.kqNopeHeadDim, kv.Stride(1), kv.Dim(1), kv.Stride(2), kv.Dim(2)) + value := kv.View(ctx, opts.kqNopeHeadDim*kv.Stride(0), + opts.vHeadDim, kv.Stride(1), + kv.Dim(1), kv.Stride(2), + kv.Dim(2)).Contiguous(ctx) + + qRot = fast.RoPE(ctx, qRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...) + kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...) + + kRot = kRot.Repeat(ctx, 1, qPass.Dim(1)) + + query = qRot.Concat(ctx, qPass, 0) + key := kRot.Concat(ctx, kPass, 0) + + attention := nn.Attention(ctx, query, key, value, opts.kqScale, cache) + attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength) + return attn.Output.Forward(ctx, attention) +} + +type MLP interface { + Forward(ml.Context, ml.Tensor, *Options) ml.Tensor +} + +type sparse struct { + Router *nn.Linear `gguf:"ffn_gate_inp"` + Gate *nn.Linear `gguf:"ffn_gate_exps"` + Up *nn.Linear `gguf:"ffn_up_exps"` + Down *nn.Linear `gguf:"ffn_down_exps"` + SharedExpert *dense `gguf:",suf:_shexp"` + ExpProbsBias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"` +} + +func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml.Tensor, opts *Options) ml.Tensor { + hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1)) + + upStates := moe.Up.Weight.MulmatID(ctx, hiddenStates, topKIndices) + hiddenStates = moe.Gate.Weight.MulmatID(ctx, hiddenStates, topKIndices) + hiddenStates = hiddenStates.SILU(ctx, upStates) + + experts := moe.Down.Weight.MulmatID(ctx, hiddenStates, topKIndices) + experts = experts.Mul(ctx, topKWeights) + nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2)) + for i := 1; i < opts.numExpertsUsed; i++ { + nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2))) + } + return nextStates +} + +func (moe *sparse) topKIndices(ctx ml.Context, scores ml.Tensor, opts *Options) ml.Tensor { + scores = scores.Add(ctx, moe.ExpProbsBias) + topKIndices := scores.TopK(ctx, opts.numExpertsUsed) + return topKIndices +} + +func (moe *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + residuals := hiddenStates + + routerLogits := moe.Router.Forward(ctx, hiddenStates) + scores := routerLogits.Sigmoid(ctx) + topKIndices := moe.topKIndices(ctx, scores, opts) + topKWeights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, topKIndices) + + if opts.normTopKProb { + topKWeights = topKWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1)) + topKWeights = topKWeights.Div(ctx, topKWeights.SumRows(ctx)) + topKWeights = topKWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1)) + } + + topKWeights = topKWeights.Scale(ctx, float64(opts.routedScalingFactor)) + hiddenStates = moe.Moe(ctx, hiddenStates, topKIndices, topKWeights, opts) + sharedExpertResult := moe.SharedExpert.Forward(ctx, residuals, opts) + + hiddenStates = hiddenStates.Add(ctx, sharedExpertResult) + return hiddenStates +} + +type dense struct { + Gate *nn.Linear `gguf:"ffn_gate"` + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) + return mlp.Down.Forward(ctx, hiddenStates) +} + +type Layer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + Attention *Attention + + MLPNorm *nn.RMSNorm `gguf:"ffn_norm"` + MLP MLP +} + +func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + residual := hiddenStates + hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, cache, opts) + + if outputs != nil { + hiddenStates = hiddenStates.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + + hiddenStates = hiddenStates.Add(ctx, residual) + residual = hiddenStates + + hiddenStates = t.MLPNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = t.MLP.Forward(ctx, hiddenStates, opts) + hiddenStates = hiddenStates.Add(ctx, residual) + return hiddenStates +} + +type Model struct { + model.Base + model.BytePairEncoding + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + Layers []Layer `gguf:"blk"` + + OutputNorm *nn.RMSNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + *Options +} + +func New(c fs.Config) (model.Model, error) { + layers := make([]Layer, c.Uint("block_count")) + + firstDenseLayerIndex := int(c.Uint("leading_dense_block_count")) + for i := range layers { + if i < firstDenseLayerIndex { + layers[i].MLP = &dense{} + } else { + layers[i].MLP = &sparse{} + } + } + + mScale := float32(1.0 + float64(c.Float("rope.scaling.yarn_log_multiplier"))*math.Log(float64(c.Float("rope.scaling.factor")))) + kqScale := float64(mScale) * float64(mScale) / math.Sqrt(float64(c.Uint("attention.key_length"))) + + m := Model{ + BytePairEncoding: model.NewBytePairEncoding( + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), + }, + // Split regex into multiple parts (according to DeepSeek3's regex) + "\\p{N}{1,3}", + `[一-龥぀-ゟ゠-ヿ]+`, + "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+", + ), + Layers: layers, + Options: &Options{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + keyLength: int(c.Uint("attention.key_length")), + valueLength: int(c.Uint("attention.value_length")), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeBase: c.Float("rope.freq_base"), + ropeScale: c.Float("rope.scaling.factor", 1), + numExperts: int(c.Uint("expert_count")), + numExpertsUsed: int(c.Uint("expert_used_count")), + normTopKProb: c.Bool("expert_weights_norm", true), + + qLoraRank: int(c.Uint("attention.q_lora_rank")), //&qLoraRankVal, + kvLoraRank: int(c.Uint("attention.kv_lora_rank")), + qkHeadDim: int(c.Uint("attention.key_length")), + vHeadDim: int(c.Uint("attention.value_length")), + qkRopeHeadDim: int(c.Uint("rope.dimension_count")), + qkNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")), + kqNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")), + + routedScalingFactor: c.Float("expert_weights_scale"), + originalContextLength: int(c.Uint("rope.scaling.original_context_length")), + + kqScale: kqScale, + }, + } + + m.Cache = kvcache.NewCausalCache(m.Shift) + return &m, nil +} + +func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return fast.RoPE(ctx, key, shift, m.qkRopeHeadDim, m.ropeBase, 1./m.ropeScale, m.RoPEOptions()...), nil +} + +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + + hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) + + for i, layer := range m.Layers { + m.Cache.SetLayer(i) + + var outputs ml.Tensor + if i == len(m.Layers)-1 { + outputs = batch.Outputs + } + + hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options) + } + + hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) + return m.Output.Forward(ctx, hiddenStates), nil +} + +func init() { + model.Register("deepseek2", New) +} diff --git a/model/models/models.go b/model/models/models.go index cc998078..0cda615a 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -2,6 +2,7 @@ package models import ( _ "github.com/ollama/ollama/model/models/bert" + _ "github.com/ollama/ollama/model/models/deepseek2" _ "github.com/ollama/ollama/model/models/gemma2" _ "github.com/ollama/ollama/model/models/gemma3" _ "github.com/ollama/ollama/model/models/gemma3n" From 2fba04b5fb4a56b1e3536d83383d7853441adea9 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 25 Sep 2025 15:37:39 -0600 Subject: [PATCH 12/16] tools: handle the case where a tool call sends "arguments" or "parameters" as a serialized json string (#12413) --- tools/tools.go | 12 ++++++++++++ tools/tools_test.go | 16 ++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/tools/tools.go b/tools/tools.go index 80fc6e0d..f9a2d3b9 100644 --- a/tools/tools.go +++ b/tools/tools.go @@ -273,9 +273,21 @@ func findArguments(buffer []byte) (map[string]any, int) { if args, ok := obj["arguments"].(map[string]any); ok { return args, true } + if argsStr, ok := obj["arguments"].(string); ok { + var argsData map[string]interface{} + if err := json.Unmarshal([]byte(argsStr), &argsData); err == nil { + return argsData, ok + } + } if args, ok := obj["parameters"].(map[string]any); ok { return args, true } + if argsStr, ok := obj["parameters"].(string); ok { + var argsData map[string]interface{} + if err := json.Unmarshal([]byte(argsStr), &argsData); err == nil { + return argsData, ok + } + } return nil, true } diff --git a/tools/tools_test.go b/tools/tools_test.go index 2a449a0e..288fa73c 100644 --- a/tools/tools_test.go +++ b/tools/tools_test.go @@ -1274,6 +1274,22 @@ func TestFindArguments(t *testing.T) { "items": []any{"{", "}", map[string]any{"key": "value"}}, }, }, + { + name: "stringified arguments", + buffer: []byte(`{"name": "get_temperature", "arguments": "{\"format\": \"fahrenheit\", \"location\": \"San Francisco, CA\"}"}`), + want: map[string]any{ + "format": "fahrenheit", + "location": "San Francisco, CA", + }, + }, + { + name: "stringified parameters", + buffer: []byte(`{"name": "get_temperature", "parameters": "{\"format\": \"fahrenheit\", \"location\": \"San Francisco, CA\"}"}`), + want: map[string]any{ + "format": "fahrenheit", + "location": "San Francisco, CA", + }, + }, } for _, tt := range tests { From 5a56ff3cf0b0d3ae4fadd1962269a90a7a6ffb73 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Thu, 25 Sep 2025 15:04:43 -0700 Subject: [PATCH 13/16] cli: add device signin flow when doing ollama push (#12405) --- cmd/cmd.go | 20 +++++++++++++++++++- cmd/cmd_test.go | 33 ++++++++++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index e8cfa134..3b41c71e 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -540,6 +540,25 @@ func PushHandler(cmd *cobra.Command, args []string) error { return err } + n := model.ParseName(args[0]) + if strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") { + _, err := client.Whoami(cmd.Context()) + if err != nil { + var aErr api.AuthorizationError + if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized { + fmt.Println("You need to be signed in to push models to ollama.com.") + fmt.Println() + + if aErr.SigninURL != "" { + fmt.Printf(ConnectInstructions, aErr.SigninURL) + } + return nil + } + + return err + } + } + p := progress.NewProgress(os.Stderr) defer p.Stop() @@ -576,7 +595,6 @@ func PushHandler(cmd *cobra.Command, args []string) error { request := api.PushRequest{Name: args[0], Insecure: insecure} - n := model.ParseName(args[0]) if err := client.Push(cmd.Context(), &request, fn); err != nil { if spinner != nil { spinner.Stop() diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index 24d28705..fb3b039e 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -491,9 +491,35 @@ func TestPushHandler(t *testing.T) { w.(http.Flusher).Flush() } }, + "/api/me": func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST request, got %s", r.Method) + } + }, }, expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n", }, + { + name: "not signed in push", + modelName: "notsignedin-model", + serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){ + "/api/me": func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST request, got %s", r.Method) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + err := json.NewEncoder(w).Encode(map[string]string{ + "error": "unauthorized", + "signin_url": "https://somethingsomething", + }) + if err != nil { + t.Fatal(err) + } + }, + }, + expectedOutput: "You need to be signed in to push", + }, { name: "unauthorized push", modelName: "unauthorized-model", @@ -508,6 +534,11 @@ func TestPushHandler(t *testing.T) { t.Fatal(err) } }, + "/api/me": func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST request, got %s", r.Method) + } + }, }, expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own", }, @@ -564,7 +595,7 @@ func TestPushHandler(t *testing.T) { t.Errorf("expected no error, got %v", err) } if tt.expectedOutput != "" { - if got := string(stdout); got != tt.expectedOutput { + if got := string(stdout); !strings.Contains(got, tt.expectedOutput) { t.Errorf("expected output %q, got %q", tt.expectedOutput, got) } } From 05ba4ca1f4b356df50ed6eede0e2bcdc76b31fb8 Mon Sep 17 00:00:00 2001 From: Devon Rifkin Date: Thu, 25 Sep 2025 15:47:46 -0700 Subject: [PATCH 14/16] parsers: fix unicode handling for qwen3-coder MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When trimming whitespace at the end of every chunk, we were iterating backwards over the string byte-by-byte instead of rune-by-rune. As an example of how this can cause corruption, suppose we have the multi-byte character ✅ (`"\u2705"`), which is represented in utf-8 as the three bytes `0xE2 0x9C 0x85`. It happens that `0x85` is NEL, which passes `unicode.IsSpace()`. Because we were iterating byte-by-byte, this caused us to mistakenly slice in the middle of the rune, removing `0x85` and leaving `0xE2 0x9C`, which beyond being the incorrect place to slice, is not even a valid utf-8 character. `trailingWhitespaceLen()` was modified to count from the end in a rune-aware way. Tests with various multibyte unicode characters were also added. Fixes: #12414 --- model/parsers/qwen3coder.go | 18 ++- model/parsers/qwen3coder_test.go | 217 +++++++++++++++++++++++++++++++ 2 files changed, 231 insertions(+), 4 deletions(-) diff --git a/model/parsers/qwen3coder.go b/model/parsers/qwen3coder.go index 0cff1ec1..f44d7c8e 100644 --- a/model/parsers/qwen3coder.go +++ b/model/parsers/qwen3coder.go @@ -11,6 +11,7 @@ import ( "strconv" "strings" "unicode" + "unicode/utf8" "github.com/ollama/ollama/api" "github.com/ollama/ollama/logutil" @@ -204,12 +205,21 @@ func overlap(s, delim string) int { } func trailingWhitespaceLen(s string) int { - for i := len(s) - 1; i >= 0; i-- { - if !unicode.IsSpace(rune(s[i])) { - return len(s) - i - 1 + remaining := s + total := 0 + for len(remaining) > 0 { + r, size := utf8.DecodeLastRuneInString(remaining) + // if it's an invalid utf8 rune, assume it isn't whitespace + if r == utf8.RuneError && size == 1 { + break } + if !unicode.IsSpace(r) { + break + } + total += size + remaining = remaining[:len(remaining)-size] } - return len(s) + return total } type XMLFunctionCall struct { diff --git a/model/parsers/qwen3coder_test.go b/model/parsers/qwen3coder_test.go index 43823e6f..c77fe2d9 100644 --- a/model/parsers/qwen3coder_test.go +++ b/model/parsers/qwen3coder_test.go @@ -166,6 +166,137 @@ func TestQwenParserStreaming(t *testing.T) { }, }, }, + { + desc: "unicode content", + steps: []step{ + { + input: "你好 🌍testمرحبا", + wantEvents: []qwenEvent{ + qwenEventContent{content: "你好 🌍"}, + qwenEventRawToolCall{raw: "test"}, + qwenEventContent{content: "مرحبا"}, + }, + }, + }, + }, + { + desc: "arabic text handling", + steps: []step{ + { + input: "مرحبا بالعالم", + wantEvents: []qwenEvent{qwenEventContent{content: "مرحبا بالعالم"}}, + }, + }, + }, + { + desc: "emoji passthrough", + steps: []step{ + { + input: "✅", + wantEvents: []qwenEvent{qwenEventContent{content: "✅"}}, + }, + }, + }, + { + desc: "emoji after tool call", + steps: []step{ + { + input: "test完成 ✅", + wantEvents: []qwenEvent{ + qwenEventRawToolCall{raw: "test"}, + qwenEventContent{content: "完成 ✅"}, + }, + }, + }, + }, + { + desc: "unicode streaming with whitespace handling", + steps: []step{ + { + input: "مرحبا", + wantEvents: []qwenEvent{ + qwenEventContent{content: "مرحبا"}, + }, + }, + { + input: " \n", + wantEvents: []qwenEvent{}, + }, + { + input: "世界", + wantEvents: []qwenEvent{ + qwenEventContent{content: " \n世界"}, + }, + }, + }, + }, + { + desc: "non-breaking space withheld across chunks", + steps: []step{ + { + input: "Hello\u00a0", + wantEvents: []qwenEvent{ + qwenEventContent{content: "Hello"}, + }, + }, + { + input: "world", + wantEvents: []qwenEvent{ + qwenEventContent{content: "\u00a0world"}, + }, + }, + }, + }, + { + desc: "ideographic space before partial tool", + steps: []step{ + { + input: "Hello\u3000abc", + wantEvents: []qwenEvent{}, + }, + { + input: "def", + wantEvents: []qwenEvent{ + qwenEventRawToolCall{raw: "abc"}, + qwenEventContent{content: "def"}, + }, + }, + }, + }, + { + desc: "ideographic space before partial tool fakeout", + steps: []step{ + { + input: "Hello\u3000abc", + wantEvents: []qwenEvent{ + qwenEventContent{content: "\u3000abc"}, + }, + }, + }, + }, + { + desc: "unicode with partial tool tag", + steps: []step{ + { + input: "测试🎯 b and a < b" }, }, }, + { + name: "unicode in function names and parameters", + tools: []api.Tool{}, + rawToolCall: ` + +北京 + + +Hello! 你好! 🌟 مرحبا + +`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "获取天气", + Arguments: map[string]any{ + "城市": "北京", + "message": "Hello! 你好! 🌟 مرحبا", + }, + }, + }, + }, } for i, step := range steps { @@ -360,6 +512,42 @@ ls && echo "a > b and a < b" } } +func TestTrailingWhitespaceLenUnicode(t *testing.T) { + cases := []struct { + name string + input string + want int + }{ + { + name: "ascii space", + input: "Hello ", + want: 1, + }, + { + name: "non-breaking space", + input: "Hello\u00a0", + want: 2, + }, + { + name: "ideographic space", + input: "Hello\u3000", + want: 3, + }, + { + name: "multiple runes of whitespace", + input: "Hi\u00a0\u3000", + want: 5, + }, + } + + for _, tc := range cases { + got := trailingWhitespaceLen(tc.input) + if got != tc.want { + t.Errorf("%s: trailingWhitespaceLen(%q) = %d, want %d", tc.name, tc.input, got, tc.want) + } + } +} + func TestQwenToolCallValueParsing(t *testing.T) { cases := []struct { desc string @@ -867,6 +1055,8 @@ func TestTrailingWhitespaceLen(t *testing.T) { {desc: "trailing whitespace with newlines", s: "abc \n", want: 2}, {desc: "only whitespace", s: " \n ", want: 4}, {desc: "leading whitespace doesn't count", s: " \n abc", want: 0}, + {desc: "unicode with trailing space", s: "测试🎯 ", want: 1}, + {desc: "unicode with trailing tab and newline", s: "مرحبا\t\n", want: 2}, } for _, tc := range cases { @@ -876,3 +1066,30 @@ func TestTrailingWhitespaceLen(t *testing.T) { } } } + +func TestOverlapFunction(t *testing.T) { + cases := []struct { + desc string + s string + delim string + want int + }{ + {desc: "no overlap", s: "hello", delim: "", want: 5}, + {desc: "partial overlap", s: "hello", want: 3}, + {desc: "unicode with partial overlap", s: "测试🎯", want: 3}, + {desc: "unicode string with no overlap", s: "مرحبا", delim: "", want: 0}, + {desc: "unicode at boundary", s: "世界<", delim: "", want: 1}, + {desc: "unicode delimiter single rune", s: "hello🔧", delim: "🔧工具", want: len("🔧")}, + {desc: "unicode delimiter multiple runes", s: "hello🔧工", delim: "🔧工具", want: len("🔧工")}, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + got := overlap(tc.s, tc.delim) + if got != tc.want { + t.Errorf("overlap(%q, %q) = %d, want %d", tc.s, tc.delim, got, tc.want) + } + }) + } +} From b04e46da3ebca69a2b1216b3943d8a463e8b4a14 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Thu, 25 Sep 2025 18:30:45 -0700 Subject: [PATCH 15/16] bugfix: restore the current runOptions if loading fails in the CLI (#12402) There are two bugs when using `/load ` for a model that doesn't exist, namely: 1. it will not restore the current model settings if the current model is a thinking model; and 2. it will crash is the current model is a non-thinking model This bug fix saves the current runOptions and then restores them if the model load doesn't happen. It also fixes the crash happening for non-thinking models. --- cmd/cmd.go | 45 +++++++ cmd/cmd_test.go | 284 +++++++++++++++++++++++++++++++++++++++++++++ cmd/interactive.go | 10 +- 3 files changed, 338 insertions(+), 1 deletion(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index 3b41c71e..369a27a4 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1118,6 +1118,51 @@ type runOptions struct { ShowConnect bool } +func (r runOptions) Copy() runOptions { + var messages []api.Message + if r.Messages != nil { + messages = make([]api.Message, len(r.Messages)) + copy(messages, r.Messages) + } + + var images []api.ImageData + if r.Images != nil { + images = make([]api.ImageData, len(r.Images)) + copy(images, r.Images) + } + + var opts map[string]any + if r.Options != nil { + opts = make(map[string]any, len(r.Options)) + for k, v := range r.Options { + opts[k] = v + } + } + + var think *api.ThinkValue + if r.Think != nil { + cThink := *r.Think + think = &cThink + } + + return runOptions{ + Model: r.Model, + ParentModel: r.ParentModel, + Prompt: r.Prompt, + Messages: messages, + WordWrap: r.WordWrap, + Format: r.Format, + System: r.System, + Images: images, + Options: opts, + MultiModal: r.MultiModal, + KeepAlive: r.KeepAlive, + Think: think, + HideThinking: r.HideThinking, + ShowConnect: r.ShowConnect, + } +} + type displayResponseState struct { lineLength int wordBuffer string diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index fb3b039e..a84272c8 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "os" + "reflect" "strings" "testing" "time" @@ -953,3 +954,286 @@ func TestNewCreateRequest(t *testing.T) { }) } } + +func TestRunOptions_Copy(t *testing.T) { + // Setup test data + originalKeepAlive := &api.Duration{Duration: 5 * time.Minute} + originalThink := &api.ThinkValue{Value: "test reasoning"} + + original := runOptions{ + Model: "test-model", + ParentModel: "parent-model", + Prompt: "test prompt", + Messages: []api.Message{ + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "hi there"}, + }, + WordWrap: true, + Format: "json", + System: "system prompt", + Images: []api.ImageData{ + []byte("image1"), + []byte("image2"), + }, + Options: map[string]any{ + "temperature": 0.7, + "max_tokens": 1000, + "top_p": 0.9, + }, + MultiModal: true, + KeepAlive: originalKeepAlive, + Think: originalThink, + HideThinking: false, + ShowConnect: true, + } + + // Test the copy + copied := original.Copy() + + // Test 1: Verify the copy is not the same instance + if &copied == &original { + t.Error("Copy should return a different instance") + } + + // Test 2: Verify all fields are copied correctly + tests := []struct { + name string + got interface{} + want interface{} + }{ + {"Model", copied.Model, original.Model}, + {"ParentModel", copied.ParentModel, original.ParentModel}, + {"Prompt", copied.Prompt, original.Prompt}, + {"WordWrap", copied.WordWrap, original.WordWrap}, + {"Format", copied.Format, original.Format}, + {"System", copied.System, original.System}, + {"MultiModal", copied.MultiModal, original.MultiModal}, + {"HideThinking", copied.HideThinking, original.HideThinking}, + {"ShowConnect", copied.ShowConnect, original.ShowConnect}, + } + + for _, tt := range tests { + if !reflect.DeepEqual(tt.got, tt.want) { + t.Errorf("%s mismatch: got %v, want %v", tt.name, tt.got, tt.want) + } + } + + // Test 3: Verify Messages slice is deeply copied + if len(copied.Messages) != len(original.Messages) { + t.Errorf("Messages length mismatch: got %d, want %d", len(copied.Messages), len(original.Messages)) + } + + if len(copied.Messages) > 0 && &copied.Messages[0] == &original.Messages[0] { + t.Error("Messages should be different instances") + } + + // Modify original to verify independence + if len(original.Messages) > 0 { + originalContent := original.Messages[0].Content + original.Messages[0].Content = "modified" + if len(copied.Messages) > 0 && copied.Messages[0].Content == "modified" { + t.Error("Messages should be independent after copy") + } + // Restore for other tests + original.Messages[0].Content = originalContent + } + + // Test 4: Verify Images slice is deeply copied + if len(copied.Images) != len(original.Images) { + t.Errorf("Images length mismatch: got %d, want %d", len(copied.Images), len(original.Images)) + } + + if len(copied.Images) > 0 && &copied.Images[0] == &original.Images[0] { + t.Error("Images should be different instances") + } + + // Modify original to verify independence + if len(original.Images) > 0 { + originalImage := original.Images[0] + original.Images[0] = []byte("modified") + if len(copied.Images) > 0 && string(copied.Images[0]) == "modified" { + t.Error("Images should be independent after copy") + } + // Restore for other tests + original.Images[0] = originalImage + } + + // Test 5: Verify Options map is deeply copied + if len(copied.Options) != len(original.Options) { + t.Errorf("Options length mismatch: got %d, want %d", len(copied.Options), len(original.Options)) + } + + if len(copied.Options) > 0 && &copied.Options == &original.Options { + t.Error("Options map should be different instances") + } + + // Modify original to verify independence + if len(original.Options) > 0 { + originalTemp := original.Options["temperature"] + original.Options["temperature"] = 0.9 + if copied.Options["temperature"] == 0.9 { + t.Error("Options should be independent after copy") + } + // Restore for other tests + original.Options["temperature"] = originalTemp + } + + // Test 6: Verify KeepAlive pointer is copied (shallow copy) + if copied.KeepAlive != original.KeepAlive { + t.Error("KeepAlive pointer should be the same (shallow copy)") + } + + // Test 7: Verify Think pointer creates a new instance + if original.Think != nil && copied.Think == original.Think { + t.Error("Think should be a different instance") + } + + if original.Think != nil && copied.Think != nil { + if !reflect.DeepEqual(copied.Think.Value, original.Think.Value) { + t.Errorf("Think.Value mismatch: got %v, want %v", copied.Think.Value, original.Think.Value) + } + } + + // Test 8: Test with zero values + zeroOriginal := runOptions{} + zeroCopy := zeroOriginal.Copy() + + if !reflect.DeepEqual(zeroCopy, zeroOriginal) { + fmt.Printf("orig: %#v\ncopy: %#v\n", zeroOriginal, zeroCopy) + t.Error("Copy of zero value should equal original zero value") + } +} + +func TestRunOptions_Copy_EmptySlicesAndMaps(t *testing.T) { + // Test with empty slices and maps + original := runOptions{ + Messages: []api.Message{}, + Images: []api.ImageData{}, + Options: map[string]any{}, + } + + copied := original.Copy() + + if copied.Messages == nil { + t.Error("Empty Messages slice should remain empty, not nil") + } + + if copied.Images == nil { + t.Error("Empty Images slice should remain empty, not nil") + } + + if copied.Options == nil { + t.Error("Empty Options map should remain empty, not nil") + } + + if len(copied.Messages) != 0 { + t.Error("Empty Messages slice should remain empty") + } + + if len(copied.Images) != 0 { + t.Error("Empty Images slice should remain empty") + } + + if len(copied.Options) != 0 { + t.Error("Empty Options map should remain empty") + } +} + +func TestRunOptions_Copy_NilPointers(t *testing.T) { + // Test with nil pointers + original := runOptions{ + KeepAlive: nil, + Think: nil, + } + + copied := original.Copy() + + if copied.KeepAlive != nil { + t.Error("Nil KeepAlive should remain nil") + } + + if copied.Think != nil { + t.Error("Nil Think should remain nil") + } +} + +func TestRunOptions_Copy_ThinkValueVariants(t *testing.T) { + tests := []struct { + name string + think *api.ThinkValue + }{ + {"nil Think", nil}, + {"bool true", &api.ThinkValue{Value: true}}, + {"bool false", &api.ThinkValue{Value: false}}, + {"string value", &api.ThinkValue{Value: "reasoning text"}}, + {"int value", &api.ThinkValue{Value: 42}}, + {"nil value", &api.ThinkValue{Value: nil}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + original := runOptions{Think: tt.think} + copied := original.Copy() + + if tt.think == nil { + if copied.Think != nil { + t.Error("Nil Think should remain nil") + } + return + } + + if copied.Think == nil { + t.Error("Non-nil Think should not become nil") + return + } + + if copied.Think == original.Think { + t.Error("Think should be a different instance") + } + + if !reflect.DeepEqual(copied.Think.Value, original.Think.Value) { + t.Errorf("Think.Value mismatch: got %v, want %v", copied.Think.Value, original.Think.Value) + } + }) + } +} + +func TestRunOptions_Copy_Independence(t *testing.T) { + // Test that modifications to original don't affect copy + originalThink := &api.ThinkValue{Value: "original"} + original := runOptions{ + Model: "original-model", + Messages: []api.Message{{Role: "user", Content: "original"}}, + Options: map[string]any{"key": "value"}, + Think: originalThink, + } + + copied := original.Copy() + + // Modify original + original.Model = "modified-model" + if len(original.Messages) > 0 { + original.Messages[0].Content = "modified" + } + original.Options["key"] = "modified" + if original.Think != nil { + original.Think.Value = "modified" + } + + // Verify copy is unchanged + if copied.Model == "modified-model" { + t.Error("Copy Model should not be affected by original modification") + } + + if len(copied.Messages) > 0 && copied.Messages[0].Content == "modified" { + t.Error("Copy Messages should not be affected by original modification") + } + + if copied.Options["key"] == "modified" { + t.Error("Copy Options should not be affected by original modification") + } + + if copied.Think != nil && copied.Think.Value == "modified" { + t.Error("Copy Think should not be affected by original modification") + } +} diff --git a/cmd/interactive.go b/cmd/interactive.go index e290d84c..cf0aced1 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -195,16 +195,24 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { fmt.Println("Usage:\n /load ") continue } + origOpts := opts.Copy() + opts.Model = args[1] opts.Messages = []api.Message{} fmt.Printf("Loading model '%s'\n", opts.Model) opts.Think, err = inferThinkingOption(nil, &opts, thinkExplicitlySet) if err != nil { + if strings.Contains(err.Error(), "not found") { + fmt.Printf("Couldn't find model '%s'\n", opts.Model) + opts = origOpts.Copy() + continue + } return err } if err := loadOrUnloadModel(cmd, &opts); err != nil { if strings.Contains(err.Error(), "not found") { - fmt.Printf("error: %v\n", err) + fmt.Printf("Couldn't find model '%s'\n", opts.Model) + opts = origOpts.Copy() continue } if strings.Contains(err.Error(), "does not support thinking") { From c47154c08d59c93a653a2a798885a4a29bff71ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BE=8A=E6=92=85=E6=92=85?= <31302548+Fachep@users.noreply.github.com> Date: Sat, 27 Sep 2025 02:38:47 +0800 Subject: [PATCH 16/16] fix: correct condition for AMDGPU_TARGETS filtering logic (#12412) --- CMakeLists.txt | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 198fcdeb..6757400e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -98,10 +98,12 @@ check_language(HIP) if(CMAKE_HIP_COMPILER) set(HIP_PLATFORM "amd") - find_package(hip REQUIRED) if(NOT AMDGPU_TARGETS) + find_package(hip REQUIRED) list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012]|120[01])$") - elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX) + endif() + + if(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX) list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX}) endif()