interleaved mrope (#12807)

* ml(ggml): mrope
* interleave mrope
This commit is contained in:
Michael Yang
2025-10-30 11:29:00 -07:00
committed by GitHub
parent 75e75d9afe
commit f67a6df110
10 changed files with 209 additions and 119 deletions

View File

@@ -11,6 +11,7 @@ package ggml
import "C"
import (
"cmp"
"context"
"encoding/binary"
"errors"
@@ -1490,14 +1491,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase, ropeScale float32, options ...func(*rope.Options)) ml.Tensor {
// Default options
opts := rope.Options{
Factors: &Tensor{},
OriginalContextLength: 131072,
ExtrapolationFactor: 0.,
AttentionFactor: 1.,
BetaFast: 32.,
BetaSlow: 1.,
}
opts := rope.Options{Factors: &Tensor{}}
// Apply any provided options
for _, option := range options {
@@ -1509,24 +1503,44 @@ func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase
dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
}
return &Tensor{
b: t.b,
t: C.ggml_rope_ext(
var tt *C.struct_ggml_tensor
if len(opts.MRoPE.Sections) > 0 {
mropeSections := make([]C.int32_t, 4)
for i, section := range opts.MRoPE.Sections {
mropeSections[i] = C.int32_t(section)
}
tt = C.ggml_rope_multi(
ctx.(*Context).ctx,
dequant,
positions.(*Tensor).t,
opts.Factors.(*Tensor).t,
C.int(ropeDim),
unsafe.SliceData(mropeSections),
C.int(opts.Type),
C.int(opts.OriginalContextLength),
C.float(ropeBase),
C.float(ropeScale),
C.float(opts.ExtrapolationFactor),
C.float(opts.AttentionFactor),
C.float(opts.BetaFast),
C.float(opts.BetaSlow),
),
cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10),
C.float(ropeBase), C.float(ropeScale),
C.float(opts.YaRN.ExtrapolationFactor),
cmp.Or(C.float(opts.YaRN.AttentionFactor), 1),
cmp.Or(C.float(opts.YaRN.BetaFast), 32),
cmp.Or(C.float(opts.YaRN.BetaSlow), 1),
)
} else {
tt = C.ggml_rope_ext(
ctx.(*Context).ctx,
dequant,
positions.(*Tensor).t,
opts.Factors.(*Tensor).t,
C.int(ropeDim), C.int(opts.Type),
cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10),
C.float(ropeBase), C.float(ropeScale),
C.float(opts.YaRN.ExtrapolationFactor),
cmp.Or(C.float(opts.YaRN.AttentionFactor), 1),
cmp.Or(C.float(opts.YaRN.BetaFast), 32),
cmp.Or(C.float(opts.YaRN.BetaSlow), 1),
)
}
return &Tensor{b: t.b, t: tt}
}
func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {

View File

@@ -5509,15 +5509,12 @@ static void ggml_mrope_cache_init(
}
float theta = theta_t;
if (sector >= sections[0] && sector < sec_w) {
if (sector % 3 == 1 && sector < 1 + 3 * sections[1]) {
theta = theta_h;
}
else if (sector >= sec_w && sector < sec_w + sections[2]) {
else if (sector % 3 == 2 && sector < 2 + 3 * sections[2]) {
theta = theta_w;
}
else if (sector >= sec_w + sections[2]) {
theta = theta_e;
}
rope_yarn(
theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]

View File

@@ -151,19 +151,13 @@ static __global__ void rope_multi(
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0;
if (sector < sections.v[0]) {
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sections.v[0] && sector < sec_w) {
float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
if (sector % 3 == 1 && sector < 1 + 3 * sections.v[1]) {
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
else if (sector % 3 == 2 && sector < 2 + 3 * sections.v[2]) {
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
}
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

View File

@@ -6523,15 +6523,11 @@ kernel void kernel_rope_multi(
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
const int sector = ic % sect_dims;
float theta_base;
if (sector < args.sect_0) {
theta_base = (float) pos[i2];
} else if (sector < sec_w01) {
float theta_base = (float) pos[i2];
if (sector % 3 == 1 && sector < 1 + 3 * args.sect_1) {
theta_base = (float) pos[i2 + args.ne02];
} else if (sector < sec_w012) {
} else if (sector % 3 == 2 && sector < 2 + 3 * args.sect_2) {
theta_base = (float) pos[i2 + args.ne02 * 2];
} else {
theta_base = (float) pos[i2 + args.ne02 * 3];
}
// end of mrope

View File

@@ -3858,15 +3858,11 @@ kernel void kernel_rope_multi(
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
const int sector = ic % sect_dims;
float theta_base;
if (sector < args.sect_0) {
theta_base = (float) pos[i2];
} else if (sector < sec_w01) {
float theta_base = (float) pos[i2];
if (sector % 3 == 1 && sector < 1 + 3 * args.sect_1) {
theta_base = (float) pos[i2 + args.ne02];
} else if (sector < sec_w012) {
} else if (sector % 3 == 2 && sector < 2 + 3 * args.sect_2) {
theta_base = (float) pos[i2 + args.ne02 * 2];
} else {
theta_base = (float) pos[i2 + args.ne02 * 3];
}
// end of mrope

View File

@@ -31,19 +31,13 @@ void main() {
const int sec_w = p.sections[1] + p.sections[0];
const uint sector = (i0 / 2) % sect_dims;
float theta_base = 0.0;
if (sector < p.sections[0]) {
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= p.sections[0] && sector < sec_w) {
float theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
if (sector % 3 == 1 && sector < 1 + 3 * p.sections[1]) {
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
else if (sector % 3 == 2 && sector < 2 + 3 * p.sections[2]) {
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w + p.sections[2]) {
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
}
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;

View File

@@ -4,21 +4,21 @@ import "github.com/ollama/ollama/ml"
// Options contains optional parameters for RoPE function
type Options struct {
Type int
Factors ml.Tensor
OriginalContextLength int
Type int
Factors ml.Tensor
// YaRN options
ExtrapolationFactor,
AttentionFactor,
BetaFast,
BetaSlow float32
}
YaRN struct {
OriginalContextLength int
ExtrapolationFactor,
AttentionFactor,
BetaFast,
BetaSlow float32
}
// WithOriginalContextLength sets a custom context length
func WithOriginalContextLength(n int) func(*Options) {
return func(opts *Options) {
opts.OriginalContextLength = n
// MRoPE options
MRoPE struct {
Sections []int
}
}
@@ -38,14 +38,28 @@ func WithFactors(factors ml.Tensor) func(*Options) {
}
}
// WithOriginalContextLength sets a custom context length
func WithOriginalContextLength(n int) func(*Options) {
return func(opts *Options) {
opts.YaRN.OriginalContextLength = n
}
}
func WithExtrapolationFactor(extrapolationFactor float32) func(*Options) {
return func(opts *Options) {
opts.ExtrapolationFactor = extrapolationFactor
opts.YaRN.ExtrapolationFactor = extrapolationFactor
}
}
func WithAttentionFactor(attentionFactor float32) func(*Options) {
return func(opts *Options) {
opts.AttentionFactor = attentionFactor
opts.YaRN.AttentionFactor = attentionFactor
}
}
func WithMRoPESections(sections []int) func(*Options) {
return func(opts *Options) {
opts.Type |= 1 << 3
opts.MRoPE.Sections = sections
}
}