This commit is contained in:
Michael Yang
2024-06-20 11:00:08 -07:00
parent 269ed6e6a2
commit 2c3fe1fd97
5 changed files with 224 additions and 113 deletions

View File

@@ -83,6 +83,7 @@ type Template struct {
raw string
}
// response is a template node that can be added to templates that don't already have one
var response = parse.ActionNode{
NodeType: parse.NodeAction,
Pipe: &parse.PipeNode{
@@ -101,28 +102,25 @@ var response = parse.ActionNode{
},
}
var funcs = template.FuncMap{
"toJson": func(v any) string {
b, err := json.Marshal(v)
if err != nil {
return ""
}
return string(b)
},
"add": func(a, b int) int {
return a + b
},
"sub": func(a, b int) int {
return a - b
},
}
func Parse(s string) (*Template, error) {
tmpl := template.New("").Option("missingkey=zero").Funcs(template.FuncMap{
"toJson": func(v any) string {
b, err := json.Marshal(v)
if err != nil {
return ""
}
return string(b)
},
"isLastMessage": func(s []*api.Message, m *api.Message) bool {
for i := len(s) - 1; i >= 0; i-- {
if m.Role != s[i].Role {
continue
}
return m == s[i]
}
return false
},
})
tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
tmpl, err := tmpl.Parse(s)
if err != nil {
@@ -218,7 +216,13 @@ func (t *Template) Execute(w io.Writer, v Values) error {
return err
}
func collate(msgs []api.Message) (system string, collated []*api.Message) {
type messages []*api.Message
// collate messages based on role. consecutive messages of the same role are merged
// into a single message. collate also pulls out and merges messages with Role == "system"
// which are templated separately. As a side effect, it mangles message content adding image
// tags ([img-%d]) as needed
func collate(msgs []api.Message) (system string, collated messages) {
var n int
for i := range msgs {
msg := msgs[i]

View File

@@ -8,6 +8,7 @@ import (
"os"
"path/filepath"
"slices"
"strconv"
"testing"
"text/template"
@@ -15,6 +16,98 @@ import (
"github.com/ollama/ollama/llm"
)
func TestFuncs(t *testing.T) {
t.Run("toJson", func(t *testing.T) {
cases := []struct {
input any
expected string
}{
{nil, "null"},
{true, "true"},
{false, "false"},
{0, "0"},
{1, "1"},
{1.0, "1"},
{1.1, "1.1"},
{"", `""`},
{"hello", `"hello"`},
{[]int{1, 2, 3}, "[1,2,3]"},
{[]string{"a", "b", "c"}, `["a","b","c"]`},
{map[string]int{"a": 1, "b": 2}, `{"a":1,"b":2}`},
{map[string]string{"a": "b", "c": "d"}, `{"a":"b","c":"d"}`},
}
for _, tt := range cases {
t.Run(tt.expected, func(t *testing.T) {
toJson, ok := funcs["toJson"].(func(any) string)
if !ok {
t.Fatal("toJson is not a function")
}
if s := toJson(tt.input); s != tt.expected {
t.Errorf("expected %q, got %q", tt.expected, s)
}
})
}
})
t.Run("add", func(t *testing.T) {
cases := []struct {
a, b int
expected int
}{
{0, 0, 0},
{0, 1, 1},
{1, 0, 1},
{1, 1, 2},
{1, -1, 0},
{-1, 1, 0},
{-1, -1, -2},
}
for _, tt := range cases {
t.Run(strconv.Itoa(tt.expected), func(t *testing.T) {
add, ok := funcs["add"].(func(int, int) int)
if !ok {
t.Fatal("add is not a function")
}
if n := add(tt.a, tt.b); n != tt.expected {
t.Errorf("expected %d, got %d", tt.expected, n)
}
})
}
})
t.Run("sub", func(t *testing.T) {
cases := []struct {
a, b int
expected int
}{
{0, 0, 0},
{0, 1, -1},
{1, 0, 1},
{1, 1, 0},
{1, -1, 2},
{-1, 1, -2},
{-1, -1, 0},
}
for _, tt := range cases {
t.Run(strconv.Itoa(tt.expected), func(t *testing.T) {
sub, ok := funcs["sub"].(func(int, int) int)
if !ok {
t.Fatal("sub is not a function")
}
if n := sub(tt.a, tt.b); n != tt.expected {
t.Errorf("expected %d, got %d", tt.expected, n)
}
})
}
})
}
func TestNamed(t *testing.T) {
f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
if err != nil {
@@ -89,77 +182,86 @@ func TestParse(t *testing.T) {
}
func TestExecuteWithMessages(t *testing.T) {
type template struct {
name string
template string
}
cases := []struct {
templates []string
name string
templates []template
values Values
expected string
}{
{
[]string{
`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `,
`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`,
`{{- range .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (isLastMessage $.Messages .) $.System }}{{ $.System }}{{ print "\n\n" }}
"mistral",
[]template{
{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `{{- range .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq (index $.Messages (sub (len $.Messages) 1)) .) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}`,
{{- end }}`},
},
Values{
Messages: []api.Message{
{Role: "user", Content: "Hello friend!"},
{Role: "assistant", Content: "Hello human!"},
{Role: "user", Content: "Yay!"},
{Role: "user", Content: "What is your name?"},
},
},
`[INST] Hello friend![/INST] Hello human![INST] Yay![/INST] `,
`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
},
{
[]string{
`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `,
`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`,
`
"mistral system",
[]template{
{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `
{{- range .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (isLastMessage $.Messages .) $.System }}{{ $.System }}{{ print "\n\n" }}
{{- if eq .Role "user" }}[INST] {{ if and (eq (index $.Messages (sub (len $.Messages) 1)) .) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}`,
{{- end }}`},
},
Values{
Messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant!"},
{Role: "user", Content: "Hello friend!"},
{Role: "assistant", Content: "Hello human!"},
{Role: "user", Content: "Yay!"},
{Role: "user", Content: "What is your name?"},
},
},
`[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant!
Yay![/INST] `,
What is your name?[/INST] `,
},
{
[]string{
`{{ if .System }}<|im_start|>system
"chatml",
[]template{
// this does not have a "no response" test because it's impossible to render the same output
{"response", `{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
`,
`
`},
{"messages", `
{{- range .Messages }}
{{- if and (eq .Role "user") (isLastMessage $.Messages .) $.System }}<|im_start|>system
{{ $.System }}<|im_end|>{{ print "\n" }}
{{- if and (eq .Role "user") (eq (index $.Messages (sub (len $.Messages) 1)) .) $.System }}<|im_start|>system
{{ $.System }}<|im_end|>{{ "\n" }}
{{- end }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>{{ print "\n" }}
{{ .Content }}<|im_end|>{{ "\n" }}
{{- end }}<|im_start|>assistant
`,
`},
},
Values{
Messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant!"},
{Role: "user", Content: "Hello friend!"},
{Role: "assistant", Content: "Hello human!"},
{Role: "user", Content: "Yay!"},
{Role: "user", Content: "What is your name?"},
},
},
`<|im_start|>user
@@ -169,23 +271,25 @@ Hello human!<|im_end|>
<|im_start|>system
You are a helpful assistant!<|im_end|>
<|im_start|>user
Yay!<|im_end|>
What is your name?<|im_end|>
<|im_start|>assistant
`,
},
{
[]string{
`{{ if .Prompt }}Question: {{ .Prompt }}
"moondream",
[]template{
// this does not have a "no response" test because it's impossible to render the same output
{"response", `{{ if .Prompt }}Question: {{ .Prompt }}
{{ end }}Answer: {{ .Response }}
`,
`
`},
{"messages", `
{{- range .Messages }}
{{- if eq .Role "user" }}Question: {{ .Content }}{{ print "\n\n" }}
{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ print "\n\n" }}
{{- if eq .Role "user" }}Question: {{ .Content }}{{ "\n\n" }}
{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ "\n\n" }}
{{- end }}
{{- end }}Answer: `,
{{- end }}Answer: `},
},
Values{
Messages: []api.Message{
@@ -211,10 +315,10 @@ Answer: `,
}
for _, tt := range cases {
t.Run("", func(t *testing.T) {
for _, tmpl := range tt.templates {
t.Run("", func(t *testing.T) {
tmpl, err := Parse(tmpl)
t.Run(tt.name, func(t *testing.T) {
for _, ttt := range tt.templates {
t.Run(ttt.name, func(t *testing.T) {
tmpl, err := Parse(ttt.template)
if err != nil {
t.Fatal(err)
}