mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 22:33:56 +00:00
fix: qwen2.5vl metal argsort
This commit is contained in:
committed by
Michael Yang
parent
d2f334c1f7
commit
d475d1f081
@@ -4,12 +4,12 @@ Date: Thu, 1 May 2025 13:45:12 -0700
|
|||||||
Subject: [PATCH] add argsort and cuda copy for i32
|
Subject: [PATCH] add argsort and cuda copy for i32
|
||||||
|
|
||||||
---
|
---
|
||||||
ggml/src/ggml-cpu/ops.cpp | 43 ++++++++++
|
ggml/src/ggml-cpu/ops.cpp | 43 ++++++
|
||||||
ggml/src/ggml-cuda/argsort.cu | 122 ++++++++++++++++++++++++---
|
ggml/src/ggml-cuda/argsort.cu | 122 +++++++++++++--
|
||||||
ggml/src/ggml-cuda/cpy-utils.cuh | 6 ++
|
ggml/src/ggml-cuda/cpy-utils.cuh | 6 +
|
||||||
ggml/src/ggml-cuda/cpy.cu | 40 +++++++++
|
ggml/src/ggml-cuda/cpy.cu | 40 +++++
|
||||||
ggml/src/ggml-metal/ggml-metal.metal | 69 +++++++++++++++
|
ggml/src/ggml-metal/ggml-metal.metal | 215 +++++++++++++++++++++++++++
|
||||||
5 files changed, 268 insertions(+), 12 deletions(-)
|
5 files changed, 414 insertions(+), 12 deletions(-)
|
||||||
|
|
||||||
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
|
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
|
||||||
index 2745fc54e..40666bab6 100644
|
index 2745fc54e..40666bab6 100644
|
||||||
@@ -292,7 +292,7 @@ index c4ceb4fc5..0e53ecc39 100644
|
|||||||
if (can_be_transposed) {
|
if (can_be_transposed) {
|
||||||
ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16, true>
|
ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16, true>
|
||||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||||
index 73b45c762..aed013a9d 100644
|
index 73b45c762..8a6c834d1 100644
|
||||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||||
@@ -4721,8 +4721,77 @@ kernel void kernel_argsort_f32_i32(
|
@@ -4721,8 +4721,77 @@ kernel void kernel_argsort_f32_i32(
|
||||||
@@ -373,3 +373,158 @@ index 73b45c762..aed013a9d 100644
|
|||||||
|
|
||||||
typedef void (argsort_merge_t)(
|
typedef void (argsort_merge_t)(
|
||||||
constant ggml_metal_kargs_argsort_merge & args,
|
constant ggml_metal_kargs_argsort_merge & args,
|
||||||
|
@@ -4877,8 +4946,154 @@ kernel void kernel_argsort_merge_f32_i32(
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
+template<ggml_sort_order order>
|
||||||
|
+kernel void kernel_argsort_merge_i32_i32(
|
||||||
|
+ constant ggml_metal_kargs_argsort_merge & args,
|
||||||
|
+ device const char * src0,
|
||||||
|
+ device const int32_t * tmp,
|
||||||
|
+ device int32_t * dst,
|
||||||
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
+ ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||||
|
+ ushort3 ntg[[threads_per_threadgroup]]) {
|
||||||
|
+
|
||||||
|
+ const int im = tgpig[0] / args.ne01;
|
||||||
|
+ const int i01 = tgpig[0] % args.ne01;
|
||||||
|
+ const int i02 = tgpig[1];
|
||||||
|
+ const int i03 = tgpig[2];
|
||||||
|
+
|
||||||
|
+ const int start = im * (2 * args.len);
|
||||||
|
+
|
||||||
|
+ const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
|
||||||
|
+ const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
|
||||||
|
+
|
||||||
|
+ const int total = len0 + len1;
|
||||||
|
+
|
||||||
|
+ device const int32_t * tmp0 = tmp + start
|
||||||
|
+ + i01*args.ne0
|
||||||
|
+ + i02*args.ne0*args.ne01
|
||||||
|
+ + i03*args.ne0*args.ne01*args.ne02;
|
||||||
|
+
|
||||||
|
+ device const int32_t * tmp1 = tmp0 + args.len;
|
||||||
|
+
|
||||||
|
+ dst += start
|
||||||
|
+ + i01*args.top_k
|
||||||
|
+ + i02*args.top_k*args.ne01
|
||||||
|
+ + i03*args.top_k*args.ne01*args.ne02;
|
||||||
|
+
|
||||||
|
+ device const int32_t * src0_row = (device const int32_t *)(src0
|
||||||
|
+ + args.nb01*i01
|
||||||
|
+ + args.nb02*i02
|
||||||
|
+ + args.nb03*i03);
|
||||||
|
+
|
||||||
|
+ if (total == 0) {
|
||||||
|
+ return;
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ const int chunk = (total + ntg.x - 1) / ntg.x;
|
||||||
|
+
|
||||||
|
+ const int k0 = tpitg.x * chunk;
|
||||||
|
+ const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
|
||||||
|
+
|
||||||
|
+ if (k0 >= args.top_k) {
|
||||||
|
+ return;
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ if (k0 >= total) {
|
||||||
|
+ return;
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ int low = k0 > len1 ? k0 - len1 : 0;
|
||||||
|
+ int high = MIN(k0, len0);
|
||||||
|
+
|
||||||
|
+ // binary-search partition (i, j) such that i + j = k
|
||||||
|
+ while (low < high) {
|
||||||
|
+ const int mid = (low + high) >> 1;
|
||||||
|
+
|
||||||
|
+ const int32_t idx0 = tmp0[mid];
|
||||||
|
+ const int32_t idx1 = tmp1[k0 - mid - 1];
|
||||||
|
+
|
||||||
|
+ const int32_t val0 = src0_row[idx0];
|
||||||
|
+ const int32_t val1 = src0_row[idx1];
|
||||||
|
+
|
||||||
|
+ bool take_left;
|
||||||
|
+ if (order == GGML_SORT_ORDER_ASC) {
|
||||||
|
+ take_left = (val0 <= val1);
|
||||||
|
+ } else {
|
||||||
|
+ take_left = (val0 >= val1);
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ if (take_left) {
|
||||||
|
+ low = mid + 1;
|
||||||
|
+ } else {
|
||||||
|
+ high = mid;
|
||||||
|
+ }
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ int i = low;
|
||||||
|
+ int j = k0 - i;
|
||||||
|
+
|
||||||
|
+ // keep the merge fronts into registers
|
||||||
|
+ int32_t idx0 = 0;
|
||||||
|
+ int32_t val0 = 0.0f;
|
||||||
|
+ if (i < len0) {
|
||||||
|
+ idx0 = tmp0[i];
|
||||||
|
+ val0 = src0_row[idx0];
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ int32_t idx1 = 0;
|
||||||
|
+ int32_t val1 = 0.0f;
|
||||||
|
+ if (j < len1) {
|
||||||
|
+ idx1 = tmp1[j];
|
||||||
|
+ val1 = src0_row[idx1];
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ for (int k = k0; k < k1; ++k) {
|
||||||
|
+ int32_t out_idx;
|
||||||
|
+
|
||||||
|
+ if (i >= len0) {
|
||||||
|
+ while (k < k1) {
|
||||||
|
+ dst[k++] = tmp1[j++];
|
||||||
|
+ }
|
||||||
|
+ break;
|
||||||
|
+ } else if (j >= len1) {
|
||||||
|
+ while (k < k1) {
|
||||||
|
+ dst[k++] = tmp0[i++];
|
||||||
|
+ }
|
||||||
|
+ break;
|
||||||
|
+ } else {
|
||||||
|
+ bool take_left;
|
||||||
|
+
|
||||||
|
+ if (order == GGML_SORT_ORDER_ASC) {
|
||||||
|
+ take_left = (val0 <= val1);
|
||||||
|
+ } else {
|
||||||
|
+ take_left = (val0 >= val1);
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ if (take_left) {
|
||||||
|
+ out_idx = idx0;
|
||||||
|
+ ++i;
|
||||||
|
+ if (i < len0) {
|
||||||
|
+ idx0 = tmp0[i];
|
||||||
|
+ val0 = src0_row[idx0];
|
||||||
|
+ }
|
||||||
|
+ } else {
|
||||||
|
+ out_idx = idx1;
|
||||||
|
+ ++j;
|
||||||
|
+ if (j < len1) {
|
||||||
|
+ idx1 = tmp1[j];
|
||||||
|
+ val1 = src0_row[idx1];
|
||||||
|
+ }
|
||||||
|
+ }
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ dst[k] = out_idx;
|
||||||
|
+ }
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
|
||||||
|
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
|
||||||
|
+template [[host_name("kernel_argsort_merge_i32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_i32_i32<GGML_SORT_ORDER_ASC>;
|
||||||
|
+template [[host_name("kernel_argsort_merge_i32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_i32_i32<GGML_SORT_ORDER_DESC>;
|
||||||
|
|
||||||
|
kernel void kernel_leaky_relu_f32(
|
||||||
|
constant ggml_metal_kargs_leaky_relu & args,
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ index 88ed79111..71ca60214 100644
|
|||||||
} else {
|
} else {
|
||||||
if (sector < sections.v[0]) {
|
if (sector < sections.v[0]) {
|
||||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||||
index aed013a9d..a489de435 100644
|
index 8a6c834d1..761b57a26 100644
|
||||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||||
@@ -4009,14 +4009,14 @@ kernel void kernel_rope_multi(
|
@@ -4009,14 +4009,14 @@ kernel void kernel_rope_multi(
|
||||||
|
|||||||
@@ -7723,8 +7723,154 @@ kernel void kernel_argsort_merge_f32_i32(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<ggml_sort_order order>
|
||||||
|
kernel void kernel_argsort_merge_i32_i32(
|
||||||
|
constant ggml_metal_kargs_argsort_merge & args,
|
||||||
|
device const char * src0,
|
||||||
|
device const int32_t * tmp,
|
||||||
|
device int32_t * dst,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||||
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||||
|
|
||||||
|
const int im = tgpig[0] / args.ne01;
|
||||||
|
const int i01 = tgpig[0] % args.ne01;
|
||||||
|
const int i02 = tgpig[1];
|
||||||
|
const int i03 = tgpig[2];
|
||||||
|
|
||||||
|
const int start = im * (2 * args.len);
|
||||||
|
|
||||||
|
const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
|
||||||
|
const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
|
||||||
|
|
||||||
|
const int total = len0 + len1;
|
||||||
|
|
||||||
|
device const int32_t * tmp0 = tmp + start
|
||||||
|
+ i01*args.ne0
|
||||||
|
+ i02*args.ne0*args.ne01
|
||||||
|
+ i03*args.ne0*args.ne01*args.ne02;
|
||||||
|
|
||||||
|
device const int32_t * tmp1 = tmp0 + args.len;
|
||||||
|
|
||||||
|
dst += start
|
||||||
|
+ i01*args.top_k
|
||||||
|
+ i02*args.top_k*args.ne01
|
||||||
|
+ i03*args.top_k*args.ne01*args.ne02;
|
||||||
|
|
||||||
|
device const int32_t * src0_row = (device const int32_t *)(src0
|
||||||
|
+ args.nb01*i01
|
||||||
|
+ args.nb02*i02
|
||||||
|
+ args.nb03*i03);
|
||||||
|
|
||||||
|
if (total == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int chunk = (total + ntg.x - 1) / ntg.x;
|
||||||
|
|
||||||
|
const int k0 = tpitg.x * chunk;
|
||||||
|
const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
|
||||||
|
|
||||||
|
if (k0 >= args.top_k) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (k0 >= total) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int low = k0 > len1 ? k0 - len1 : 0;
|
||||||
|
int high = MIN(k0, len0);
|
||||||
|
|
||||||
|
// binary-search partition (i, j) such that i + j = k
|
||||||
|
while (low < high) {
|
||||||
|
const int mid = (low + high) >> 1;
|
||||||
|
|
||||||
|
const int32_t idx0 = tmp0[mid];
|
||||||
|
const int32_t idx1 = tmp1[k0 - mid - 1];
|
||||||
|
|
||||||
|
const int32_t val0 = src0_row[idx0];
|
||||||
|
const int32_t val1 = src0_row[idx1];
|
||||||
|
|
||||||
|
bool take_left;
|
||||||
|
if (order == GGML_SORT_ORDER_ASC) {
|
||||||
|
take_left = (val0 <= val1);
|
||||||
|
} else {
|
||||||
|
take_left = (val0 >= val1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (take_left) {
|
||||||
|
low = mid + 1;
|
||||||
|
} else {
|
||||||
|
high = mid;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int i = low;
|
||||||
|
int j = k0 - i;
|
||||||
|
|
||||||
|
// keep the merge fronts into registers
|
||||||
|
int32_t idx0 = 0;
|
||||||
|
int32_t val0 = 0.0f;
|
||||||
|
if (i < len0) {
|
||||||
|
idx0 = tmp0[i];
|
||||||
|
val0 = src0_row[idx0];
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t idx1 = 0;
|
||||||
|
int32_t val1 = 0.0f;
|
||||||
|
if (j < len1) {
|
||||||
|
idx1 = tmp1[j];
|
||||||
|
val1 = src0_row[idx1];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int k = k0; k < k1; ++k) {
|
||||||
|
int32_t out_idx;
|
||||||
|
|
||||||
|
if (i >= len0) {
|
||||||
|
while (k < k1) {
|
||||||
|
dst[k++] = tmp1[j++];
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
} else if (j >= len1) {
|
||||||
|
while (k < k1) {
|
||||||
|
dst[k++] = tmp0[i++];
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
bool take_left;
|
||||||
|
|
||||||
|
if (order == GGML_SORT_ORDER_ASC) {
|
||||||
|
take_left = (val0 <= val1);
|
||||||
|
} else {
|
||||||
|
take_left = (val0 >= val1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (take_left) {
|
||||||
|
out_idx = idx0;
|
||||||
|
++i;
|
||||||
|
if (i < len0) {
|
||||||
|
idx0 = tmp0[i];
|
||||||
|
val0 = src0_row[idx0];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
out_idx = idx1;
|
||||||
|
++j;
|
||||||
|
if (j < len1) {
|
||||||
|
idx1 = tmp1[j];
|
||||||
|
val1 = src0_row[idx1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst[k] = out_idx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
|
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
|
||||||
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
|
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
|
||||||
|
template [[host_name("kernel_argsort_merge_i32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_i32_i32<GGML_SORT_ORDER_ASC>;
|
||||||
|
template [[host_name("kernel_argsort_merge_i32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_i32_i32<GGML_SORT_ORDER_DESC>;
|
||||||
|
|
||||||
kernel void kernel_leaky_relu_f32(
|
kernel void kernel_leaky_relu_f32(
|
||||||
constant ggml_metal_kargs_leaky_relu & args,
|
constant ggml_metal_kargs_leaky_relu & args,
|
||||||
|
|||||||
146
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
vendored
146
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
vendored
@@ -4946,8 +4946,154 @@ kernel void kernel_argsort_merge_f32_i32(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<ggml_sort_order order>
|
||||||
|
kernel void kernel_argsort_merge_i32_i32(
|
||||||
|
constant ggml_metal_kargs_argsort_merge & args,
|
||||||
|
device const char * src0,
|
||||||
|
device const int32_t * tmp,
|
||||||
|
device int32_t * dst,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||||
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||||
|
|
||||||
|
const int im = tgpig[0] / args.ne01;
|
||||||
|
const int i01 = tgpig[0] % args.ne01;
|
||||||
|
const int i02 = tgpig[1];
|
||||||
|
const int i03 = tgpig[2];
|
||||||
|
|
||||||
|
const int start = im * (2 * args.len);
|
||||||
|
|
||||||
|
const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
|
||||||
|
const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
|
||||||
|
|
||||||
|
const int total = len0 + len1;
|
||||||
|
|
||||||
|
device const int32_t * tmp0 = tmp + start
|
||||||
|
+ i01*args.ne0
|
||||||
|
+ i02*args.ne0*args.ne01
|
||||||
|
+ i03*args.ne0*args.ne01*args.ne02;
|
||||||
|
|
||||||
|
device const int32_t * tmp1 = tmp0 + args.len;
|
||||||
|
|
||||||
|
dst += start
|
||||||
|
+ i01*args.top_k
|
||||||
|
+ i02*args.top_k*args.ne01
|
||||||
|
+ i03*args.top_k*args.ne01*args.ne02;
|
||||||
|
|
||||||
|
device const int32_t * src0_row = (device const int32_t *)(src0
|
||||||
|
+ args.nb01*i01
|
||||||
|
+ args.nb02*i02
|
||||||
|
+ args.nb03*i03);
|
||||||
|
|
||||||
|
if (total == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int chunk = (total + ntg.x - 1) / ntg.x;
|
||||||
|
|
||||||
|
const int k0 = tpitg.x * chunk;
|
||||||
|
const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
|
||||||
|
|
||||||
|
if (k0 >= args.top_k) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (k0 >= total) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int low = k0 > len1 ? k0 - len1 : 0;
|
||||||
|
int high = MIN(k0, len0);
|
||||||
|
|
||||||
|
// binary-search partition (i, j) such that i + j = k
|
||||||
|
while (low < high) {
|
||||||
|
const int mid = (low + high) >> 1;
|
||||||
|
|
||||||
|
const int32_t idx0 = tmp0[mid];
|
||||||
|
const int32_t idx1 = tmp1[k0 - mid - 1];
|
||||||
|
|
||||||
|
const int32_t val0 = src0_row[idx0];
|
||||||
|
const int32_t val1 = src0_row[idx1];
|
||||||
|
|
||||||
|
bool take_left;
|
||||||
|
if (order == GGML_SORT_ORDER_ASC) {
|
||||||
|
take_left = (val0 <= val1);
|
||||||
|
} else {
|
||||||
|
take_left = (val0 >= val1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (take_left) {
|
||||||
|
low = mid + 1;
|
||||||
|
} else {
|
||||||
|
high = mid;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int i = low;
|
||||||
|
int j = k0 - i;
|
||||||
|
|
||||||
|
// keep the merge fronts into registers
|
||||||
|
int32_t idx0 = 0;
|
||||||
|
int32_t val0 = 0.0f;
|
||||||
|
if (i < len0) {
|
||||||
|
idx0 = tmp0[i];
|
||||||
|
val0 = src0_row[idx0];
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t idx1 = 0;
|
||||||
|
int32_t val1 = 0.0f;
|
||||||
|
if (j < len1) {
|
||||||
|
idx1 = tmp1[j];
|
||||||
|
val1 = src0_row[idx1];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int k = k0; k < k1; ++k) {
|
||||||
|
int32_t out_idx;
|
||||||
|
|
||||||
|
if (i >= len0) {
|
||||||
|
while (k < k1) {
|
||||||
|
dst[k++] = tmp1[j++];
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
} else if (j >= len1) {
|
||||||
|
while (k < k1) {
|
||||||
|
dst[k++] = tmp0[i++];
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
bool take_left;
|
||||||
|
|
||||||
|
if (order == GGML_SORT_ORDER_ASC) {
|
||||||
|
take_left = (val0 <= val1);
|
||||||
|
} else {
|
||||||
|
take_left = (val0 >= val1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (take_left) {
|
||||||
|
out_idx = idx0;
|
||||||
|
++i;
|
||||||
|
if (i < len0) {
|
||||||
|
idx0 = tmp0[i];
|
||||||
|
val0 = src0_row[idx0];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
out_idx = idx1;
|
||||||
|
++j;
|
||||||
|
if (j < len1) {
|
||||||
|
idx1 = tmp1[j];
|
||||||
|
val1 = src0_row[idx1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst[k] = out_idx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
|
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
|
||||||
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
|
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
|
||||||
|
template [[host_name("kernel_argsort_merge_i32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_i32_i32<GGML_SORT_ORDER_ASC>;
|
||||||
|
template [[host_name("kernel_argsort_merge_i32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_i32_i32<GGML_SORT_ORDER_DESC>;
|
||||||
|
|
||||||
kernel void kernel_leaky_relu_f32(
|
kernel void kernel_leaky_relu_f32(
|
||||||
constant ggml_metal_kargs_leaky_relu & args,
|
constant ggml_metal_kargs_leaky_relu & args,
|
||||||
|
|||||||
Reference in New Issue
Block a user