From c1149875234a51aa1e5e60b74f3807f5982c60fa Mon Sep 17 00:00:00 2001
From: Parth Sareen
Date: Thu, 13 Nov 2025 13:49:25 -0800
Subject: [PATCH] logprob: add bytes to logprobs (#13068)

---
 api/types.go                   |   3 +
 integration/api_test.go        |  25 +++
 server/logprob.go              |  15 ++
 server/routes_generate_test.go | 269 +++++++++++++++++++++++++++++++++
 4 files changed, 312 insertions(+)

diff --git a/api/types.go b/api/types.go
index d5788d54..d8467629 100644
--- a/api/types.go
+++ b/api/types.go
@@ -366,6 +366,9 @@ type TokenLogprob struct {
 
 	// Logprob is the log probability of this token.
 	Logprob float64 `json:"logprob"`
+
+	// Bytes contains the raw byte representation of the token.
+	Bytes []int `json:"bytes,omitempty"`
 }
 
 // Logprob contains log probability information for a generated token.
diff --git a/integration/api_test.go b/integration/api_test.go
index 839e14d7..66d1e9e9 100644
--- a/integration/api_test.go
+++ b/integration/api_test.go
@@ -14,6 +14,23 @@ import (
 	"github.com/ollama/ollama/api"
 )
 
+func assertBytesMatchToken(t *testing.T, label, token string, ints []int) {
+	t.Helper()
+
+	raw := []byte(token)
+	if len(ints) != len(raw) {
+		t.Errorf("%s expected %d bytes for token %q, got %d (%v)", label, len(raw), token, len(ints), ints)
+		return
+	}
+
+	for i, b := range raw {
+		if ints[i] != int(b) {
+			t.Errorf("%s byte[%d] mismatch for token %q: got %d want %d", label, i, token, ints[i], int(b))
+			return
+		}
+	}
+}
+
 func TestAPIGenerate(t *testing.T) {
 	initialTimeout := 60 * time.Second
 	streamTimeout := 30 * time.Second
@@ -466,6 +483,7 @@ func TestAPIGenerateLogprobs(t *testing.T) {
 			if lp.Logprob > 0 {
 				t.Errorf("logprob[%d] has positive logprob %f (should be <= 0)", i, lp.Logprob)
 			}
+			assertBytesMatchToken(t, fmt.Sprintf("generate logprob[%d]", i), lp.Token, lp.Bytes)
 
 			// Check top_logprobs if requested
 			if test.topLogprobs > 0 {
@@ -482,6 +500,9 @@ func TestAPIGenerateLogprobs(t *testing.T) {
 						t.Errorf("logprob[%d].top_logprobs not sorted: %f < %f", i, lp.TopLogprobs[j-1].Logprob, lp.TopLogprobs[j].Logprob)
 					}
 				}
+				for j, top := range lp.TopLogprobs {
+					assertBytesMatchToken(t, fmt.Sprintf("generate logprob[%d].top[%d]", i, j), top.Token, top.Bytes)
+				}
 			} else if len(lp.TopLogprobs) > 0 {
 				t.Errorf("logprob[%d] has top_logprobs but none were requested", i)
 			}
@@ -544,11 +565,15 @@ func TestAPIChatLogprobs(t *testing.T) {
 		if lp.Logprob > 0 {
 			t.Errorf("logprob[%d] has positive logprob %f", i, lp.Logprob)
 		}
+		assertBytesMatchToken(t, fmt.Sprintf("chat logprob[%d]", i), lp.Token, lp.Bytes)
 		if len(lp.TopLogprobs) == 0 {
 			t.Errorf("logprob[%d] expected top_logprobs but got none", i)
 		}
 		if len(lp.TopLogprobs) > 3 {
 			t.Errorf("logprob[%d] has %d top_logprobs, expected max 3", i, len(lp.TopLogprobs))
 		}
+		for j, top := range lp.TopLogprobs {
+			assertBytesMatchToken(t, fmt.Sprintf("chat logprob[%d].top[%d]", i, j), top.Token, top.Bytes)
+		}
 	}
 }
diff --git a/server/logprob.go b/server/logprob.go
index 51996c2a..4a6e1408 100644
--- a/server/logprob.go
+++ b/server/logprob.go
@@ -12,6 +12,7 @@ func toAPILogprobs(logprobs []llm.Logprob) []api.Logprob {
 		result[i] = api.Logprob{
 			TokenLogprob: api.TokenLogprob{
 				Token:   lp.Token,
+				Bytes:   stringToByteInts(lp.Token),
 				Logprob: lp.Logprob,
 			},
 		}
@@ -20,6 +21,7 @@ func toAPILogprobs(logprobs []llm.Logprob) []api.Logprob {
 			for j, tlp := range lp.TopLogprobs {
 				result[i].TopLogprobs[j] = api.TokenLogprob{
 					Token:   tlp.Token,
+					Bytes:   stringToByteInts(tlp.Token),
 					Logprob: tlp.Logprob,
 				}
 			}
@@ -27,3 +29,16 @@ func toAPILogprobs(logprobs []llm.Logprob) []api.Logprob {
 	}
 	return result
 }
+
+func stringToByteInts(s string) []int {
+	if s == "" {
+		return nil
+	}
+
+	raw := []byte(s)
+	ints := make([]int, len(raw))
+	for i, b := range raw {
+		ints[i] = int(b)
+	}
+	return ints
+}
diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go
index a6be3bf3..a9931ea2 100644
--- a/server/routes_generate_test.go
+++ b/server/routes_generate_test.go
@@ -1220,6 +1220,139 @@ func TestGenerateLogprobs(t *testing.T) {
 			t.Errorf("mismatch (-got +want):\n%s", diff)
 		}
 	})
+
+	t.Run("returns logprob bytes when requested", func(t *testing.T) {
+		gin.SetMode(gin.TestMode)
+
+		mock := &mockRunner{}
+		expectedPrimary := llm.TokenLogprob{
+			Token:   "Hi",
+			Logprob: -0.01,
+		}
+		expectedAlternatives := []llm.TokenLogprob{
+			{
+				Token:   "Hello",
+				Logprob: -0.25,
+			},
+			{
+				Token:   "Hey",
+				Logprob: -0.5,
+			},
+		}
+
+		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
+			fn(llm.CompletionResponse{
+				Content:            "Hi",
+				Done:               true,
+				DoneReason:         llm.DoneReasonStop,
+				PromptEvalCount:    1,
+				PromptEvalDuration: 1,
+				EvalCount:          1,
+				EvalDuration:       1,
+				Logprobs: []llm.Logprob{
+					{
+						TokenLogprob: expectedPrimary,
+						TopLogprobs:  expectedAlternatives,
+					},
+				},
+			})
+			return nil
+		}
+
+		s := &Server{
+			sched: &Scheduler{
+				pendingReqCh:    make(chan *LlmRequest, 1),
+				finishedReqCh:   make(chan *LlmRequest, 1),
+				expiredCh:       make(chan *runnerRef, 1),
+				unloadedCh:      make(chan any, 1),
+				loaded:          make(map[string]*runnerRef),
+				newServerFn:     newMockServer(mock),
+				getGpuFn:        getGpuFn,
+				getSystemInfoFn: getSystemInfoFn,
+				waitForRecovery: 250 * time.Millisecond,
+				loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
+					req.successCh <- &runnerRef{llama: mock}
+					return false
+				},
+			},
+		}
+
+		go s.sched.Run(t.Context())
+
+		_, digest := createBinFile(t, ggml.KV{
+			"general.architecture":          "llama",
+			"llama.block_count":             uint32(1),
+			"llama.context_length":          uint32(8192),
+			"llama.embedding_length":        uint32(4096),
+			"llama.attention.head_count":    uint32(32),
+			"llama.attention.head_count_kv": uint32(8),
+			"tokenizer.ggml.tokens":         []string{""},
+			"tokenizer.ggml.scores":         []float32{0},
+			"tokenizer.ggml.token_type":     []int32{0},
+		}, []*ggml.Tensor{
+			{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		})
+
+		stream := false
+		if w := createRequest(t, s.CreateHandler, api.CreateRequest{
+			Model:    "test-logprob-bytes",
+			Files:    map[string]string{"file.gguf": digest},
+			Template: `{{ .Prompt }}`,
+			Stream:   &stream,
+		}); w.Code != http.StatusOK {
+			t.Fatalf("expected status 200, got %d", w.Code)
+		}
+
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:       "test-logprob-bytes",
+			Prompt:      "Hi",
+			Stream:      &stream,
+			Logprobs:    true,
+			TopLogprobs: len(expectedAlternatives),
+		})
+
+		if w.Code != http.StatusOK {
+			t.Fatalf("expected status 200, got %d", w.Code)
+		}
+
+		var resp api.GenerateResponse
+		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+			t.Fatalf("failed to decode response: %v", err)
+		}
+
+		if len(resp.Logprobs) != 1 {
+			t.Fatalf("expected 1 logprob entry, got %d", len(resp.Logprobs))
+		}
+
+		expectedPrimaryBytes := stringToByteInts(expectedPrimary.Token)
+		expectedAlternativesBytes := make([][]int, len(expectedAlternatives))
+		for i, alternative := range expectedAlternatives {
+			expectedAlternativesBytes[i] = stringToByteInts(alternative.Token)
+		}
+		if diff := cmp.Diff(expectedPrimaryBytes, resp.Logprobs[0].Bytes); diff != "" {
+			t.Fatalf("primary token bytes mismatch (-want +got):\n%s", diff)
+		}
+
+		if len(resp.Logprobs[0].TopLogprobs) != len(expectedAlternatives) {
+			t.Fatalf("expected %d top logprobs, got %d", len(expectedAlternatives), len(resp.Logprobs[0].TopLogprobs))
+		}
+
+		for i, top := range resp.Logprobs[0].TopLogprobs {
+			if diff := cmp.Diff(expectedAlternativesBytes[i], top.Bytes); diff != "" {
+				t.Fatalf("top logprob[%d] bytes mismatch (-want +got):\n%s", i, diff)
+			}
+		}
+	})
 }
 
 func TestChatLogprobs(t *testing.T) {
@@ -1262,6 +1395,142 @@ func TestChatLogprobs(t *testing.T) {
 			t.Errorf("mismatch (-got +want):\n%s", diff)
 		}
 	})
+
+	t.Run("returns logprob bytes when requested", func(t *testing.T) {
+		gin.SetMode(gin.TestMode)
+
+		mock := &mockRunner{}
+		expectedPrimary := llm.TokenLogprob{
+			Token:   "Hi",
+			Logprob: -0.02,
+		}
+		expectedAlternatives := []llm.TokenLogprob{
+			{
+				Token:   "Hello",
+				Logprob: -0.3,
+			},
+			{
+				Token:   "Hey",
+				Logprob: -0.45,
+			},
+		}
+		expectedPrimaryBytes := stringToByteInts(expectedPrimary.Token)
+		expectedAlternativesBytes := make([][]int, len(expectedAlternatives))
+		for i, alternative := range expectedAlternatives {
+			expectedAlternativesBytes[i] = stringToByteInts(alternative.Token)
+		}
+
+		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
+			fn(llm.CompletionResponse{
+				Content:            "Hi",
+				Done:               true,
+				DoneReason:         llm.DoneReasonStop,
+				PromptEvalCount:    1,
+				PromptEvalDuration: 1,
+				EvalCount:          1,
+				EvalDuration:       1,
+				Logprobs: []llm.Logprob{
+					{
+						TokenLogprob: expectedPrimary,
+						TopLogprobs:  expectedAlternatives,
+					},
+				},
+			})
+			return nil
+		}
+
+		s := &Server{
+			sched: &Scheduler{
+				pendingReqCh:    make(chan *LlmRequest, 1),
+				finishedReqCh:   make(chan *LlmRequest, 1),
+				expiredCh:       make(chan *runnerRef, 1),
+				unloadedCh:      make(chan any, 1),
+				loaded:          make(map[string]*runnerRef),
+				newServerFn:     newMockServer(mock),
+				getGpuFn:        getGpuFn,
+				getSystemInfoFn: getSystemInfoFn,
+				waitForRecovery: 250 * time.Millisecond,
+				loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
+					req.successCh <- &runnerRef{llama: mock}
+					return false
+				},
+			},
+		}
+
+		go s.sched.Run(t.Context())
+
+		_, digest := createBinFile(t, ggml.KV{
+			"general.architecture":          "llama",
+			"llama.block_count":             uint32(1),
+			"llama.context_length":          uint32(8192),
+			"llama.embedding_length":        uint32(4096),
+			"llama.attention.head_count":    uint32(32),
+			"llama.attention.head_count_kv": uint32(8),
+			"tokenizer.ggml.tokens":         []string{""},
+			"tokenizer.ggml.scores":         []float32{0},
+			"tokenizer.ggml.token_type":     []int32{0},
+		}, []*ggml.Tensor{
+			{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		})
+
+		stream := false
+		if w := createRequest(t, s.CreateHandler, api.CreateRequest{
+			Model:    "test-chat-logprob-bytes",
+			Files:    map[string]string{"file.gguf": digest},
+			Template: `{{- range .Messages }}{{ .Role }}: {{ .Content }}
+{{ end }}`,
+			Stream: &stream,
+		}); w.Code != http.StatusOK {
+			t.Fatalf("expected status 200, got %d", w.Code)
+		}
+
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test-chat-logprob-bytes",
+			Messages: []api.Message{
+				{Role: "user", Content: "Say hi"},
+			},
+			Stream:      &stream,
+			Logprobs:    true,
+			TopLogprobs: len(expectedAlternatives),
+		})
+
+		if w.Code != http.StatusOK {
+			t.Fatalf("expected status 200, got %d", w.Code)
+		}
+
+		var resp api.ChatResponse
+		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+			t.Fatalf("failed to decode response: %v", err)
+		}
+
+		if len(resp.Logprobs) != 1 {
+			t.Fatalf("expected 1 logprob entry, got %d", len(resp.Logprobs))
+		}
+
+		if diff := cmp.Diff(expectedPrimaryBytes, resp.Logprobs[0].Bytes); diff != "" {
+			t.Fatalf("primary token bytes mismatch (-want +got):\n%s", diff)
+		}
+
+		if len(resp.Logprobs[0].TopLogprobs) != len(expectedAlternatives) {
+			t.Fatalf("expected %d top logprobs, got %d", len(expectedAlternatives), len(resp.Logprobs[0].TopLogprobs))
+		}
+
+		for i, top := range resp.Logprobs[0].TopLogprobs {
+			if diff := cmp.Diff(expectedAlternativesBytes[i], top.Bytes); diff != "" {
+				t.Fatalf("top logprob[%d] bytes mismatch (-want +got):\n%s", i, diff)
+			}
+		}
+	})
 }
 
 func TestChatWithPromptEndingInThinkTag(t *testing.T) {
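
Note on the design (not part of the patch): returning `bytes` as an array of integers alongside `token` matters because a tokenizer can split a single multi-byte UTF-8 character across adjacent tokens, so the individual token strings are not always valid UTF-8 on their own; clients should concatenate the byte values across tokens and decode once. A minimal client-side sketch of that use, with a hypothetical `entries` fixture standing in for the `bytes` fields of a real response:

package main

import "fmt"

// tokenBytes mirrors the new api.TokenLogprob.Bytes field: the raw byte
// values of one token, encoded as ints for JSON transport.
type tokenBytes []int

func main() {
	// Hypothetical logprob entries where "é" (0xC3 0xA9) was split across
	// two tokens; neither token decodes to a valid rune by itself.
	entries := []tokenBytes{{0xC3}, {0xA9}}

	// Concatenate every byte value first, then decode the UTF-8 once.
	var raw []byte
	for _, e := range entries {
		for _, b := range e {
			raw = append(raw, byte(b))
		}
	}
	fmt.Println(string(raw)) // prints: é
}

The `[]int` encoding presumably follows the OpenAI-style logprobs payload, where `bytes` is likewise an integer array rather than a base64 string.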