mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 14:26:30 +00:00
sample: do all sorting in topK
This commit is contained in:
committed by
Parth Sareen
parent
3ba91634c1
commit
4aeb67ef4c
@@ -53,8 +53,17 @@ func temperature(ts []token, temp float32) []token {
|
||||
|
||||
// topK limits the number of tokens considered to the k highest logits
|
||||
func topK(ts []token, k int) []token {
|
||||
if k >= len(ts) {
|
||||
sortLogits(ts)
|
||||
if k >= len(ts) || k <= 0 {
|
||||
slices.SortFunc(ts, func(a, b token) int {
|
||||
switch {
|
||||
case a.value < b.value:
|
||||
return 1
|
||||
case a.value > b.value:
|
||||
return -1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
})
|
||||
return ts
|
||||
}
|
||||
|
||||
@@ -125,17 +134,3 @@ func minP(ts []token, p float32) []token {
|
||||
ts = validTokens
|
||||
return ts
|
||||
}
|
||||
|
||||
// sortLogits sorts the tokens in descending order of logits
|
||||
func sortLogits(ts []token) {
|
||||
slices.SortFunc(ts, func(a, b token) int {
|
||||
switch {
|
||||
case a.value < b.value:
|
||||
return 1
|
||||
case a.value > b.value:
|
||||
return -1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user