From 0d140bd1af59def462a0d3fe61c89b468162b5e7 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 29 Oct 2025 11:03:43 -0700 Subject: [PATCH] fix: conv2d bias (#12834) --- ml/nn/convolution.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ml/nn/convolution.go b/ml/nn/convolution.go index db8c6147..2954de00 100644 --- a/ml/nn/convolution.go +++ b/ml/nn/convolution.go @@ -10,7 +10,8 @@ type Conv2D struct { func (m *Conv2D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { t = m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1) if m.Bias != nil { - t = t.Add(ctx, m.Bias) + // Bias shape is (out_channels,) while t shape is (width, height, out_channels, batch) + t = t.Add(ctx, m.Bias.Reshape(ctx, 1, 1, -1)) } return t }