diff --git a/kvcache/causal.go b/kvcache/causal.go index 543a65a6..c7b3595e 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -40,11 +40,6 @@ type Causal struct { // ** current forward pass ** - // curReserve indicates that this forward pass is only for - // memory reservation and we should not update our metadata - // based on it. - curReserve bool - // the active layer for Get and Put curLayer int @@ -206,13 +201,12 @@ func (c *Causal) Close() { } func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { - c.curReserve = reserve c.curBatchSize = len(batch.Positions) c.curSequences = batch.Sequences c.curPositions = batch.Positions c.opts.Except = nil - if !c.curReserve { + if !reserve { c.updateSlidingWindow() var err error @@ -379,10 +373,6 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor { length := c.curCellRange.max - c.curCellRange.min + 1 - if c.curReserve { - return ctx.Input().Empty(c.config.MaskDType, length, batchSize) - } - mask := make([]float32, batchSize*length) for i := range c.curBatchSize {