mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 14:26:30 +00:00
sample: make mutations in transforms explicit (#9743)
* updated minP to use early exit making use of sorted tokens
This commit is contained in:
@@ -26,17 +26,16 @@ func (h *tokenHeap) Pop() any {
|
||||
}
|
||||
|
||||
// temperature applies scaling to the logits
|
||||
func temperature(ts []token, temp float32) []token {
|
||||
func temperature(ts []token, temp float32) {
|
||||
// Ensure temperature clipping near 0 to avoid numerical instability
|
||||
temp = max(temp, 1e-7)
|
||||
for i := range ts {
|
||||
ts[i].value = ts[i].value / temp
|
||||
}
|
||||
return ts
|
||||
}
|
||||
|
||||
// softmax applies normalization to the logits
|
||||
func softmax(ts []token) []token {
|
||||
func softmax(ts []token) {
|
||||
// Find max logit for numerical stability
|
||||
maxLogit := float32(math.Inf(-1))
|
||||
for _, t := range ts {
|
||||
@@ -56,8 +55,6 @@ func softmax(ts []token) []token {
|
||||
for i := range ts {
|
||||
ts[i].value /= sum
|
||||
}
|
||||
|
||||
return ts
|
||||
}
|
||||
|
||||
// topK limits the number of tokens considered to the k highest logits
|
||||
@@ -99,6 +96,7 @@ func topK(ts []token, k int) []token {
|
||||
}
|
||||
|
||||
// topP limits tokens to those with cumulative probability p
|
||||
// requires ts to be sorted in descending order of probabilities
|
||||
func topP(ts []token, p float32) []token {
|
||||
if p == 1.0 {
|
||||
return ts
|
||||
@@ -109,37 +107,24 @@ func topP(ts []token, p float32) []token {
|
||||
for i, t := range ts {
|
||||
sum += t.value
|
||||
if sum > float32(p) {
|
||||
ts = ts[:i+1]
|
||||
return ts
|
||||
return ts[:i+1]
|
||||
}
|
||||
}
|
||||
|
||||
return ts
|
||||
}
|
||||
|
||||
// minP limits tokens to those with cumulative probability p
|
||||
// minP filters tokens with probabilities >= p * max_prob
|
||||
// requires ts to be sorted in descending order of probabilities
|
||||
func minP(ts []token, p float32) []token {
|
||||
if p == 1.0 {
|
||||
return ts
|
||||
}
|
||||
maxProb := ts[0].value
|
||||
|
||||
maxProb := float32(math.Inf(-1))
|
||||
for _, token := range ts {
|
||||
if token.value > maxProb {
|
||||
maxProb = token.value
|
||||
threshold := maxProb * p
|
||||
|
||||
for i, t := range ts {
|
||||
if t.value < threshold {
|
||||
return ts[:i]
|
||||
}
|
||||
}
|
||||
|
||||
threshold := maxProb * float32(p)
|
||||
|
||||
// Filter tokens in-place
|
||||
validTokens := ts[:0]
|
||||
for i, token := range ts {
|
||||
if token.value >= threshold {
|
||||
validTokens = append(validTokens, ts[i])
|
||||
}
|
||||
}
|
||||
|
||||
ts = validTokens
|
||||
return ts
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user