mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 22:33:56 +00:00
Merge pull request #11910 from ollama/drifkin/harmony-fn-names
harmony: convert fn names to be valid ts identifiers
This commit is contained in:
@@ -2,6 +2,7 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -275,8 +276,9 @@ const (
|
|||||||
// HarmonyMessageHandler processes harmony events and accumulates content appropriately.
|
// HarmonyMessageHandler processes harmony events and accumulates content appropriately.
|
||||||
// This is a higher level interface that maps harmony concepts into ollama concepts
|
// This is a higher level interface that maps harmony concepts into ollama concepts
|
||||||
type HarmonyMessageHandler struct {
|
type HarmonyMessageHandler struct {
|
||||||
state harmonyMessageState
|
state harmonyMessageState
|
||||||
harmonyParser *HarmonyParser
|
harmonyParser *HarmonyParser
|
||||||
|
functionNameMap *FunctionNameMap
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHarmonyMessageHandler creates a new message handler
|
// NewHarmonyMessageHandler creates a new message handler
|
||||||
@@ -288,6 +290,7 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler {
|
|||||||
MessageEndTag: "<|end|>",
|
MessageEndTag: "<|end|>",
|
||||||
HeaderEndTag: "<|message|>",
|
HeaderEndTag: "<|message|>",
|
||||||
},
|
},
|
||||||
|
functionNameMap: NewFunctionNameMap(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -378,3 +381,97 @@ func (a *HarmonyToolCallAccumulator) Drain() (*string, string) {
|
|||||||
func (a *HarmonyToolCallAccumulator) Content() string {
|
func (a *HarmonyToolCallAccumulator) Content() string {
|
||||||
return a.acc.String()
|
return a.acc.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FunctionNameMap maps a user-specified function name to a valid function
|
||||||
|
// name for harmony (which look like TypeScript identifiers). This is needed to
|
||||||
|
// transform user-specified function names, which might contain characters that
|
||||||
|
// are not allowed in TypeScript identifiers
|
||||||
|
type FunctionNameMap struct {
|
||||||
|
userToHarmony map[string]string
|
||||||
|
harmonyToUser map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFunctionNameMap() *FunctionNameMap {
|
||||||
|
return &FunctionNameMap{
|
||||||
|
userToHarmony: make(map[string]string),
|
||||||
|
harmonyToUser: make(map[string]string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *FunctionNameMap) ConvertAndAdd(userFunctionName string) string {
|
||||||
|
harmonyFunctionName := m.deriveName(userFunctionName)
|
||||||
|
m.userToHarmony[userFunctionName] = harmonyFunctionName
|
||||||
|
m.harmonyToUser[harmonyFunctionName] = userFunctionName
|
||||||
|
return harmonyFunctionName
|
||||||
|
}
|
||||||
|
|
||||||
|
// OriginalFromConverted looks up the reverse-mapping of a previously-converted
|
||||||
|
// user->harmony function name. To unmap reliably, the mapping must exist, as
|
||||||
|
// the conversion process is not reversible without the appropriate state
|
||||||
|
func (m *FunctionNameMap) OriginalFromConverted(harmonyFunctionName string) string {
|
||||||
|
if userFunctionName, ok := m.harmonyToUser[harmonyFunctionName]; ok {
|
||||||
|
return userFunctionName
|
||||||
|
}
|
||||||
|
slog.Warn("harmony parser: no reverse mapping found for function name", "harmonyFunctionName", harmonyFunctionName)
|
||||||
|
// fallback to the original function name if we can't find a mapping
|
||||||
|
return harmonyFunctionName
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertToValidChars converts a user-specified function name to a valid
|
||||||
|
// TypeScript identifier.
|
||||||
|
//
|
||||||
|
// Limitations:
|
||||||
|
//
|
||||||
|
// - This doesn't restrict reserved TypeScript keywords.
|
||||||
|
// - We don't perform a real ID_Start/ID_Continue check, and instead use the more
|
||||||
|
// restrictive unicode.IsLetter/unicode.IsDigit check. Unclear what kind of
|
||||||
|
// identifiers these models were trained on, so in the end we might want to
|
||||||
|
// convert unicode-heavy identifiers to their closest ASCII equivalents.
|
||||||
|
func (m *FunctionNameMap) convertToValidChars(userFunctionName string) string {
|
||||||
|
mapper := func(r rune) rune {
|
||||||
|
// first, replace certain characters with underscores
|
||||||
|
if r == ' ' || r == '-' || r == '.' {
|
||||||
|
return '_'
|
||||||
|
}
|
||||||
|
|
||||||
|
if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// finally, remove any other characters
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
candidate := strings.Map(mapper, userFunctionName)
|
||||||
|
|
||||||
|
// set a default name if we end up with nothing left
|
||||||
|
if candidate == "" {
|
||||||
|
return "unnamed"
|
||||||
|
}
|
||||||
|
|
||||||
|
// if the candidate starts with a number, prepend an underscore to make it a
|
||||||
|
// valid identifier
|
||||||
|
if unicode.IsDigit(rune(candidate[0])) {
|
||||||
|
candidate = "_" + candidate
|
||||||
|
}
|
||||||
|
|
||||||
|
return candidate
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *FunctionNameMap) deriveName(userFunctionName string) string {
|
||||||
|
originalCandidate := m.convertToValidChars(userFunctionName)
|
||||||
|
candidate := originalCandidate
|
||||||
|
|
||||||
|
// Check for dupes, and if so, add a number to the end.
|
||||||
|
// We start at 2 because if we have dupes and the first is never renamed, it
|
||||||
|
// makes sense for them to be named, say, `f`, `f_2`, `f_3`
|
||||||
|
count := 2
|
||||||
|
for {
|
||||||
|
if _, exists := m.harmonyToUser[candidate]; !exists {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
candidate = fmt.Sprintf("%s_%d", originalCandidate, count)
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
|
||||||
|
return candidate
|
||||||
|
}
|
||||||
|
|||||||
@@ -467,3 +467,71 @@ func TestHarmonyParserStreaming(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestFunctionConvertToValidChars tests only FunctionNameMap.convert(), which doesn't
|
||||||
|
// handle any saving (and therefore no dupe handling)
|
||||||
|
func TestFunctionConvertToValidChars(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
in string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{name: "replace spaces with underscores", in: "get weather", want: "get_weather"},
|
||||||
|
{name: "replace hyphens with underscores", in: "get-weather", want: "get_weather"},
|
||||||
|
{name: "replace periods with underscores", in: "get.weather", want: "get_weather"},
|
||||||
|
{name: "disallow non-word characters", in: "get weather!", want: "get_weather"},
|
||||||
|
{name: "strip out invalid non-alphanumeric unicode characters", in: "a🫠bc", want: "abc"},
|
||||||
|
{name: "names that only contain invalid characters", in: "🫠", want: "unnamed"},
|
||||||
|
{name: "leading number", in: "123", want: "_123"},
|
||||||
|
{name: "$ allowed", in: "$", want: "$"},
|
||||||
|
// show that we allow weird unicode letter characters, though we might want
|
||||||
|
// to convert them to their closest ASCII equivalents in the future
|
||||||
|
{name: "allow weird unicode letter characters", in: "𝓸𝓵𝓵𝓪𝓶𝓪", want: "𝓸𝓵𝓵𝓪𝓶𝓪"},
|
||||||
|
// names that look like words but are invalid (i.e., not ID_Start/ID_Continue)
|
||||||
|
{name: "disallow non-word characters that look like words", in: "ⓞⓛⓛⓐⓜⓐ123", want: "_123"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
parser := NewFunctionNameMap()
|
||||||
|
got := parser.convertToValidChars(tt.in)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("case %d: got %q, want %q", i, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFunctionConvertAndAdd(t *testing.T) {
|
||||||
|
// make a fresh map for each test, but within a test use the same map so we can test for dupe handling
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
in []string
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
{name: "basic dupe handling", in: []string{"get weather", "get weather"}, want: []string{"get_weather", "get_weather_2"}},
|
||||||
|
{name: "dupes from different user-specified names", in: []string{"get weather", "get_weather", "get-weather"}, want: []string{"get_weather", "get_weather_2", "get_weather_3"}},
|
||||||
|
{name: "non dupes after dupes", in: []string{"get weather", "get_weather", "get-weather", "something-different"}, want: []string{"get_weather", "get_weather_2", "get_weather_3", "something_different"}},
|
||||||
|
{name: "multiple sets of dupes", in: []string{"a", "a", "b", "a", "a", "b", "a"}, want: []string{"a", "a_2", "b", "a_3", "a_4", "b_2", "a_5"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
parser := NewFunctionNameMap()
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
for j, in := range tt.in {
|
||||||
|
got := parser.ConvertAndAdd(in)
|
||||||
|
want := tt.want[j]
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("case %d: got %q, want %q", i, got, want)
|
||||||
|
}
|
||||||
|
// check that the maps are correct
|
||||||
|
if parser.userToHarmony[in] != want {
|
||||||
|
t.Errorf("case %d: userToHarmony[%q] = %q, want %q", i, in, parser.userToHarmony[in], want)
|
||||||
|
}
|
||||||
|
if parser.harmonyToUser[want] != in {
|
||||||
|
t.Errorf("case %d: harmonyToUser[%q] = %q, want %q", i, want, parser.harmonyToUser[want], in)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1603,7 +1603,31 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
msgs = filterThinkTags(msgs, m)
|
msgs = filterThinkTags(msgs, m)
|
||||||
|
|
||||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools, req.Think)
|
var harmonyMessageHandler *HarmonyMessageHandler
|
||||||
|
var harmonyToolParser *HarmonyToolCallAccumulator
|
||||||
|
|
||||||
|
useHarmony := shouldUseHarmony(*m)
|
||||||
|
|
||||||
|
processedTools := req.Tools
|
||||||
|
if useHarmony {
|
||||||
|
harmonyMessageHandler = NewHarmonyMessageHandler()
|
||||||
|
var lastMessage *api.Message
|
||||||
|
if len(msgs) > 0 {
|
||||||
|
lastMessage = &msgs[len(msgs)-1]
|
||||||
|
}
|
||||||
|
harmonyMessageHandler.harmonyParser.AddImplicitStartOrPrefill(lastMessage)
|
||||||
|
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
||||||
|
|
||||||
|
// make a copy of tools to pass to the chat prompt. Function names may be
|
||||||
|
// renamed to be valid Harmony function names.
|
||||||
|
processedTools = make([]api.Tool, len(req.Tools))
|
||||||
|
copy(processedTools, req.Tools)
|
||||||
|
for i, tool := range processedTools {
|
||||||
|
processedTools[i].Function.Name = harmonyMessageHandler.functionNameMap.ConvertAndAdd(tool.Function.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("chat prompt error", "error", err)
|
slog.Error("chat prompt error", "error", err)
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
@@ -1623,27 +1647,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
useHarmony := shouldUseHarmony(*m)
|
|
||||||
|
|
||||||
// Validate Think value: string values currently only allowed for gptoss models
|
// Validate Think value: string values currently only allowed for gptoss models
|
||||||
if req.Think != nil && req.Think.IsString() && !useHarmony {
|
if req.Think != nil && req.Think.IsString() && !useHarmony {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())})
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var harmonyMessageHandler *HarmonyMessageHandler
|
|
||||||
var harmonyToolParser *HarmonyToolCallAccumulator
|
|
||||||
|
|
||||||
if useHarmony {
|
|
||||||
harmonyMessageHandler = NewHarmonyMessageHandler()
|
|
||||||
var lastMessage *api.Message
|
|
||||||
if len(msgs) > 0 {
|
|
||||||
lastMessage = &msgs[len(msgs)-1]
|
|
||||||
}
|
|
||||||
harmonyMessageHandler.harmonyParser.AddImplicitStartOrPrefill(lastMessage)
|
|
||||||
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
|
||||||
}
|
|
||||||
|
|
||||||
var thinkingState *thinking.Parser
|
var thinkingState *thinking.Parser
|
||||||
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
||||||
if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" {
|
if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" {
|
||||||
@@ -1696,6 +1705,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
toolName, toolContent := harmonyToolParser.Drain()
|
toolName, toolContent := harmonyToolParser.Drain()
|
||||||
if toolName != nil {
|
if toolName != nil {
|
||||||
*toolName = strings.TrimPrefix(*toolName, "functions.")
|
*toolName = strings.TrimPrefix(*toolName, "functions.")
|
||||||
|
*toolName = harmonyMessageHandler.functionNameMap.OriginalFromConverted(*toolName)
|
||||||
var args api.ToolCallFunctionArguments
|
var args api.ToolCallFunctionArguments
|
||||||
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
|
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
|
||||||
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
|
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
|
||||||
|
|||||||
Reference in New Issue
Block a user