interleaved mrope (#12807)

* ml(ggml): mrope
* interleave mrope
This commit is contained in:
Michael Yang
2025-10-30 11:29:00 -07:00
committed by GitHub
parent 75e75d9afe
commit f67a6df110
10 changed files with 209 additions and 119 deletions

View File

@@ -11,6 +11,7 @@ package ggml
import "C"
import (
"cmp"
"context"
"encoding/binary"
"errors"
@@ -1490,14 +1491,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase, ropeScale float32, options ...func(*rope.Options)) ml.Tensor {
// Default options
opts := rope.Options{
Factors: &Tensor{},
OriginalContextLength: 131072,
ExtrapolationFactor: 0.,
AttentionFactor: 1.,
BetaFast: 32.,
BetaSlow: 1.,
}
opts := rope.Options{Factors: &Tensor{}}
// Apply any provided options
for _, option := range options {
@@ -1509,24 +1503,44 @@ func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase
dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
}
return &Tensor{
b: t.b,
t: C.ggml_rope_ext(
var tt *C.struct_ggml_tensor
if len(opts.MRoPE.Sections) > 0 {
mropeSections := make([]C.int32_t, 4)
for i, section := range opts.MRoPE.Sections {
mropeSections[i] = C.int32_t(section)
}
tt = C.ggml_rope_multi(
ctx.(*Context).ctx,
dequant,
positions.(*Tensor).t,
opts.Factors.(*Tensor).t,
C.int(ropeDim),
unsafe.SliceData(mropeSections),
C.int(opts.Type),
C.int(opts.OriginalContextLength),
C.float(ropeBase),
C.float(ropeScale),
C.float(opts.ExtrapolationFactor),
C.float(opts.AttentionFactor),
C.float(opts.BetaFast),
C.float(opts.BetaSlow),
),
cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10),
C.float(ropeBase), C.float(ropeScale),
C.float(opts.YaRN.ExtrapolationFactor),
cmp.Or(C.float(opts.YaRN.AttentionFactor), 1),
cmp.Or(C.float(opts.YaRN.BetaFast), 32),
cmp.Or(C.float(opts.YaRN.BetaSlow), 1),
)
} else {
tt = C.ggml_rope_ext(
ctx.(*Context).ctx,
dequant,
positions.(*Tensor).t,
opts.Factors.(*Tensor).t,
C.int(ropeDim), C.int(opts.Type),
cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10),
C.float(ropeBase), C.float(ropeScale),
C.float(opts.YaRN.ExtrapolationFactor),
cmp.Or(C.float(opts.YaRN.AttentionFactor), 1),
cmp.Or(C.float(opts.YaRN.BetaFast), 32),
cmp.Or(C.float(opts.YaRN.BetaSlow), 1),
)
}
return &Tensor{b: t.b, t: tt}
}
func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {