Mirror of https://github.com/likelovewant/ollama-for-amd.git (synced 2025-12-25 07:58:01 +00:00)
server: add logprobs and top_logprobs support to Ollama's API (#12899)
Adds logprobs support to Ollama's API, including Ollama's OpenAI-compatible API. When the new 'logprobs' boolean parameter is set, Ollama returns the log probability of each generated token. An integer 'top_logprobs' parameter (up to 20) can also be specified; when set, the API additionally returns that many of the most likely tokens, with their log probabilities, at each token position.

Co-authored-by: Baptiste Jamin <baptiste@crisp.chat>
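As a sketch of how a client exercises the new parameters through the OpenAI-compatible endpoint (the field names follow the diff below; the model name, host, and port here are illustrative assumptions):

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Request logprobs for each generated token plus the 5 most likely
	// alternatives per position (top_logprobs accepts up to 20).
	body, _ := json.Marshal(map[string]any{
		"model":        "llama3.2",
		"messages":     []map[string]string{{"role": "user", "content": "Hello"}},
		"logprobs":     true,
		"top_logprobs": 5,
	})

	resp, err := http.Post("http://localhost:11434/v1/chat/completions",
		"application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// choices[0].logprobs.content carries the per-token log probabilities.
	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out))
}
```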
```diff
@@ -40,22 +40,29 @@ type Message struct {
 	ToolCallID string `json:"tool_call_id,omitempty"`
 }
 
+type ChoiceLogprobs struct {
+	Content []api.Logprob `json:"content"`
+}
+
 type Choice struct {
-	Index        int     `json:"index"`
-	Message      Message `json:"message"`
-	FinishReason *string `json:"finish_reason"`
+	Index        int             `json:"index"`
+	Message      Message         `json:"message"`
+	FinishReason *string         `json:"finish_reason"`
+	Logprobs     *ChoiceLogprobs `json:"logprobs,omitempty"`
 }
 
 type ChunkChoice struct {
-	Index        int     `json:"index"`
-	Delta        Message `json:"delta"`
-	FinishReason *string `json:"finish_reason"`
+	Index        int             `json:"index"`
+	Delta        Message         `json:"delta"`
+	FinishReason *string         `json:"finish_reason"`
+	Logprobs     *ChoiceLogprobs `json:"logprobs,omitempty"`
 }
 
 type CompleteChunkChoice struct {
-	Text         string  `json:"text"`
-	Index        int     `json:"index"`
-	FinishReason *string `json:"finish_reason"`
+	Text         string          `json:"text"`
+	Index        int             `json:"index"`
+	FinishReason *string         `json:"finish_reason"`
+	Logprobs     *ChoiceLogprobs `json:"logprobs,omitempty"`
 }
 
 type Usage struct {
```
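The `omitempty` tag on the new Logprobs pointer keeps existing responses unchanged: a nil pointer drops the key from the JSON entirely. A minimal sketch of that behavior with local stand-in types (the real Logprob type lives in Ollama's api package):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Stand-ins for the types in the diff above.
type tokenLogprob struct {
	Token   string  `json:"token"`
	Logprob float64 `json:"logprob"`
}

type choiceLogprobs struct {
	Content []tokenLogprob `json:"content"`
}

type choice struct {
	Index    int             `json:"index"`
	Logprobs *choiceLogprobs `json:"logprobs,omitempty"`
}

func main() {
	// Nil pointer: the "logprobs" key is omitted, so old clients see no change.
	plain, _ := json.Marshal(choice{Index: 0})
	fmt.Println(string(plain)) // {"index":0}

	// Populated pointer: the key appears with per-token content.
	withLp, _ := json.Marshal(choice{
		Index:    0,
		Logprobs: &choiceLogprobs{Content: []tokenLogprob{{Token: "Hello", Logprob: -0.5}}},
	})
	fmt.Println(string(withLp)) // {"index":0,"logprobs":{"content":[{"token":"Hello","logprob":-0.5}]}}
}
```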
```diff
@@ -104,6 +111,8 @@ type ChatCompletionRequest struct {
 	Tools           []api.Tool `json:"tools"`
 	Reasoning       *Reasoning `json:"reasoning,omitempty"`
 	ReasoningEffort *string    `json:"reasoning_effort,omitempty"`
+	Logprobs        *bool      `json:"logprobs"`
+	TopLogprobs     int        `json:"top_logprobs"`
 	DebugRenderOnly bool       `json:"_debug_render_only"`
 }
 
@@ -142,6 +151,7 @@ type CompletionRequest struct {
 	Temperature     *float32 `json:"temperature"`
 	TopP            float32  `json:"top_p"`
 	Suffix          string   `json:"suffix"`
+	Logprobs        *int     `json:"logprobs"`
 	DebugRenderOnly bool     `json:"_debug_render_only"`
 }
 
@@ -251,6 +261,12 @@ func ToToolCalls(tc []api.ToolCall) []ToolCall {
 // ToChatCompletion converts an api.ChatResponse to ChatCompletion
 func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 	toolCalls := ToToolCalls(r.Message.ToolCalls)
+
+	var logprobs *ChoiceLogprobs
+	if len(r.Logprobs) > 0 {
+		logprobs = &ChoiceLogprobs{Content: r.Logprobs}
+	}
+
 	return ChatCompletion{
 		Id:     id,
 		Object: "chat.completion",
@@ -269,6 +285,7 @@ func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 				}
 				return nil
 			}(r.DoneReason),
+			Logprobs: logprobs,
 		}}, Usage: ToUsage(r),
 		DebugInfo: r.DebugInfo,
 	}
```
```diff
@@ -277,6 +294,12 @@ func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 // ToChunk converts an api.ChatResponse to ChatCompletionChunk
 func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
 	toolCalls := ToToolCalls(r.Message.ToolCalls)
+
+	var logprobs *ChoiceLogprobs
+	if len(r.Logprobs) > 0 {
+		logprobs = &ChoiceLogprobs{Content: r.Logprobs}
+	}
+
 	return ChatCompletionChunk{
 		Id:     id,
 		Object: "chat.completion.chunk",
@@ -295,6 +318,7 @@ func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChu
 				}
 				return nil
 			}(r.DoneReason),
+			Logprobs: logprobs,
 		}},
 	}
 }
```
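ToChatCompletion and ToChunk share the same guard: the pointer stays nil for both a nil and an empty logprobs slice, so the key only appears when at least one token was scored. The pattern in isolation, with stand-in types:

```go
package main

import "fmt"

// Stand-in types; the real slice element is api.Logprob.
type logprob struct{ Token string }

type choiceLogprobs struct{ Content []logprob }

// wrap mirrors the guard in ToChatCompletion and ToChunk: both a nil and an
// empty slice leave the pointer nil, so "logprobs" is marshaled only when at
// least one token was scored.
func wrap(lps []logprob) *choiceLogprobs {
	if len(lps) > 0 {
		return &choiceLogprobs{Content: lps}
	}
	return nil
}

func main() {
	fmt.Println(wrap(nil))                      // <nil>
	fmt.Println(wrap([]logprob{}))              // <nil>
	fmt.Println(wrap([]logprob{{Token: "Hi"}})) // &{[{Hi}]}
}
```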
```diff
@@ -604,6 +628,8 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 		Stream:          &r.Stream,
 		Tools:           r.Tools,
 		Think:           think,
+		Logprobs:        r.Logprobs != nil && *r.Logprobs,
+		TopLogprobs:     r.TopLogprobs,
 		DebugRenderOnly: r.DebugRenderOnly,
 	}, nil
 }
```
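FromChatRequest collapses the request's optional *bool into the internal plain bool, so an omitted field and an explicit false behave identically. A sketch of the expression:

```go
package main

import "fmt"

// derefLogprobs mirrors the expression r.Logprobs != nil && *r.Logprobs:
// an omitted field (nil pointer) behaves exactly like an explicit false.
func derefLogprobs(p *bool) bool {
	return p != nil && *p
}

func main() {
	t, f := true, false
	fmt.Println(derefLogprobs(nil)) // false
	fmt.Println(derefLogprobs(&f))  // false
	fmt.Println(derefLogprobs(&t))  // true
}
```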
```diff
@@ -680,12 +706,21 @@ func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
 		options["top_p"] = 1.0
 	}
 
+	var logprobs bool
+	var topLogprobs int
+	if r.Logprobs != nil && *r.Logprobs > 0 {
+		logprobs = true
+		topLogprobs = *r.Logprobs
+	}
+
 	return api.GenerateRequest{
 		Model:           r.Model,
 		Prompt:          r.Prompt,
 		Options:         options,
 		Stream:          &r.Stream,
 		Suffix:          r.Suffix,
+		Logprobs:        logprobs,
+		TopLogprobs:     topLogprobs,
 		DebugRenderOnly: r.DebugRenderOnly,
 	}, nil
 }
```
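The legacy completions endpoint differs: OpenAI's logprobs parameter there is an integer count, not a boolean, so FromCompleteRequest derives both the boolean switch and the top-k count from a single value. The same mapping in isolation (mapLogprobs is a hypothetical helper, not part of the diff):

```go
package main

import "fmt"

// mapLogprobs is a hypothetical helper showing the mapping FromCompleteRequest
// applies inline: a positive count switches logprobs on and doubles as top-k.
func mapLogprobs(n *int) (logprobs bool, topLogprobs int) {
	if n != nil && *n > 0 {
		return true, *n
	}
	return false, 0
}

func main() {
	five := 5
	on, top := mapLogprobs(&five)
	fmt.Println(on, top) // true 5

	on, top = mapLogprobs(nil)
	fmt.Println(on, top) // false 0
}
```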
```diff
@@ -3,6 +3,7 @@ package openai
 import (
 	"encoding/base64"
 	"testing"
+	"time"
 
 	"github.com/google/go-cmp/cmp"
 
@@ -218,3 +219,218 @@ func TestToToolCallsPreservesIDs(t *testing.T) {
 		t.Errorf("input tool calls mutated (-want +got):\n%s", diff)
 	}
 }
+
+func TestFromChatRequest_WithLogprobs(t *testing.T) {
+	trueVal := true
+
+	req := ChatCompletionRequest{
+		Model: "test-model",
+		Messages: []Message{
+			{Role: "user", Content: "Hello"},
+		},
+		Logprobs:    &trueVal,
+		TopLogprobs: 5,
+	}
+
+	result, err := FromChatRequest(req)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	if !result.Logprobs {
+		t.Error("expected Logprobs to be true")
+	}
+
+	if result.TopLogprobs != 5 {
+		t.Errorf("expected TopLogprobs to be 5, got %d", result.TopLogprobs)
+	}
+}
+
+func TestFromChatRequest_LogprobsDefault(t *testing.T) {
+	req := ChatCompletionRequest{
+		Model: "test-model",
+		Messages: []Message{
+			{Role: "user", Content: "Hello"},
+		},
+	}
+
+	result, err := FromChatRequest(req)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	if result.Logprobs {
+		t.Error("expected Logprobs to be false by default")
+	}
+
+	if result.TopLogprobs != 0 {
+		t.Errorf("expected TopLogprobs to be 0 by default, got %d", result.TopLogprobs)
+	}
+}
+
+func TestFromCompleteRequest_WithLogprobs(t *testing.T) {
+	logprobsVal := 5
+
+	req := CompletionRequest{
+		Model:    "test-model",
+		Prompt:   "Hello",
+		Logprobs: &logprobsVal,
+	}
+
+	result, err := FromCompleteRequest(req)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	if !result.Logprobs {
+		t.Error("expected Logprobs to be true")
+	}
+
+	if result.TopLogprobs != 5 {
+		t.Errorf("expected TopLogprobs to be 5, got %d", result.TopLogprobs)
+	}
+}
+
+func TestToChatCompletion_WithLogprobs(t *testing.T) {
+	createdAt := time.Unix(1234567890, 0)
+	resp := api.ChatResponse{
+		Model:     "test-model",
+		CreatedAt: createdAt,
+		Message:   api.Message{Role: "assistant", Content: "Hello there"},
+		Logprobs: []api.Logprob{
+			{
+				TokenLogprob: api.TokenLogprob{
+					Token:   "Hello",
+					Logprob: -0.5,
+				},
+				TopLogprobs: []api.TokenLogprob{
+					{Token: "Hello", Logprob: -0.5},
+					{Token: "Hi", Logprob: -1.2},
+				},
+			},
+			{
+				TokenLogprob: api.TokenLogprob{
+					Token:   " there",
+					Logprob: -0.3,
+				},
+				TopLogprobs: []api.TokenLogprob{
+					{Token: " there", Logprob: -0.3},
+					{Token: " world", Logprob: -1.5},
+				},
+			},
+		},
+		Done: true,
+		Metrics: api.Metrics{
+			PromptEvalCount: 5,
+			EvalCount:       2,
+		},
+	}
+
+	id := "test-id"
+
+	result := ToChatCompletion(id, resp)
+
+	if result.Id != id {
+		t.Errorf("expected Id %q, got %q", id, result.Id)
+	}
+
+	if result.Created != 1234567890 {
+		t.Errorf("expected Created %d, got %d", int64(1234567890), result.Created)
+	}
+
+	if len(result.Choices) != 1 {
+		t.Fatalf("expected 1 choice, got %d", len(result.Choices))
+	}
+
+	choice := result.Choices[0]
+	if choice.Message.Content != "Hello there" {
+		t.Errorf("expected content %q, got %q", "Hello there", choice.Message.Content)
+	}
+
+	if choice.Logprobs == nil {
+		t.Fatal("expected Logprobs to be present")
+	}
+
+	if len(choice.Logprobs.Content) != 2 {
+		t.Fatalf("expected 2 logprobs, got %d", len(choice.Logprobs.Content))
+	}
+
+	// Verify first logprob
+	if choice.Logprobs.Content[0].Token != "Hello" {
+		t.Errorf("expected first token %q, got %q", "Hello", choice.Logprobs.Content[0].Token)
+	}
+	if choice.Logprobs.Content[0].Logprob != -0.5 {
+		t.Errorf("expected first logprob -0.5, got %f", choice.Logprobs.Content[0].Logprob)
+	}
+	if len(choice.Logprobs.Content[0].TopLogprobs) != 2 {
+		t.Errorf("expected 2 top_logprobs, got %d", len(choice.Logprobs.Content[0].TopLogprobs))
+	}
+
+	// Verify second logprob
+	if choice.Logprobs.Content[1].Token != " there" {
+		t.Errorf("expected second token %q, got %q", " there", choice.Logprobs.Content[1].Token)
+	}
+}
+
+func TestToChatCompletion_WithoutLogprobs(t *testing.T) {
+	createdAt := time.Unix(1234567890, 0)
+	resp := api.ChatResponse{
+		Model:     "test-model",
+		CreatedAt: createdAt,
+		Message:   api.Message{Role: "assistant", Content: "Hello"},
+		Done:      true,
+		Metrics: api.Metrics{
+			PromptEvalCount: 5,
+			EvalCount:       1,
+		},
+	}
+
+	id := "test-id"
+
+	result := ToChatCompletion(id, resp)
+
+	if len(result.Choices) != 1 {
+		t.Fatalf("expected 1 choice, got %d", len(result.Choices))
+	}
+
+	// When no logprobs, Logprobs should be nil
+	if result.Choices[0].Logprobs != nil {
+		t.Error("expected Logprobs to be nil when not requested")
+	}
+}
+
+func TestFromChatRequest_TopLogprobsRange(t *testing.T) {
+	tests := []struct {
+		name        string
+		topLogprobs int
+		expectValid bool
+	}{
+		{name: "valid: 0", topLogprobs: 0, expectValid: true},
+		{name: "valid: 1", topLogprobs: 1, expectValid: true},
+		{name: "valid: 10", topLogprobs: 10, expectValid: true},
+		{name: "valid: 20", topLogprobs: 20, expectValid: true},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			trueVal := true
+			req := ChatCompletionRequest{
+				Model: "test-model",
+				Messages: []Message{
+					{Role: "user", Content: "Hello"},
+				},
+				Logprobs:    &trueVal,
+				TopLogprobs: tt.topLogprobs,
+			}
+
+			result, err := FromChatRequest(req)
+			if err != nil {
+				t.Fatalf("unexpected error: %v", err)
+			}
+
+			if result.TopLogprobs != tt.topLogprobs {
+				t.Errorf("expected TopLogprobs %d, got %d", tt.topLogprobs, result.TopLogprobs)
+			}
+		})
+	}
+}
```