input: Rename Options to Batch

Options is no longer very descriptive of this struct.
This commit is contained in:
Jesse Gross
2025-03-19 14:28:15 -07:00
committed by Jesse Gross
parent ffbfe833da
commit 0c220935bd
14 changed files with 73 additions and 60 deletions

View File

@@ -52,7 +52,7 @@ type Cache interface {
// StartForward is called before the start of the model's forward pass.
// For each token in the coming batch, there must be a corresponding
// entry in positions and seqs.
StartForward(ctx ml.Context, opts input.Options) error
StartForward(ctx ml.Context, batch input.Batch) error
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32)

View File

@@ -140,10 +140,10 @@ func (c *Causal) Close() {
}
}
func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
c.curBatchSize = len(opts.Positions)
c.curSequences = opts.Sequences
c.curPositions = opts.Positions
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
c.curBatchSize = len(batch.Positions)
c.curSequences = batch.Sequences
c.curPositions = batch.Positions
c.opts.Except = nil
var err error
@@ -157,8 +157,8 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
}
c.curCellRange = newRange()
for i, pos := range opts.Positions {
seq := opts.Sequences[i]
for i, pos := range batch.Positions {
seq := batch.Sequences[i]
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}

View File

@@ -270,7 +270,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
context := backend.NewContext()
defer context.Close()
err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs})
err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs})
if err != nil {
panic(err)
}

View File

@@ -79,10 +79,10 @@ func (c *EncoderCache) Close() {
}
}
func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error {
func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch) error {
// We work with the most recent image
if len(opts.Multimodal) > 0 {
c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index]
if len(batch.Multimodal) > 0 {
c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
}
return nil

View File

@@ -41,14 +41,14 @@ func (c *WrapperCache) Close() {
}
}
func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error {
func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch) error {
for i, cache := range c.caches {
err := cache.StartForward(ctx, opts)
err := cache.StartForward(ctx, batch)
if err != nil {
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
for j := i - 1; j >= 0; j-- {
for k := range opts.Positions {
_ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32)
for k := range batch.Positions {
_ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
}
}
return err