mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 14:26:30 +00:00
llama: remove model loading for grammar (#10096)
This commit is contained in:
@@ -5,9 +5,9 @@ import (
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
// token represents information about a single token during sampling
|
||||
@@ -22,7 +22,7 @@ type Sampler struct {
|
||||
topP float32
|
||||
minP float32
|
||||
temperature float32
|
||||
grammar *Grammar
|
||||
grammar *GrammarSampler
|
||||
}
|
||||
|
||||
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||
@@ -127,7 +127,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
||||
}
|
||||
|
||||
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
|
||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *GrammarSampler) Sampler {
|
||||
var rng *rand.Rand
|
||||
if seed != -1 {
|
||||
// PCG requires two parameters: sequence and stream
|
||||
@@ -164,63 +164,43 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
|
||||
}
|
||||
}
|
||||
|
||||
type Grammar struct {
|
||||
vocab *Vocab
|
||||
grammar string
|
||||
sampler *llama.Sampler
|
||||
type GrammarSampler struct {
|
||||
grammar *llama.Grammar
|
||||
}
|
||||
|
||||
func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
|
||||
v, err := vocab.Load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) {
|
||||
vocabIds := make([]uint32, len(model.Vocabulary().Values))
|
||||
pieces := make([]string, len(model.Vocabulary().Values))
|
||||
for i := range model.Vocabulary().Values {
|
||||
pieces[i], _ = model.Decode([]int32{int32(i)})
|
||||
vocabIds[i] = uint32(i)
|
||||
}
|
||||
|
||||
return &Grammar{
|
||||
vocab: vocab,
|
||||
grammar: grammar,
|
||||
sampler: llama.NewGrammarSampler(v, grammar),
|
||||
}, nil
|
||||
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, []uint32{uint32(model.Vocabulary().EOS), uint32(model.Vocabulary().EOT)})
|
||||
if grammar == nil {
|
||||
return nil, errors.New("sample: failed to initialize grammar")
|
||||
}
|
||||
|
||||
return &GrammarSampler{grammar: grammar}, nil
|
||||
}
|
||||
|
||||
func (g *Grammar) Apply(tokens []token) {
|
||||
func (g *GrammarSampler) Apply(tokens []token) {
|
||||
tds := make([]llama.TokenData, len(tokens))
|
||||
for i, token := range tokens {
|
||||
tds[i].Id = token.id
|
||||
tds[i].ID = token.id
|
||||
tds[i].Logit = token.value
|
||||
}
|
||||
|
||||
g.sampler.Apply(tds)
|
||||
g.grammar.Apply(tds)
|
||||
|
||||
for i := range tokens {
|
||||
tokens[i].value = tds[i].Logit
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Grammar) Accept(token int32) {
|
||||
g.sampler.Accept(token)
|
||||
func (g *GrammarSampler) Accept(token int32) {
|
||||
g.grammar.Accept(token)
|
||||
}
|
||||
|
||||
type Vocab struct {
|
||||
once sync.Once
|
||||
vocab *llama.Vocab
|
||||
err error
|
||||
path string
|
||||
}
|
||||
|
||||
func NewVocab(path string) *Vocab {
|
||||
return &Vocab{path: path}
|
||||
}
|
||||
|
||||
// Load returns the lazily-loaded vocabulary
|
||||
func (v *Vocab) Load() (*llama.Vocab, error) {
|
||||
v.once.Do(func() {
|
||||
vocab, err := llama.LoadVocabFromFile(v.path)
|
||||
if err != nil {
|
||||
v.err = err
|
||||
return
|
||||
}
|
||||
v.vocab = vocab
|
||||
})
|
||||
return v.vocab, v.err
|
||||
func (g *GrammarSampler) Free() {
|
||||
g.grammar.Free()
|
||||
}
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
func TestWeighted(t *testing.T) {
|
||||
@@ -55,6 +60,97 @@ func TestWeighted(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func modelHelper(t testing.TB) model.BytePairEncoding {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.Open(filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
vocab := make(map[string]int32)
|
||||
if err := json.NewDecoder(f).Decode(&vocab); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
types := make([]uint32, len(vocab))
|
||||
tokens := make([]string, len(vocab))
|
||||
for token, id := range vocab {
|
||||
tokens[id] = token
|
||||
}
|
||||
|
||||
merges := make([]string, 0, 1)
|
||||
// Only need vocab for Grammar Test
|
||||
return model.NewBytePairEncoding(
|
||||
``,
|
||||
&model.Vocabulary{
|
||||
Values: tokens,
|
||||
Types: types,
|
||||
Merges: merges,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func TestGrammar(t *testing.T) {
|
||||
tokenizer := modelHelper(t)
|
||||
|
||||
grammarJSON := `
|
||||
root ::= object
|
||||
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||
object ::=
|
||||
"{" ws (
|
||||
string ":" ws value
|
||||
("," ws string ":" ws value)*
|
||||
)? "}" ws
|
||||
array ::=
|
||||
"[" ws (
|
||||
value
|
||||
("," ws value)*
|
||||
)? "]" ws
|
||||
string ::=
|
||||
"\"" (
|
||||
[^"\\\x7F\x00-\x1F] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||
)* "\"" ws
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||
ws ::= ([ \t\n] ws)?
|
||||
`
|
||||
grammar, err := NewGrammarSampler(tokenizer, grammarJSON)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer grammar.Free()
|
||||
|
||||
logits := make([]float32, len(tokenizer.Vocabulary().Values))
|
||||
for i := range logits {
|
||||
logits[i] = rand.Float32()
|
||||
}
|
||||
tokens := make([]token, len(logits))
|
||||
for i := range tokens {
|
||||
tokens[i].id = int32(i)
|
||||
tokens[i].value = logits[i]
|
||||
}
|
||||
|
||||
grammar.Apply(tokens)
|
||||
nonInfCount := 0
|
||||
infCount := 0
|
||||
for _, tok := range tokens {
|
||||
if math.IsInf(float64(tok.value), -1) {
|
||||
infCount++
|
||||
} else {
|
||||
nonInfCount++
|
||||
}
|
||||
}
|
||||
if nonInfCount == 0 {
|
||||
t.Error("expected at least one non -inf token after grammar application, got none")
|
||||
}
|
||||
if infCount == 0 {
|
||||
t.Error("expected some -inf tokens after grammar application, got none")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSample(b *testing.B) {
|
||||
samplers := map[string]Sampler{
|
||||
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
|
||||
|
||||
Reference in New Issue
Block a user