diff --git a/harmony/harmonyparser.go b/harmony/harmonyparser.go
index a51819dd..addce4c9 100644
--- a/harmony/harmonyparser.go
+++ b/harmony/harmonyparser.go
@@ -3,15 +3,29 @@ package harmony
 import (
 	"fmt"
 	"log/slog"
+	"slices"
 	"strings"
 	"unicode"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/logutil"
+	"github.com/ollama/ollama/template"
 )
 
 type harmonyParserState int
 
+func ShouldUseHarmony(modelFamily string, template *template.Template) bool {
+	if slices.Contains([]string{"gptoss", "gpt-oss"}, modelFamily) {
+		// heuristic to check whether the template expects to be parsed via harmony:
+		// search for harmony tags that are nearly always used
+		if template.Contains("<|start|>") && template.Contains("<|end|>") {
+			return true
+		}
+	}
+
+	return false
+}
+
 const (
 	harmonyParserState_LookingForMessageStart harmonyParserState = iota
 	harmonyParserState_ParsingHeader
@@ -75,18 +89,28 @@ func (s *HarmonyParser) AddImplicitStart() {
 	s.acc.WriteString("<|start|>assistant")
 }
 
-func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) {
-	if lastMessage != nil && lastMessage.Role == "assistant" {
-		// handle prefilling conditions
-		if lastMessage.Content != "" {
-			s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
-			return
-		} else if lastMessage.Thinking != "" {
-			s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
-			return
-		}
+func Prefill(lastMessage api.Message) string {
+	if lastMessage.Role != "assistant" {
+		return ""
+	}
+
+	switch {
+	case strings.TrimSpace(lastMessage.Content) != "":
+		return "<|start|>assistant<|channel|>final<|message|>"
+	case strings.TrimSpace(lastMessage.Thinking) != "":
+		return "<|start|>assistant<|channel|>analysis<|message|>"
+	default:
+		return ""
+	}
+}
+
+// AddImplicitStartOrPrefill adds an implicit start tag or prefill string if provided
+func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillString string) {
+	if strings.TrimSpace(prefillString) != "" {
+		s.acc.WriteString(prefillString)
+	} else {
+		s.AddImplicitStart()
 	}
-	s.AddImplicitStart()
 }
 
 func (s *HarmonyParser) AddContent(content string) []HarmonyEvent {
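With `ShouldUseHarmony` and `Prefill` exported, detection and prefill computation can now happen outside the parser, in whatever layer owns the model metadata. A minimal sketch of how a caller might compose the three new entry points (`useHarmonyFor` and its parameters are hypothetical, not part of the diff):

```go
package server

import (
	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/harmony"
	"github.com/ollama/ollama/template"
)

// useHarmonyFor is a hypothetical helper composing the new exported entry
// points: family/template detection, prefill from the last message, and
// seeding the parser state.
func useHarmonyFor(modelFamily string, tmpl *template.Template, msgs []api.Message, raw bool) (*harmony.HarmonyMessageHandler, bool) {
	if raw || !harmony.ShouldUseHarmony(modelFamily, tmpl) {
		return nil, false
	}
	var prefill string
	if len(msgs) > 0 {
		// Prefill only inspects the last message; it returns "" for
		// non-assistant roles or when content and thinking are empty.
		prefill = harmony.Prefill(msgs[len(msgs)-1])
	}
	handler := harmony.NewHarmonyMessageHandler()
	// Seed the parser with the prefill, or the implicit "<|start|>assistant" tag.
	handler.HarmonyParser.AddImplicitStartOrPrefill(prefill)
	return handler, true
}
```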
diff --git a/harmony/harmonyparser_test.go b/harmony/harmonyparser_test.go
index b988a018..dcf1af4e 100644
--- a/harmony/harmonyparser_test.go
+++ b/harmony/harmonyparser_test.go
@@ -3,6 +3,7 @@ package harmony
 import (
 	"fmt"
 	"reflect"
+	"strings"
 	"testing"
 )
 
@@ -535,3 +536,202 @@ func TestFunctionConvertAndAdd(t *testing.T) {
 		})
 	}
 }
+
+func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
+	t.Run("thinking_then_content_streams", func(t *testing.T) {
+		handler := NewHarmonyMessageHandler()
+		handler.HarmonyParser.AddImplicitStart()
+		tp := handler.CreateToolParser()
+		type step struct {
+			in           string
+			wantContent  string
+			wantThinking string
+		}
+		steps := []step{
+			{in: "<|channel|>analysis<|message|>Thinking...", wantThinking: "Thinking..."},
+			{in: "<|end|>", wantThinking: ""},
+			{in: "<|start|>assistant<|message|>Answer", wantContent: "Answer"},
+			{in: "<|end|>", wantContent: ""},
+		}
+		for i, s := range steps {
+			content, thinking, tool := handler.AddContent(s.in, tp)
+			if tool != "" {
+				tp.Add(tool)
+			}
+			if content != s.wantContent || thinking != s.wantThinking {
+				t.Fatalf("step %d: got (content=%q thinking=%q), want (content=%q thinking=%q)", i, content, thinking, s.wantContent, s.wantThinking)
+			}
+		}
+	})
+
+	t.Run("content_streams_as_it_arrives", func(t *testing.T) {
+		handler := NewHarmonyMessageHandler()
+		handler.HarmonyParser.AddImplicitStart()
+		tp := handler.CreateToolParser()
+		inputs := []string{
+			"<|start|>assistant<|message|>Hello",
+			", world",
+			"!<|end|>",
+		}
+		var got []string
+		for _, in := range inputs {
+			content, thinking, tool := handler.AddContent(in, tp)
+			if tool != "" {
+				tp.Add(tool)
+			}
+			if thinking != "" {
+				t.Fatalf("unexpected thinking %q", thinking)
+			}
+			if content != "" {
+				got = append(got, content)
+			}
+		}
+		want := []string{"Hello", ", world", "!"}
+		if !reflect.DeepEqual(got, want) {
+			t.Fatalf("content pieces mismatch: got %v want %v", got, want)
+		}
+	})
+
+	t.Run("thinking_streams_separately_from_content", func(t *testing.T) {
+		handler := NewHarmonyMessageHandler()
+		handler.HarmonyParser.AddImplicitStart()
+		tp := handler.CreateToolParser()
+		inputs := []string{
+			"<|channel|>analysis<|message|>Thinking...",
+			"<|end|>",
+			"<|start|>assistant<|message|>Answer",
+			"<|end|>",
+		}
+		var got []string
+		for _, in := range inputs {
+			content, thinking, tool := handler.AddContent(in, tp)
+			if tool != "" {
+				tp.Add(tool)
+			}
+			if thinking != "" {
+				got = append(got, thinking)
+			}
+			if content != "" {
+				got = append(got, content)
+			}
+		}
+		want := []string{"Thinking...", "Answer"}
+		if !reflect.DeepEqual(got, want) {
+			t.Fatalf("pieces mismatch: got %v want %v", got, want)
+		}
+	})
+
+	t.Run("partial_tags_buffer_until_complete", func(t *testing.T) {
+		handler := NewHarmonyMessageHandler()
+		handler.HarmonyParser.AddImplicitStart()
+		tp := handler.CreateToolParser()
+		inputs := []string{
+			"<|chan",
+			"nel|>analysis<|mess",
+			"age|>Deep ",
+			"thought",
+			"<|end|>",
+			"<|start|>assistant<|message|>Done",
+			"<|end|>",
+		}
+		var thinkingPieces []string
+		var contentPieces []string
+		for _, in := range inputs {
+			content, thinking, tool := handler.AddContent(in, tp)
+			if tool != "" {
+				tp.Add(tool)
+			}
+			if thinking != "" {
+				thinkingPieces = append(thinkingPieces, thinking)
+			}
+			if content != "" {
+				contentPieces = append(contentPieces, content)
+			}
+		}
+		if want := []string{"Deep ", "thought"}; !reflect.DeepEqual(thinkingPieces, want) {
+			t.Fatalf("thinking pieces mismatch: got %v want %v", thinkingPieces, want)
+		}
+		if want := []string{"Done"}; !reflect.DeepEqual(contentPieces, want) {
+			t.Fatalf("content pieces mismatch: got %v want %v", contentPieces, want)
+		}
+	})
+
+	t.Run("simple_assistant_after_analysis", func(t *testing.T) {
+		handler := NewHarmonyMessageHandler()
+		handler.HarmonyParser.AddImplicitStart()
+		tp := handler.CreateToolParser()
+		inputs := []string{
+			"<|channel|>analysis<|message|>Think",
+			"<|end|>",
+			"<|start|>assistant<|message|>Answer",
+			"<|end|>",
+		}
+		var contentSb, thinkingSb strings.Builder
+		for _, in := range inputs {
+			content, thinking, tool := handler.AddContent(in, tp)
+			if tool != "" {
+				tp.Add(tool)
+			}
+			contentSb.WriteString(content)
+			thinkingSb.WriteString(thinking)
+		}
+		if contentSb.String() != "Answer" {
+			t.Fatalf("content mismatch: got %q want %q", contentSb.String(), "Answer")
+		}
+		if thinkingSb.String() != "Think" {
+			t.Fatalf("thinking mismatch: got %q want %q", thinkingSb.String(), "Think")
+		}
+	})
+
+	t.Run("tool_call_parsed_and_returned_correctly", func(t *testing.T) {
+		handler := NewHarmonyMessageHandler()
+		handler.HarmonyParser.AddImplicitStart()
+		tp := handler.CreateToolParser()
+		inputs := []string{
+			"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+2\"}<|end|>",
+		}
+		for _, in := range inputs {
+			content, thinking, tool := handler.AddContent(in, tp)
+			if content != "" || thinking != "" {
+				continue
+			}
tool != "" { + tp.Add(tool) + } + } + name, args := tp.Drain() + if name == nil || *name != "functions.calculate" { + t.Fatalf("unexpected tool name: %v", name) + } + if got, want := args, "{\"expression\":\"2+2\"}"; got != want { + t.Fatalf("unexpected tool args: got %s want %s", got, want) + } + }) + + t.Run("tool_call_across_chunks", func(t *testing.T) { + handler := NewHarmonyMessageHandler() + handler.HarmonyParser.AddImplicitStart() + tp := handler.CreateToolParser() + inputs := []string{ + "<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+", + "2\"}", + "<|end|>", + } + for _, in := range inputs { + content, thinking, tool := handler.AddContent(in, tp) + if content != "" || thinking != "" { + continue + } + if tool != "" { + tp.Add(tool) + } + } + name, args := tp.Drain() + if name == nil || *name != "functions.calculate" { + t.Fatalf("unexpected tool name: %v", name) + } + if got, want := args, "{\"expression\":\"2+2\"}"; got != want { + t.Fatalf("unexpected tool args: got %s want %s", got, want) + } + }) +} diff --git a/llm/server.go b/llm/server.go index 664a69fb..e0a652ec 100644 --- a/llm/server.go +++ b/llm/server.go @@ -1347,7 +1347,9 @@ type CompletionRequest struct { Images []ImageData Options *api.Options - Grammar string // set before sending the request to the subprocess + Grammar string // set before sending the request to the subprocess + UseHarmony bool + PrefillString string } // DoneReason represents the reason why a completion response is done @@ -1360,6 +1362,8 @@ const ( DoneReasonLength // DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed DoneReasonConnectionClosed + // DoneReasonTokenRepeatLimit indicates the completion stopped due to a token repeat limit + DoneReasonTokenRepeatLimit ) func (d DoneReason) String() string { @@ -1368,19 +1372,23 @@ func (d DoneReason) String() string { return "length" case DoneReasonStop: return "stop" + case DoneReasonTokenRepeatLimit: + return "token_repeat_limit" default: return "" // closed } } type CompletionResponse struct { - Content string `json:"content"` - DoneReason DoneReason `json:"done_reason"` - Done bool `json:"done"` - PromptEvalCount int `json:"prompt_eval_count"` - PromptEvalDuration time.Duration `json:"prompt_eval_duration"` - EvalCount int `json:"eval_count"` - EvalDuration time.Duration `json:"eval_duration"` + Content string `json:"content"` + Thinking string `json:"thinking"` + ToolCalls []api.ToolCall `json:"tool_calls"` + DoneReason DoneReason `json:"done_reason"` + Done bool `json:"done"` + PromptEvalCount int `json:"prompt_eval_count"` + PromptEvalDuration time.Duration `json:"prompt_eval_duration"` + EvalCount int `json:"eval_count"` + EvalDuration time.Duration `json:"eval_duration"` } func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { @@ -1498,7 +1506,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu return fmt.Errorf("error unmarshalling llm prediction response: %v", err) } switch { - case strings.TrimSpace(c.Content) == lastToken: + // TODO(parthsareen): token repeat limit is now handled in the runner, this currently support legacy model and can be removed in the future + case strings.TrimSpace(c.Content) == lastToken && c.Content != "": tokenRepeat++ default: lastToken = strings.TrimSpace(c.Content) @@ -1511,16 +1520,14 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu return ctx.Err() } - if 
c.Content != "" { - fn(CompletionResponse{ - Content: c.Content, - }) - } - if c.Done { fn(c) return nil } + + if c.Content != "" || c.Thinking != "" || len(c.ToolCalls) > 0 { + fn(c) + } } } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index df3ce1d9..a40643ef 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -30,6 +30,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/harmony" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" @@ -781,6 +782,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } + var harmonyMessageHandler *harmony.HarmonyMessageHandler + var harmonyToolParser *harmony.HarmonyToolCallAccumulator + if req.UseHarmony { + harmonyMessageHandler = harmony.NewHarmonyMessageHandler() + harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(req.PrefillString) + harmonyToolParser = harmonyMessageHandler.CreateToolParser() + } + if req.Options == nil { opts := api.DefaultOptions() req.Options = &opts @@ -863,6 +872,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { http.Error(w, "could not find an available sequence", http.StatusInternalServerError) return } + var lastToken string + tokenRepeat := 0 + const tokenRepeatLimit = 30 for { select { @@ -871,8 +883,27 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return case content, ok := <-seq.responses: if ok { + if strings.TrimSpace(content) == lastToken { + tokenRepeat++ + } + if tokenRepeat == tokenRepeatLimit { + http.Error(w, "token repeat limit reached", http.StatusInternalServerError) + seq.doneReason = llm.DoneReasonTokenRepeatLimit + close(seq.quit) + return + } + lastToken = strings.TrimSpace(content) + + var thinking string + if harmonyMessageHandler != nil { + var toolContent string + content, thinking, toolContent = harmonyMessageHandler.AddContent(content, harmonyToolParser) + harmonyToolParser.Add(toolContent) + } + if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ - Content: content, + Content: content, + Thinking: thinking, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) close(seq.quit) @@ -881,7 +912,29 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { flusher.Flush() } else { + var toolCalls []api.ToolCall + if harmonyMessageHandler != nil { + // these tools still need to be transformed to the original function name + toolName, toolContent := harmonyToolParser.Drain() + if toolName != nil { + *toolName = strings.TrimPrefix(*toolName, "functions.") + var args api.ToolCallFunctionArguments + if err := json.Unmarshal([]byte(toolContent), &args); err != nil { + http.Error(w, fmt.Sprintf("failed to unmarshal tool call function arguments: %v", err), http.StatusInternalServerError) + close(seq.quit) + return + } + toolCalls = append(toolCalls, api.ToolCall{ + Function: api.ToolCallFunction{ + Name: *toolName, + Arguments: args, + }, + }) + } + } + if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ + ToolCalls: toolCalls, Done: true, DoneReason: seq.doneReason, PromptEvalCount: seq.numPromptInputs, diff --git a/server/routes.go b/server/routes.go index e6e4e2c4..73ea5fea 100644 --- a/server/routes.go +++ b/server/routes.go @@ -46,18 +46,6 @@ import ( "github.com/ollama/ollama/version" ) -func shouldUseHarmony(model *Model) bool { - if 
slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) { - // heuristic to check whether the template expects to be parsed via harmony: - // search for harmony tags that are nearly always used - if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") { - return true - } - } - - return false -} - func experimentEnabled(name string) bool { return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name) } @@ -207,13 +195,11 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - useHarmony := shouldUseHarmony(m) && !req.Raw - var harmonyMessageHandler *harmony.HarmonyMessageHandler - var harmonyToolParser *harmony.HarmonyToolCallAccumulator + useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw + var functionNameMap *harmony.FunctionNameMap + if useHarmony { - harmonyMessageHandler = harmony.NewHarmonyMessageHandler() - harmonyMessageHandler.HarmonyParser.AddImplicitStart() - harmonyToolParser = harmonyMessageHandler.CreateToolParser() + functionNameMap = harmony.NewFunctionNameMap() } // Validate Think value: string values currently only allowed for gptoss models @@ -357,16 +343,19 @@ func (s *Server) GenerateHandler(c *gin.Context) { var sb strings.Builder defer close(ch) if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ - Prompt: prompt, - Images: images, - Format: req.Format, - Options: opts, + Prompt: prompt, + Images: images, + Format: req.Format, + Options: opts, + UseHarmony: useHarmony, }, func(cr llm.CompletionResponse) { res := api.GenerateResponse{ Model: req.Model, CreatedAt: time.Now().UTC(), Response: cr.Content, Done: cr.Done, + Thinking: cr.Thinking, + ToolCalls: cr.ToolCalls, Metrics: api.Metrics{ PromptEvalCount: cr.PromptEvalCount, PromptEvalDuration: cr.PromptEvalDuration, @@ -375,12 +364,22 @@ func (s *Server) GenerateHandler(c *gin.Context) { }, } + if res.Done { + res.DoneReason = cr.DoneReason.String() + res.TotalDuration = time.Since(checkpointStart) + res.LoadDuration = checkpointLoaded.Sub(checkpointStart) + } + if useHarmony { - content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser) - res.Response = content - res.Thinking = thinking - harmonyToolParser.Add(toolContent) - } else if thinkingState != nil { + for i, tool := range res.ToolCalls { + res.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name) + } + if res.Response != "" || res.Thinking != "" || len(res.ToolCalls) > 0 || res.Done { + ch <- res + } + return + } + if thinkingState != nil { thinking, content := thinkingState.AddContent(cr.Content) res.Thinking = thinking res.Response = content @@ -391,30 +390,6 @@ func (s *Server) GenerateHandler(c *gin.Context) { } if cr.Done { - if useHarmony { - toolName, toolContent := harmonyToolParser.Drain() - if toolName != nil { - *toolName = strings.TrimPrefix(*toolName, "functions.") - var args api.ToolCallFunctionArguments - if err := json.Unmarshal([]byte(toolContent), &args); err != nil { - errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error()) - ch <- gin.H{"error": errStr} - return - } - - res.ToolCalls = append(res.ToolCalls, api.ToolCall{ - Function: api.ToolCallFunction{ - Name: *toolName, - Arguments: args, - }, - }) - } - } - - res.DoneReason = cr.DoneReason.String() - res.TotalDuration = time.Since(checkpointStart) - res.LoadDuration = checkpointLoaded.Sub(checkpointStart) - if !req.Raw { tokens, err := 
 					tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
 					if err != nil {
@@ -1616,27 +1591,21 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}
 
 	msgs = filterThinkTags(msgs, m)
 
-	var harmonyMessageHandler *harmony.HarmonyMessageHandler
-	var harmonyToolParser *harmony.HarmonyToolCallAccumulator
-
-	useHarmony := shouldUseHarmony(m)
+	useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template)
 	processedTools := req.Tools
+	var functionNameMap *harmony.FunctionNameMap
+	var prefillString string
+
+	// TODO(parthsareen): this can be abstracted to not be model specific and potentially moved to the runner
 	if useHarmony {
-		harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
-		var lastMessage *api.Message
-		if len(msgs) > 0 {
-			lastMessage = &msgs[len(msgs)-1]
-		}
-		harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(lastMessage)
-		harmonyToolParser = harmonyMessageHandler.CreateToolParser()
-
+		prefillString = harmony.Prefill(msgs[len(msgs)-1])
+		functionNameMap = harmony.NewFunctionNameMap()
 		// make a copy of tools to pass to the chat prompt. Function names may be
 		// renamed to be valid Harmony function names.
 		processedTools = make([]api.Tool, len(req.Tools))
 		copy(processedTools, req.Tools)
 		for i, tool := range processedTools {
-			processedTools[i].Function.Name = harmonyMessageHandler.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
+			processedTools[i].Function.Name = functionNameMap.ConvertAndAdd(tool.Function.Name)
 		}
 	}
@@ -1689,15 +1658,17 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		defer close(ch)
 
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:  prompt,
-			Images:  images,
-			Format:  req.Format,
-			Options: opts,
+			Prompt:        prompt,
+			Images:        images,
+			Format:        req.Format,
+			Options:       opts,
+			UseHarmony:    useHarmony,
+			PrefillString: prefillString,
 		}, func(r llm.CompletionResponse) {
 			res := api.ChatResponse{
 				Model:     req.Model,
 				CreatedAt: time.Now().UTC(),
-				Message:   api.Message{Role: "assistant", Content: r.Content},
+				Message:   api.Message{Role: "assistant", Content: r.Content, Thinking: r.Thinking, ToolCalls: r.ToolCalls},
 				Done:      r.Done,
 				Metrics: api.Metrics{
 					PromptEvalCount:    r.PromptEvalCount,
@@ -1713,31 +1684,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			}
 
 			if useHarmony {
-				content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
-				res.Message.Content = content
-				res.Message.Thinking = thinking
-				harmonyToolParser.Add(toolContent)
-
-				if r.Done {
-					toolName, toolContent := harmonyToolParser.Drain()
-					if toolName != nil {
-						*toolName = strings.TrimPrefix(*toolName, "functions.")
-						*toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
-						var args api.ToolCallFunctionArguments
-						if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
-							errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
-							ch <- gin.H{"error": errStr}
-							return
-						}
-						res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}}
-					}
+				for i, tool := range res.Message.ToolCalls {
+					res.Message.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
 				}
-
 				// only send messages with meaningful content (empty messages confuse clients)
 				if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
 					ch <- res
 				}
-
 				return
 			}
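With parsing moved into the runner, the handlers' remaining harmony duty is the function-name round trip: `ConvertAndAdd` sanitizes names before the prompt is templated, and `OriginalFromConverted` restores them on every response. A small sketch of that round trip as the hunks above use it (the two helper functions are hypothetical):

```go
// renameForHarmony copies the tools and rewrites their names so the
// template only ever sees harmony-valid function names.
func renameForHarmony(fnMap *harmony.FunctionNameMap, tools []api.Tool) []api.Tool {
	processed := make([]api.Tool, len(tools))
	copy(processed, tools)
	for i, tool := range processed {
		processed[i].Function.Name = fnMap.ConvertAndAdd(tool.Function.Name)
	}
	return processed
}

// restoreOriginalNames maps the runner's converted names back to the
// names the caller originally registered.
func restoreOriginalNames(fnMap *harmony.FunctionNameMap, calls []api.ToolCall) {
	for i, call := range calls {
		calls[i].Function.Name = fnMap.OriginalFromConverted(call.Function.Name)
	}
}
```

The same map instance has to serve both directions, which is why each handler now keeps one `*harmony.FunctionNameMap` per request.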
diff --git a/server/routes_harmony_streaming_test.go b/server/routes_harmony_streaming_test.go
index b1ede4e3..bcb02088 100644
--- a/server/routes_harmony_streaming_test.go
+++ b/server/routes_harmony_streaming_test.go
@@ -7,7 +7,6 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
-	"net/http"
 	"strings"
 	"testing"
 	"time"
@@ -118,7 +117,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
 			name: "content streams as it arrives",
 			steps: []step{
 				{
-					input:       llm.CompletionResponse{Content: "<|message|>Hello", Done: false},
+					input:       llm.CompletionResponse{Content: "Hello", Done: false},
 					wantContent: "Hello",
 				},
 				{
 					input:       llm.CompletionResponse{Content: ", world", Done: false},
 					wantContent: ", world",
 				},
 				{
-					input:       llm.CompletionResponse{Content: "!<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
+					input:       llm.CompletionResponse{Content: "!", Done: true, DoneReason: llm.DoneReasonStop},
 					wantContent: "!",
 				},
 			},
 		},
@@ -135,20 +134,15 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
 			name: "thinking streams separately from content",
 			steps: []step{
 				{
-					input:        llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Thinking...", Done: false},
+					input:        llm.CompletionResponse{Thinking: "Thinking...", Done: false},
 					wantThinking: "Thinking...",
 				},
 				{
-					input: llm.CompletionResponse{Content: "<|end|>", Done: false},
-					// No output expected - just closes the analysis message and resets state to normal
+					input:       llm.CompletionResponse{Content: "Answer", Done: false},
+					wantContent: "Answer",
 				},
 				{
-					input:       llm.CompletionResponse{Content: "<|start|>assistant<|message|>Answer", Done: false},
-					wantContent: "Answer", // After message end, state is reset to normal
-				},
-				{
-					input: llm.CompletionResponse{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
-					// No output expected - just closes the assistant message
+					input: llm.CompletionResponse{Done: true, DoneReason: llm.DoneReasonStop},
 				},
 			},
 		},
@@ -156,24 +150,16 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
 			name: "partial tags buffer until complete",
 			steps: []step{
 				{
-					input: llm.CompletionResponse{Content: "<|chan", Done: false},
-					// No output - partial tag
-				},
-				{
-					input: llm.CompletionResponse{Content: "nel|>analysis<|mess", Done: false},
-					// No output - still building tags
-				},
-				{
-					input:        llm.CompletionResponse{Content: "age|>Deep ", Done: false},
+					input:        llm.CompletionResponse{Thinking: "Deep ", Done: false},
 					wantThinking: "Deep ",
 				},
 				{
-					input:        llm.CompletionResponse{Content: "thought<|end|>", Done: false},
+					input:        llm.CompletionResponse{Thinking: "thought", Done: false},
 					wantThinking: "thought",
 				},
 				{
-					input:       llm.CompletionResponse{Content: "<|start|>assistant<|message|>Done<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
-					wantContent: "Done", // After message end, state is reset to normal
+					input:       llm.CompletionResponse{Content: "Done", Done: true, DoneReason: llm.DoneReasonStop},
+					wantContent: "Done",
 				},
 			},
 		},
@@ -181,7 +167,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
 			name: "simple assistant after analysis",
 			steps: []step{
 				{
-					input:        llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Think<|end|><|start|>assistant<|message|>Answer<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
+					input:        llm.CompletionResponse{Thinking: "Think", Content: "Answer", Done: true, DoneReason: llm.DoneReasonStop},
 					wantContent:  "Answer",
 					wantThinking: "Think",
 				},
@@ -191,7 +177,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
 			name: "tool call parsed and returned correctly",
 			steps: []step{
 				{
-					input:       llm.CompletionResponse{Content: "<|channel|>commentary to=functions.get_weather<|message|>{\"location\":\"San Francisco\"}<|end|><|start|>assistant<|message|>The weather is sunny<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
+					input:       llm.CompletionResponse{Content: "The weather is sunny", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"location": "San Francisco"}}}}, Done: true, DoneReason: llm.DoneReasonStop},
 					wantContent: "The weather is sunny",
 					wantToolCalls: []api.ToolCall{
 						{
@@ -210,15 +196,10 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
 			name: "tool call with streaming JSON across chunks",
 			steps: []step{
 				{
-					input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.calculate<|message|>{\"expr", Done: false},
-					// No output yet - incomplete JSON
+					input: llm.CompletionResponse{Done: false},
 				},
 				{
-					input: llm.CompletionResponse{Content: "ession\":\"2+", Done: false},
-					// Still no output - incomplete JSON
-				},
-				{
-					input: llm.CompletionResponse{Content: "2\"}", Done: true},
+					input: llm.CompletionResponse{ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}}}, Done: true},
 					wantToolCalls: []api.ToolCall{
 						{
 							Function: api.ToolCallFunction{
@@ -400,9 +381,9 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
 	gin.SetMode(gin.TestMode)
 
 	mockResponses := []llm.CompletionResponse{
-		{Content: "<|message|>First ", Done: false},
+		{Content: "First ", Done: false},
 		{Content: "chunk ", Done: false},
-		{Content: "here<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
+		{Content: "here", Done: true, DoneReason: llm.DoneReasonStop},
 	}
 
 	mock := mockRunner{
@@ -507,189 +488,3 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
 		t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks)
 	}
 }
-
-func TestChatHarmonyParserStreaming(t *testing.T) {
-	gin.SetMode(gin.TestMode)
-
-	type expectedChunk struct {
-		afterResponse int    // Which mock response this chunk should appear after
-		content       string // Expected content in this chunk
-		thinking      string // Expected thinking in this chunk
-	}
-
-	testCases := []struct {
-		name           string
-		mockResponses  []llm.CompletionResponse
-		expectedChunks []expectedChunk
-		wantContent    string
-		wantThinking   string
-	}{
-		{
-			name: "simple message without thinking",
-			mockResponses: []llm.CompletionResponse{
-				{Content: "<|start|>assistant<|message|>Hello, ", Done: false},
-				{Content: "how can I help?", Done: false},
-				{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
-			},
-			expectedChunks: []expectedChunk{
-				{afterResponse: 1, content: "Hello, "},
-				{afterResponse: 2, content: "how can I help?"},
-			},
-			wantContent: "Hello, how can I help?",
-		},
-		{
-			name: "message with analysis channel for thinking",
-			mockResponses: []llm.CompletionResponse{
-				{Content: "<|channel|>analysis<|message|>", Done: false},
-				{Content: "Let me think ", Done: false},
-				{Content: "about this problem...", Done: false},
-				{Content: "<|end|>", Done: false},
-				{Content: "<|start|>assistant<|message|>", Done: false},
-				{Content: "The answer ", Done: false},
-				{Content: "is 42", Done: false},
-				{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
-			},
-			expectedChunks: []expectedChunk{
-				{afterResponse: 2, thinking: "Let me think "},
-				{afterResponse: 3, thinking: "about this problem..."},
-				{afterResponse: 6, content: "The answer "},
-				{afterResponse: 7, content: "is 42"},
-			},
wantContent: "The answer is 42", - wantThinking: "Let me think about this problem...", - }, - { - name: "streaming with partial tags across boundaries", - mockResponses: []llm.CompletionResponse{ - {Content: "<|chan", Done: false}, - {Content: "nel|>analy", Done: false}, - {Content: "sis<|mess", Done: false}, - {Content: "age|>Think", Done: false}, - {Content: "ing deeply...<|end|>", Done: false}, - {Content: "<|start|>assi", Done: false}, - {Content: "stant<|message|>Result ", Done: false}, - {Content: "computed<|e", Done: false}, - {Content: "nd|>", Done: true, DoneReason: llm.DoneReasonStop}, - }, - expectedChunks: []expectedChunk{ - {afterResponse: 4, thinking: "Think"}, - {afterResponse: 5, thinking: "ing deeply..."}, - {afterResponse: 7, content: "Result "}, - {afterResponse: 8, content: "computed"}, - }, - wantContent: "Result computed", - wantThinking: "Thinking deeply...", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Channel to synchronize mock responses with chunk verification - responsesSent := make(chan int, len(tc.mockResponses)) - - mock := mockRunner{ - CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error { - // Send mock responses one at a time, notifying when each is sent - for i, resp := range tc.mockResponses { - fn(resp) - responsesSent <- i + 1 - } - close(responsesSent) - return nil - }, - } - - s := Server{ - sched: &Scheduler{ - pendingReqCh: make(chan *LlmRequest, 1), - finishedReqCh: make(chan *LlmRequest, 1), - expiredCh: make(chan *runnerRef, 1), - unloadedCh: make(chan any, 1), - loaded: make(map[string]*runnerRef), - newServerFn: newMockServer(&mock), - getGpuFn: discover.GetGPUInfo, - getCpuFn: discover.GetCPUInfo, - reschedDelay: 250 * time.Millisecond, - loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { - req.successCh <- &runnerRef{ - llama: &mock, - } - return false - }, - }, - } - - go s.sched.Run(t.Context()) - - // Create a minimal model - _, digest := createHarmonyTestModel(t) - - // Create model with passthrough template - stream := false - w := createRequest(t, s.CreateHandler, api.CreateRequest{ - Model: "harmony-test", - Files: map[string]string{"file.gguf": digest}, - Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`, - Stream: &stream, - }) - - if w.Code != http.StatusOK { - t.Fatalf("failed to create model: %d", w.Code) - } - - // Test chat endpoint with streaming - streamTrue := true - w = createRequest(t, s.ChatHandler, api.ChatRequest{ - Model: "harmony-test", - Messages: []api.Message{{Role: "user", Content: "Hello"}}, - Stream: &streamTrue, - Tools: getTestTools(), - }) - - if w.Code != http.StatusOK { - t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String()) - } - - // Parse streaming response - var chunks []api.ChatResponse - var content, thinking strings.Builder - - decoder := json.NewDecoder(w.Body) - for decoder.More() { - var chunk api.ChatResponse - if err := decoder.Decode(&chunk); err != nil { - t.Fatalf("failed to decode chunk: %v", err) - } - chunks = append(chunks, chunk) - - // Accumulate content and thinking from each chunk - content.WriteString(chunk.Message.Content) - thinking.WriteString(chunk.Message.Thinking) - - // Debug output - t.Logf("Chunk %d: content=%q thinking=%q done=%v", len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done) - } - - // Verify we got streaming chunks - if len(chunks) == 0 { - t.Fatal("expected streaming chunks, got none") - } - 
-			gotContent := content.String()
-			gotThinking := thinking.String()
-
-			if gotContent != tc.wantContent {
-				t.Errorf("content mismatch: got %q, want %q", gotContent, tc.wantContent)
-			}
-			if gotThinking != tc.wantThinking {
-				t.Errorf("thinking mismatch: got %q, want %q", gotThinking, tc.wantThinking)
-			}
-
-			// Verify last chunk has done=true
-			lastChunk := chunks[len(chunks)-1]
-			if !lastChunk.Done {
-				t.Error("expected last chunk to have done=true")
-			}
-		})
-	}
-}
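End to end, a `Completion` caller no longer scrapes harmony tags out of `Content`; it forwards the structured fields the runner now populates. A sketch of the callback shape both handlers use (`streamCompletion` and `ch` are stand-ins, and `llm.LlamaServer` is assumed to be the server interface `llmServer` implements):

```go
// streamCompletion forwards structured runner output to ch, mirroring the
// filtering both handlers apply.
func streamCompletion(ctx context.Context, r llm.LlamaServer, req llm.CompletionRequest, ch chan<- llm.CompletionResponse) error {
	return r.Completion(ctx, req, func(cr llm.CompletionResponse) {
		if cr.Done {
			// The final chunk carries accumulated ToolCalls plus the DoneReason,
			// which may now be DoneReasonTokenRepeatLimit ("token_repeat_limit").
			ch <- cr
			return
		}
		if cr.Content != "" || cr.Thinking != "" || len(cr.ToolCalls) > 0 {
			ch <- cr // forward deltas as they arrive; empty chunks confuse clients
		}
	})
}
```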