chore: update models to use slice/chunk/chunksections (#12934)

* use slice/chunks

* bert

* llama4

* gemma3n

* gptoss

* mistral3

* qwen3vl

* qwen25vl

* deepseek2

* remove unused ops
Michael Yang
2025-11-13 15:20:12 -08:00
committed by GitHub
parent c114987523
commit 333203d871
13 changed files with 59 additions and 135 deletions


@@ -110,9 +110,12 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
 	for name, mxfp4 := range mxfp4s {
 		dims := mxfp4.blocks.Shape()
+		if !strings.HasSuffix(name, ".weight") {
+			name = name + ".weight"
+		}
 		if strings.Contains(name, "ffn_down_exps") {
 			out = append(out, &ggml.Tensor{
-				Name: name + ".weight",
+				Name: name,
 				Kind: uint32(ggml.TensorTypeMXFP4),
 				Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
 				WriterTo: mxfp4,
@@ -121,12 +124,12 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
 			// gate_up_exps is interleaved, need to split into gate_exps and up_exps
 			// e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...]
 			out = append(out, &ggml.Tensor{
-				Name: strings.Replace(name, "gate_up", "gate", 1) + ".weight",
+				Name: strings.Replace(name, "gate_up", "gate", 1),
 				Kind: uint32(ggml.TensorTypeMXFP4),
 				Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
 				WriterTo: mxfp4.slice(1, 0, int(dims[1]), 2),
 			}, &ggml.Tensor{
-				Name: strings.Replace(name, "gate_up", "up", 1) + ".weight",
+				Name: strings.Replace(name, "gate_up", "up", 1),
 				Kind: uint32(ggml.TensorTypeMXFP4),
 				Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
 				WriterTo: mxfp4.slice(1, 1, int(dims[1]), 2),


@@ -146,7 +146,6 @@ type Tensor interface {
 	FromFloats([]float32)
 	FromInts([]int32)
-	Neg(ctx Context) Tensor
 	Add(ctx Context, t2 Tensor) Tensor
 	Sub(ctx Context, t2 Tensor) Tensor
 	Mul(ctx Context, t2 Tensor) Tensor
@@ -185,7 +184,6 @@ type Tensor interface {
 	View(ctx Context, offset int, shape ...int) Tensor
 	Permute(ctx Context, shape ...int) Tensor
 	Contiguous(ctx Context, shape ...int) Tensor
-	Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor
 	Pad(ctx Context, shape ...int) Tensor
@@ -209,7 +207,6 @@ type Tensor interface {
 	Stddev(ctx Context) Tensor
 	Sqr(ctx Context) Tensor
 	Sqrt(ctx Context) Tensor
-	Clamp(ctx Context, min, max float32) Tensor
 }
 // ScaledDotProductAttention implements a fused attention


@@ -1137,13 +1137,6 @@ func (t *Tensor) Cast(ctx ml.Context, dtype ml.DType) ml.Tensor {
 	}
 }
-func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
-	return &Tensor{
-		b: t.b,
-		t: C.ggml_neg(ctx.(*Context).ctx, t.t),
-	}
-}
 func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	return &Tensor{
 		b: t.b,
@@ -1632,20 +1625,6 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
 	}
 }
-func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
-	var tt *C.struct_ggml_tensor
-	switch len(strides) {
-	case 0:
-		tt = C.ggml_set_1d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
-	case 1:
-		tt = C.ggml_set_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
-	default:
-		panic("unsupported number of dimensions")
-	}
-	return &Tensor{b: t.b, t: tt}
-}
 func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, scale float64) ml.Tensor {
 	var kqMask *C.struct_ggml_tensor
 	if mask != nil {
@@ -1732,13 +1711,6 @@ func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
 	}
 }
-func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
-	return &Tensor{
-		b: t.b,
-		t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
-	}
-}
 // Slice returns a view of the tensor sliced along dim from low to high in step steps.
 // Slice panics if the dimension is invalid or the slice parameters are out of range.
 // If dim=0 and step>1, the tensor is a copy rather than a view to ensure proper shape.
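Read together with the hunks in this commit, Slice(ctx, dim, low, high, step) appears to follow half-open, step-strided indexing (NumPy-style low:high:step along dim). A hypothetical index-only sketch of that convention, not the ml.Tensor implementation:

```go
package main

import "fmt"

// sliceIndices lists the indices a slice(dim, low, high, step) would select
// along one dimension, assuming half-open [low, high) semantics.
func sliceIndices(low, high, step int) []int {
	var idx []int
	for i := low; i < high; i += step {
		idx = append(idx, i)
	}
	return idx
}

func main() {
	fmt.Println(sliceIndices(0, 1, 1)) // [0]       e.g. CLS pooling: first token only
	fmt.Println(sliceIndices(0, 8, 2)) // [0 2 4 6] e.g. the 0::2 interleave half
	fmt.Println(sliceIndices(1, 8, 2)) // [1 3 5 7] e.g. the 1::2 interleave half
}
```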


@@ -32,10 +32,9 @@ func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
 		hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
 		return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
 	case TypeCLS:
-		return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
+		return hiddenStates.Slice(ctx, 1, 0, 1, 1)
 	case TypeLast:
-		hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0))
-		return hiddenStates
+		return hiddenStates.Slice(ctx, 1, hiddenStates.Dim(1)-1, hiddenStates.Dim(1), 1)
 	default:
 		panic("unknown pooling type")
 	}


@@ -29,7 +29,7 @@ type Model struct {
 // Forward implements model.Model.
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
-	hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize))
+	hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.Slice(ctx, 1, 0, 1, 1))
 	hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromInts(batch.Positions, len(batch.Positions))))
 	hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)


@@ -78,44 +78,31 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
 	}
 	query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength)
-	qPass := query.View(ctx, 0,
-		opts.qkNopeHeadDim, query.Stride(1),
-		query.Dim(1), query.Stride(2),
-		query.Dim(2))
-	qRot := query.View(ctx, opts.qkNopeHeadDim*query.Stride(0),
-		opts.qkRopeHeadDim, query.Stride(1),
-		query.Dim(1), query.Stride(2),
-		query.Dim(2))
+	queryChunks := query.ChunkSections(ctx, 0, opts.qkNopeHeadDim, opts.qkRopeHeadDim)
 	compressedKV := attn.KVA.Forward(ctx, hiddenStates)
-	kPass := compressedKV.View(ctx, 0, opts.kvLoraRank, compressedKV.Stride(1), compressedKV.Dim(1))
-	kRot := compressedKV.View(ctx, opts.kvLoraRank*compressedKV.Stride(0),
-		opts.qkRopeHeadDim, compressedKV.Stride(1),
-		1, compressedKV.Stride(1),
-		compressedKV.Dim(1))
+	kPass := compressedKV.Slice(ctx, 0, 0, opts.kvLoraRank, 1)
+	kRot := compressedKV.View(ctx,
+		opts.kvLoraRank*compressedKV.Stride(0), opts.qkRopeHeadDim,
+		compressedKV.Stride(1), 1,
+		compressedKV.Stride(1), compressedKV.Dim(1),
+	)
 	kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
 	kPass = attn.KVB.Forward(ctx, kPass)
 	kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
-	kPass = kv.View(ctx, 0, opts.kqNopeHeadDim, kv.Stride(1), kv.Dim(1), kv.Stride(2), kv.Dim(2))
-	value := kv.View(ctx, opts.kqNopeHeadDim*kv.Stride(0),
-		opts.vHeadDim, kv.Stride(1),
-		kv.Dim(1), kv.Stride(2),
-		kv.Dim(2)).Contiguous(ctx)
+	kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
-	qRot = fast.RoPE(ctx, qRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
+	qRot := fast.RoPE(ctx, queryChunks[1], positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
 	kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
-	kRot = kRot.Repeat(ctx, 1, qPass.Dim(1))
+	kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
-	query = qRot.Concat(ctx, qPass, 0)
-	key := kRot.Concat(ctx, kPass, 0)
+	query = qRot.Concat(ctx, queryChunks[0], 0)
+	key := kRot.Concat(ctx, kvChunks[0], 0)
-	attention := nn.Attention(ctx, query, key, value, opts.kqScale, cache)
+	attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
 	attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
 	return attn.Output.Forward(ctx, attention)
 }
@@ -142,6 +129,7 @@ func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml
 	experts := moe.Down.Weight.MulmatID(ctx, hiddenStates, topKIndices)
 	experts = experts.Mul(ctx, topKWeights)
 	nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
 	for i := 1; i < opts.numExpertsUsed; i++ {
 		nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
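In the attention hunk above, query.ChunkSections(ctx, 0, opts.qkNopeHeadDim, opts.qkRopeHeadDim) and kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim) replace pairs of hand-built Views: the dimension appears to be split into consecutive sections of the given widths. A hedged sketch of that bookkeeping on a flat vector (chunkSections is an illustrative helper, not the ml API):

```go
package main

import "fmt"

// chunkSections splits v into consecutive sections of the given sizes,
// mirroring what ChunkSections seems to do along a single dimension.
func chunkSections(v []float32, sizes ...int) [][]float32 {
	out := make([][]float32, 0, len(sizes))
	off := 0
	for _, n := range sizes {
		out = append(out, v[off:off+n])
		off += n
	}
	return out
}

func main() {
	// one query head: four "nope" dims followed by two "rope" dims
	head := []float32{1, 2, 3, 4, 100, 200}
	chunks := chunkSections(head, 4, 2)
	fmt.Println(chunks[0]) // [1 2 3 4] -> the part left unrotated (qPass)
	fmt.Println(chunks[1]) // [100 200] -> the part fed to RoPE (qRot)
}
```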


@@ -64,18 +64,18 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
 		cache.(*kvcache.WrapperCache).SetLayerType(layerType)
-		// inputPerLayer = inputsPerLayer[:, i, :]
-		inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2)).Contiguous(ctx)
+		// inputPerLayer = inputsPerLayer[:, i, :].squeeze(1)
+		inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2))
 		hiddenStates = layer.Forward(ctx, hiddenStates, inputPerLayer, positions, one, cache, i >= firstSharedKeyValue, ropeBase, float64(m.activationSparsityScale[i]), &m.TextOptions)
 	}
 	// hiddenStates = hiddenStates[:, :, 0]
-	hiddenStates0 := hiddenStates.View(ctx, 0, hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1))
+	hiddenStates0 := hiddenStates.Slice(ctx, 2, 0, 1, 1)
 	targetMagnitude = hiddenStates0.Sqr(ctx).Mean(ctx).Sqrt(ctx)
 	targetMagnitude = targetMagnitude.Repeat(ctx, 2, m.altupInputs-1)
 	// hiddenState = hiddenStates[:, :, 1:]
-	hiddenState = hiddenStates.View(ctx, hiddenStates.Stride(2), hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), m.altupInputs-1)
+	hiddenState = hiddenStates.Slice(ctx, 2, 1, hiddenStates.Dim(2), 1)
 	altupUnembdProj := m.AltupUnembd.Forward(ctx, hiddenState)
 	altupUnembdProj = altupUnembdProj.Mul(ctx, targetMagnitude.Div(ctx, altupUnembdProj.Sqr(ctx).Mean(ctx).Sqrt(ctx)))
@@ -176,10 +176,10 @@ func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, position
 	active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)
 	// inactive := predictions[:, :, 1:]
-	inactive := predictions.View(ctx, predictions.Stride(2), predictions.Dim(0), predictions.Stride(1), predictions.Dim(1), predictions.Stride(2), predictions.Dim(2)-1)
+	inactive := predictions.Slice(ctx, 2, 1, predictions.Dim(2), 1)
 	active = inactive.Add(ctx, active)
-	predictions0 := predictions.View(ctx, 0, predictions.Dim(0), predictions.Stride(1), predictions.Dim(1))
+	predictions0 := predictions.Slice(ctx, 2, 0, 1, 1)
 	return predictions0.Concat(ctx, active, 2)
 }
@@ -319,7 +319,7 @@ type TextOptions struct {
 func (o *TextOptions) altupActive(ctx ml.Context, t ml.Tensor) ml.Tensor {
 	// t[:, :, o.altupActiveIndex]
-	return t.View(ctx, o.altupActiveIndex*t.Stride(2), t.Dim(0), t.Stride(1), t.Dim(1))
+	return t.Slice(ctx, 2, o.altupActiveIndex, o.altupActiveIndex+1, 1)
 }
 func (o *TextOptions) headDim() int {


@@ -121,30 +121,9 @@ func (attn *AttentionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.T
 	var query, key, value ml.Tensor
 	if attn.QKV != nil {
 		qkv := attn.QKV.Forward(ctx, hiddenStates)
-		// query = qkv[..., : num_attention_heads * head_dim].reshape(batch_size, num_attention_heads, head_dim)
-		query = qkv.View(ctx,
-			0,
-			opts.headDim(), qkv.Stride(0)*opts.headDim(),
-			opts.numHeads, qkv.Stride(1),
-			batchSize,
-		)
-		// key = qkv[..., num_attention_heads * head_dim:(num_attention_heads + num_key_value_heads) * head_dim].reshape(batch_size, num_key_value_heads, head_dim)
-		key = qkv.View(ctx,
-			qkv.Stride(0)*opts.headDim()*opts.numHeads,
-			opts.headDim(), qkv.Stride(0)*opts.headDim(),
-			opts.numKVHeads, qkv.Stride(1),
-			batchSize,
-		)
-		// value = qkv[..., (num_attention_heads + num_key_value_heads) * head_dim:].reshape(batch_size, num_key_value_heads, head_dim)
-		value = qkv.View(ctx,
-			qkv.Stride(0)*opts.headDim()*(opts.numHeads+opts.numKVHeads),
-			opts.headDim(), qkv.Stride(0)*opts.headDim(),
-			opts.numKVHeads, qkv.Stride(1),
-			batchSize,
-		)
+		qkv = qkv.Reshape(ctx, opts.headDim(), -1, batchSize)
+		chunks := qkv.ChunkSections(ctx, 1, opts.numHeads, opts.numKVHeads, opts.numKVHeads)
+		query, key, value = chunks[0], chunks[1], chunks[2]
 	} else {
 		query = attn.Query.Forward(ctx, hiddenStates)
 		query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
@@ -195,15 +174,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
 	var gate, up ml.Tensor
 	if mlp.GateUp != nil {
 		hiddenStates = mlp.GateUp.Forward(ctx, hiddenStates, selectedExperts)
-		hiddenStates = hiddenStates.Reshape(ctx, 2, hiddenStates.Dim(0)/2, hiddenStates.Dim(1), hiddenStates.Dim(2))
-		dimStride := []int{hiddenStates.Dim(0) / 2, hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), hiddenStates.Dim(2), hiddenStates.Stride(3), hiddenStates.Dim(3)}
-		gate = hiddenStates.View(ctx, 0, dimStride...)
-		gate = gate.Contiguous(ctx, gate.Dim(0)*gate.Dim(1), gate.Dim(2), gate.Dim(3))
-		up = hiddenStates.View(ctx, hiddenStates.Stride(0), dimStride...)
-		up = up.Contiguous(ctx, up.Dim(0)*up.Dim(1), up.Dim(2), up.Dim(3))
+		gate = hiddenStates.Slice(ctx, 0, 0, hiddenStates.Dim(0), 2)
+		up = hiddenStates.Slice(ctx, 0, 1, hiddenStates.Dim(0), 2)
 	} else {
 		gate = mlp.Gate.Forward(ctx, hiddenStates, selectedExperts)
 		up = mlp.Up.Forward(ctx, hiddenStates, selectedExperts)


@@ -105,9 +105,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 	for range aspectRatio.Y {
 		for x := range aspectRatio.X {
-			view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
-				projectedOutputs.Dim(0), projectedOutputs.Stride(1),
-				patchesPerChunk)
+			view := projectedOutputs.Slice(ctx, 1, offset, offset+patchesPerChunk, 1)
 			var separator separator
 			if x < aspectRatio.X-1 {
 				separator.x = true // <|tile_x_separator|>
@@ -120,9 +118,7 @@
 		}
 	}
-	view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
-		projectedOutputs.Dim(0), projectedOutputs.Stride(1),
-		patchesPerChunk)
+	view := projectedOutputs.Slice(ctx, 1, offset, offset+patchesPerChunk, 1)
 	multimodal = append(multimodal, input.Multimodal{Tensor: view, Data: &separator{}})
 	return multimodal, nil


@@ -37,27 +37,23 @@ type VisionAttention struct {
 func applyVisionRotaryEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
 	width, height, channels, tiles := t.Dim(0), t.Dim(1), t.Dim(2), t.Dim(3)
-	t = t.Reshape(ctx, 2, t.Dim(0)/2, t.Dim(1)*t.Dim(2)*t.Dim(3))
 	// t1 = t[..., 0::2]
-	t1 := t.View(ctx, 0, 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
-	t1 = t1.Reshape(ctx, width/2, height, channels, tiles)
+	t1 := t.Slice(ctx, 0, 0, t.Dim(0), 2)
 	// t2 = t[..., 1::2]
-	t2 := t.View(ctx, t.Stride(0), 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
-	t2 = t2.Reshape(ctx, width/2, height, channels, tiles)
+	t2 := t.Slice(ctx, 0, 1, t.Dim(0), 2)
 	// cos_out = torch.stack((t1 * cos, t2 * cos), dim=-1)
 	cosOut := t1.Mul(ctx, cos).Concat(ctx, t2.Mul(ctx, cos), 0)
-	cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, cosOut.Dim(1)*cosOut.Dim(2)*cosOut.Dim(3))
-	cosOut = cosOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-	cosOut = cosOut.Reshape(ctx, width, height, channels, tiles)
+	cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, -1)
+	cosOut = cosOut.Permute(ctx, 1, 0, 2, 3)
+	cosOut = cosOut.Contiguous(ctx, width, height, channels, tiles)
 	// sin_out = torch.stack((-t2 * sin, t1 * sin), dim=-1)
-	sinOut := t2.Neg(ctx).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
-	sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, sinOut.Dim(1)*sinOut.Dim(2)*sinOut.Dim(3))
-	sinOut = sinOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-	sinOut = sinOut.Reshape(ctx, width, height, channels, tiles)
+	sinOut := t2.Scale(ctx, -1).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
+	sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, -1)
+	sinOut = sinOut.Permute(ctx, 1, 0, 2, 3)
+	sinOut = sinOut.Contiguous(ctx, width, height, channels, tiles)
 	return cosOut.Add(ctx, sinOut)
 }
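With Neg removed from the interface, -t2 becomes t2.Scale(ctx, -1); the arithmetic is unchanged. Per even/odd pair (t1, t2), cos_out + sin_out works out to (t1*cos - t2*sin, t2*cos + t1*sin), i.e. a plain 2-D rotation. A minimal scalar sketch, independent of the tensor API:

```go
package main

import (
	"fmt"
	"math"
)

// rotateInterleaved applies (t1, t2) -> (t1*cos - t2*sin, t2*cos + t1*sin)
// to consecutive (even, odd) pairs, matching cos_out + sin_out above.
func rotateInterleaved(t []float64, theta float64) []float64 {
	cos, sin := math.Cos(theta), math.Sin(theta)
	out := make([]float64, len(t))
	for i := 0; i+1 < len(t); i += 2 {
		t1, t2 := t[i], t[i+1]
		out[i] = t1*cos - t2*sin
		out[i+1] = t2*cos + t1*sin
	}
	return out
}

func main() {
	fmt.Println(rotateInterleaved([]float64{1, 0, 0, 1}, math.Pi/2))
	// ≈ [0 1 -1 0]: each (even, odd) pair rotated by 90°
}
```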


@@ -11,9 +11,9 @@ import (
 var batchSize int = 1
 func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
-	x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
-	x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
-	return x2.Neg(ctx).Concat(ctx, x1, 0)
+	x1 := t.Slice(ctx, 0, 0, t.Dim(0)/2, 1)
+	x2 := t.Slice(ctx, 0, t.Dim(0)/2, t.Dim(0), 1).Contiguous(ctx)
+	return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
 }
 func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
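rotateHalf is the half-split rotary variant: split the head dimension in two and return concat(-x2, x1), which the new code expresses with two Slice calls and Scale(ctx, -1). A plain-slice sketch of the same permutation (illustrative only, not the model code):

```go
package main

import "fmt"

// rotateHalf returns concat(-x2, x1) where x1, x2 are the first and second
// halves of t, the permutation built above with Slice and Scale.
func rotateHalf(t []float32) []float32 {
	h := len(t) / 2
	out := make([]float32, 0, len(t))
	for _, v := range t[h:] {
		out = append(out, -v)
	}
	return append(out, t[:h]...)
}

func main() {
	fmt.Println(rotateHalf([]float32{1, 2, 3, 4})) // [-3 -4 1 2]
}
```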


@@ -13,9 +13,9 @@ import (
 var batchSize int = 1
 func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
-	x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
-	x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
-	return x2.Neg(ctx).Concat(ctx, x1, 0)
+	x1 := t.Slice(ctx, 0, 0, t.Dim(0)/2, 1)
+	x2 := t.Slice(ctx, 0, t.Dim(0)/2, t.Dim(0), 1).Contiguous(ctx)
+	return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
 }
 func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {


@@ -18,8 +18,8 @@ type VisionAttention struct {
 }
 func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
-	x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
-	x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
+	x1 := t.Slice(ctx, 0, 0, t.Dim(0)/2, 1)
+	x2 := t.Slice(ctx, 0, t.Dim(0)/2, t.Dim(0), 1).Contiguous(ctx)
 	return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
 }
@@ -160,10 +160,11 @@ func (m *VisionPositionEmbedding) Forward(ctx ml.Context, hiddenStates ml.Tensor
 	positionEmbeds = positionEmbeds.Mul(ctx, weights)
 	positionEmbeds = positionEmbeds.Reshape(ctx, n, -1, 4)
-	positionEmbeds = positionEmbeds.View(ctx, 0, n, positionEmbeds.Stride(1), grid.Height*grid.Width).
-		Add(ctx, positionEmbeds.View(ctx, 1*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width)).
-		Add(ctx, positionEmbeds.View(ctx, 2*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width)).
-		Add(ctx, positionEmbeds.View(ctx, 3*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width))
+	positionEmbedsChunks := positionEmbeds.Chunk(ctx, 2, 1)
+	positionEmbeds = positionEmbedsChunks[0].
+		Add(ctx, positionEmbedsChunks[1]).
+		Add(ctx, positionEmbedsChunks[2]).
+		Add(ctx, positionEmbedsChunks[3])
 	positionEmbeds = positionEmbeds.Reshape(ctx, -1, grid.Width/opts.spatialMergeSize, opts.spatialMergeSize, grid.Height/opts.spatialMergeSize)
 	positionEmbeds = positionEmbeds.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, n, -1)
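positionEmbeds.Chunk(ctx, 2, 1) appears to cut the size-4 dimension 2 into four single-slot chunks, which are then summed; this replaces the four strided Views. A hedged one-dimensional sketch of that equal-size chunking followed by the same sum (chunk is an illustrative helper, not the ml API):

```go
package main

import "fmt"

// chunk splits v into equal pieces of chunkSize, analogous to how
// Chunk seems to split the size-4 dimension above.
func chunk(v []float32, chunkSize int) [][]float32 {
	var out [][]float32
	for i := 0; i < len(v); i += chunkSize {
		out = append(out, v[i:i+chunkSize])
	}
	return out
}

func main() {
	// four position-embedding variants for one feature, summed like
	// chunks[0] + chunks[1] + chunks[2] + chunks[3]
	v := []float32{1, 10, 100, 1000}
	chunks := chunk(v, 1)
	sum := chunks[0][0] + chunks[1][0] + chunks[2][0] + chunks[3][0]
	fmt.Println(sum) // 1111
}
```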