From c116a7523ddc067db2b86aab38172c05ad01c710 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Mon, 28 Jul 2025 11:29:25 -0700 Subject: [PATCH] kvcache: Don't shift empty batches When we context shift, we delete half the context and apply RoPE with an offset to the other half. We used to RoPE across the entire context in a single pass with a zero offset for the deleted section. With the change to shifting in batches, we can skip any batches where all of the offsets would be zero. This typically reduces the number of operations by half. --- kvcache/causal.go | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/kvcache/causal.go b/kvcache/causal.go index 8b101a81..496eeaa6 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -646,18 +646,31 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error { seqRange := c.cellRanges[seq] for start := seqRange.min; start <= seqRange.max; start += c.maxBatch { - ctx := c.backend.NewContext() - size := min(seqRange.max-start+1, c.maxBatch) offsets := make([]int32, size) + + var batchFirst, batchLast int + + batchFirst = -1 for i := range offsets { cell := c.cells[start+i] if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex { offsets[i] = offset + if batchFirst < 0 { + batchFirst = i + } + batchLast = i } } + if batchFirst < 0 { + continue + } + + offsets = offsets[batchFirst : batchLast+1] + + ctx := c.backend.NewContext() kShift := ctx.Input().FromIntSlice(offsets, len(offsets)) for i, key := range c.keys { @@ -669,10 +682,10 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error { numKVHeads := key.Dim(1) rowSize := key.Stride(2) - key = key.View(ctx, rowSize*start, + key = key.View(ctx, rowSize*(start+batchFirst), kHeadDim, key.Stride(1), numKVHeads, key.Stride(2), - size, + len(offsets), ) roped, err := c.shiftFn(ctx, i, key, kShift)