From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Thu, 19 Jun 2025 08:05:21 +0300
Subject: [PATCH] metal : add mean kernel (#14267)

* metal : add mean kernel

ggml-ci

* cont : dedup implementation

ggml-ci
---
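Note (below the fold, so not part of the commit message): the rewritten
kernel reduces each row cooperatively instead of with a single thread.
Each of the nth threads accumulates a strided partial sum, simd_sum()
folds those partials within every 32-lane simdgroup, the per-simdgroup
results are staged in 32 floats of threadgroup memory (which is why the
encoder reserves 32*sizeof(float)), and a second simd_sum() over the
staging buffer produces the row total; kernel_sum_rows<true>, exported
as kernel_mean, additionally divides by ne00. A rough, sequential C
model of the same three-stage reduction follows; NE00 and NTH are
made-up illustrative values, not taken from the patch:

    #include <stdio.h>

    #define NE00 100 /* row length */
    #define NTH   64 /* threads per threadgroup: power of two, >= 32 */

    int main(void) {
        float row[NE00];
        for (int i = 0; i < NE00; i++) row[i] = (float) i;

        /* stage 1: each "thread" sums a strided slice of the row */
        float partial[NTH] = {0};
        for (int t = 0; t < NTH; t++)
            for (int i0 = t; i0 < NE00; i0 += NTH)
                partial[t] += row[i0];

        /* stage 2: simd_sum within each 32-lane simdgroup,
         * staged in "threadgroup memory" */
        float shmem[32] = {0};
        for (int sg = 0; sg < NTH/32; sg++)
            for (int lane = 0; lane < 32; lane++)
                shmem[sg] += partial[sg*32 + lane];

        /* stage 3: final simd_sum over the staged partials */
        float sum = 0;
        for (int i = 0; i < 32; i++) sum += shmem[i];

        printf("sum = %.1f, mean = %.3f\n", sum, sum / NE00);
        return 0;
    }

Compiled with e.g. cc model.c, this prints sum = 4950.0, mean = 49.500,
matching a plain single-threaded sum over the row.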
 ggml/src/ggml-metal/ggml-metal.m     | 33 ++++++++++++++++---
 ggml/src/ggml-metal/ggml-metal.metal | 48 ++++++++++++++++++++++------
 2 files changed, 67 insertions(+), 14 deletions(-)

diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index a9eeebc6..110c9ece 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -489,6 +489,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_NEG,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+    GGML_METAL_KERNEL_TYPE_MEAN,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
     GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
     GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -1436,6 +1437,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
@@ -1634,6 +1636,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_LOG:
             return false; // TODO: implement
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:
             return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
@@ -2362,11 +2365,30 @@ static bool ggml_metal_encode_node(
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
             {
                 GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
 
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (dst->op) {
+                    case GGML_OP_SUM_ROWS:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+                        break;
+                    case GGML_OP_MEAN:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
+                        break;
+                    default:
+                        GGML_ABORT("fatal error");
+                }
+
+                int nth = 32; // SIMD width
+
+                while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                    nth *= 2;
+                }
+
+                nth = MIN(nth, ne00);
 
                 ggml_metal_kargs_sum_rows args = {
                     /*.ne00 =*/ ne00,
@@ -2396,11 +2418,12 @@ static bool ggml_metal_encode_node(
                 };
 
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-                [encoder setBytes:&args length:sizeof(args) atIndex:2];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
-                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_SOFT_MAX:
             {
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 9cfddf45..08e8d807 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -956,31 +956,61 @@ kernel void kernel_neg(
     dst[tpig] = -src0[tpig];
 }
 
+template <bool norm>
 kernel void kernel_sum_rows(
+        constant ggml_metal_kargs_sum_rows & args,
         device const float * src0,
         device       float * dst,
-        constant ggml_metal_kargs_sum_rows & args,
-        uint3 tpig[[thread_position_in_grid]]) {
-    int64_t i3 = tpig.z;
-    int64_t i2 = tpig.y;
-    int64_t i1 = tpig.x;
+        threadgroup  float * shmem_f32 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    int64_t i3 = tgpig.z;
+    int64_t i2 = tgpig.y;
+    int64_t i1 = tgpig.x;
 
     if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
         return;
     }
 
+    if (sgitg == 0) {
+        shmem_f32[tiisg] = 0.0f;
+    }
+
     device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
     device       float * dst_row = (device       float *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);
 
-    float row_sum = 0;
+    float sumf = 0;
 
-    for (int64_t i0 = 0; i0 < args.ne00; i0++) {
-        row_sum += src_row[i0];
+    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+        sumf += src_row[i0];
     }
 
-    dst_row[0] = row_sum;
+    sumf = simd_sum(sumf);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
+
+    if (tpitg.x == 0) {
+        dst_row[0] = norm ? sumf / args.ne00 : sumf;
+    }
 }
 
+typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
+
+template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
+template [[host_name("kernel_mean")]]     kernel kernel_sum_rows_t kernel_sum_rows<true>;
+
 template<typename T>
 kernel void kernel_soft_max(
         device const char * src0,