From da0e345200fbb47653d2f9c60fcc60ba7b0a7187 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Tue, 18 Mar 2025 18:08:19 -0700 Subject: [PATCH] ml: use input context for extracting outputs (#9875) --- model/models/gemma2/model.go | 2 +- model/models/gemma3/model.go | 2 +- model/models/llama/model.go | 2 +- model/models/mllama/model.go | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 2b8597c4..fbefebe2 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -179,7 +179,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { return nil, err } - outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs)) + outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs)) if err != nil { return nil, err } diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index 32ad80f4..95f89ad4 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -150,7 +150,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { return nil, err } - outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs)) + outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs)) if err != nil { return nil, err } diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 19a2ab8c..87eb9b75 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -150,7 +150,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { return nil, err } - outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs)) + outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs)) if err != nil { return nil, err } diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index fa4d570c..0aa11f17 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -154,7 +154,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { return nil, err } - outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs)) + outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs)) if err != nil { return nil, err }