diff --git a/kvcache/causal.go b/kvcache/causal.go index b594d0b4..8b101a81 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -25,6 +25,9 @@ type Causal struct { opts CausalOptions + // maxBatch is the largest batch that we might receive + maxBatch int + // config controls mostly backend-specific optimizations config *ml.CacheConfig @@ -147,6 +150,7 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity c.DType = dtype c.cellRanges = make(map[int]cellRange) c.backend = backend + c.maxBatch = maxBatch } func (c *Causal) SetConfig(config ml.CacheConfig) { @@ -639,48 +643,51 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error { return ErrNotSupported } - ctx := c.backend.NewContext() - defer ctx.Close() - seqRange := c.cellRanges[seq] - size := seqRange.max - seqRange.min + 1 - offsets := make([]int32, size) - for i := range offsets { - cell := c.cells[seqRange.min+i] + for start := seqRange.min; start <= seqRange.max; start += c.maxBatch { + ctx := c.backend.NewContext() - if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex { - offsets[i] = offset + size := min(seqRange.max-start+1, c.maxBatch) + offsets := make([]int32, size) + for i := range offsets { + cell := c.cells[start+i] + + if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex { + offsets[i] = offset + } } + + kShift := ctx.Input().FromIntSlice(offsets, len(offsets)) + + for i, key := range c.keys { + if key == nil { + continue + } + + kHeadDim := key.Dim(0) + numKVHeads := key.Dim(1) + rowSize := key.Stride(2) + + key = key.View(ctx, rowSize*start, + kHeadDim, key.Stride(1), + numKVHeads, key.Stride(2), + size, + ) + + roped, err := c.shiftFn(ctx, i, key, kShift) + if err != nil { + ctx.Close() + return err + } + + ctx.Forward(roped.Copy(ctx, key)) + } + + ctx.Compute() + ctx.Close() } - kShift := ctx.Input().FromIntSlice(offsets, len(offsets)) - - for i, key := range c.keys { - if key == nil { - continue - } - - kHeadDim := key.Dim(0) - numKVHeads := key.Dim(1) - rowSize := key.Stride(2) - - key = key.View(ctx, rowSize*seqRange.min, - kHeadDim, key.Stride(1), - numKVHeads, key.Stride(2), - size, - ) - - roped, err := c.shiftFn(ctx, i, key, kShift) - if err != nil { - return err - } - - ctx.Forward(roped.Copy(ctx, key)) - } - - ctx.Compute() - return nil }