mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 14:26:30 +00:00
Adds logprobs support to Ollama's API, including Ollama's OpenAI-compatible API. By specifying the new 'logprobs' boolean parameter in the API, Ollama will return the log probabilities for each generated token. An integer 'top_logprobs' parameter (with a maximum value of 20) can also be specified; when set, the API will additionally return that many of the most likely tokens at each token position. Co-authored-by: Baptiste Jamin <baptiste@crisp.chat>
80 lines
1.9 KiB
Go
80 lines
1.9 KiB
Go
package common
|
|
|
|
import (
|
|
"math"
|
|
"sort"
|
|
|
|
"github.com/ollama/ollama/llm"
|
|
)
|
|
|
|
// TokenDecoderFunc is a function that converts token IDs to text.
// It is supplied by the caller (e.g. backed by the model's tokenizer);
// behavior for out-of-range token IDs depends on the implementation.
type TokenDecoderFunc func(tokenID int) string
|
|
|
|
// CalculateLogprobs converts raw logits to log probabilities and finds top K tokens.
|
|
// It uses numerically stable softmax to compute log probabilities.
|
|
func CalculateLogprobs(logits []float32, selectedToken int, topK int, decoder TokenDecoderFunc) []llm.Logprob {
|
|
if len(logits) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Step 1: Convert logits to log probabilities using numerically stable softmax
|
|
maxLogit := logits[0]
|
|
for _, logit := range logits[1:] {
|
|
if logit > maxLogit {
|
|
maxLogit = logit
|
|
}
|
|
}
|
|
|
|
var sumExp float64
|
|
for _, logit := range logits {
|
|
sumExp += math.Exp(float64(logit - maxLogit))
|
|
}
|
|
logSumExp := float32(math.Log(sumExp))
|
|
|
|
logProbs := make([]float32, len(logits))
|
|
for i, logit := range logits {
|
|
logProbs[i] = (logit - maxLogit) - logSumExp
|
|
}
|
|
|
|
// Step 2: Get selected token's information
|
|
selectedLogprob := logProbs[selectedToken]
|
|
selectedText := decoder(selectedToken)
|
|
|
|
result := llm.Logprob{
|
|
TokenLogprob: llm.TokenLogprob{
|
|
Token: selectedText,
|
|
Logprob: float64(selectedLogprob),
|
|
},
|
|
}
|
|
|
|
// Step 3: If topK requested, find the top K tokens
|
|
if topK > 0 {
|
|
type tokenLogprobPair struct {
|
|
tokenID int
|
|
logprob float32
|
|
}
|
|
|
|
pairs := make([]tokenLogprobPair, len(logProbs))
|
|
for i, lp := range logProbs {
|
|
pairs[i] = tokenLogprobPair{tokenID: i, logprob: lp}
|
|
}
|
|
|
|
sort.Slice(pairs, func(i, j int) bool {
|
|
return pairs[i].logprob > pairs[j].logprob
|
|
})
|
|
|
|
k := min(topK, len(pairs))
|
|
topLogprobs := make([]llm.TokenLogprob, k)
|
|
for i := range k {
|
|
tokenText := decoder(pairs[i].tokenID)
|
|
topLogprobs[i] = llm.TokenLogprob{
|
|
Token: tokenText,
|
|
Logprob: float64(pairs[i].logprob),
|
|
}
|
|
}
|
|
result.TopLogprobs = topLogprobs
|
|
}
|
|
|
|
return []llm.Logprob{result}
|
|
}
|