mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 14:26:30 +00:00
model: handle multiple eos tokens (#10577)
* get eos_token_id from generation_config.json * refactor * include both ids and strings in trace * comments * remove special case for gemma3 special vocab (#10743)
This commit is contained in:
@@ -53,8 +53,11 @@ func (ModelParameters) KV(t *Tokenizer) ggml.KV {
|
||||
}
|
||||
|
||||
for _, sv := range t.SpecialVocabulary {
|
||||
kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID)
|
||||
kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken
|
||||
kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID)
|
||||
if len(sv.IDs) > 0 {
|
||||
kv[fmt.Sprintf("tokenizer.ggml.%s_token_ids", sv.Key())] = sv.IDs
|
||||
}
|
||||
}
|
||||
|
||||
return kv
|
||||
|
||||
@@ -110,6 +110,7 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
|
||||
}
|
||||
|
||||
if f, err := fsys.Open("tokenizer_config.json"); errors.Is(err, os.ErrNotExist) {
|
||||
// noop
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
@@ -171,6 +172,34 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
|
||||
}
|
||||
}
|
||||
|
||||
if f, err := fsys.Open("generation_config.json"); errors.Is(err, os.ErrNotExist) {
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
defer f.Close()
|
||||
|
||||
var p map[string]json.RawMessage
|
||||
if err := json.NewDecoder(f).Decode(&p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, st := range specialTokenTypes {
|
||||
if bts, ok := p[fmt.Sprintf("%s_token_id", st)]; ok {
|
||||
var ids []int32
|
||||
if err := json.Unmarshal(bts, &ids); err != nil {
|
||||
// value is not a list so the existing ID is used
|
||||
continue
|
||||
}
|
||||
|
||||
if i := slices.IndexFunc(t.SpecialVocabulary, func(sv *SpecialVocabulary) bool {
|
||||
return sv.Type == st
|
||||
}); i >= 0 {
|
||||
t.SpecialVocabulary[i].IDs = ids
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
@@ -280,6 +309,9 @@ type SpecialVocabulary struct {
|
||||
ID int
|
||||
Content string
|
||||
AddToken bool
|
||||
|
||||
// IDs is populated by generation_config.json
|
||||
IDs []int32
|
||||
}
|
||||
|
||||
func (sv SpecialVocabulary) Key() string {
|
||||
|
||||
@@ -247,6 +247,67 @@ func TestParseTokenizer(t *testing.T) {
|
||||
Pre: "default",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "generation config eos token ids",
|
||||
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
||||
"tokenizer.json": strings.NewReader(`{
|
||||
"added_tokens": [
|
||||
{
|
||||
"id": 0,
|
||||
"content": "<bos>",
|
||||
"special": true
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"content": "<eos>",
|
||||
"special": true
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"content": "<eot>",
|
||||
"special": true
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"content": "<eom>",
|
||||
"special": true
|
||||
}
|
||||
],
|
||||
"model": {
|
||||
"vocab": {
|
||||
"<bos>": 0,
|
||||
"<eos>": 1,
|
||||
"<eot>": 2,
|
||||
"<eom>": 3
|
||||
}
|
||||
}
|
||||
}`),
|
||||
"tokenizer_config.json": strings.NewReader(`{
|
||||
"add_bos_token": true,
|
||||
"add_eos_token": false,
|
||||
"bos_token": "<bos>",
|
||||
"eos_token": "<eos>"
|
||||
}`),
|
||||
"generation_config.json": strings.NewReader(`{
|
||||
"bos_token_id": 0,
|
||||
"eos_token_id": [1, 2, 3]
|
||||
}`),
|
||||
}),
|
||||
specialTokenTypes: []string{"pad", "eos", "bos", "unk"},
|
||||
want: &Tokenizer{
|
||||
Vocabulary: &Vocabulary{
|
||||
Model: "gpt2",
|
||||
Tokens: []string{"<bos>", "<eos>", "<eot>", "<eom>"},
|
||||
Scores: []float32{0, 1, 2, 3},
|
||||
Types: []int32{3, 3, 3, 3},
|
||||
},
|
||||
SpecialVocabulary: []*SpecialVocabulary{
|
||||
{Type: "eos", Content: "<eos>", ID: 1, IDs: []int32{1, 2, 3}, AddToken: false},
|
||||
{Type: "bos", Content: "<bos>", ID: 0, AddToken: true},
|
||||
},
|
||||
Pre: "default",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
|
||||
Reference in New Issue
Block a user