Files
ollama-for-amd/model/renderers/functiongemma.go
2025-12-18 07:55:37 -08:00

288 lines
6.8 KiB
Go

package renderers
import (
"fmt"
"sort"
"strings"
"github.com/ollama/ollama/api"
)
type FunctionGemmaRenderer struct{}
const defaultSystemMessage = "You can do function calling with the following functions:"
func (r *FunctionGemmaRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
sb.WriteString("<bos>")
var systemMessage string
var loopMessages []api.Message
if len(messages) > 0 && (messages[0].Role == "system" || messages[0].Role == "developer") {
systemMessage = messages[0].Content
loopMessages = messages[1:]
} else {
loopMessages = messages
}
if systemMessage != "" || len(tools) > 0 {
sb.WriteString("<start_of_turn>developer\n")
if systemMessage != "" {
sb.WriteString(strings.TrimSpace(systemMessage))
}
if len(tools) > 0 {
if systemMessage != "" {
sb.WriteString("\n")
}
if strings.TrimSpace(systemMessage) != defaultSystemMessage {
// Only add default message if user does not provide it
sb.WriteString(defaultSystemMessage)
}
}
for _, tool := range tools {
sb.WriteString(r.renderToolDeclaration(tool))
}
sb.WriteString("<end_of_turn>\n")
}
// Track previous message type for tool response handling
prevMessageType := ""
for i, message := range loopMessages {
switch message.Role {
case "assistant":
if prevMessageType != "tool_response" {
sb.WriteString("<start_of_turn>model\n")
}
prevMessageType = ""
if message.Content != "" {
sb.WriteString(strings.TrimSpace(message.Content))
}
if len(message.ToolCalls) > 0 {
for _, tc := range message.ToolCalls {
sb.WriteString(r.formatToolCall(tc))
}
// After tool calls, expect tool responses
if i+1 < len(loopMessages) && loopMessages[i+1].Role == "tool" {
sb.WriteString("<start_function_response>")
prevMessageType = "tool_call"
} else {
sb.WriteString("<end_of_turn>\n")
}
} else {
sb.WriteString("<end_of_turn>\n")
}
case "user":
if prevMessageType != "tool_response" {
sb.WriteString("<start_of_turn>user\n")
}
prevMessageType = ""
sb.WriteString(strings.TrimSpace(message.Content))
sb.WriteString("<end_of_turn>\n")
case "tool":
toolName := ""
// Find the tool name from the previous assistant's tool call
for j := i - 1; j >= 0; j-- {
if loopMessages[j].Role == "assistant" && len(loopMessages[j].ToolCalls) > 0 {
// Count how many tool messages came before this one
toolIdx := 0
for k := j + 1; k < i; k++ {
if loopMessages[k].Role == "tool" {
toolIdx++
}
}
if toolIdx < len(loopMessages[j].ToolCalls) {
toolName = loopMessages[j].ToolCalls[toolIdx].Function.Name
}
break
}
}
if prevMessageType != "tool_call" {
sb.WriteString("<start_function_response>")
}
sb.WriteString("response:" + toolName + "{" + r.formatArgValue(message.Content) + "}<end_function_response>")
prevMessageType = "tool_response"
default:
sb.WriteString("<start_of_turn>" + message.Role + "\n")
sb.WriteString(strings.TrimSpace(message.Content))
sb.WriteString("<end_of_turn>\n")
}
}
if prevMessageType != "tool_response" {
sb.WriteString("<start_of_turn>model\n")
}
return sb.String(), nil
}
func (r *FunctionGemmaRenderer) renderToolDeclaration(tool api.Tool) string {
var sb strings.Builder
fn := tool.Function
sb.WriteString("<start_function_declaration>declaration:" + fn.Name + "{")
sb.WriteString("description:<escape>" + fn.Description + "<escape>")
if fn.Parameters.Properties != nil || fn.Parameters.Type != "" {
sb.WriteString(",parameters:{")
needsComma := false
// Only include properties:{} if there are actual properties
if len(fn.Parameters.Properties) > 0 {
sb.WriteString("properties:{")
r.writeProperties(&sb, fn.Parameters.Properties)
sb.WriteString("}")
needsComma = true
}
if len(fn.Parameters.Required) > 0 {
if needsComma {
sb.WriteString(",")
}
sb.WriteString("required:[")
for i, req := range fn.Parameters.Required {
if i > 0 {
sb.WriteString(",")
}
sb.WriteString("<escape>" + req + "<escape>")
}
sb.WriteString("]")
needsComma = true
}
if fn.Parameters.Type != "" {
if needsComma {
sb.WriteString(",")
}
sb.WriteString("type:<escape>" + strings.ToUpper(fn.Parameters.Type) + "<escape>")
}
sb.WriteString("}")
}
sb.WriteString("}<end_function_declaration>")
return sb.String()
}
func (r *FunctionGemmaRenderer) writeProperties(sb *strings.Builder, props map[string]api.ToolProperty) {
keys := make([]string, 0, len(props))
for k := range props {
keys = append(keys, k)
}
sort.Strings(keys)
first := true
for _, name := range keys {
prop := props[name]
if !first {
sb.WriteString(",")
}
first = false
sb.WriteString(name + ":{description:<escape>")
sb.WriteString(prop.Description)
sb.WriteString("<escape>")
if len(prop.Type) > 0 {
sb.WriteString(",type:<escape>" + strings.ToUpper(prop.Type[0]) + "<escape>")
}
sb.WriteString("}")
}
}
func (r *FunctionGemmaRenderer) formatToolCall(tc api.ToolCall) string {
var sb strings.Builder
sb.WriteString("<start_function_call>call:" + tc.Function.Name + "{")
keys := make([]string, 0, len(tc.Function.Arguments))
for k := range tc.Function.Arguments {
keys = append(keys, k)
}
sort.Strings(keys)
first := true
for _, key := range keys {
value := tc.Function.Arguments[key]
if !first {
sb.WriteString(",")
}
first = false
sb.WriteString(key + ":" + r.formatArgValue(value))
}
sb.WriteString("}<end_function_call>")
return sb.String()
}
func (r *FunctionGemmaRenderer) formatArgValue(value any) string {
switch v := value.(type) {
case string:
return "<escape>" + v + "<escape>"
case bool:
if v {
return "true"
}
return "false"
case float64:
if v == float64(int64(v)) {
return fmt.Sprintf("%d", int64(v))
}
return fmt.Sprintf("%v", v)
case int, int64, int32:
return fmt.Sprintf("%d", v)
case map[string]any:
return r.formatMapValue(v)
case []any:
return r.formatArrayValue(v)
default:
return fmt.Sprintf("%v", v)
}
}
func (r *FunctionGemmaRenderer) formatMapValue(m map[string]any) string {
var sb strings.Builder
sb.WriteString("{")
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
sort.Strings(keys)
first := true
for _, key := range keys {
if !first {
sb.WriteString(",")
}
first = false
sb.WriteString(key + ":" + r.formatArgValue(m[key]))
}
sb.WriteString("}")
return sb.String()
}
func (r *FunctionGemmaRenderer) formatArrayValue(arr []any) string {
var sb strings.Builder
sb.WriteString("[")
for i, item := range arr {
if i > 0 {
sb.WriteString(",")
}
sb.WriteString(r.formatArgValue(item))
}
sb.WriteString("]")
return sb.String()
}