From d475d1f081e5455dcfdf9e958619223565b9bf52 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 8 Dec 2025 13:17:03 -0800 Subject: [PATCH] fix: qwen2.5vl metal argsort --- [review note - free text after the `---` separator is not part of the commit message: this change adds the merge kernel `kernel_argsort_merge_i32_i32` (host names `kernel_argsort_merge_i32_i32_asc` / `_desc`), which compares int32 values read from src0 through the indices in tmp, to nested patch 0013 and to both ml/backend Metal sources (ggml-metal.metal and ggml-metal-embed.metal), and updates the ggml-metal.metal blob ids referenced by patch 0027. The kernel's merge-front value registers val0/val1 are int32_t, so all three copies zero-initialise them with the integer literal 0 instead of the float literal 0.0f; the three copies must stay textually identical. NOTE(review): the `index` blob ids below were not regenerated after that literal cleanup - re-export the patch before merging.] ...13-add-argsort-and-cuda-copy-for-i32.patch | 169 +++++++++++++++++- .../patches/0027-interleave-multi-rope.patch | 2 +- .../src/ggml-metal/ggml-metal-embed.metal | 146 +++++++++++++++ .../ggml/ggml/src/ggml-metal/ggml-metal.metal | 146 +++++++++++++++ 4 files changed, 455 insertions(+), 8 deletions(-) diff --git a/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch b/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch index 5e5bc110..26c6dca7 100644 --- a/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch +++ b/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch @@ -4,12 +4,12 @@ Date: Thu, 1 May 2025 13:45:12 -0700 Subject: [PATCH] add argsort and cuda copy for i32 --- - ggml/src/ggml-cpu/ops.cpp | 43 ++++++++++ - ggml/src/ggml-cuda/argsort.cu | 122 ++++++++++++++++++++++++--- - ggml/src/ggml-cuda/cpy-utils.cuh | 6 ++ - ggml/src/ggml-cuda/cpy.cu | 40 +++++++++ - ggml/src/ggml-metal/ggml-metal.metal | 69 +++++++++++++++ - 5 files changed, 268 insertions(+), 12 deletions(-) + ggml/src/ggml-cpu/ops.cpp | 43 ++++++ + ggml/src/ggml-cuda/argsort.cu | 122 +++++++++++++-- + ggml/src/ggml-cuda/cpy-utils.cuh | 6 + + ggml/src/ggml-cuda/cpy.cu | 40 +++++ + ggml/src/ggml-metal/ggml-metal.metal | 215 +++++++++++++++++++++++++++ + 5 files changed, 414 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 2745fc54e..40666bab6 100644 @@ -292,7 +292,7 @@ index c4ceb4fc5..0e53ecc39 100644 if (can_be_transposed) { ggml_cpy_scalar_cuda diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal -index 73b45c762..aed013a9d 100644 +index 73b45c762..8a6c834d1 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4721,8 +4721,77 @@ kernel void
kernel_argsort_f32_i32( @@ -373,3 +373,158 @@ index 73b45c762..aed013a9d 100644 typedef void (argsort_merge_t)( constant ggml_metal_kargs_argsort_merge & args, +@@ -4877,8 +4946,154 @@ kernel void kernel_argsort_merge_f32_i32( + } + } + ++template ++kernel void kernel_argsort_merge_i32_i32( ++ constant ggml_metal_kargs_argsort_merge & args, ++ device const char * src0, ++ device const int32_t * tmp, ++ device int32_t * dst, ++ uint3 tgpig[[threadgroup_position_in_grid]], ++ ushort3 tpitg[[thread_position_in_threadgroup]], ++ ushort3 ntg[[threads_per_threadgroup]]) { ++ ++ const int im = tgpig[0] / args.ne01; ++ const int i01 = tgpig[0] % args.ne01; ++ const int i02 = tgpig[1]; ++ const int i03 = tgpig[2]; ++ ++ const int start = im * (2 * args.len); ++ ++ const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start))); ++ const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len))); ++ ++ const int total = len0 + len1; ++ ++ device const int32_t * tmp0 = tmp + start ++ + i01*args.ne0 ++ + i02*args.ne0*args.ne01 ++ + i03*args.ne0*args.ne01*args.ne02; ++ ++ device const int32_t * tmp1 = tmp0 + args.len; ++ ++ dst += start ++ + i01*args.top_k ++ + i02*args.top_k*args.ne01 ++ + i03*args.top_k*args.ne01*args.ne02; ++ ++ device const int32_t * src0_row = (device const int32_t *)(src0 ++ + args.nb01*i01 ++ + args.nb02*i02 ++ + args.nb03*i03); ++ ++ if (total == 0) { ++ return; ++ } ++ ++ const int chunk = (total + ntg.x - 1) / ntg.x; ++ ++ const int k0 = tpitg.x * chunk; ++ const int k1 = MIN(MIN(k0 + chunk, total), args.top_k); ++ ++ if (k0 >= args.top_k) { ++ return; ++ } ++ ++ if (k0 >= total) { ++ return; ++ } ++ ++ int low = k0 > len1 ?
k0 - len1 : 0; ++ int high = MIN(k0, len0); ++ ++ // binary-search partition (i, j) such that i + j = k ++ while (low < high) { ++ const int mid = (low + high) >> 1; ++ ++ const int32_t idx0 = tmp0[mid]; ++ const int32_t idx1 = tmp1[k0 - mid - 1]; ++ ++ const int32_t val0 = src0_row[idx0]; ++ const int32_t val1 = src0_row[idx1]; ++ ++ bool take_left; ++ if (order == GGML_SORT_ORDER_ASC) { ++ take_left = (val0 <= val1); ++ } else { ++ take_left = (val0 >= val1); ++ } ++ ++ if (take_left) { ++ low = mid + 1; ++ } else { ++ high = mid; ++ } ++ } ++ ++ int i = low; ++ int j = k0 - i; ++ ++ // keep the merge fronts into registers ++ int32_t idx0 = 0; ++ int32_t val0 = 0; ++ if (i < len0) { ++ idx0 = tmp0[i]; ++ val0 = src0_row[idx0]; ++ } ++ ++ int32_t idx1 = 0; ++ int32_t val1 = 0; ++ if (j < len1) { ++ idx1 = tmp1[j]; ++ val1 = src0_row[idx1]; ++ } ++ ++ for (int k = k0; k < k1; ++k) { ++ int32_t out_idx; ++ ++ if (i >= len0) { ++ while (k < k1) { ++ dst[k++] = tmp1[j++]; ++ } ++ break; ++ } else if (j >= len1) { ++ while (k < k1) { ++ dst[k++] = tmp0[i++]; ++ } ++ break; ++ } else { ++ bool take_left; ++ ++ if (order == GGML_SORT_ORDER_ASC) { ++ take_left = (val0 <= val1); ++ } else { ++ take_left = (val0 >= val1); ++ } ++ ++ if (take_left) { ++ out_idx = idx0; ++ ++i; ++ if (i < len0) { ++ idx0 = tmp0[i]; ++ val0 = src0_row[idx0]; ++ } ++ } else { ++ out_idx = idx1; ++ ++j; ++ if (j < len1) { ++ idx1 = tmp1[j]; ++ val1 = src0_row[idx1]; ++ } ++ } ++ } ++ ++ dst[k] = out_idx; ++ } ++} ++ + template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32; + template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32; ++template [[host_name("kernel_argsort_merge_i32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_i32_i32; ++template [[host_name("kernel_argsort_merge_i32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_i32_i32; + + kernel void
kernel_leaky_relu_f32( + constant ggml_metal_kargs_leaky_relu & args, diff --git a/llama/patches/0027-interleave-multi-rope.patch b/llama/patches/0027-interleave-multi-rope.patch index 1fee6b75..7d36d355 100644 --- a/llama/patches/0027-interleave-multi-rope.patch +++ b/llama/patches/0027-interleave-multi-rope.patch @@ -59,7 +59,7 @@ index 88ed79111..71ca60214 100644 } else { if (sector < sections.v[0]) { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal -index aed013a9d..a489de435 100644 +index 8a6c834d1..761b57a26 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4009,14 +4009,14 @@ kernel void kernel_rope_multi( diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal index da4c2bb0..9903af36 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal @@ -7723,8 +7723,154 @@ kernel void kernel_argsort_merge_f32_i32( } } +template +kernel void kernel_argsort_merge_i32_i32( + constant ggml_metal_kargs_argsort_merge & args, + device const char * src0, + device const int32_t * tmp, + device int32_t * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + + const int im = tgpig[0] / args.ne01; + const int i01 = tgpig[0] % args.ne01; + const int i02 = tgpig[1]; + const int i03 = tgpig[2]; + + const int start = im * (2 * args.len); + + const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start))); + const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len))); + + const int total = len0 + len1; + + device const int32_t * tmp0 = tmp + start + + i01*args.ne0 + + i02*args.ne0*args.ne01 + + i03*args.ne0*args.ne01*args.ne02; + + device const int32_t * tmp1 = tmp0 + args.len; + + dst += start + + i01*args.top_k + +
i02*args.top_k*args.ne01 + + i03*args.top_k*args.ne01*args.ne02; + + device const int32_t * src0_row = (device const int32_t *)(src0 + + args.nb01*i01 + + args.nb02*i02 + + args.nb03*i03); + + if (total == 0) { + return; + } + + const int chunk = (total + ntg.x - 1) / ntg.x; + + const int k0 = tpitg.x * chunk; + const int k1 = MIN(MIN(k0 + chunk, total), args.top_k); + + if (k0 >= args.top_k) { + return; + } + + if (k0 >= total) { + return; + } + + int low = k0 > len1 ? k0 - len1 : 0; + int high = MIN(k0, len0); + + // binary-search partition (i, j) such that i + j = k + while (low < high) { + const int mid = (low + high) >> 1; + + const int32_t idx0 = tmp0[mid]; + const int32_t idx1 = tmp1[k0 - mid - 1]; + + const int32_t val0 = src0_row[idx0]; + const int32_t val1 = src0_row[idx1]; + + bool take_left; + if (order == GGML_SORT_ORDER_ASC) { + take_left = (val0 <= val1); + } else { + take_left = (val0 >= val1); + } + + if (take_left) { + low = mid + 1; + } else { + high = mid; + } + } + + int i = low; + int j = k0 - i; + + // keep the merge fronts into registers + int32_t idx0 = 0; + int32_t val0 = 0; + if (i < len0) { + idx0 = tmp0[i]; + val0 = src0_row[idx0]; + } + + int32_t idx1 = 0; + int32_t val1 = 0; + if (j < len1) { + idx1 = tmp1[j]; + val1 = src0_row[idx1]; + } + + for (int k = k0; k < k1; ++k) { + int32_t out_idx; + + if (i >= len0) { + while (k < k1) { + dst[k++] = tmp1[j++]; + } + break; + } else if (j >= len1) { + while (k < k1) { + dst[k++] = tmp0[i++]; + } + break; + } else { + bool take_left; + + if (order == GGML_SORT_ORDER_ASC) { + take_left = (val0 <= val1); + } else { + take_left = (val0 >= val1); + } + + if (take_left) { + out_idx = idx0; + ++i; + if (i < len0) { + idx0 = tmp0[i]; + val0 = src0_row[idx0]; + } + } else { + out_idx = idx1; + ++j; + if (j < len1) { + idx1 = tmp1[j]; + val1 = src0_row[idx1]; + } + } + } + + dst[k] = out_idx; + } +} + template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t
kernel_argsort_merge_f32_i32; template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32; +template [[host_name("kernel_argsort_merge_i32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_i32_i32; +template [[host_name("kernel_argsort_merge_i32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_i32_i32; kernel void kernel_leaky_relu_f32( constant ggml_metal_kargs_leaky_relu & args, diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal index a489de43..761b57a2 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal @@ -4946,8 +4946,154 @@ kernel void kernel_argsort_merge_f32_i32( } } +template +kernel void kernel_argsort_merge_i32_i32( + constant ggml_metal_kargs_argsort_merge & args, + device const char * src0, + device const int32_t * tmp, + device int32_t * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + + const int im = tgpig[0] / args.ne01; + const int i01 = tgpig[0] % args.ne01; + const int i02 = tgpig[1]; + const int i03 = tgpig[2]; + + const int start = im * (2 * args.len); + + const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start))); + const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len))); + + const int total = len0 + len1; + + device const int32_t * tmp0 = tmp + start + + i01*args.ne0 + + i02*args.ne0*args.ne01 + + i03*args.ne0*args.ne01*args.ne02; + + device const int32_t * tmp1 = tmp0 + args.len; + + dst += start + + i01*args.top_k + + i02*args.top_k*args.ne01 + + i03*args.top_k*args.ne01*args.ne02; + + device const int32_t * src0_row = (device const int32_t *)(src0 + + args.nb01*i01 + + args.nb02*i02 + + args.nb03*i03); + + if (total == 0) { + return; + } + + const int chunk = (total + ntg.x - 1) / ntg.x; + + const int k0 =
tpitg.x * chunk; + const int k1 = MIN(MIN(k0 + chunk, total), args.top_k); + + if (k0 >= args.top_k) { + return; + } + + if (k0 >= total) { + return; + } + + int low = k0 > len1 ? k0 - len1 : 0; + int high = MIN(k0, len0); + + // binary-search partition (i, j) such that i + j = k + while (low < high) { + const int mid = (low + high) >> 1; + + const int32_t idx0 = tmp0[mid]; + const int32_t idx1 = tmp1[k0 - mid - 1]; + + const int32_t val0 = src0_row[idx0]; + const int32_t val1 = src0_row[idx1]; + + bool take_left; + if (order == GGML_SORT_ORDER_ASC) { + take_left = (val0 <= val1); + } else { + take_left = (val0 >= val1); + } + + if (take_left) { + low = mid + 1; + } else { + high = mid; + } + } + + int i = low; + int j = k0 - i; + + // keep the merge fronts into registers + int32_t idx0 = 0; + int32_t val0 = 0; + if (i < len0) { + idx0 = tmp0[i]; + val0 = src0_row[idx0]; + } + + int32_t idx1 = 0; + int32_t val1 = 0; + if (j < len1) { + idx1 = tmp1[j]; + val1 = src0_row[idx1]; + } + + for (int k = k0; k < k1; ++k) { + int32_t out_idx; + + if (i >= len0) { + while (k < k1) { + dst[k++] = tmp1[j++]; + } + break; + } else if (j >= len1) { + while (k < k1) { + dst[k++] = tmp0[i++]; + } + break; + } else { + bool take_left; + + if (order == GGML_SORT_ORDER_ASC) { + take_left = (val0 <= val1); + } else { + take_left = (val0 >= val1); + } + + if (take_left) { + out_idx = idx0; + ++i; + if (i < len0) { + idx0 = tmp0[i]; + val0 = src0_row[idx0]; + } + } else { + out_idx = idx1; + ++j; + if (j < len1) { + idx1 = tmp1[j]; + val1 = src0_row[idx1]; + } + } + } + + dst[k] = out_idx; + } +} + template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32; template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32; +template [[host_name("kernel_argsort_merge_i32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_i32_i32; +template
[[host_name("kernel_argsort_merge_i32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_i32_i32; kernel void kernel_leaky_relu_f32( constant ggml_metal_kargs_leaky_relu & args,