ollama-for-amd/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch

From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <git@mxy.ng>
Date: Thu, 1 May 2025 13:45:12 -0700
Subject: [PATCH] add argsort and cuda copy for i32

---
 ggml/src/ggml-cpu/ops.cpp            |  43 ++++++++++
 ggml/src/ggml-cuda/argsort.cu        | 122 ++++++++++++++++++++++++---
 ggml/src/ggml-cuda/cpy-utils.cuh     |   6 ++
 ggml/src/ggml-cuda/cpy.cu            |  40 +++++++++
 ggml/src/ggml-metal/ggml-metal.metal |  64 ++++++++++++++
 5 files changed, 263 insertions(+), 12 deletions(-)
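
Usage sketch (illustrative only, not applied by this patch): how the new I32 argsort
path is reached through the public ggml API. ggml_argsort, GGML_TYPE_I32 and
GGML_SORT_ORDER_ASC are existing upstream names; the context size, tensor length,
and backend choice below are arbitrary assumptions for the sketch.

    #include "ggml.h"

    int main(void) {
        // small context; the size is arbitrary for this sketch
        struct ggml_init_params ip = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(ip);

        // a 1-D I32 tensor of ids to be sorted
        struct ggml_tensor * ids  = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 8);

        // argsort over an I32 source; the result tensor is GGML_TYPE_I32 as before
        struct ggml_tensor * perm = ggml_argsort(ctx, ids, GGML_SORT_ORDER_ASC);

        // ... fill ids->data, build a graph containing perm, and compute it on the
        //     CPU, CUDA, or Metal backend; all three use the code paths added here ...

        ggml_free(ctx);
        return 0;
    }
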
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index b52f0f847..902fdad69 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -7889,6 +7889,45 @@ static void ggml_compute_forward_argsort_f32(
}
}
+static void ggml_compute_forward_argsort_i32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ GGML_ASSERT(nb0 == sizeof(int32_t));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t nr = ggml_nrows(src0);
+
+ ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
+
+ for (int64_t i = ith; i < nr; i += nth) {
+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+ const int32_t * src_data = (int32_t *)((char *) src0->data + i*nb01);
+
+ for (int64_t j = 0; j < ne0; j++) {
+ dst_data[j] = j;
+ }
+
+ // C doesn't have a functional sort, so we do a bubble sort instead
+ for (int64_t j = 0; j < ne0; j++) {
+ for (int64_t k = j + 1; k < ne0; k++) {
+ if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
+ (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
+ int32_t tmp = dst_data[j];
+ dst_data[j] = dst_data[k];
+ dst_data[k] = tmp;
+ }
+ }
+ }
+ }
+}
+
void ggml_compute_forward_argsort(
const ggml_compute_params * params,
ggml_tensor * dst) {
@@ -7900,6 +7939,10 @@ void ggml_compute_forward_argsort(
{
ggml_compute_forward_argsort_f32(params, dst);
} break;
+ case GGML_TYPE_I32:
+ {
+ ggml_compute_forward_argsort_i32(params, dst);
+ } break;
default:
{
GGML_ABORT("fatal error");
diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu
index 6e7b90d42..08dd30525 100644
--- a/ggml/src/ggml-cuda/argsort.cu
+++ b/ggml/src/ggml-cuda/argsort.cu
@@ -168,13 +168,107 @@ static void argsort_f32_i32_cuda_bitonic(const float * x,
}
}
+
+template<ggml_sort_order order>
+static __global__ void k_argsort_i32_i32(const int32_t * x, int * dst, const int ncols, const int ncols_pad) {
+ extern __shared__ int shared_mem[];
+ int * indices = shared_mem;
+
+ const int tid = threadIdx.x;
+ const int row = blockIdx.y;
+
+ // Initialize all indices, handling the case where threads < ncols_pad
+ for (int i = tid; i < ncols_pad; i += blockDim.x) {
+ indices[i] = i < ncols ? i : 0; // Use 0 for padding indices
+ }
+ __syncthreads();
+
+ // Bitonic sort
+ for (int k = 2; k <= ncols_pad; k *= 2) {
+ for (int j = k/2; j > 0; j /= 2) {
+ for (int i = tid; i < ncols_pad; i += blockDim.x) {
+ const int ij = i ^ j;
+ if (ij > i) {
+ // Only compare values within the actual data range
+ if (i < ncols && ij < ncols) {
+ if ((i & k) == 0) {
+ if (order == GGML_SORT_ORDER_ASC) {
+ if (x[row * ncols + indices[i]] > x[row * ncols + indices[ij]]) {
+ int tmp = indices[i];
+ indices[i] = indices[ij];
+ indices[ij] = tmp;
+ }
+ } else {
+ if (x[row * ncols + indices[i]] < x[row * ncols + indices[ij]]) {
+ int tmp = indices[i];
+ indices[i] = indices[ij];
+ indices[ij] = tmp;
+ }
+ }
+ } else {
+ if (order == GGML_SORT_ORDER_ASC) {
+ if (x[row * ncols + indices[i]] < x[row * ncols + indices[ij]]) {
+ int tmp = indices[i];
+ indices[i] = indices[ij];
+ indices[ij] = tmp;
+ }
+ } else {
+ if (x[row * ncols + indices[i]] > x[row * ncols + indices[ij]]) {
+ int tmp = indices[i];
+ indices[i] = indices[ij];
+ indices[ij] = tmp;
+ }
+ }
+ }
+ }
+ }
+ }
+ __syncthreads();
+ }
+ }
+
+ // Write sorted indices to output, only threads handling valid data
+ for (int i = tid; i < ncols; i += blockDim.x) {
+ dst[row * ncols + i] = indices[i];
+ }
+}
+
+static void argsort_i32_i32_cuda(const int32_t * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
+ // Bitonic sort requires ncols to be power of 2
+ const int ncols_pad = next_power_of_2(ncols);
+
+ // Ensure thread count doesn't exceed maximum (typically 1024)
+ const int max_threads = 1024; // This is the typical max for most GPUs
+ const int threads_per_block = ncols_pad > max_threads ? max_threads : ncols_pad;
+
+ const dim3 block_dims(threads_per_block, 1, 1);
+ const dim3 block_nums(1, nrows, 1);
+ const size_t shared_mem = ncols_pad * sizeof(int);
+
+ // Check if shared memory size is within limits
+ const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
+
+ // The required shared memory must fit within the device limit
+ GGML_ASSERT(shared_mem <= max_shared_mem && "argsort: required shared memory exceeds device limit");
+
+ // Launch the kernel for the requested sort order
+ if (order == GGML_SORT_ORDER_ASC) {
+ k_argsort_i32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+ } else if (order == GGML_SORT_ORDER_DESC) {
+ k_argsort_i32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+ } else {
+ GGML_ABORT("fatal error");
+ }
+}
+
+
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
GGML_ASSERT( dst->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_is_contiguous(src0));
@@ -183,18 +277,22 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
-#ifdef GGML_CUDA_USE_CUB
- const int ncols_pad = next_power_of_2(ncols);
- const size_t shared_mem = ncols_pad * sizeof(int);
- const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
-
- if (shared_mem > max_shared_mem || ncols > 1024) {
- ggml_cuda_pool & pool = ctx.pool();
- argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
+ if (src0->type == GGML_TYPE_I32) {
+ argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream);
} else {
- argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
- }
+#ifdef GGML_CUDA_USE_CUB
+ const int ncols_pad = next_power_of_2(ncols);
+ const size_t shared_mem = ncols_pad * sizeof(int);
+ const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
+
+ if (shared_mem > max_shared_mem || ncols > 1024) {
+ ggml_cuda_pool & pool = ctx.pool();
+ argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
+ } else {
+ argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
+ }
#else
- argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
+ argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
#endif
+ }
}
diff --git a/ggml/src/ggml-cuda/cpy-utils.cuh b/ggml/src/ggml-cuda/cpy-utils.cuh
index e621cb981..597c0c8b3 100644
--- a/ggml/src/ggml-cuda/cpy-utils.cuh
+++ b/ggml/src/ggml-cuda/cpy-utils.cuh
@@ -215,3 +215,9 @@ template<typename src_t, typename dst_t>
static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
*(dst_t *) cdsti = ggml_cuda_cast<dst_t>(*(const src_t *) cxi);
}
+
+static __device__ void cpy_1_i32_i32(const char * cxi, char * cdsti) {
+ const int32_t * src = (const int32_t *)cxi;
+ int32_t * dst = (int32_t *)cdsti;
+ *dst = *src;
+}
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index 12d5bf776..a0e34030e 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -251,6 +251,43 @@ static void ggml_cpy_f32_iq4_nl_cuda(
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
+template <cpy_kernel_t cpy_1>
+static __global__ void cpy_i32_i32(
+ const char *cx, char *cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
+ const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+ const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
+
+ if (i >= ne) {
+ return;
+ }
+
+ const int64_t i03 = i / (ne00 * ne01 * ne02);
+ const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
+ const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
+ const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;
+ const int64_t x_offset = i00 * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;
+
+ const int64_t i13 = i / (ne10 * ne11 * ne12);
+ const int64_t i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);
+ const int64_t i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;
+ const int64_t i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
+ const int64_t dst_offset = i10 * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
+
+ cpy_1(cx + x_offset, cdst + dst_offset);
+}
+
+static void ggml_cpy_i32_i32_cuda(
+ const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
+ const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+ const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+ cpy_i32_i32<cpy_1_i32_i32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, stream);
+}
+
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1));
@@ -332,6 +369,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
+ // TODO consider converting to template
+ ggml_cpy_i32_i32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 2c2f01415..50b8071de 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -4467,8 +4467,72 @@ kernel void kernel_argsort_f32_i32(
}
}
+typedef void (i32_argsort_t)(
+ constant ggml_metal_kargs_argsort & args,
+ device const int32_t * x,
+ device int32_t * dst,
+ threadgroup int32_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]);
+
+template<ggml_sort_order order>
+kernel void kernel_argsort_i32_i32(
+ constant ggml_metal_kargs_argsort & args,
+ device const int32_t * x,
+ device int32_t * dst,
+ threadgroup int32_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]) {
+ // bitonic sort
+ int col = tpitg[0];
+ int row = tgpig[1];
+
+ if (col >= args.ncols_pad) return;
+
+ device const int32_t * x_row = x + row * args.ncols;
+ threadgroup int32_t * dst_row = shared_values;
+
+ // initialize indices
+ dst_row[col] = col;
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (int k = 2; k <= args.ncols_pad; k *= 2) {
+ for (int j = k / 2; j > 0; j /= 2) {
+ int ixj = col ^ j;
+ if (ixj > col) {
+ if ((col & k) == 0) {
+ if (dst_row[col] >= args.ncols ||
+ (dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
+ x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+ ) {
+ SWAP(dst_row[col], dst_row[ixj]);
+ }
+ } else {
+ if (dst_row[ixj] >= args.ncols ||
+ (dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
+ x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+ ) {
+ SWAP(dst_row[col], dst_row[ixj]);
+ }
+ }
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+ }
+
+ // copy the result to dst without the padding
+ if (col < args.ncols) {
+ dst[row * args.ncols + col] = dst_row[col];
+ }
+}
+
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
+template [[host_name("kernel_argsort_i32_i32_asc")]] kernel i32_argsort_t kernel_argsort_i32_i32<GGML_SORT_ORDER_ASC>;
+template [[host_name("kernel_argsort_i32_i32_desc")]] kernel i32_argsort_t kernel_argsort_i32_i32<GGML_SORT_ORDER_DESC>;
kernel void kernel_leaky_relu_f32(
constant ggml_metal_kargs_leaky_relu & args,