Merge remote-tracking branch 'upstream/main'

Author: likelovewant
Date:   2025-03-18 18:09:35 +08:00
9 changed files with 271 additions and 102 deletions

View File

@@ -98,7 +98,7 @@ if(CMAKE_HIP_COMPILER)
     find_package(hip REQUIRED)
     if(NOT AMDGPU_TARGETS)
-        list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(803|900(:xnack-)|902|906(:xnack-)|90c(:xnack-)|1010(:xnack-)|1011|1012(:xnack-)|103[0-6]|110[0-3]|1150)$")
+        list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(803|900(:xnack-)|902|906(:xnack-)|90c(:xnack-)|1010(:xnack-)|1011(:xnack-)|1012(:xnack-)|103[0-6]|110[0-3]|115[01]|1201)$")
     elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
         list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
     endif()
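
The easiest way to sanity-check the new filter is to run the pattern outside CMake. A minimal Go sketch (not part of this commit; the target strings are only illustrative) showing that bare gfx1011 now requires the :xnack- suffix while gfx1151 and gfx1201 are newly admitted:

```go
package main

import (
	"fmt"
	"regexp"
)

func main() {
	// The updated CMake filter, transcribed verbatim for a quick check.
	pattern := regexp.MustCompile(`^gfx(803|900(:xnack-)|902|906(:xnack-)|90c(:xnack-)|1010(:xnack-)|1011(:xnack-)|1012(:xnack-)|103[0-6]|110[0-3]|115[01]|1201)$`)

	for _, target := range []string{
		"gfx1011",        // false: matched bare before, now needs :xnack-
		"gfx1011:xnack-", // true
		"gfx1151",        // true: newly admitted by 115[01]
		"gfx1201",        // true: newly admitted
		"gfx940",         // false: filtered out before and after
	} {
		fmt.Printf("%-16s %v\n", target, pattern.MatchString(target))
	}
}
```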

View File

@@ -56,7 +56,7 @@
       "name": "ROCm 6",
       "inherits": [ "ROCm" ],
       "cacheVariables": {
-        "AMDGPU_TARGETS": "gfx803;gfx902;gfx1011;gfx1030;gfx1031;gfx1032;gfx1034;gfx1035;gfx1036;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx900:xnack-;gfx906:xnack-;gfx90c:xnack-;gfx1010:xnack-;gfx1012:xnack-;"
+        "AMDGPU_TARGETS": "gfx803;gfx902;gfx1030;gfx1031;gfx1032;gfx1034;gfx1035;gfx1036;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1201;gfx900:xnack-;gfx906:xnack-;gfx90c:xnack-;gfx1010:xnack-;gfx1011:xnack-;gfx1012:xnack-;"
       }
     }
   ],

View File

@@ -312,17 +312,19 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 			return fmt.Errorf("unassigned tensor: %s", t.Name)
 		}
 
-		bts := make([]byte, t.Size())
-		n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
-		if err != nil {
-			return err
-		}
-
-		if n != len(bts) {
-			return errors.New("short read")
-		}
-
-		C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size()))
+		bts := C.malloc(C.size_t(t.Size()))
+		if bts == nil {
+			return errors.New("failed to allocate tensor buffer")
+		}
+		defer C.free(bts)
+
+		buf := unsafe.Slice((*byte)(bts), t.Size())
+		n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), buf)
+		if err != nil || n != len(buf) {
+			return errors.New("read failed")
+		}
+
+		C.ggml_backend_tensor_set(tt, bts, 0, C.size_t(t.Size()))
 		return nil
 	})
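
Note on the pattern above: `C.malloc` puts the staging buffer on the C heap, so the pointer handed to `ggml_backend_tensor_set` is stable and exempt from cgo's restrictions on passing Go-heap pointers into C, while `unsafe.Slice` views the same memory as a `[]byte` for `io.ReadFull`. A standalone sketch of that pattern (hypothetical size and contents, not part of the commit):

```go
package main

/*
#include <stdlib.h>
*/
import "C"

import (
	"fmt"
	"unsafe"
)

func main() {
	const n = 8 // hypothetical buffer size

	// C.malloc returns memory the Go GC neither tracks nor moves, so the
	// pointer can be handed to a C API without violating the cgo
	// pointer-passing rules.
	p := C.malloc(C.size_t(n))
	if p == nil {
		panic("malloc failed")
	}
	defer C.free(p)

	// unsafe.Slice views the C allocation as a []byte, which is what lets
	// io.ReadFull fill it directly in the code above.
	buf := unsafe.Slice((*byte)(p), n)
	copy(buf, []byte("tensors!"))

	fmt.Println(string(buf))
}
```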
@@ -371,7 +373,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 			(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
 			C.int(len(schedBackends)),
 			C.size_t(maxGraphNodes),
-			true,
+			C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
 		),
 		input:  deviceBufferTypes[input.d],
 		output: deviceBufferTypes[output.d],
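
The replaced `true` appears to be the final boolean argument of ggml's `ggml_backend_sched_new`, i.e. the scheduler's parallel flag; the new expression requests it only when more than one GPU is in use and the output layer lives on one of them. A toy evaluation of the condition (hypothetical device values, standing in for the real handles):

```go
package main

import (
	"fmt"
	"slices"
)

func main() {
	// Hypothetical stand-ins for the real device handles.
	gpus := []string{"gpu0", "gpu1"}
	outputDevice := "gpu1"

	// Mirror of the new expression: enable parallel scheduling only when
	// several GPUs are present and the output runs on one of them.
	parallel := len(gpus) > 1 && slices.Contains(gpus, outputDevice)
	fmt.Println(parallel) // true
}
```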

View File

@@ -89,7 +89,7 @@ type InputCacheSlot struct {
 	lastUsed time.Time
 }
 
-func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
+func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
 	var slot *InputCacheSlot
 	var numPast int32
 	var err error
@@ -107,11 +107,6 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
 		return nil, nil, err
 	}
 
-	// TODO (brucemacd): cachePrompt is always true for completion, but false for embedding, can this be improved?
-	if !cachePrompt {
-		numPast = 0
-	}
-
 	slot.InUse = true
 	slot.lastUsed = time.Now()

View File

@@ -297,3 +297,131 @@ func TestShiftDiscard(t *testing.T) {
 		})
 	}
 }
+
+func TestLoadCacheSlot(t *testing.T) {
+	tests := []struct {
+		name           string
+		cache          InputCache
+		prompt         []input.Input
+		wantErr        bool
+		expectedSlotId int
+		expectedPrompt int // expected length of remaining prompt
+	}{
+		{
+			name: "Basic cache hit - single user",
+			cache: InputCache{
+				multiUserCache: false,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    false,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+					{
+						Id:       1,
+						Inputs:   []input.Input{},
+						InUse:    false,
+						lastUsed: time.Now().Add(-2 * time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			wantErr:        false,
+			expectedSlotId: 0,
+			expectedPrompt: 1, // Only token 3 remains
+		},
+		{
+			name: "Basic cache hit - multi user",
+			cache: InputCache{
+				multiUserCache: true,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    false,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+					{
+						Id:       1,
+						Inputs:   []input.Input{},
+						InUse:    false,
+						lastUsed: time.Now().Add(-2 * time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			wantErr:        false,
+			expectedSlotId: 0,
+			expectedPrompt: 1, // Only token 3 remains
+		},
+		{
+			name: "Exact match - leave one input",
+			cache: InputCache{
+				multiUserCache: false,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    false,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}},
+			wantErr:        false,
+			expectedSlotId: 0,
+			expectedPrompt: 1, // Should leave 1 token for sampling
+		},
+		{
+			name: "No available slots",
+			cache: InputCache{
+				multiUserCache: false,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    true,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			wantErr:        true,
+			expectedSlotId: -1,
+			expectedPrompt: -1,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
+
+			// Check error state
+			if (err != nil) != tt.wantErr {
+				t.Errorf("LoadCacheSlot() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+
+			if tt.wantErr {
+				return // Skip further checks if we expected an error
+			}
+
+			// Verify slot ID
+			if slot.Id != tt.expectedSlotId {
+				t.Errorf("LoadCacheSlot() slot ID = %v, expected %v", slot.Id, tt.expectedSlotId)
+			}
+
+			// Verify slot is now marked in use
+			if !slot.InUse {
+				t.Errorf("LoadCacheSlot() slot not marked InUse")
+			}
+
+			// Verify remaining prompt length
+			if len(remainingPrompt) != tt.expectedPrompt {
+				t.Errorf("LoadCacheSlot() remaining prompt length = %v, expected %v",
+					len(remainingPrompt), tt.expectedPrompt)
+			}
+		})
+	}
+}

View File

@@ -115,6 +115,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		params.numKeep = int32(len(inputs))
 	}
 
+	// TODO(jessegross): We should ensure that we always leave minBatch of context space to shift,
+	// otherwise we might truncate or split the batch against the model's wishes
+
 	// Ensure that at least 1 input can be discarded during shift
 	params.numKeep = min(params.numKeep, s.cache.numCtx-1)
@@ -366,17 +369,6 @@ func (s *Server) processBatch() error {
 		batchSize := s.batchSize
 
 		for j, inp := range seq.inputs {
-			if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
-				if len(seq.pendingInputs) == 0 {
-					err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
-					if err != nil {
-						return err
-					}
-				} else {
-					break
-				}
-			}
-
 			// If we are required to put following inputs into a single batch then extend the
 			// batch size. Since we are only extending the size the minimum amount possible, this
 			// will cause a break if we have pending inputs.
@@ -389,6 +381,20 @@ func (s *Server) processBatch() error {
 				break
 			}
 
+			// If the sum of our working set (already processed tokens, tokens we added to this
+			// batch, required following tokens) exceeds the context size, then trigger a shift
+			// now so we don't have to do one later when we can't break the batch.
+			if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
+				if len(seq.pendingInputs) != 0 {
+					break
+				}
+
+				err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
+				if err != nil {
+					return err
+				}
+			}
+
 			options.Inputs = append(options.Inputs, inp.Token)
 			if inp.Multimodal != nil {
 				options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
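
The arithmetic in the new check is easier to see with concrete numbers. A toy walkthrough (all values hypothetical) of when the shift fires versus when the batch is flushed first:

```go
package main

import "fmt"

// Toy walkthrough of the working-set check (all numbers hypothetical).
func main() {
	const (
		numCtx   = 8 // context window size
		minBatch = 4 // inputs that must stay together in one batch
	)

	cached := 5  // len(seq.cache.Inputs): already-processed inputs
	pending := 0 // len(seq.pendingInputs): inputs queued in this batch

	if cached+pending+minBatch > numCtx {
		if pending != 0 {
			// Can't shift mid-batch: flush what is pending, shift next pass.
			fmt.Println("break and flush the pending inputs first")
		} else {
			// Nothing pending yet, so shifting now is safe.
			fmt.Println("shift the cache slot before adding this input")
		}
	}
}
```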
@@ -590,7 +596,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	found := false
 	for i, sq := range s.seqs {
 		if sq == nil {
-			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
 			if err != nil {
 				s.mu.Unlock()
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)

View File

@@ -87,8 +87,9 @@ func (s *Sampler) sample(tokens []token) (token, error) {
 	// topK also sorts the tokens in descending order of logits
 	tokens = topK(tokens, s.topK)
 
-	tokens = temperature(tokens, s.temperature)
-	tokens = softmax(tokens)
+	// scale and normalize the tokens in place
+	temperature(tokens, s.temperature)
+	softmax(tokens)
 
 	tokens = topP(tokens, s.topP)
 	tokens = minP(tokens, s.minP)
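
The asymmetry in the new calls follows a common Go convention: transforms that only mutate element values (`temperature`, `softmax`) can work in place and return nothing, while transforms that shorten the slice (`topK`, `topP`, `minP`) must return the re-sliced header. A minimal illustration with hypothetical `scale`/`keepTop` helpers:

```go
package main

import "fmt"

// scale mutates values through the shared backing array, so it needs no
// return value -- the same shape as the new temperature and softmax.
func scale(xs []float32, k float32) {
	for i := range xs {
		xs[i] *= k
	}
}

// keepTop changes the slice's length, so it must return the new header,
// just as topK, topP, and minP still do.
func keepTop(xs []float32, n int) []float32 {
	if n < len(xs) {
		return xs[:n]
	}
	return xs
}

func main() {
	xs := []float32{4, 2, 1}
	scale(xs, 0.5)      // in place: xs is now [2 1 0.5]
	xs = keepTop(xs, 2) // must reassign: xs is now [2 1]
	fmt.Println(xs)
}
```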

View File

@@ -26,17 +26,16 @@ func (h *tokenHeap) Pop() any {
 }
 
 // temperature applies scaling to the logits
-func temperature(ts []token, temp float32) []token {
+func temperature(ts []token, temp float32) {
 	// Ensure temperature clipping near 0 to avoid numerical instability
 	temp = max(temp, 1e-7)
 	for i := range ts {
 		ts[i].value = ts[i].value / temp
 	}
-	return ts
 }
 
 // softmax applies normalization to the logits
-func softmax(ts []token) []token {
+func softmax(ts []token) {
 	// Find max logit for numerical stability
 	maxLogit := float32(math.Inf(-1))
 	for _, t := range ts {
@@ -56,8 +55,6 @@ func softmax(ts []token) {
 	for i := range ts {
 		ts[i].value /= sum
 	}
-
-	return ts
 }
 
 // topK limits the number of tokens considered to the k highest logits
@@ -99,6 +96,7 @@ func topK(ts []token, k int) []token {
 }
 
 // topP limits tokens to those with cumulative probability p
+// requires ts to be sorted in descending order of probabilities
 func topP(ts []token, p float32) []token {
 	if p == 1.0 {
 		return ts
@@ -109,37 +107,24 @@ func topP(ts []token, p float32) []token {
 	for i, t := range ts {
 		sum += t.value
 		if sum > float32(p) {
-			ts = ts[:i+1]
-			return ts
+			return ts[:i+1]
 		}
 	}
 
 	return ts
 }
 
-// minP limits tokens to those with cumulative probability p
+// minP filters tokens with probabilities >= p * max_prob
+// requires ts to be sorted in descending order of probabilities
 func minP(ts []token, p float32) []token {
-	if p == 1.0 {
-		return ts
-	}
-
-	maxProb := float32(math.Inf(-1))
-	for _, token := range ts {
-		if token.value > maxProb {
-			maxProb = token.value
-		}
-	}
-
-	threshold := maxProb * float32(p)
-
-	// Filter tokens in-place
-	validTokens := ts[:0]
-	for i, token := range ts {
-		if token.value >= threshold {
-			validTokens = append(validTokens, ts[i])
-		}
-	}
-	ts = validTokens
+	maxProb := ts[0].value
+	threshold := maxProb * p
+
+	for i, t := range ts {
+		if t.value < threshold {
+			return ts[:i]
+		}
+	}
 
 	return ts
 }
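
A concrete check of the rewritten `minP`, which assumes its input is sorted in descending order: the threshold is `p` times the top probability, and the first value below it truncates the tail. A self-contained sketch with toy probabilities (not part of the commit):

```go
package main

import "fmt"

// minP over probabilities sorted in descending order, matching the
// rewritten transform above: cut at the first value below p * max.
func minP(ps []float32, p float32) []float32 {
	threshold := ps[0] * p
	for i, v := range ps {
		if v < threshold {
			return ps[:i]
		}
	}
	return ps
}

func main() {
	ps := []float32{0.50, 0.25, 0.15, 0.07, 0.03} // toy distribution, sums to 1
	fmt.Println(minP(ps, 0.2)) // threshold 0.10 -> [0.5 0.25 0.15]
	fmt.Println(minP(ps, 1.0)) // threshold 0.50 -> [0.5] (why the test below wants len == 1)
}
```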

View File

@@ -34,17 +34,22 @@ func compareLogits(t *testing.T, name string, want []float32, got []token) {
 
 func TestTemperature(t *testing.T) {
 	input := []float32{1.0, 4.0, -2.0, 0.0}
 
-	got := temperature(toTokens(input), 0.5)
+	tokens := toTokens(input)
+	temperature(tokens, 0.5)
 	want := []float32{2.0, 8.0, -4.0, 0.0}
-	compareLogits(t, "temperature(0.5)", want, got)
+	compareLogits(t, "temperature(0.5)", want, tokens)
 
-	got = temperature(toTokens(input), 1.0)
+	input = []float32{1.0, 4.0, -2.0, 0.0}
+	tokens = toTokens(input)
+	temperature(tokens, 1.0)
 	want = []float32{1.0, 4.0, -2.0, 0.0}
-	compareLogits(t, "temperature(1)", want, got)
+	compareLogits(t, "temperature(1)", want, tokens)
 
-	got = temperature(toTokens(input), 0.0)
+	input = []float32{1.0, 4.0, -2.0, 0.0}
+	tokens = toTokens(input)
+	temperature(tokens, 0.0)
 	want = []float32{1e7, 4e7, -2e7, 0.0}
-	compareLogits(t, "temperature(0)", want, got)
+	compareLogits(t, "temperature(0)", want, tokens)
 }
 
 func TestSoftmax(t *testing.T) {
@@ -90,16 +95,17 @@ func TestSoftmax(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			got := softmax(toTokens(tt.input))
+			tokens := toTokens(tt.input)
+			softmax(tokens)
 
 			if tt.expected != nil {
-				compareLogits(t, tt.name, tt.expected, got)
+				compareLogits(t, tt.name, tt.expected, tokens)
 				return
 			}
 
 			// Check probabilities sum to 1
 			var sum float32
-			for _, token := range got {
+			for _, token := range tokens {
 				sum += token.value
 				if token.value < 0 || token.value > 1 {
 					t.Errorf("probability out of range [0,1]: got %f", token.value)
@@ -114,38 +120,44 @@ func TestSoftmax(t *testing.T) {
 
 func TestTopK(t *testing.T) {
 	input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
+	tokens := toTokens(input)
 
-	// Test k=5
-	got := topK(toTokens(input), 5)
-	if len(got) != 5 {
-		t.Errorf("topK(5): wrong length: want 5, got %d", len(got))
+	tokens = topK(tokens, 5)
+	if len(tokens) != 5 {
+		t.Errorf("topK(5): wrong length: want 5, got %d", len(tokens))
 	}
-	// Should keep highest 3 values in descending order
 	want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
-	compareLogits(t, "topK(3)", want, got)
+	compareLogits(t, "topK(5)", want, tokens)
 
-	got = topK(toTokens(input), 20)
-	if len(got) != len(input) {
-		t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got))
+	tokens = toTokens(input)
+	tokens = topK(tokens, 20)
+	if len(tokens) != len(input) {
+		t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(tokens))
 	}
 
-	// Test k=-1
 	input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
 	want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
-	got = topK(toTokens(input), -1)
-	if len(got) != len(input) {
-		t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
+	tokens = toTokens(input)
+	tokens = topK(tokens, -1)
+	if len(tokens) != len(input) {
+		t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
 	}
-	compareLogits(t, "topK(-1)", want, got)
+	compareLogits(t, "topK(-1)", want, tokens)
 
-	// Test k=0
 	input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
 	want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
-	got = topK(toTokens(input), 0)
-	if len(got) != len(input) {
-		t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
+	tokens = toTokens(input)
+	tokens = topK(tokens, 0)
+	if len(tokens) != len(input) {
+		t.Errorf("topK(0): wrong length: want %d, got %d", len(input), len(tokens))
+	}
+	compareLogits(t, "topK(0)", want, tokens)
+
+	input = []float32{-1e7, -2e7, -3e7, -4e7}
+	tokens = toTokens(input)
+	tokens = topK(tokens, 1)
+	if len(tokens) < 1 {
+		t.Error("topK should keep at least one token")
 	}
-	compareLogits(t, "topK(-1)", want, got)
 }
 
 func TestTopP(t *testing.T) {
@@ -153,16 +165,25 @@ func TestTopP(t *testing.T) {
 	tokens := toTokens(input)
 
 	// First apply temperature and softmax to get probabilities
-	tokens = softmax(tokens)
+	softmax(tokens)
 	tokens = topK(tokens, 20)
 
 	// Then apply topP
-	got := topP(tokens, 0.95)
+	tokens = topP(tokens, 0.95)
 
 	// Should keep tokens until cumsum > 0.95
-	if len(got) > 3 {
-		t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
-		t.Logf("got: %v", got)
+	if len(tokens) > 3 {
+		t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
+		t.Logf("got: %v", tokens)
+	}
+
+	// Test edge case - ensure at least one token remains
+	input = []float32{-1e6, -1e6, -1e6} // all tokens equally likely
+	tokens = toTokens(input)
+	softmax(tokens)
+	tokens = topP(tokens, 0.0) // Very small p
+	if len(tokens) < 1 {
+		t.Error("topP should keep at least one token")
 	}
 }
@@ -171,14 +192,45 @@ func TestMinP(t *testing.T) {
 	tokens := toTokens(input)
 
 	// First apply temperature and softmax
-	tokens = softmax(tokens)
+	tokens = topK(tokens, 20)
+	softmax(tokens)
 
-	// Then apply minP
-	got := minP(tokens, 0.2)
+	tokens = minP(tokens, 1.0)
+	if len(tokens) != 1 {
+		t.Errorf("minP(1.0): should keep only the top token, got %d, want 1", len(tokens))
+	}
+
+	// Test with normal p value
+	tokens = toTokens(input) // Reset tokens
+	tokens = topK(tokens, 20)
+	softmax(tokens)
+	tokens = minP(tokens, 0.2)
 
 	// Should keep tokens with prob >= 0.2 * max_prob
-	if len(got) > 3 {
-		t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
+	if len(tokens) > 3 {
+		t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens))
+		t.Logf("got: %v", tokens)
+	}
+
+	// Test with zero p value
+	tokens = toTokens(input) // Reset tokens
+	tokens = topK(tokens, 20)
+	softmax(tokens)
+	tokens = minP(tokens, 0.0)
+
+	// Should keep all tokens, since every probability clears a zero threshold
+	if len(tokens) != len(input) {
+		t.Errorf("minP(0.0): should keep all tokens, got %d, want %d", len(tokens), len(input))
+		t.Logf("got: %v", tokens)
+	}
+
+	input = []float32{1e-10, 1e-10, 1e-10}
+	tokens = toTokens(input)
+	softmax(tokens)
+	tokens = minP(tokens, 1.0)
+	if len(tokens) < 1 {
+		t.Error("minP should keep at least one token even with extreme probabilities")
 	}
 }
@@ -231,7 +283,7 @@ func BenchmarkTransforms(b *testing.B) {
 		b.ResetTimer()
 		for b.Loop() {
 			copy(tokensCopy, tokens)
-			topK(tokensCopy, 10)
+			tokens = topK(tokensCopy, 10)
 		}
 	})
 
@@ -239,7 +291,7 @@ func BenchmarkTransforms(b *testing.B) {
 		b.ResetTimer()
 		for b.Loop() {
 			copy(tokensCopy, tokens)
-			topP(tokensCopy, 0.9)
+			tokens = topP(tokensCopy, 0.9)
 		}
 	})
 
@@ -247,7 +299,7 @@ func BenchmarkTransforms(b *testing.B) {
 		b.ResetTimer()
 		for b.Loop() {
 			copy(tokensCopy, tokens)
-			minP(tokensCopy, 0.2)
+			tokens = minP(tokensCopy, 0.2)
 		}
 	})
 
@@ -255,7 +307,7 @@ func BenchmarkTransforms(b *testing.B) {
 		b.ResetTimer()
 		for b.Loop() {
 			copy(tokensCopy, tokens)
-			topK(tokensCopy, 200000)
+			tokens = topK(tokensCopy, 200000)
 		}
 	})
}