diff --git a/model/renderers/qwen3coder.go b/model/renderers/qwen3coder.go index df3b3a45..32611791 100644 --- a/model/renderers/qwen3coder.go +++ b/model/renderers/qwen3coder.go @@ -99,9 +99,7 @@ func Qwen3CoderRenderer(messages []api.Message, tools []api.Tool, _ *api.ThinkVa sb.WriteString("\n" + name + "") if len(prop.Type) > 0 { - // TODO(!!!)(drifkin): we should match the reference implementation for - // more complex types here instead of using this format - sb.WriteString("\n" + prop.ToTypeScriptType() + "") + sb.WriteString("\n" + formatToolDefinitionType(prop.Type) + "") } if prop.Description != "" { @@ -215,3 +213,24 @@ func formatToolCallArgument(value any) string { return fmt.Sprintf("%v", value) } + +func formatToolDefinitionType(tp api.PropertyType) string { + if len(tp) == 0 { + return "[]" + } + + if len(tp) == 1 { + return tp[0] + } + + // TODO(drifkin): it would be nice to format the JSON here similarly to + // python's default json.dumps behavior (spaces after commas and colons). + // This would let us be byte-for-byte compatible with the reference + // implementation for most common inputs + jsonBytes, err := json.Marshal(tp) + if err != nil { + return "[]" + } + + return string(jsonBytes) +} diff --git a/model/renderers/qwen3coder_test.go b/model/renderers/qwen3coder_test.go index 4aaa066d..6a9e5ecc 100644 --- a/model/renderers/qwen3coder_test.go +++ b/model/renderers/qwen3coder_test.go @@ -336,3 +336,35 @@ func TestFormatToolCallArgument(t *testing.T) { }) } } + +func TestQwen3ToolDefinitionTypes(t *testing.T) { + tests := []struct { + name string + propertyType api.PropertyType + expected string + }{ + { + name: "simple", + propertyType: api.PropertyType{"string"}, + expected: "string", + }, + { + name: "multiple", + propertyType: api.PropertyType{"string", "number"}, + expected: "[\"string\",\"number\"]", + }, + { + name: "empty", + propertyType: api.PropertyType{}, + expected: "[]", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := formatToolDefinitionType(tt.propertyType) + if got != tt.expected { + t.Errorf("formatToolDefinitionType() = %v, want %v", got, tt.expected) + } + }) + } +}