Fix follow up images and images split across batches

This commit is contained in:
Jesse Gross
2025-03-09 21:29:58 -07:00
committed by Michael Yang
parent e95278932b
commit 2c40c4d35e
2 changed files with 73 additions and 47 deletions

View File

@@ -173,24 +173,53 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, multimodal []input.MultimodalIndex, cache kvcache.Cache) ml.Tensor {
func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex, positions []int32) []int32 {
var embedding ml.Tensor
var src, dst, length int
var except []int32
for _, image := range multimodal {
imageToken := image.Multimodal.(imageToken)
imageSrc := imageToken.index
imageDst := image.Index
if embedding == nil {
embedding = imageToken.embedding
src = imageSrc
dst = imageDst
length = 1
} else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst {
src = imageSrc
dst = imageDst
length++
} else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst {
length++
} else {
visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
embedding = imageToken.embedding
src = imageSrc
dst = imageDst
length = 1
}
except = append(except, positions[imageDst])
}
if embedding != nil {
visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
}
return except
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
if multimodal != nil {
visionOutputs := multimodal[0].Multimodal.(ml.Tensor)
offset := multimodal[0].Index - 1 - visionOutputs.Dim(1)
hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(1))
if causal, ok := cache.(*kvcache.WrapperCache).UnderlyingCache().(*kvcache.Causal); ok {
except := make([]int32, visionOutputs.Dim(1))
for i := 0; i < visionOutputs.Dim(1); i++ {
except[i] = int32(offset + i)
}
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
}
}
except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal, opts.Positions)
for i, layer := range m.Layers {
// gemma alternates between the sliding window (local) and causal (global)
@@ -203,6 +232,10 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
wc := cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
}
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs