mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-23 23:18:26 +00:00
parsers/renderers: functiongemma (#13521)
This commit is contained in:
287
model/renderers/functiongemma.go
Normal file
287
model/renderers/functiongemma.go
Normal file
@@ -0,0 +1,287 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type FunctionGemmaRenderer struct{}
|
||||
|
||||
const defaultSystemMessage = "You can do function calling with the following functions:"
|
||||
|
||||
func (r *FunctionGemmaRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("<bos>")
|
||||
|
||||
var systemMessage string
|
||||
var loopMessages []api.Message
|
||||
if len(messages) > 0 && (messages[0].Role == "system" || messages[0].Role == "developer") {
|
||||
systemMessage = messages[0].Content
|
||||
loopMessages = messages[1:]
|
||||
} else {
|
||||
loopMessages = messages
|
||||
}
|
||||
|
||||
if systemMessage != "" || len(tools) > 0 {
|
||||
sb.WriteString("<start_of_turn>developer\n")
|
||||
if systemMessage != "" {
|
||||
sb.WriteString(strings.TrimSpace(systemMessage))
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
if systemMessage != "" {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
if strings.TrimSpace(systemMessage) != defaultSystemMessage {
|
||||
// Only add default message if user does not provide it
|
||||
sb.WriteString(defaultSystemMessage)
|
||||
}
|
||||
}
|
||||
for _, tool := range tools {
|
||||
sb.WriteString(r.renderToolDeclaration(tool))
|
||||
}
|
||||
sb.WriteString("<end_of_turn>\n")
|
||||
}
|
||||
|
||||
// Track previous message type for tool response handling
|
||||
prevMessageType := ""
|
||||
|
||||
for i, message := range loopMessages {
|
||||
switch message.Role {
|
||||
case "assistant":
|
||||
if prevMessageType != "tool_response" {
|
||||
sb.WriteString("<start_of_turn>model\n")
|
||||
}
|
||||
prevMessageType = ""
|
||||
|
||||
if message.Content != "" {
|
||||
sb.WriteString(strings.TrimSpace(message.Content))
|
||||
}
|
||||
|
||||
if len(message.ToolCalls) > 0 {
|
||||
for _, tc := range message.ToolCalls {
|
||||
sb.WriteString(r.formatToolCall(tc))
|
||||
}
|
||||
// After tool calls, expect tool responses
|
||||
if i+1 < len(loopMessages) && loopMessages[i+1].Role == "tool" {
|
||||
sb.WriteString("<start_function_response>")
|
||||
prevMessageType = "tool_call"
|
||||
} else {
|
||||
sb.WriteString("<end_of_turn>\n")
|
||||
}
|
||||
} else {
|
||||
sb.WriteString("<end_of_turn>\n")
|
||||
}
|
||||
|
||||
case "user":
|
||||
if prevMessageType != "tool_response" {
|
||||
sb.WriteString("<start_of_turn>user\n")
|
||||
}
|
||||
prevMessageType = ""
|
||||
sb.WriteString(strings.TrimSpace(message.Content))
|
||||
sb.WriteString("<end_of_turn>\n")
|
||||
|
||||
case "tool":
|
||||
toolName := ""
|
||||
// Find the tool name from the previous assistant's tool call
|
||||
for j := i - 1; j >= 0; j-- {
|
||||
if loopMessages[j].Role == "assistant" && len(loopMessages[j].ToolCalls) > 0 {
|
||||
// Count how many tool messages came before this one
|
||||
toolIdx := 0
|
||||
for k := j + 1; k < i; k++ {
|
||||
if loopMessages[k].Role == "tool" {
|
||||
toolIdx++
|
||||
}
|
||||
}
|
||||
if toolIdx < len(loopMessages[j].ToolCalls) {
|
||||
toolName = loopMessages[j].ToolCalls[toolIdx].Function.Name
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if prevMessageType != "tool_call" {
|
||||
sb.WriteString("<start_function_response>")
|
||||
}
|
||||
sb.WriteString("response:" + toolName + "{" + r.formatArgValue(message.Content) + "}<end_function_response>")
|
||||
prevMessageType = "tool_response"
|
||||
|
||||
default:
|
||||
sb.WriteString("<start_of_turn>" + message.Role + "\n")
|
||||
sb.WriteString(strings.TrimSpace(message.Content))
|
||||
sb.WriteString("<end_of_turn>\n")
|
||||
}
|
||||
}
|
||||
|
||||
if prevMessageType != "tool_response" {
|
||||
sb.WriteString("<start_of_turn>model\n")
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
func (r *FunctionGemmaRenderer) renderToolDeclaration(tool api.Tool) string {
|
||||
var sb strings.Builder
|
||||
|
||||
fn := tool.Function
|
||||
sb.WriteString("<start_function_declaration>declaration:" + fn.Name + "{")
|
||||
sb.WriteString("description:<escape>" + fn.Description + "<escape>")
|
||||
|
||||
if fn.Parameters.Properties != nil || fn.Parameters.Type != "" {
|
||||
sb.WriteString(",parameters:{")
|
||||
|
||||
needsComma := false
|
||||
|
||||
// Only include properties:{} if there are actual properties
|
||||
if len(fn.Parameters.Properties) > 0 {
|
||||
sb.WriteString("properties:{")
|
||||
r.writeProperties(&sb, fn.Parameters.Properties)
|
||||
sb.WriteString("}")
|
||||
needsComma = true
|
||||
}
|
||||
|
||||
if len(fn.Parameters.Required) > 0 {
|
||||
if needsComma {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString("required:[")
|
||||
for i, req := range fn.Parameters.Required {
|
||||
if i > 0 {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString("<escape>" + req + "<escape>")
|
||||
}
|
||||
sb.WriteString("]")
|
||||
needsComma = true
|
||||
}
|
||||
|
||||
if fn.Parameters.Type != "" {
|
||||
if needsComma {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString("type:<escape>" + strings.ToUpper(fn.Parameters.Type) + "<escape>")
|
||||
}
|
||||
|
||||
sb.WriteString("}")
|
||||
}
|
||||
|
||||
sb.WriteString("}<end_function_declaration>")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (r *FunctionGemmaRenderer) writeProperties(sb *strings.Builder, props map[string]api.ToolProperty) {
|
||||
keys := make([]string, 0, len(props))
|
||||
for k := range props {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
first := true
|
||||
for _, name := range keys {
|
||||
prop := props[name]
|
||||
if !first {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
|
||||
sb.WriteString(name + ":{description:<escape>")
|
||||
sb.WriteString(prop.Description)
|
||||
sb.WriteString("<escape>")
|
||||
|
||||
if len(prop.Type) > 0 {
|
||||
sb.WriteString(",type:<escape>" + strings.ToUpper(prop.Type[0]) + "<escape>")
|
||||
}
|
||||
|
||||
sb.WriteString("}")
|
||||
}
|
||||
}
|
||||
|
||||
func (r *FunctionGemmaRenderer) formatToolCall(tc api.ToolCall) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("<start_function_call>call:" + tc.Function.Name + "{")
|
||||
|
||||
keys := make([]string, 0, len(tc.Function.Arguments))
|
||||
for k := range tc.Function.Arguments {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
first := true
|
||||
for _, key := range keys {
|
||||
value := tc.Function.Arguments[key]
|
||||
if !first {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
sb.WriteString(key + ":" + r.formatArgValue(value))
|
||||
}
|
||||
|
||||
sb.WriteString("}<end_function_call>")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (r *FunctionGemmaRenderer) formatArgValue(value any) string {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return "<escape>" + v + "<escape>"
|
||||
case bool:
|
||||
if v {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
case float64:
|
||||
if v == float64(int64(v)) {
|
||||
return fmt.Sprintf("%d", int64(v))
|
||||
}
|
||||
return fmt.Sprintf("%v", v)
|
||||
case int, int64, int32:
|
||||
return fmt.Sprintf("%d", v)
|
||||
case map[string]any:
|
||||
return r.formatMapValue(v)
|
||||
case []any:
|
||||
return r.formatArrayValue(v)
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *FunctionGemmaRenderer) formatMapValue(m map[string]any) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("{")
|
||||
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
first := true
|
||||
for _, key := range keys {
|
||||
if !first {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
sb.WriteString(key + ":" + r.formatArgValue(m[key]))
|
||||
}
|
||||
|
||||
sb.WriteString("}")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (r *FunctionGemmaRenderer) formatArrayValue(arr []any) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("[")
|
||||
|
||||
for i, item := range arr {
|
||||
if i > 0 {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString(r.formatArgValue(item))
|
||||
}
|
||||
|
||||
sb.WriteString("]")
|
||||
return sb.String()
|
||||
}
|
||||
Reference in New Issue
Block a user