Revert "add truncate and shift parameters (#12519)" (#12545)

This reverts commit 6a62b894c7.
Authored by Jeffrey Morgan on 2025-10-08 17:57:57 -07:00; committed by GitHub
parent 6a62b894c7
commit 7d965258ce
8 changed files with 67 additions and 272 deletions

View File

@@ -106,14 +106,6 @@ type GenerateRequest struct {
 // before this option was introduced)
 Think *ThinkValue `json:"think,omitempty"`
-// Truncate is a boolean that, when set to true, truncates the chat history messages
-// if the rendered prompt exceeds the context length limit.
-Truncate *bool `json:"truncate,omitempty"`
-// Shift is a boolean that, when set to true, shifts the chat history
-// when hitting the context length limit instead of erroring.
-Shift *bool `json:"shift,omitempty"`
 // DebugRenderOnly is a debug option that, when set to true, returns the rendered
 // template instead of calling the model.
 DebugRenderOnly bool `json:"_debug_render_only,omitempty"`

@@ -148,14 +140,6 @@ type ChatRequest struct {
 // for supported models.
 Think *ThinkValue `json:"think,omitempty"`
-// Truncate is a boolean that, when set to true, truncates the chat history messages
-// if the rendered prompt exceeds the context length limit.
-Truncate *bool `json:"truncate,omitempty"`
-// Shift is a boolean that, when set to true, shifts the chat history
-// when hitting the context length limit instead of erroring.
-Shift *bool `json:"shift,omitempty"`
 // DebugRenderOnly is a debug option that, when set to true, returns the rendered
 // template instead of calling the model.
 DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
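
Note on client impact (not part of the diff itself): once these fields are removed, a request body that still sends "truncate" or "shift" has nothing to bind to. A minimal sketch of that behavior under Go's default JSON decoding, using a hypothetical trimmed-down request type rather than the real api package (the server's actual decoder settings may differ):

package main

import (
    "encoding/json"
    "fmt"
)

// generateRequest is a hypothetical, trimmed-down stand-in for api.GenerateRequest.
type generateRequest struct {
    Model  string `json:"model"`
    Prompt string `json:"prompt"`
}

func main() {
    // A client still sending the reverted fields.
    body := []byte(`{"model":"test","prompt":"hi","truncate":false,"shift":false}`)

    var req generateRequest
    if err := json.Unmarshal(body, &req); err != nil {
        panic(err)
    }

    // With encoding/json's default decoder, the unknown keys are silently dropped.
    fmt.Printf("%+v\n", req) // {Model:test Prompt:hi}
}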

View File

@@ -1380,8 +1380,6 @@ type CompletionRequest struct {
Options *api.Options Options *api.Options
Grammar string // set before sending the request to the subprocess Grammar string // set before sending the request to the subprocess
Shift bool
Truncate bool
} }
// DoneReason represents the reason why a completion response is done // DoneReason represents the reason why a completion response is done
@@ -1500,7 +1498,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return fmt.Errorf("failed reading llm error response: %w", err) return fmt.Errorf("failed reading llm error response: %w", err)
} }
log.Printf("llm predict error: %s", bodyBytes) log.Printf("llm predict error: %s", bodyBytes)
return api.StatusError{StatusCode: res.StatusCode, Status: res.Status, ErrorMessage: strings.TrimSpace(string(bodyBytes))} return fmt.Errorf("%s", bodyBytes)
} }
scanner := bufio.NewScanner(res.Body) scanner := bufio.NewScanner(res.Body)
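
The second hunk above restores the pre-#12519 error path: rather than surfacing the runner's HTTP status through api.StatusError, the response body is flattened into a plain formatted error, so the upstream status code is no longer recoverable with errors.As. A minimal sketch of that difference, using a hypothetical statusError mirror instead of the real api.StatusError:

package main

import (
    "errors"
    "fmt"
)

// statusError is a hypothetical mirror of the removed api.StatusError, for illustration only.
type statusError struct {
    StatusCode   int
    ErrorMessage string
}

func (e statusError) Error() string { return e.ErrorMessage }

func main() {
    typed := error(statusError{StatusCode: 429, ErrorMessage: "rate limit exceeded"})
    plain := fmt.Errorf("%s", "rate limit exceeded") // the restored behavior

    var se statusError
    fmt.Println(errors.As(typed, &se), se.StatusCode) // true 429: the status survives
    fmt.Println(errors.As(plain, &se))                // false: the status is gone
}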

View File

@@ -79,9 +79,6 @@ type Sequence struct {
 // true if an embedding are to be returned instead of text generation
 embeddingOnly bool
-// shift if context window is exceeded
-shift bool
 doneReason llm.DoneReason

 // Metrics

@@ -97,12 +94,8 @@ type NewSequenceParams struct {
 numKeep int
 samplingParams *llama.SamplingParams
 embedding bool
-shift bool
-truncate bool
 }

-var errorInputTooLong = errors.New("the input length exceeds the context length")

 func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
 s.ready.Wait()

@@ -128,10 +121,6 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 if len(inputs) > s.cache.numCtx {
 discard := len(inputs) - s.cache.numCtx
-if !params.truncate {
-return nil, errorInputTooLong
-}
 newInputs := inputs[:params.numKeep]
 newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
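
With the !params.truncate guard removed, the runner always falls back to the keep-prefix truncation shown in the hunk above: keep the first numKeep inputs, then drop the next discard inputs so the total fits the context window. A small self-contained sketch of that arithmetic (the names are illustrative, not the runner's actual types):

package main

import "fmt"

// truncateInputs keeps the first numKeep tokens and drops just enough of the
// tokens that follow them so the result fits within numCtx.
func truncateInputs(inputs []int, numCtx, numKeep int) []int {
    if len(inputs) <= numCtx {
        return inputs
    }
    discard := len(inputs) - numCtx
    out := append([]int{}, inputs[:numKeep]...)
    return append(out, inputs[numKeep+discard:]...)
}

func main() {
    inputs := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
    // 10 tokens into a 6-token window, keeping the first 2: drop tokens 2..5.
    fmt.Println(truncateInputs(inputs, 6, 2)) // [0 1 6 7 8 9]
}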
@@ -399,11 +388,6 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 for i, input := range seq.inputs {
 if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
 if len(seq.pendingInputs) == 0 {
-if !seq.shift {
-s.removeSequence(seqIdx, llm.DoneReasonLength)
-break
-}
 err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
 if err != nil {
 var reprocess *ErrReprocessInputs

@@ -599,14 +583,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 numKeep: req.Options.NumKeep,
 samplingParams: &samplingParams,
 embedding: false,
-shift: req.Shift,
-truncate: req.Truncate,
 })
 if err != nil {
-if errors.Is(err, errorInputTooLong) {
-http.Error(w, err.Error(), http.StatusBadRequest)
-return
-}
 http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
 return
 }

View File

@@ -88,9 +88,6 @@ type Sequence struct {
 // true if an embedding are to be returned instead of text generation
 embeddingOnly bool
-// shift if context window is exceeded
-shift bool
 doneReason llm.DoneReason

 // Metrics

@@ -106,12 +103,8 @@ type NewSequenceParams struct {
 numKeep int32
 sampler sample.Sampler
 embedding bool
-shift bool
-truncate bool
 }

-var errorInputTooLong = errors.New("the input length exceeds the context length")

 func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
 s.ready.Wait()

@@ -133,11 +126,6 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 if int32(len(inputs)) > s.cache.numCtx {
 discard := int32(len(inputs)) - s.cache.numCtx
-if !params.truncate {
-return nil, errorInputTooLong
-}
 promptStart := params.numKeep + discard

 // If we need to truncate in the middle of a unbreakable batch, remove the entire batch

@@ -190,7 +178,6 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 embeddingOnly: params.embedding,
 stop: params.stop,
 numKeep: params.numKeep,
-shift: params.shift,
 }, nil
 }

@@ -535,12 +522,6 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
 break
 }
-if !seq.shift {
-s.removeSequence(seqIdx, llm.DoneReasonLength)
-nextBatch.seqs[seqIdx] = nil
-break
-}
 err = s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
 if err != nil {
 var reprocess *ErrReprocessInputs

@@ -843,14 +824,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 numKeep: int32(req.Options.NumKeep),
 sampler: sampler,
 embedding: false,
-shift: req.Shift,
-truncate: req.Truncate,
 })
 if err != nil {
-if errors.Is(err, errorInputTooLong) {
-http.Error(w, err.Error(), http.StatusBadRequest)
-return
-}
 http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
 return
 }

View File

@@ -20,7 +20,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)

 // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
 // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
 // latest message and 2) system messages
-func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue, truncate bool) (prompt string, images []llm.ImageData, _ error) {
+func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (prompt string, images []llm.ImageData, _ error) {
 var system []api.Message

 // TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent

@@ -59,7 +59,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 }
 }

-if truncate && ctxLen > opts.NumCtx {
+if ctxLen > opts.NumCtx {
 slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
 break
 } else {
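
With the truncate flag gone from chatPrompt, truncation is again unconditional whenever the rendered prompt would exceed NumCtx, while still keeping the latest message and all system messages, as the doc comment in the hunk above describes. A rough, self-contained sketch of that selection policy, assuming a toy word-count tokenizer rather than the model's real tokenizer and template rendering:

package main

import (
    "fmt"
    "strings"
)

type message struct {
    Role    string
    Content string
}

// countTokens stands in for a real tokenizer: one word is roughly one token.
func countTokens(s string) int { return len(strings.Fields(s)) }

// truncateMessages keeps all system messages plus the newest suffix of the
// conversation (always including the latest message) that fits within numCtx.
func truncateMessages(msgs []message, numCtx int) []message {
    used := 0
    for _, m := range msgs {
        if m.Role == "system" {
            used += countTokens(m.Content)
        }
    }

    cut := len(msgs) - 1 // the latest message is always kept
    if msgs[cut].Role != "system" {
        used += countTokens(msgs[cut].Content)
    }
    for i := cut - 1; i >= 0; i-- {
        if msgs[i].Role == "system" {
            cut = i // already counted above
            continue
        }
        if used+countTokens(msgs[i].Content) > numCtx {
            break
        }
        used += countTokens(msgs[i].Content)
        cut = i
    }

    var kept []message
    for i, m := range msgs {
        if i >= cut || m.Role == "system" {
            kept = append(kept, m)
        }
    }
    return kept
}

func main() {
    msgs := []message{
        {Role: "system", Content: "You are the Test Who Lived."},
        {Role: "user", Content: "You're a test, Harry!"},
        {Role: "assistant", Content: "I-I'm a what?"},
        {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
    }
    // With a tight budget, older non-system turns are dropped first.
    for _, m := range truncateMessages(msgs, 18) {
        fmt.Println(m.Role+":", m.Content)
    }
}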

View File

@@ -30,7 +30,6 @@ func TestChatPrompt(t *testing.T) {
 name string
 model Model
 limit int
-truncate bool
 msgs []api.Message
 expect
 }{

@@ -38,7 +37,6 @@
 name: "messages",
 model: visionModel,
 limit: 64,
-truncate: true,
 msgs: []api.Message{
 {Role: "user", Content: "You're a test, Harry!"},
 {Role: "assistant", Content: "I-I'm a what?"},

@@ -52,7 +50,6 @@
 name: "truncate messages",
 model: visionModel,
 limit: 1,
-truncate: true,
 msgs: []api.Message{
 {Role: "user", Content: "You're a test, Harry!"},
 {Role: "assistant", Content: "I-I'm a what?"},

@@ -66,7 +63,6 @@
 name: "truncate messages with image",
 model: visionModel,
 limit: 64,
-truncate: true,
 msgs: []api.Message{
 {Role: "user", Content: "You're a test, Harry!"},
 {Role: "assistant", Content: "I-I'm a what?"},

@@ -83,7 +79,6 @@
 name: "truncate messages with images",
 model: visionModel,
 limit: 64,
-truncate: true,
 msgs: []api.Message{
 {Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
 {Role: "assistant", Content: "I-I'm a what?"},

@@ -100,7 +95,6 @@
 name: "messages with images",
 model: visionModel,
 limit: 2048,
-truncate: true,
 msgs: []api.Message{
 {Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
 {Role: "assistant", Content: "I-I'm a what?"},

@@ -118,7 +112,6 @@
 name: "message with image tag",
 model: visionModel,
 limit: 2048,
-truncate: true,
 msgs: []api.Message{
 {Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
 {Role: "assistant", Content: "I-I'm a what?"},

@@ -136,7 +129,6 @@
 name: "messages with interleaved images",
 model: visionModel,
 limit: 2048,
-truncate: true,
 msgs: []api.Message{
 {Role: "user", Content: "You're a test, Harry!"},
 {Role: "user", Images: []api.ImageData{[]byte("something")}},

@@ -156,7 +148,6 @@
 name: "truncate message with interleaved images",
 model: visionModel,
 limit: 1024,
-truncate: true,
 msgs: []api.Message{
 {Role: "user", Content: "You're a test, Harry!"},
 {Role: "user", Images: []api.ImageData{[]byte("something")}},

@@ -175,7 +166,6 @@
 name: "message with system prompt",
 model: visionModel,
 limit: 2048,
-truncate: true,
 msgs: []api.Message{
 {Role: "system", Content: "You are the Test Who Lived."},
 {Role: "user", Content: "You're a test, Harry!"},

@@ -190,7 +180,6 @@
 name: "out of order system",
 model: visionModel,
 limit: 2048,
-truncate: true,
 msgs: []api.Message{
 {Role: "user", Content: "You're a test, Harry!"},
 {Role: "assistant", Content: "I-I'm a what?"},

@@ -205,7 +194,6 @@
 name: "multiple images same prompt",
 model: visionModel,
 limit: 2048,
-truncate: true,
 msgs: []api.Message{
 {Role: "user", Content: "Compare these two pictures of hotdogs", Images: []api.ImageData{[]byte("one hotdog"), []byte("two hotdogs")}},
 },

@@ -214,20 +202,6 @@
 images: [][]byte{[]byte("one hotdog"), []byte("two hotdogs")},
 },
 },
-{
-name: "no truncate with limit exceeded",
-model: visionModel,
-limit: 10,
-truncate: false,
-msgs: []api.Message{
-{Role: "user", Content: "You're a test, Harry!"},
-{Role: "assistant", Content: "I-I'm a what?"},
-{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
-},
-expect: expect{
-prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
-},
-},
 }

 for _, tt := range cases {

@@ -235,7 +209,7 @@ func TestChatPrompt(t *testing.T) {
 model := tt.model
 opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
 think := false
-prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, tt.truncate)
+prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think})
 if tt.error == nil && err != nil {
 t.Fatal(err)
 } else if tt.error != nil && err != tt.error {

View File

@@ -472,8 +472,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 Images: images,
 Format: req.Format,
 Options: opts,
-Shift: req.Shift == nil || *req.Shift,
-Truncate: req.Truncate == nil || *req.Truncate,
 }, func(cr llm.CompletionResponse) {
 res := api.GenerateResponse{
 Model: req.Model,

@@ -535,7 +533,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 ch <- res
 }); err != nil {
-ch <- err
+ch <- gin.H{"error": err.Error()}
 }
 }()

@@ -549,11 +547,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 sbThinking.WriteString(t.Thinking)
 sbContent.WriteString(t.Response)
 r = t
-case api.StatusError:
-c.JSON(t.StatusCode, gin.H{"error": t.ErrorMessage})
-return
-case error:
-c.JSON(http.StatusInternalServerError, gin.H{"error": t.Error()})
+case gin.H:
+msg, ok := t["error"].(string)
+if !ok {
+msg = "unexpected error format in response"
+}
+c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
 return
 default:
 c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})

@@ -1618,18 +1618,6 @@ func streamResponse(c *gin.Context, ch chan any) {
 return false
 }
-if statusError, ok := val.(api.StatusError); ok {
-c.Header("Content-Type", "application/json")
-c.AbortWithStatusJSON(statusError.StatusCode, gin.H{"error": statusError.ErrorMessage})
-return false
-}
-if err, ok := val.(error); ok {
-c.Header("Content-Type", "application/json")
-c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-return false
-}
 bts, err := json.Marshal(val)
 if err != nil {
 slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err))

@@ -1947,8 +1935,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 }
 }
-truncate := req.Truncate == nil || *req.Truncate
-prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
+prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think)
 if err != nil {
 slog.Error("chat prompt error", "error", err)
 c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})

@@ -2001,8 +1988,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
 Images: images,
 Format: req.Format,
 Options: opts,
-Shift: req.Shift == nil || *req.Shift,
-Truncate: truncate,
 }, func(r llm.CompletionResponse) {
 res := api.ChatResponse{
 Model: req.Model,

@@ -2075,7 +2060,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 ch <- res
 }); err != nil {
-ch <- err
+ch <- gin.H{"error": err.Error()}
 }
 }()

@@ -2093,11 +2078,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
 if len(req.Tools) > 0 {
 toolCalls = append(toolCalls, t.Message.ToolCalls...)
 }
-case api.StatusError:
-c.JSON(t.StatusCode, gin.H{"error": t.ErrorMessage})
-return
-case error:
-c.JSON(http.StatusInternalServerError, gin.H{"error": t.Error()})
+case gin.H:
+msg, ok := t["error"].(string)
+if !ok {
+msg = "unexpected error format in response"
+}
+c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
 return
 default:
 c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})

View File

@@ -594,58 +594,6 @@ func TestGenerateChat(t *testing.T) {
 t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
 }
 })
-t.Run("status error non-streaming", func(t *testing.T) {
-mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
-return api.StatusError{
-StatusCode: http.StatusServiceUnavailable,
-Status: "Service Unavailable",
-ErrorMessage: "model is overloaded",
-}
-}
-stream := false
-w := createRequest(t, s.ChatHandler, api.ChatRequest{
-Model: "test",
-Messages: []api.Message{
-{Role: "user", Content: "Hello!"},
-},
-Stream: &stream,
-})
-if w.Code != http.StatusServiceUnavailable {
-t.Errorf("expected status 503, got %d", w.Code)
-}
-if diff := cmp.Diff(w.Body.String(), `{"error":"model is overloaded"}`); diff != "" {
-t.Errorf("mismatch (-got +want):\n%s", diff)
-}
-})
-t.Run("status error streaming", func(t *testing.T) {
-mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
-return api.StatusError{
-StatusCode: http.StatusTooManyRequests,
-Status: "Too Many Requests",
-ErrorMessage: "rate limit exceeded",
-}
-}
-w := createRequest(t, s.ChatHandler, api.ChatRequest{
-Model: "test",
-Messages: []api.Message{
-{Role: "user", Content: "Hello!"},
-},
-})
-if w.Code != http.StatusTooManyRequests {
-t.Errorf("expected status 429, got %d", w.Code)
-}
-if diff := cmp.Diff(w.Body.String(), `{"error":"rate limit exceeded"}`); diff != "" {
-t.Errorf("mismatch (-got +want):\n%s", diff)
-}
-})
 }

 func TestGenerate(t *testing.T) {

@@ -1020,55 +968,6 @@
 t.Errorf("mismatch (-got +want):\n%s", diff)
 }
 })
-t.Run("status error non-streaming", func(t *testing.T) {
-mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
-return api.StatusError{
-StatusCode: http.StatusServiceUnavailable,
-Status: "Service Unavailable",
-ErrorMessage: "model is overloaded",
-}
-}
-streamRequest := false
-w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
-Model: "test",
-Prompt: "Hello!",
-Stream: &streamRequest,
-})
-if w.Code != http.StatusServiceUnavailable {
-t.Errorf("expected status 503, got %d", w.Code)
-}
-if diff := cmp.Diff(w.Body.String(), `{"error":"model is overloaded"}`); diff != "" {
-t.Errorf("mismatch (-got +want):\n%s", diff)
-}
-})
-t.Run("status error streaming", func(t *testing.T) {
-mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
-return api.StatusError{
-StatusCode: http.StatusTooManyRequests,
-Status: "Too Many Requests",
-ErrorMessage: "rate limit exceeded",
-}
-}
-w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
-Model: "test",
-Prompt: "Hello!",
-Stream: &stream,
-})
-if w.Code != http.StatusTooManyRequests {
-t.Errorf("expected status 429, got %d", w.Code)
-}
-if diff := cmp.Diff(w.Body.String(), `{"error":"rate limit exceeded"}`); diff != "" {
-t.Errorf("mismatch (-got +want):\n%s", diff)
-}
-})
 }

 func TestChatWithPromptEndingInThinkTag(t *testing.T) {