chore: update models to use slice/chunk/chunksections (#12934)
* use slice/chunks
* bert
* llama4
* gemma3n
* gptoss
* mistral3
* qwen3vl
* qwen25vl
* deepseek2
* remove unused ops
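Read off the call sites in the hunks below, the new helpers have roughly these shapes. This is a sketch for orientation only, not the ml package's documented API; in particular the meaning of Chunk's last argument is inferred from a single call site.

package sketch

// Context and Tensor stand in for ml.Context and ml.Tensor so the sketch compiles on its own.
type (
    Context any
    Tensor  any
)

// tensorSlicing summarizes the signatures as they appear at call sites in this commit.
type tensorSlicing interface {
    // Slice: view along dim covering [low, high) advancing by step,
    // per the doc comment added to the ml package below.
    Slice(ctx Context, dim, low, high, step int) Tensor
    // ChunkSections: split dim into consecutive sections of the given sizes,
    // e.g. query.ChunkSections(ctx, 0, qkNopeHeadDim, qkRopeHeadDim).
    ChunkSections(ctx Context, dim int, sections ...int) []Tensor
    // Chunk: split dim into equal chunks; the last argument appears to be the
    // chunk size, judging from positionEmbeds.Chunk(ctx, 2, 1) in the vision
    // position-embedding hunk below.
    Chunk(ctx Context, dim, size int) []Tensor
}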
@@ -110,9 +110,12 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {

     for name, mxfp4 := range mxfp4s {
         dims := mxfp4.blocks.Shape()
+        if !strings.HasSuffix(name, ".weight") {
+            name = name + ".weight"
+        }
         if strings.Contains(name, "ffn_down_exps") {
             out = append(out, &ggml.Tensor{
-                Name: name + ".weight",
+                Name: name,
                 Kind: uint32(ggml.TensorTypeMXFP4),
                 Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
                 WriterTo: mxfp4,
@@ -121,12 +124,12 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
             // gate_up_exps is interleaved, need to split into gate_exps and up_exps
             // e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...]
             out = append(out, &ggml.Tensor{
-                Name: strings.Replace(name, "gate_up", "gate", 1) + ".weight",
+                Name: strings.Replace(name, "gate_up", "gate", 1),
                 Kind: uint32(ggml.TensorTypeMXFP4),
                 Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
                 WriterTo: mxfp4.slice(1, 0, int(dims[1]), 2),
             }, &ggml.Tensor{
-                Name: strings.Replace(name, "gate_up", "up", 1) + ".weight",
+                Name: strings.Replace(name, "gate_up", "up", 1),
                 Kind: uint32(ggml.TensorTypeMXFP4),
                 Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
                 WriterTo: mxfp4.slice(1, 1, int(dims[1]), 2),
@@ -146,7 +146,6 @@ type Tensor interface {
     FromFloats([]float32)
     FromInts([]int32)

-    Neg(ctx Context) Tensor
     Add(ctx Context, t2 Tensor) Tensor
     Sub(ctx Context, t2 Tensor) Tensor
     Mul(ctx Context, t2 Tensor) Tensor
@@ -185,7 +184,6 @@ type Tensor interface {
     View(ctx Context, offset int, shape ...int) Tensor
     Permute(ctx Context, shape ...int) Tensor
     Contiguous(ctx Context, shape ...int) Tensor
-    Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor

     Pad(ctx Context, shape ...int) Tensor

@@ -209,7 +207,6 @@ type Tensor interface {
     Stddev(ctx Context) Tensor
     Sqr(ctx Context) Tensor
     Sqrt(ctx Context) Tensor
-    Clamp(ctx Context, min, max float32) Tensor
 }

 // ScaledDotProductAttention implements a fused attention
@@ -1137,13 +1137,6 @@ func (t *Tensor) Cast(ctx ml.Context, dtype ml.DType) ml.Tensor {
     }
 }

-func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
-    return &Tensor{
-        b: t.b,
-        t: C.ggml_neg(ctx.(*Context).ctx, t.t),
-    }
-}
-
 func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
     return &Tensor{
         b: t.b,
@@ -1632,20 +1625,6 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
     }
 }

-func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
-    var tt *C.struct_ggml_tensor
-    switch len(strides) {
-    case 0:
-        tt = C.ggml_set_1d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
-    case 1:
-        tt = C.ggml_set_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
-    default:
-        panic("unsupported number of dimensions")
-    }
-
-    return &Tensor{b: t.b, t: tt}
-}
-
 func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, scale float64) ml.Tensor {
     var kqMask *C.struct_ggml_tensor
     if mask != nil {
@@ -1732,13 +1711,6 @@ func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
     }
 }

-func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
-    return &Tensor{
-        b: t.b,
-        t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
-    }
-}
-
 // Slice returns a view of the tensor sliced along dim from low to high in step steps.
 // Slice panics if the dimension is invalid or the slice parameters are out of range.
 // If dim=0 and step>1, the tensor is a copy rather than a view to ensure proper shape.
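The Slice documentation above describes a half-open [low, high) range walked in step-sized strides. A small, runnable illustration of that indexing convention on a plain Go slice (illustrative only; the real method returns a tensor view, or a copy when dim=0 and step>1):

package main

import "fmt"

// slice1D applies Slice's documented convention to a 1-D Go slice:
// take indices low, low+step, ... strictly below high.
func slice1D(t []float32, low, high, step int) []float32 {
    var out []float32
    for i := low; i < high; i += step {
        out = append(out, t[i])
    }
    return out
}

func main() {
    t := []float32{0, 1, 2, 3, 4, 5, 6, 7}
    fmt.Println(slice1D(t, 0, len(t), 2)) // [0 2 4 6]  ~ t[0::2], the even half of an interleaved tensor
    fmt.Println(slice1D(t, 1, len(t), 2)) // [1 3 5 7]  ~ t[1::2], the odd half
    fmt.Println(slice1D(t, 2, 5, 1))      // [2 3 4]    ~ t[2:5]
}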
@@ -32,10 +32,9 @@ func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
         hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
         return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
     case TypeCLS:
-        return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
+        return hiddenStates.Slice(ctx, 1, 0, 1, 1)
     case TypeLast:
-        hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0))
-        return hiddenStates
+        return hiddenStates.Slice(ctx, 1, hiddenStates.Dim(1)-1, hiddenStates.Dim(1), 1)
     default:
         panic("unknown pooling type")
     }
@@ -29,7 +29,7 @@ type Model struct {
 // Forward implements model.Model.
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
     hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
-    hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize))
+    hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.Slice(ctx, 1, 0, 1, 1))
     hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromInts(batch.Positions, len(batch.Positions))))
     hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)

@@ -78,44 +78,31 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
     }

     query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength)
-    qPass := query.View(ctx, 0,
-        opts.qkNopeHeadDim, query.Stride(1),
-        query.Dim(1), query.Stride(2),
-        query.Dim(2))
-
-    qRot := query.View(ctx, opts.qkNopeHeadDim*query.Stride(0),
-        opts.qkRopeHeadDim, query.Stride(1),
-        query.Dim(1), query.Stride(2),
-        query.Dim(2))
+    queryChunks := query.ChunkSections(ctx, 0, opts.qkNopeHeadDim, opts.qkRopeHeadDim)

     compressedKV := attn.KVA.Forward(ctx, hiddenStates)
-    kPass := compressedKV.View(ctx, 0, opts.kvLoraRank, compressedKV.Stride(1), compressedKV.Dim(1))
-    kRot := compressedKV.View(ctx, opts.kvLoraRank*compressedKV.Stride(0),
-        opts.qkRopeHeadDim, compressedKV.Stride(1),
-        1, compressedKV.Stride(1),
-        compressedKV.Dim(1))
+    kPass := compressedKV.Slice(ctx, 0, 0, opts.kvLoraRank, 1)
+    kRot := compressedKV.View(ctx,
+        opts.kvLoraRank*compressedKV.Stride(0), opts.qkRopeHeadDim,
+        compressedKV.Stride(1), 1,
+        compressedKV.Stride(1), compressedKV.Dim(1),
+    )

     kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
     kPass = attn.KVB.Forward(ctx, kPass)

     kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
-    kPass = kv.View(ctx, 0, opts.kqNopeHeadDim, kv.Stride(1), kv.Dim(1), kv.Stride(2), kv.Dim(2))
-    value := kv.View(ctx, opts.kqNopeHeadDim*kv.Stride(0),
-        opts.vHeadDim, kv.Stride(1),
-        kv.Dim(1), kv.Stride(2),
-        kv.Dim(2)).Contiguous(ctx)
+    kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)

-    qRot = fast.RoPE(ctx, qRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
+    qRot := fast.RoPE(ctx, queryChunks[1], positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
     kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)

-    kRot = kRot.Repeat(ctx, 1, qPass.Dim(1))
+    kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))

-    query = qRot.Concat(ctx, qPass, 0)
-    key := kRot.Concat(ctx, kPass, 0)
+    query = qRot.Concat(ctx, queryChunks[0], 0)
+    key := kRot.Concat(ctx, kvChunks[0], 0)

-    attention := nn.Attention(ctx, query, key, value, opts.kqScale, cache)
+    attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
     attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
     return attn.Output.Forward(ctx, attention)
 }
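The hunk above swaps several hand-computed Views for ChunkSections along dim 0. A minimal runnable sketch, on plain Go slices, of what splitting a dimension into consecutive sections means (the section widths are made-up numbers for illustration; the real call operates on tensor views):

package main

import "fmt"

// chunkSections1D cuts a 1-D slice into consecutive sections of the given sizes,
// mirroring ChunkSections along a single tensor dimension.
func chunkSections1D(t []float32, sections ...int) [][]float32 {
    out := make([][]float32, 0, len(sections))
    off := 0
    for _, n := range sections {
        out = append(out, t[off:off+n])
        off += n
    }
    return out
}

func main() {
    // A per-head query of width 6 split into a 4-wide "nope" part and a 2-wide "rope" part,
    // analogous to query.ChunkSections(ctx, 0, opts.qkNopeHeadDim, opts.qkRopeHeadDim).
    head := []float32{0, 1, 2, 3, 4, 5}
    chunks := chunkSections1D(head, 4, 2)
    fmt.Println(chunks[0]) // [0 1 2 3] -> queryChunks[0]
    fmt.Println(chunks[1]) // [4 5]     -> queryChunks[1]
}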
@@ -142,6 +129,7 @@ func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml

     experts := moe.Down.Weight.MulmatID(ctx, hiddenStates, topKIndices)
     experts = experts.Mul(ctx, topKWeights)

     nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
     for i := 1; i < opts.numExpertsUsed; i++ {
         nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
@@ -64,18 +64,18 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac

         cache.(*kvcache.WrapperCache).SetLayerType(layerType)

-        // inputPerLayer = inputsPerLayer[:, i, :]
-        inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2)).Contiguous(ctx)
+        // inputPerLayer = inputsPerLayer[:, i, :].squeeze(1)
+        inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2))
         hiddenStates = layer.Forward(ctx, hiddenStates, inputPerLayer, positions, one, cache, i >= firstSharedKeyValue, ropeBase, float64(m.activationSparsityScale[i]), &m.TextOptions)
     }

     // hiddenStates = hiddenStates[:, :, 0]
-    hiddenStates0 := hiddenStates.View(ctx, 0, hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1))
+    hiddenStates0 := hiddenStates.Slice(ctx, 2, 0, 1, 1)
     targetMagnitude = hiddenStates0.Sqr(ctx).Mean(ctx).Sqrt(ctx)
     targetMagnitude = targetMagnitude.Repeat(ctx, 2, m.altupInputs-1)

     // hiddenState = hiddenStates[:, :, 1:]
-    hiddenState = hiddenStates.View(ctx, hiddenStates.Stride(2), hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), m.altupInputs-1)
+    hiddenState = hiddenStates.Slice(ctx, 2, 1, hiddenStates.Dim(2), 1)
     altupUnembdProj := m.AltupUnembd.Forward(ctx, hiddenState)
     altupUnembdProj = altupUnembdProj.Mul(ctx, targetMagnitude.Div(ctx, altupUnembdProj.Sqr(ctx).Mean(ctx).Sqrt(ctx)))

@@ -176,10 +176,10 @@ func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, position
     active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)

     // inactive := predictions[:, :, 1:]
-    inactive := predictions.View(ctx, predictions.Stride(2), predictions.Dim(0), predictions.Stride(1), predictions.Dim(1), predictions.Stride(2), predictions.Dim(2)-1)
+    inactive := predictions.Slice(ctx, 2, 1, predictions.Dim(2), 1)
     active = inactive.Add(ctx, active)

-    predictions0 := predictions.View(ctx, 0, predictions.Dim(0), predictions.Stride(1), predictions.Dim(1))
+    predictions0 := predictions.Slice(ctx, 2, 0, 1, 1)
     return predictions0.Concat(ctx, active, 2)
 }

@@ -319,7 +319,7 @@ type TextOptions struct {

 func (o *TextOptions) altupActive(ctx ml.Context, t ml.Tensor) ml.Tensor {
     // t[:, :, o.altupActiveIndex]
-    return t.View(ctx, o.altupActiveIndex*t.Stride(2), t.Dim(0), t.Stride(1), t.Dim(1))
+    return t.Slice(ctx, 2, o.altupActiveIndex, o.altupActiveIndex+1, 1)
 }

 func (o *TextOptions) headDim() int {
@@ -121,30 +121,9 @@ func (attn *AttentionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.T
     var query, key, value ml.Tensor
     if attn.QKV != nil {
         qkv := attn.QKV.Forward(ctx, hiddenStates)
-        // query = qkv[..., : num_attention_heads * head_dim].reshape(batch_size, num_attention_heads, head_dim)
-        query = qkv.View(ctx,
-            0,
-            opts.headDim(), qkv.Stride(0)*opts.headDim(),
-            opts.numHeads, qkv.Stride(1),
-            batchSize,
-        )
-
-        // key = qkv[..., num_attention_heads * head_dim:(num_attention_heads + num_key_value_heads) * head_dim].reshape(batch_size, num_key_value_heads, head_dim)
-        key = qkv.View(ctx,
-            qkv.Stride(0)*opts.headDim()*opts.numHeads,
-            opts.headDim(), qkv.Stride(0)*opts.headDim(),
-            opts.numKVHeads, qkv.Stride(1),
-            batchSize,
-        )
-
-        // value = qkv[..., (num_attention_heads + num_key_value_heads) * head_dim:].reshape(batch_size, num_key_value_heads, head_dim)
-        value = qkv.View(ctx,
-            qkv.Stride(0)*opts.headDim()*(opts.numHeads+opts.numKVHeads),
-            opts.headDim(), qkv.Stride(0)*opts.headDim(),
-            opts.numKVHeads, qkv.Stride(1),
-            batchSize,
-        )
+        qkv = qkv.Reshape(ctx, opts.headDim(), -1, batchSize)
+        chunks := qkv.ChunkSections(ctx, 1, opts.numHeads, opts.numKVHeads, opts.numKVHeads)
+        query, key, value = chunks[0], chunks[1], chunks[2]
     } else {
         query = attn.Query.Forward(ctx, hiddenStates)
         query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
@@ -195,15 +174,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
     var gate, up ml.Tensor
     if mlp.GateUp != nil {
         hiddenStates = mlp.GateUp.Forward(ctx, hiddenStates, selectedExperts)
-        hiddenStates = hiddenStates.Reshape(ctx, 2, hiddenStates.Dim(0)/2, hiddenStates.Dim(1), hiddenStates.Dim(2))
-        dimStride := []int{hiddenStates.Dim(0) / 2, hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), hiddenStates.Dim(2), hiddenStates.Stride(3), hiddenStates.Dim(3)}
-
-        gate = hiddenStates.View(ctx, 0, dimStride...)
-        gate = gate.Contiguous(ctx, gate.Dim(0)*gate.Dim(1), gate.Dim(2), gate.Dim(3))
-
-        up = hiddenStates.View(ctx, hiddenStates.Stride(0), dimStride...)
-        up = up.Contiguous(ctx, up.Dim(0)*up.Dim(1), up.Dim(2), up.Dim(3))
+        gate = hiddenStates.Slice(ctx, 0, 0, hiddenStates.Dim(0), 2)
+        up = hiddenStates.Slice(ctx, 0, 1, hiddenStates.Dim(0), 2)
     } else {
         gate = mlp.Gate.Forward(ctx, hiddenStates, selectedExperts)
         up = mlp.Up.Forward(ctx, hiddenStates, selectedExperts)
@@ -105,9 +105,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input

     for range aspectRatio.Y {
         for x := range aspectRatio.X {
-            view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
-                projectedOutputs.Dim(0), projectedOutputs.Stride(1),
-                patchesPerChunk)
+            view := projectedOutputs.Slice(ctx, 1, offset, offset+patchesPerChunk, 1)
             var separator separator
             if x < aspectRatio.X-1 {
                 separator.x = true // <|tile_x_separator|>
@@ -120,9 +118,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
         }
     }

-    view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
-        projectedOutputs.Dim(0), projectedOutputs.Stride(1),
-        patchesPerChunk)
+    view := projectedOutputs.Slice(ctx, 1, offset, offset+patchesPerChunk, 1)
     multimodal = append(multimodal, input.Multimodal{Tensor: view, Data: &separator{}})

     return multimodal, nil
@@ -37,27 +37,23 @@ type VisionAttention struct {
 func applyVisionRotaryEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
     width, height, channels, tiles := t.Dim(0), t.Dim(1), t.Dim(2), t.Dim(3)

-    t = t.Reshape(ctx, 2, t.Dim(0)/2, t.Dim(1)*t.Dim(2)*t.Dim(3))
-
     // t1 = t[..., 0::2]
-    t1 := t.View(ctx, 0, 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
-    t1 = t1.Reshape(ctx, width/2, height, channels, tiles)
+    t1 := t.Slice(ctx, 0, 0, t.Dim(0), 2)

     // t2 = t[..., 1::2]
-    t2 := t.View(ctx, t.Stride(0), 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
-    t2 = t2.Reshape(ctx, width/2, height, channels, tiles)
+    t2 := t.Slice(ctx, 0, 1, t.Dim(0), 2)

     // cos_out = torch.stack((t1 * cos, t2 * cos), dim=-1)
     cosOut := t1.Mul(ctx, cos).Concat(ctx, t2.Mul(ctx, cos), 0)
-    cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, cosOut.Dim(1)*cosOut.Dim(2)*cosOut.Dim(3))
-    cosOut = cosOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-    cosOut = cosOut.Reshape(ctx, width, height, channels, tiles)
+    cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, -1)
+    cosOut = cosOut.Permute(ctx, 1, 0, 2, 3)
+    cosOut = cosOut.Contiguous(ctx, width, height, channels, tiles)

     // sin_out = torch.stack((-t2 * sin, t1 * sin), dim=-1)
-    sinOut := t2.Neg(ctx).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
-    sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, sinOut.Dim(1)*sinOut.Dim(2)*sinOut.Dim(3))
-    sinOut = sinOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-    sinOut = sinOut.Reshape(ctx, width, height, channels, tiles)
+    sinOut := t2.Scale(ctx, -1).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
+    sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, -1)
+    sinOut = sinOut.Permute(ctx, 1, 0, 2, 3)
+    sinOut = sinOut.Contiguous(ctx, width, height, channels, tiles)

     return cosOut.Add(ctx, sinOut)
 }
@@ -11,9 +11,9 @@ import (
 var batchSize int = 1

 func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
-    x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
-    x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
-    return x2.Neg(ctx).Concat(ctx, x1, 0)
+    x1 := t.Slice(ctx, 0, 0, t.Dim(0)/2, 1)
+    x2 := t.Slice(ctx, 0, t.Dim(0)/2, t.Dim(0), 1).Contiguous(ctx)
+    return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
 }

 func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
@@ -13,9 +13,9 @@ import (
 var batchSize int = 1

 func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
-    x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
-    x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
-    return x2.Neg(ctx).Concat(ctx, x1, 0)
+    x1 := t.Slice(ctx, 0, 0, t.Dim(0)/2, 1)
+    x2 := t.Slice(ctx, 0, t.Dim(0)/2, t.Dim(0), 1).Contiguous(ctx)
+    return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
 }

 func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
@@ -18,8 +18,8 @@ type VisionAttention struct {
 }

 func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
-    x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
-    x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
+    x1 := t.Slice(ctx, 0, 0, t.Dim(0)/2, 1)
+    x2 := t.Slice(ctx, 0, t.Dim(0)/2, t.Dim(0), 1).Contiguous(ctx)
     return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
 }

@@ -160,10 +160,11 @@ func (m *VisionPositionEmbedding) Forward(ctx ml.Context, hiddenStates ml.Tensor
     positionEmbeds = positionEmbeds.Mul(ctx, weights)
     positionEmbeds = positionEmbeds.Reshape(ctx, n, -1, 4)

-    positionEmbeds = positionEmbeds.View(ctx, 0, n, positionEmbeds.Stride(1), grid.Height*grid.Width).
-        Add(ctx, positionEmbeds.View(ctx, 1*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width)).
-        Add(ctx, positionEmbeds.View(ctx, 2*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width)).
-        Add(ctx, positionEmbeds.View(ctx, 3*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width))
+    positionEmbedsChunks := positionEmbeds.Chunk(ctx, 2, 1)
+    positionEmbeds = positionEmbedsChunks[0].
+        Add(ctx, positionEmbedsChunks[1]).
+        Add(ctx, positionEmbedsChunks[2]).
+        Add(ctx, positionEmbedsChunks[3])

     positionEmbeds = positionEmbeds.Reshape(ctx, -1, grid.Width/opts.spatialMergeSize, opts.spatialMergeSize, grid.Height/opts.spatialMergeSize)
     positionEmbeds = positionEmbeds.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, n, -1)