mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 22:33:56 +00:00
fix(tokenizer): add special tokens to empty inputs (#13091)
This commit is contained in:
@@ -237,7 +237,7 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if addSpecial && len(ids) > 0 {
|
if addSpecial {
|
||||||
ids = bpe.vocab.addSpecials(ids)
|
ids = bpe.vocab.addSpecials(ids)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -181,7 +181,7 @@ func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if addSpecial && len(ids) > 0 {
|
if addSpecial {
|
||||||
ids = spm.vocab.addSpecials(ids)
|
ids = spm.vocab.addSpecials(ids)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ func (v *Vocabulary) Is(id int32, special Special) bool {
|
|||||||
|
|
||||||
func (v *Vocabulary) addSpecials(ids []int32) []int32 {
|
func (v *Vocabulary) addSpecials(ids []int32) []int32 {
|
||||||
if v.AddBOS && len(v.BOS) > 0 {
|
if v.AddBOS && len(v.BOS) > 0 {
|
||||||
if slices.Contains(v.BOS, ids[0]) {
|
if len(ids) > 0 && slices.Contains(v.BOS, ids[0]) {
|
||||||
slog.Warn("adding bos token to prompt which already has it", "id", v.BOS)
|
slog.Warn("adding bos token to prompt which already has it", "id", v.BOS)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,7 +54,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if v.AddEOS && len(v.EOS) > 0 {
|
if v.AddEOS && len(v.EOS) > 0 {
|
||||||
if slices.Contains(v.BOS, ids[len(ids)-1]) {
|
if len(ids) > 0 && slices.Contains(v.BOS, ids[len(ids)-1]) {
|
||||||
slog.Warn("adding eos token to prompt which already has it", "id", v.EOS)
|
slog.Warn("adding eos token to prompt which already has it", "id", v.EOS)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,12 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
func TestVocabulary_SpecialVocabulary(t *testing.T) {
|
"github.com/google/go-cmp/cmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSpecialVocabulary(t *testing.T) {
|
||||||
vocab := &Vocabulary{
|
vocab := &Vocabulary{
|
||||||
Values: []string{"<|startoftext|>", "<|endoftext|>", "<|tool_call_start|>", "<|tool_call_end|>", "hi"},
|
Values: []string{"<|startoftext|>", "<|endoftext|>", "<|tool_call_start|>", "<|tool_call_end|>", "hi"},
|
||||||
Types: []int32{TOKEN_TYPE_CONTROL, TOKEN_TYPE_CONTROL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_NORMAL},
|
Types: []int32{TOKEN_TYPE_CONTROL, TOKEN_TYPE_CONTROL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_NORMAL},
|
||||||
@@ -14,3 +18,90 @@ func TestVocabulary_SpecialVocabulary(t *testing.T) {
|
|||||||
t.Errorf("expected 4 special tokens, got %d", len(specialVocab))
|
t.Errorf("expected 4 special tokens, got %d", len(specialVocab))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAddSpecialVocabulary(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
vocab *Vocabulary
|
||||||
|
input []int32
|
||||||
|
want []int32
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "add bos",
|
||||||
|
vocab: &Vocabulary{
|
||||||
|
BOS: []int32{0},
|
||||||
|
EOS: []int32{1},
|
||||||
|
AddBOS: true,
|
||||||
|
AddEOS: false,
|
||||||
|
},
|
||||||
|
input: []int32{2, 3, 4},
|
||||||
|
want: []int32{0, 2, 3, 4},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// TODO(mxyng): this is to match previous behaviour
|
||||||
|
name: "add bos when already present",
|
||||||
|
vocab: &Vocabulary{
|
||||||
|
BOS: []int32{0},
|
||||||
|
EOS: []int32{1},
|
||||||
|
AddBOS: true,
|
||||||
|
AddEOS: false,
|
||||||
|
},
|
||||||
|
input: []int32{0, 2, 3, 4},
|
||||||
|
want: []int32{0, 0, 2, 3, 4},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "add eos",
|
||||||
|
vocab: &Vocabulary{
|
||||||
|
BOS: []int32{0},
|
||||||
|
EOS: []int32{1},
|
||||||
|
AddBOS: false,
|
||||||
|
AddEOS: true,
|
||||||
|
},
|
||||||
|
input: []int32{2, 3, 4},
|
||||||
|
want: []int32{2, 3, 4, 1},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// TODO(mxyng): this is to match previous behaviour
|
||||||
|
name: "add eos when already present",
|
||||||
|
vocab: &Vocabulary{
|
||||||
|
BOS: []int32{0},
|
||||||
|
EOS: []int32{1},
|
||||||
|
AddBOS: false,
|
||||||
|
AddEOS: true,
|
||||||
|
},
|
||||||
|
input: []int32{2, 3, 4, 1},
|
||||||
|
want: []int32{2, 3, 4, 1, 1},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "add both",
|
||||||
|
vocab: &Vocabulary{
|
||||||
|
BOS: []int32{0},
|
||||||
|
EOS: []int32{1},
|
||||||
|
AddBOS: true,
|
||||||
|
AddEOS: true,
|
||||||
|
},
|
||||||
|
input: []int32{2, 3, 4},
|
||||||
|
want: []int32{0, 2, 3, 4, 1},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "add bos to empty inputs",
|
||||||
|
vocab: &Vocabulary{
|
||||||
|
BOS: []int32{0},
|
||||||
|
EOS: []int32{1},
|
||||||
|
AddBOS: true,
|
||||||
|
AddEOS: false,
|
||||||
|
},
|
||||||
|
input: []int32{},
|
||||||
|
want: []int32{0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := tt.vocab.addSpecials(tt.input)
|
||||||
|
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||||
|
t.Errorf("no match (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if addSpecial && len(ids) > 0 {
|
if addSpecial {
|
||||||
ids = wpm.vocab.addSpecials(ids)
|
ids = wpm.vocab.addSpecials(ids)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user