batch: use tensors for outputs (#12185)

this cleans up the model interface slightly without too much impact in
other areas
This commit is contained in:
Michael Yang
2025-09-15 14:33:06 -07:00
committed by GitHub
parent 92b96d54ef
commit 6f7117145f
14 changed files with 27 additions and 37 deletions

View File

@@ -41,8 +41,8 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err
}
var outputs ml.Tensor
if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 {
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if i == len(m.TransformerBlocks)-1 {
outputs = batch.Outputs
}
hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options)