batch: use tensors for outputs (#12185)

this cleans up the model interface slightly without too much impact in
other areas
This commit is contained in:
Michael Yang
2025-09-15 14:33:06 -07:00
committed by GitHub
parent 92b96d54ef
commit 6f7117145f
14 changed files with 27 additions and 37 deletions

View File

@@ -83,7 +83,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)))
hiddenStates = hiddenStates.Rows(ctx, batch.Outputs)
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil