sample: separate softmax and temperature transforms (#9732)

This commit is contained in:
Parth Sareen
2025-03-13 09:53:27 -07:00
committed by GitHub
parent 4aeb67ef4c
commit 5c0b663969
3 changed files with 98 additions and 25 deletions

View File

@@ -25,8 +25,18 @@ func (h *tokenHeap) Pop() any {
return x
}
// temperature divides the value of every token in ts by temp and returns ts.
// The slice is modified in place; no normalization is performed here.
// Any temp below 1e-7 (including zero and negative values) is raised to 1e-7.
func temperature(ts []token, temp float32) []token {
	// Clamp to a tiny positive floor so the division below stays
	// numerically stable when the requested temperature is near zero.
	divisor := max(temp, 1e-7)

	for idx := range ts {
		ts[idx].value /= divisor
	}

	return ts
}
// softmax applies normalization to the logits
func softmax(ts []token) []token {
// Find max logit for numerical stability
maxLogit := float32(math.Inf(-1))
for _, t := range ts {
@@ -35,15 +45,14 @@ func temperature(ts []token, temp float32) []token {
}
}
// Apply temperature and compute exp(x - max)
temp = max(temp, 1e-7)
// Compute exp(x - max)
var sum float32
for i, v := range ts {
ts[i].value = float32(math.Exp(float64((v.value - maxLogit) / temp)))
ts[i].value = float32(math.Exp(float64(v.value - maxLogit)))
sum += ts[i].value
}
// Normalize
// exp(x - max) / sum(exp(x - max))
for i := range ts {
ts[i].value /= sum
}