mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 22:33:56 +00:00
sample: add error handling for empty logits (#9740)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"testing"
|
||||
)
|
||||
@@ -29,6 +30,29 @@ func TestWeighted(t *testing.T) {
|
||||
if want != got {
|
||||
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||
}
|
||||
|
||||
// Test very high p
|
||||
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
|
||||
// Use extremely small topP to filter out all tokens
|
||||
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
|
||||
got, err = sampler.Sample(logits)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
// Should get the token with the highest logit
|
||||
want = int32(0)
|
||||
if want != got {
|
||||
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||
}
|
||||
|
||||
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
|
||||
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
|
||||
got, err = sampler.Sample(logits)
|
||||
if err == nil {
|
||||
t.Errorf("expected error, got %d", got)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSample(b *testing.B) {
|
||||
|
||||
Reference in New Issue
Block a user