Reapply "add truncate and shift parameters" (#12582)

This commit is contained in:
Jeffrey Morgan
2025-10-11 16:06:14 -07:00
committed by GitHub
parent 5db8a818a1
commit 6544e14735
8 changed files with 298 additions and 57 deletions

View File

@@ -79,6 +79,9 @@ type Sequence struct {
// true if an embedding are to be returned instead of text generation
embeddingOnly bool
// shift if context window is exceeded
shift bool
doneReason llm.DoneReason
// Metrics
@@ -94,8 +97,12 @@ type NewSequenceParams struct {
numKeep int
samplingParams *llama.SamplingParams
embedding bool
shift bool
truncate bool
}
var errorInputTooLong = errors.New("the input length exceeds the context length")
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
s.ready.Wait()
@@ -119,6 +126,10 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
if len(inputs) > s.cache.numCtx {
discard := len(inputs) - s.cache.numCtx
if !params.truncate {
return nil, errorInputTooLong
}
newInputs := inputs[:params.numKeep]
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
@@ -385,6 +396,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
for i, input := range seq.inputs {
if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
if len(seq.pendingInputs) == 0 {
if !seq.shift {
s.removeSequence(seqIdx, llm.DoneReasonLength)
break
}
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
var reprocess *ErrReprocessInputs
@@ -583,8 +599,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
numKeep: req.Options.NumKeep,
samplingParams: &samplingParams,
embedding: false,
shift: req.Shift,
truncate: req.Truncate,
})
if err != nil {
if errors.Is(err, errorInputTooLong) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
return
}

View File

@@ -88,6 +88,9 @@ type Sequence struct {
// true if an embedding are to be returned instead of text generation
embeddingOnly bool
// shift if context window is exceeded
shift bool
doneReason llm.DoneReason
// Metrics
@@ -104,8 +107,12 @@ type NewSequenceParams struct {
numKeep int32
sampler sample.Sampler
embedding bool
shift bool
truncate bool
}
var errorInputTooLong = errors.New("the input length exceeds the context length")
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
s.ready.Wait()
@@ -125,6 +132,11 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
if int32(len(inputs)) > s.cache.numCtx {
discard := int32(len(inputs)) - s.cache.numCtx
if !params.truncate {
return nil, errorInputTooLong
}
promptStart := params.numKeep + discard
// If we need to truncate in the middle of a unbreakable batch, remove the entire batch
@@ -176,6 +188,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
embeddingOnly: params.embedding,
stop: params.stop,
numKeep: params.numKeep,
shift: params.shift,
}, nil
}
@@ -517,6 +530,12 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
break
}
if !seq.shift {
s.removeSequence(seqIdx, llm.DoneReasonLength)
nextBatch.seqs[seqIdx] = nil
break
}
err = s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
var reprocess *ErrReprocessInputs
@@ -832,8 +851,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
numKeep: int32(req.Options.NumKeep),
sampler: sampler,
embedding: false,
shift: req.Shift,
truncate: req.Truncate,
})
if err != nil {
if errors.Is(err, errorInputTooLong) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
return
}