From 31b8c6a214dbfd7f1c711869572301ef7bf41b58 Mon Sep 17 00:00:00 2001 From: Sos Pogosyan <55689991+ZeeeUs@users.noreply.github.com> Date: Fri, 5 Dec 2025 08:33:07 +0300 Subject: [PATCH 01/35] fix(api): correct Content-Type header for /api/chat and /api/generate when using cloud models (#13279) --------- Co-authored-by: Pogosyan Sos Co-authored-by: Patrick Devine --- server/routes.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/server/routes.go b/server/routes.go index 16df3f4f..e5e6dd5a 100644 --- a/server/routes.go +++ b/server/routes.go @@ -262,6 +262,12 @@ func (s *Server) GenerateHandler(c *gin.Context) { slog.Warn("embedded messages in the model not supported with '/api/generate'; try '/api/chat' instead") } + contentType := "application/x-ndjson" + if req.Stream != nil && !*req.Stream { + contentType = "application/json; charset=utf-8" + } + c.Header("Content-Type", contentType) + fn := func(resp api.GenerateResponse) error { resp.Model = origModel resp.RemoteModel = m.Config.RemoteModel @@ -303,12 +309,6 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - contentType := "application/json; charset=utf-8" - if req.Stream != nil && *req.Stream { - contentType = "application/x-ndjson" - } - c.Header("Content-Type", contentType) - return } @@ -1939,6 +1939,12 @@ func (s *Server) ChatHandler(c *gin.Context) { } } + contentType := "application/x-ndjson" + if req.Stream != nil && !*req.Stream { + contentType = "application/json; charset=utf-8" + } + c.Header("Content-Type", contentType) + fn := func(resp api.ChatResponse) error { resp.Model = origModel resp.RemoteModel = m.Config.RemoteModel @@ -1980,12 +1986,6 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - contentType := "application/json; charset=utf-8" - if req.Stream != nil && *req.Stream { - contentType = "application/x-ndjson" - } - c.Header("Content-Type", contentType) - return } From c146a138e35520cd7ee132da46c6ae185778b4f0 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Fri, 5 Dec 2025 16:10:33 -0800 Subject: [PATCH 02/35] ggml: handle all streams (#13350) Follow up from #12992 Free all streams, and keep the alloc logic aligned across streams. --- llama/patches/0020-ggml-No-alloc-mode.patch | 42 ++++++++++++------- ...gml-Enable-resetting-backend-devices.patch | 8 ++-- .../0024-GPU-discovery-enhancements.patch | 14 +++---- .../0029-ggml-cuda-skip-large-batches.patch | 4 +- ml/backend/ggml/ggml/src/ggml-cuda/common.cuh | 12 ++++-- .../ggml/ggml/src/ggml-cuda/ggml-cuda.cu | 10 ++++- 6 files changed, 55 insertions(+), 35 deletions(-) diff --git a/llama/patches/0020-ggml-No-alloc-mode.patch b/llama/patches/0020-ggml-No-alloc-mode.patch index 0dff5573..01a42690 100644 --- a/llama/patches/0020-ggml-No-alloc-mode.patch +++ b/llama/patches/0020-ggml-No-alloc-mode.patch @@ -10,10 +10,10 @@ must be recreated with no-alloc set to false before loading data. 
--- ggml/include/ggml-backend.h | 1 + ggml/src/ggml-backend-impl.h | 16 +++ - ggml/src/ggml-backend.cpp | 72 ++++++++++- - ggml/src/ggml-cuda/common.cuh | 58 ++++++++- - ggml/src/ggml-cuda/ggml-cuda.cu | 218 ++++++++++++++++++++++++++------ - 5 files changed, 321 insertions(+), 44 deletions(-) + ggml/src/ggml-backend.cpp | 72 +++++++++- + ggml/src/ggml-cuda/common.cuh | 62 ++++++++- + ggml/src/ggml-cuda/ggml-cuda.cu | 224 ++++++++++++++++++++++++++------ + 5 files changed, 331 insertions(+), 44 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 2763f2bd6..b3b5b356a 100644 @@ -219,7 +219,7 @@ index f511e8d76..74b7f070c 100644 void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh -index 611341deb..c3f8ca914 100644 +index 611341deb..ee463af9c 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -37,6 +37,41 @@ @@ -274,7 +274,7 @@ index 611341deb..c3f8ca914 100644 }; template -@@ -1179,11 +1217,11 @@ struct ggml_backend_cuda_context { +@@ -1179,11 +1217,15 @@ struct ggml_backend_cuda_context { // pool std::unique_ptr pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; @@ -284,11 +284,15 @@ index 611341deb..c3f8ca914 100644 ggml_cuda_pool & pool(int device) { if (pools[device][curr_stream_no] == nullptr) { - pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no); -+ pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no, true); ++ bool alloc = true; ++ if (pools[device][0] != nullptr) { ++ alloc = pools[device][0]->alloc_memory(); ++ } ++ pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no, alloc); } return *pools[device][curr_stream_no]; } -@@ -1191,6 +1229,22 @@ struct ggml_backend_cuda_context { +@@ -1191,6 +1233,22 @@ struct ggml_backend_cuda_context { ggml_cuda_pool & pool() { return pool(device); } @@ -301,18 +305,18 @@ index 611341deb..c3f8ca914 100644 + } + } + -+ size_t pool_get_alloc_size() { -+ if (pools[device][curr_stream_no] == nullptr) { ++ size_t pool_get_alloc_size(int stream_no) { ++ if (pools[device][stream_no] == nullptr) { + return 0; + } + -+ return pools[device][curr_stream_no]->alloc_size(); ++ return pools[device][stream_no]->alloc_size(); + } }; struct ggml_cuda_mm_fusion_args_host { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index 78fb2d8b3..fe0da71ca 100644 +index 78fb2d8b3..f1c178f31 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -361,6 +361,8 @@ const ggml_cuda_device_info & ggml_cuda_info() { @@ -583,7 +587,7 @@ index 78fb2d8b3..fe0da71ca 100644 ggml_cuda_set_device(cuda_ctx->device); -@@ -3766,6 +3836,71 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, +@@ -3766,6 +3836,77 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, return GGML_STATUS_SUCCESS; } @@ -644,18 +648,24 @@ index 78fb2d8b3..fe0da71ca 100644 + +static size_t ggml_backend_cuda_buffer_size(ggml_backend_t backend) { + ggml_backend_cuda_context * ctx = (ggml_backend_cuda_context *)backend->context; -+ return ctx->pool_get_alloc_size(); ++ size_t allocs = 0; ++ for (int i = 0; i < GGML_CUDA_MAX_STREAMS; i++) { ++ allocs += ctx->pool_get_alloc_size(i); ++ } ++ return allocs; +} + +static void ggml_backend_cuda_reset(ggml_backend_t backend) { + ggml_backend_cuda_context * ctx = 
(ggml_backend_cuda_context *)backend->context; -+ ctx->pools[ctx->device][ctx->curr_stream_no] = NULL; ++ for (int i = 0; i < GGML_CUDA_MAX_STREAMS; i++) { ++ ctx->pools[ctx->device][i] = NULL; ++ } +} + static void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; -@@ -4035,6 +4170,9 @@ static const ggml_backend_i ggml_backend_cuda_interface = { +@@ -4035,6 +4176,9 @@ static const ggml_backend_i ggml_backend_cuda_interface = { /* .event_record = */ ggml_backend_cuda_event_record, /* .event_wait = */ ggml_backend_cuda_event_wait, /* .graph_optimize = */ ggml_backend_cuda_graph_optimize, diff --git a/llama/patches/0022-ggml-Enable-resetting-backend-devices.patch b/llama/patches/0022-ggml-Enable-resetting-backend-devices.patch index c65d84f7..04a6b0be 100644 --- a/llama/patches/0022-ggml-Enable-resetting-backend-devices.patch +++ b/llama/patches/0022-ggml-Enable-resetting-backend-devices.patch @@ -62,7 +62,7 @@ index 74b7f070c..8d2cc167f 100644 GGML_ASSERT(device); return device->iface.get_buffer_type(device); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index fe0da71ca..0787e443c 100644 +index f1c178f31..1110ca372 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -109,6 +109,11 @@ int ggml_cuda_get_device() { @@ -77,7 +77,7 @@ index fe0da71ca..0787e443c 100644 static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) { ggml_cuda_set_device(device); cudaError_t err; -@@ -4380,7 +4385,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back +@@ -4386,7 +4391,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back props->id = ggml_backend_cuda_device_get_id(dev); props->type = ggml_backend_cuda_device_get_type(dev); props->device_id = ctx->pci_bus_id.empty() ? 
nullptr : ctx->pci_bus_id.c_str(); @@ -89,7 +89,7 @@ index fe0da71ca..0787e443c 100644 bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; #ifdef GGML_CUDA_NO_PEER_COPY -@@ -4835,6 +4843,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g +@@ -4841,6 +4849,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context)); } @@ -101,7 +101,7 @@ index fe0da71ca..0787e443c 100644 static const ggml_backend_device_i ggml_backend_cuda_device_interface = { /* .get_name = */ ggml_backend_cuda_device_get_name, /* .get_description = */ ggml_backend_cuda_device_get_description, -@@ -4851,6 +4864,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = { +@@ -4857,6 +4870,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = { /* .event_new = */ ggml_backend_cuda_device_event_new, /* .event_free = */ ggml_backend_cuda_device_event_free, /* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize, diff --git a/llama/patches/0024-GPU-discovery-enhancements.patch b/llama/patches/0024-GPU-discovery-enhancements.patch index c372f0bc..e4cebfae 100644 --- a/llama/patches/0024-GPU-discovery-enhancements.patch +++ b/llama/patches/0024-GPU-discovery-enhancements.patch @@ -58,7 +58,7 @@ index 6d493a4ff..ac8f38464 100644 set_target_properties(ggml-base PROPERTIES diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index 0787e443c..736d47c07 100644 +index 1110ca372..c1bfadb3e 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -263,6 +263,16 @@ static ggml_cuda_device_info ggml_cuda_init() { @@ -90,7 +90,7 @@ index 0787e443c..736d47c07 100644 GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, ID: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? 
"yes" : "no", ggml_cuda_parse_uuid(prop, id).c_str()); -@@ -4249,6 +4264,11 @@ struct ggml_backend_cuda_device_context { +@@ -4255,6 +4270,11 @@ struct ggml_backend_cuda_device_context { std::string description; std::string pci_bus_id; std::string id; @@ -102,7 +102,7 @@ index 0787e443c..736d47c07 100644 }; static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { -@@ -4345,6 +4365,28 @@ static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) { +@@ -4351,6 +4371,28 @@ static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) { static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; ggml_cuda_set_device(ctx->device); @@ -131,7 +131,7 @@ index 0787e443c..736d47c07 100644 CUDA_CHECK(cudaMemGetInfo(free, total)); // ref: https://github.com/ggml-org/llama.cpp/pull/17368 -@@ -4377,6 +4419,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend +@@ -4383,6 +4425,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend return GGML_BACKEND_DEVICE_TYPE_GPU; } @@ -139,7 +139,7 @@ index 0787e443c..736d47c07 100644 static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; -@@ -4390,6 +4433,19 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back +@@ -4396,6 +4439,19 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back // If you need the memory data, call ggml_backend_dev_memory() explicitly. props->memory_total = props->memory_free = 0; @@ -159,7 +159,7 @@ index 0787e443c..736d47c07 100644 bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; #ifdef GGML_CUDA_NO_PEER_COPY bool events = false; -@@ -4974,6 +5030,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { +@@ -4980,6 +5036,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { std::lock_guard lock(mutex); if (!initialized) { ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context; @@ -167,7 +167,7 @@ index 0787e443c..736d47c07 100644 for (int i = 0; i < ggml_cuda_info().device_count; i++) { ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context; -@@ -4989,6 +5046,14 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { +@@ -4995,6 +5052,14 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID); dev_ctx->pci_bus_id = pci_bus_id; diff --git a/llama/patches/0029-ggml-cuda-skip-large-batches.patch b/llama/patches/0029-ggml-cuda-skip-large-batches.patch index d1d1addd..834b6e9d 100644 --- a/llama/patches/0029-ggml-cuda-skip-large-batches.patch +++ b/llama/patches/0029-ggml-cuda-skip-large-batches.patch @@ -10,10 +10,10 @@ fallback to cpu 1 file changed, 3 insertions(+) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index 736d47c07..7350f6758 100644 +index c1bfadb3e..16c166a08 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu -@@ -4564,6 +4564,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g +@@ -4570,6 +4570,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) { return 
false; } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh index c3f8ca91..ee463af9 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh @@ -1221,7 +1221,11 @@ struct ggml_backend_cuda_context { ggml_cuda_pool & pool(int device) { if (pools[device][curr_stream_no] == nullptr) { - pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no, true); + bool alloc = true; + if (pools[device][0] != nullptr) { + alloc = pools[device][0]->alloc_memory(); + } + pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no, alloc); } return *pools[device][curr_stream_no]; } @@ -1238,12 +1242,12 @@ struct ggml_backend_cuda_context { } } - size_t pool_get_alloc_size() { - if (pools[device][curr_stream_no] == nullptr) { + size_t pool_get_alloc_size(int stream_no) { + if (pools[device][stream_no] == nullptr) { return 0; } - return pools[device][curr_stream_no]->alloc_size(); + return pools[device][stream_no]->alloc_size(); } }; diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu index 7350f675..16c166a0 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3913,12 +3913,18 @@ static enum ggml_status ggml_backend_cuda_graph_reserve(ggml_backend_t backend, static size_t ggml_backend_cuda_buffer_size(ggml_backend_t backend) { ggml_backend_cuda_context * ctx = (ggml_backend_cuda_context *)backend->context; - return ctx->pool_get_alloc_size(); + size_t allocs = 0; + for (int i = 0; i < GGML_CUDA_MAX_STREAMS; i++) { + allocs += ctx->pool_get_alloc_size(i); + } + return allocs; } static void ggml_backend_cuda_reset(ggml_backend_t backend) { ggml_backend_cuda_context * ctx = (ggml_backend_cuda_context *)backend->context; - ctx->pools[ctx->device][ctx->curr_stream_no] = NULL; + for (int i = 0; i < GGML_CUDA_MAX_STREAMS; i++) { + ctx->pools[ctx->device][i] = NULL; + } } static void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) { From 5a41d69b2ab7d133a27e4f6d5666982c73a5b5ad Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Sun, 7 Dec 2025 21:49:14 -0800 Subject: [PATCH 03/35] fs/ggml: write int32 and int64 values to gguf files (#13335) --- fs/ggml/gguf.go | 8 ++++++++ fs/ggml/gguf_test.go | 14 +++++++++++--- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/fs/ggml/gguf.go b/fs/ggml/gguf.go index b694dead..e093efea 100644 --- a/fs/ggml/gguf.go +++ b/fs/ggml/gguf.go @@ -597,6 +597,10 @@ func ggufWriteKV(ws io.WriteSeeker, arch, k string, v any) error { var err error switch v := v.(type) { + case int32: + err = writeGGUF(ws, ggufTypeInt32, v) + case int64: + err = writeGGUF(ws, ggufTypeInt64, v) case uint32, FileType: err = writeGGUF(ws, ggufTypeUint32, v) case uint64: @@ -611,6 +615,10 @@ func ggufWriteKV(ws io.WriteSeeker, arch, k string, v any) error { err = writeGGUFArray(ws, ggufTypeInt32, v) case *array[int32]: err = writeGGUFArray(ws, ggufTypeInt32, v.values) + case []int64: + err = writeGGUFArray(ws, ggufTypeInt64, v) + case *array[int64]: + err = writeGGUFArray(ws, ggufTypeInt64, v.values) case []uint32: err = writeGGUFArray(ws, ggufTypeUint32, v) case *array[uint32]: diff --git a/fs/ggml/gguf_test.go b/fs/ggml/gguf_test.go index f0c2826c..51430e3b 100644 --- a/fs/ggml/gguf_test.go +++ b/fs/ggml/gguf_test.go @@ -42,6 +42,10 @@ func TestWriteGGUF(t *testing.T) { "general.architecture": "test", 
"general.alignment": uint32(16), "test.key": "value", + "test.int32_key": int32(-42), + "test.int64_key": int64(-9223372036854775808), + "test.int32_array": []int32{-1, 0, 1, 2147483647, -2147483648}, + "test.int64_array": []int64{-1, 0, 1, 9223372036854775807, -9223372036854775808}, "attention.key": "value2", "tokenizer.key": "value3", "adapter.key": "value4", @@ -55,7 +59,7 @@ func TestWriteGGUF(t *testing.T) { } defer r.Close() - ff, err := Decode(r, 0) + ff, err := Decode(r, -1) if err != nil { t.Fatal(err) } @@ -65,15 +69,19 @@ func TestWriteGGUF(t *testing.T) { "general.alignment": uint32(16), "general.parameter_count": uint64(54), "test.key": "value", + "test.int32_key": int32(-42), + "test.int64_key": int64(-9223372036854775808), + "test.int32_array": &array[int32]{size: 5, values: []int32{-1, 0, 1, 2147483647, -2147483648}}, + "test.int64_array": &array[int64]{size: 5, values: []int64{-1, 0, 1, 9223372036854775807, -9223372036854775808}}, "test.attention.key": "value2", "tokenizer.key": "value3", "adapter.key": "value4", - }, ff.KV()); diff != "" { + }, ff.KV(), cmp.AllowUnexported(array[int32]{}, array[int64]{})); diff != "" { t.Errorf("Mismatch (-want +got):\n%s", diff) } if diff := cmp.Diff(Tensors{ - Offset: 800, + Offset: 992, items: []*Tensor{ {Name: "blk.0.attn_k.weight", Offset: 0, Shape: []uint64{2, 3}}, {Name: "blk.0.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}}, From 0c787231741eaa2e6d5b145e8565a62364a852b3 Mon Sep 17 00:00:00 2001 From: JJ Date: Sun, 7 Dec 2025 21:49:52 -0800 Subject: [PATCH 04/35] readme: fix broken Swollama link in community integrations (#13370) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7cca20ba..1f1560ca 100644 --- a/README.md +++ b/README.md @@ -555,7 +555,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Parakeet](https://github.com/parakeet-nest/parakeet) is a GoLang library, made to simplify the development of small generative AI applications with Ollama. 
- [Haverscript](https://github.com/andygill/haverscript) with [examples](https://github.com/andygill/haverscript/tree/main/examples) - [Ollama for Swift](https://github.com/mattt/ollama-swift) -- [Swollama for Swift](https://github.com/marcusziade/Swollama) with [DocC](https://marcusziade.github.io/Swollama/documentation/swollama/) +- [Swollama for Swift]([https://github.com/marcusziade/Swollama](https://github.com/guitaripod/Swollama) with [DocC]( https://guitaripod.github.io/Swollama/documentation/swollama) - [GoLamify](https://github.com/prasad89/golamify) - [Ollama for Haskell](https://github.com/tusharad/ollama-haskell) - [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API) From 5dae738067414d235ee386abd690faf1a8da9ff4 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Mon, 8 Dec 2025 09:48:49 -0800 Subject: [PATCH 05/35] CI: use vendor base commit in cache keys (#13348) Prevent CGO from accidentally reusing old object files from the cache across vendor updates --- .github/workflows/release.yaml | 17 ++++++++++++++--- .github/workflows/test.yaml | 9 +++++++-- Makefile.sync | 8 ++++++-- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 8c3b3120..b4b9602b 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -16,13 +16,15 @@ jobs: outputs: GOFLAGS: ${{ steps.goflags.outputs.GOFLAGS }} VERSION: ${{ steps.goflags.outputs.VERSION }} + vendorsha: ${{ steps.changes.outputs.vendorsha }} steps: - uses: actions/checkout@v4 - name: Set environment id: goflags run: | - echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" >>$GITHUB_OUTPUT - echo VERSION="${GITHUB_REF_NAME#v}" >>$GITHUB_OUTPUT + echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" | tee -a $GITHUB_OUTPUT + echo VERSION="${GITHUB_REF_NAME#v}" | tee -a $GITHUB_OUTPUT + echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT darwin-build: runs-on: macos-14-xlarge @@ -53,6 +55,9 @@ jobs: - uses: actions/setup-go@v5 with: go-version-file: go.mod + cache-dependency-path: | + go.sum + Makefile.sync - run: | ./scripts/build_darwin.sh - name: Log build results @@ -185,7 +190,7 @@ jobs: - uses: actions/cache@v4 with: path: ${{ github.workspace }}\.ccache - key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }} + key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}-${{ needs.setup-environment.outputs.vendorsha }} - name: Build target "${{ matrix.preset }}" run: | Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll' @@ -249,6 +254,9 @@ jobs: - uses: actions/setup-go@v5 with: go-version-file: go.mod + cache-dependency-path: | + go.sum + Makefile.sync - name: Verify gcc is actually clang run: | $ErrorActionPreference='Continue' @@ -302,6 +310,9 @@ jobs: - uses: actions/setup-go@v5 with: go-version-file: go.mod + cache-dependency-path: | + go.sum + Makefile.sync - uses: actions/download-artifact@v4 with: pattern: depends-windows* diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 08a0a714..b614d2f0 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -22,6 +22,7 @@ jobs: runs-on: 
ubuntu-latest outputs: changed: ${{ steps.changes.outputs.changed }} + vendorsha: ${{ steps.changes.outputs.vendorsha }} steps: - uses: actions/checkout@v4 with: @@ -37,6 +38,7 @@ jobs: } echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT + echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT linux: needs: [changes] @@ -83,7 +85,7 @@ jobs: - uses: actions/cache@v4 with: path: /github/home/.cache/ccache - key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }} + key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }} - run: | cmake --preset ${{ matrix.preset }} ${{ matrix.flags }} cmake --build --preset ${{ matrix.preset }} --parallel @@ -178,7 +180,7 @@ jobs: - uses: actions/cache@v4 with: path: ${{ github.workspace }}\.ccache - key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }} + key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }} - run: | Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll' Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo' @@ -206,6 +208,9 @@ jobs: - uses: actions/setup-go@v5 with: go-version-file: 'go.mod' + cache-dependency-path: | + go.sum + Makefile.sync - uses: actions/setup-node@v4 with: node-version: '20' diff --git a/Makefile.sync b/Makefile.sync index 4991ad84..a485d6f2 100644 --- a/Makefile.sync +++ b/Makefile.sync @@ -57,7 +57,7 @@ checkout: $(WORKDIR) $(WORKDIR): git clone $(UPSTREAM) $(WORKDIR) -.PHONE: format-patches +.PHONY: format-patches format-patches: llama/patches git -C $(WORKDIR) format-patch \ --no-signature \ @@ -66,7 +66,11 @@ format-patches: llama/patches -o $(realpath $<) \ $(FETCH_HEAD) -.PHONE: clean +.PHONY: clean clean: checkout @git -C $(WORKDIR) am --abort || true $(RM) llama/patches/.*.patched + +.PHONY: print-base +print-base: + @echo $(FETCH_HEAD) \ No newline at end of file From e082d60a2406d54cc8c13d7e408f08818e7939d1 Mon Sep 17 00:00:00 2001 From: nicole pardal <109545900+npardal@users.noreply.github.com> Date: Mon, 8 Dec 2025 11:20:28 -0800 Subject: [PATCH 06/35] truncation: fixed runner truncation logic + removed server truncation (#12839) This PR consolidates all embedding prompt-length checking, truncation, and prompt token counting into the runner to ensure a single source of truth. 
--- integration/embed_test.go | 205 +++++++++++++++++++++++++++++++--- llm/server.go | 30 ++--- runner/llamarunner/runner.go | 13 ++- runner/ollamarunner/runner.go | 17 +-- server/routes.go | 97 ++++++++-------- server/sched_test.go | 4 +- 6 files changed, 278 insertions(+), 88 deletions(-) diff --git a/integration/embed_test.go b/integration/embed_test.go index e155498d..f01903ee 100644 --- a/integration/embed_test.go +++ b/integration/embed_test.go @@ -4,7 +4,9 @@ package integration import ( "context" + "errors" "math" + "strings" "testing" "time" @@ -204,8 +206,8 @@ func TestAllMiniLMEmbed(t *testing.T) { t.Fatalf("expected %v, got %v (similarity: %f)", expected[0:5], res.Embeddings[0][0:5], sim) } - if res.PromptEvalCount != 6 { - t.Fatalf("expected 6 prompt tokens, got %d", res.PromptEvalCount) + if res.PromptEvalCount != 8 { + t.Fatalf("expected 8 prompt tokens, got %d", res.PromptEvalCount) } } @@ -251,8 +253,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) { t.Fatalf("expected %v, got %v (similarity: %f)", expected[1][0:5], res.Embeddings[1][0:5], sim) } - if res.PromptEvalCount != 12 { - t.Fatalf("expected 12 prompt tokens, got %d", res.PromptEvalCount) + if res.PromptEvalCount != 16 { + t.Fatalf("expected 16 prompt tokens, got %d", res.PromptEvalCount) } } @@ -275,7 +277,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { cases := []struct { name string request api.EmbedRequest - check func(*api.EmbedResponse, error) + check func(*testing.T, *api.EmbedResponse, error) }{ { name: "target truncation", @@ -283,7 +285,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { Model: "all-minilm", Input: "why", }, - check: func(got *api.EmbedResponse, err error) { + check: func(t *testing.T, got *api.EmbedResponse, err error) { if err != nil { t.Fatal(err) } @@ -300,10 +302,11 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { Input: "why is the sky blue?", Options: map[string]any{"num_ctx": 3}, }, - check: func(got *api.EmbedResponse, err error) { + check: func(t *testing.T, got *api.EmbedResponse, err error) { if err != nil { t.Fatal(err) } + t.Logf("PromptEvalCount: want=%d got=%d", want.PromptEvalCount, got.PromptEvalCount) if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" { t.Errorf("embedding mismatch (-want +got):\n%s", diff) } @@ -317,10 +320,11 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { Truncate: &truncTrue, Options: map[string]any{"num_ctx": 3}, }, - check: func(got *api.EmbedResponse, err error) { + check: func(t *testing.T, got *api.EmbedResponse, err error) { if err != nil { t.Fatal(err) } + t.Logf("PromptEvalCount: want=%d got=%d", want.PromptEvalCount, got.PromptEvalCount) if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" { t.Errorf("embedding mismatch (-want +got):\n%s", diff) } @@ -334,21 +338,21 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { Truncate: &truncFalse, Options: map[string]any{"num_ctx": 3}, }, - check: func(res *api.EmbedResponse, err error) { - if err.Error() != "input exceeds maximum context length" { + check: func(t *testing.T, res *api.EmbedResponse, err error) { + if err.Error() != "the input length exceeds the context length" { t.Fatalf("expected truncation error, got: %v", err) } }, }, { - name: "input after truncate error", + name: "input after truncate error with context length of 1", request: api.EmbedRequest{ Model: "all-minilm", Input: "why is the sky blue?", Truncate: &truncTrue, Options: map[string]any{"num_ctx": 1}, }, - check: func(res *api.EmbedResponse, err error) { + check: func(t 
*testing.T, res *api.EmbedResponse, err error) { if err.Error() != "input after truncation exceeds maximum context length" { t.Fatalf("expected truncation error, got: %v", err) } @@ -362,7 +366,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { Truncate: &truncTrue, Options: map[string]any{"num_ctx": 0}, }, - check: func(res *api.EmbedResponse, err error) { + check: func(t *testing.T, res *api.EmbedResponse, err error) { if err.Error() != "input after truncation exceeds maximum context length" { t.Fatalf("expected truncation error, got: %v", err) } @@ -375,7 +379,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { Input: "why is the sky blue? Why is the sky blue? hi there my", Options: map[string]any{"num_ctx": 16}, }, - check: func(res *api.EmbedResponse, err error) { + check: func(t *testing.T, res *api.EmbedResponse, err error) { if err != nil { t.Fatal(err) } @@ -385,7 +389,8 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { for _, req := range cases { t.Run(req.name, func(t *testing.T) { - req.check(embedTestHelper(ctx, client, t, req.request)) + resp, err := embedTestHelper(ctx, client, t, req.request) + req.check(t, resp, err) }) } } @@ -409,3 +414,173 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req return client.Embed(ctx, &req) } + +func TestEmbedTruncation(t *testing.T) { + // Use test deadline if set, otherwise default to 2 minutes + timeout := 2 * time.Minute + if deadline, ok := t.Deadline(); ok { + timeout = time.Until(deadline) - 10*time.Second // Reserve 10s buffer + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + for _, model := range libraryEmbedModels { + model := model + t.Run(model, func(t *testing.T) { + // Check if we're running out of time (reserve 20s for current model) + if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second { + t.Skip("skipping remaining tests to avoid timeout") + } + + // Give each model its own budget to account for first-time pulls/loads + mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute) + defer mcancel() + + t.Run("truncation batch", func(t *testing.T) { + truncTrue := true + req := api.EmbedRequest{ + Model: model, + Input: []string{"short", strings.Repeat("long ", 100), "medium text"}, + Truncate: &truncTrue, + Options: map[string]any{"num_ctx": 30}, + } + + res, err := embedTestHelper(mctx, client, t, req) + if err != nil { + t.Fatal(err) + } + + if len(res.Embeddings) != 3 { + t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings)) + } + + if res.PromptEvalCount > 90 { + t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount) + } + }) + + t.Run("runner token count accuracy", func(t *testing.T) { + baseline := api.EmbedRequest{Model: model, Input: "test"} + baseRes, err := embedTestHelper(mctx, client, t, baseline) + if err != nil { + t.Fatal(err) + } + + batch := api.EmbedRequest{ + Model: model, + Input: []string{"test", "test", "test"}, + } + batchRes, err := embedTestHelper(mctx, client, t, batch) + if err != nil { + t.Fatal(err) + } + + expectedCount := baseRes.PromptEvalCount * 3 + if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 { + t.Fatalf("expected ~%d tokens (3 × %d), got %d", + expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount) + } + }) + }) + } +} + +// TestEmbedStatusCode tests that errors from the embedding endpoint +// properly preserve their HTTP status codes 
when returned to the client. +// This test specifically checks the error handling path in EmbedHandler +// where api.StatusError errors should maintain their original status code. +func TestEmbedStatusCode(t *testing.T) { + // Use test deadline if set, otherwise default to 2 minutes + timeout := 2 * time.Minute + if deadline, ok := t.Deadline(); ok { + timeout = time.Until(deadline) - 10*time.Second // Reserve 10s buffer + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + for _, model := range libraryEmbedModels { + model := model + t.Run(model, func(t *testing.T) { + // Check if we're running out of time (reserve 20s for current model) + if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second { + t.Skip("skipping remaining tests to avoid timeout") + } + + mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute) + defer mcancel() + + // Pull the model if needed + if err := PullIfMissing(mctx, client, model); err != nil { + t.Fatal(err) + } + + t.Run("truncation error status code", func(t *testing.T) { + truncFalse := false + longInput := strings.Repeat("word ", 100) + + req := api.EmbedRequest{ + Model: model, + Input: longInput, + Truncate: &truncFalse, + Options: map[string]any{"num_ctx": 10}, + } + + _, err := embedTestHelper(mctx, client, t, req) + if err == nil { + t.Fatal("expected error when truncate=false with long input") + } + + // Check that it's a StatusError with the correct status code + var statusErr api.StatusError + if !errors.As(err, &statusErr) { + t.Fatalf("expected api.StatusError, got %T: %v", err, err) + } + + // The error should be a 4xx client error (likely 400 Bad Request) + // not a 500 Internal Server Error + if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 { + t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode) + } + + // Verify the error message is meaningful + if !strings.Contains(err.Error(), "context length") { + t.Errorf("expected error message to mention context length, got: %v", err) + } + }) + + t.Run("batch truncation error status code", func(t *testing.T) { + truncFalse := false + req := api.EmbedRequest{ + Model: model, + Input: []string{ + "short input", + strings.Repeat("very long input ", 100), + "another short input", + }, + Truncate: &truncFalse, + Options: map[string]any{"num_ctx": 10}, + } + + _, err := embedTestHelper(mctx, client, t, req) + if err == nil { + t.Fatal("expected error when one input exceeds context with truncate=false") + } + + // Check that it's a StatusError with the correct status code + var statusErr api.StatusError + if !errors.As(err, &statusErr) { + t.Fatalf("expected api.StatusError, got %T: %v", err, err) + } + + // The error should be a 4xx client error, not a 500 Internal Server Error + if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 { + t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode) + } + }) + }) + } +} diff --git a/llm/server.go b/llm/server.go index e9d0a030..1c47601f 100644 --- a/llm/server.go +++ b/llm/server.go @@ -69,7 +69,7 @@ type LlamaServer interface { Ping(ctx context.Context) error WaitUntilRunning(ctx context.Context) error Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error - Embedding(ctx context.Context, input string) ([]float32, error) + Embedding(ctx context.Context, input string) ([]float32, int, error) Tokenize(ctx context.Context, content string) ([]int, error) 
Detokenize(ctx context.Context, tokens []int) (string, error) Close() error @@ -1629,10 +1629,11 @@ type EmbeddingRequest struct { } type EmbeddingResponse struct { - Embedding []float32 `json:"embedding"` + Embedding []float32 `json:"embedding"` + PromptEvalCount int `json:"prompt_eval_count"` } -func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) { +func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, int, error) { logutil.Trace("embedding request", "input", input) if err := s.sem.Acquire(ctx, 1); err != nil { @@ -1641,51 +1642,54 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err } else { slog.Error("Failed to acquire semaphore", "error", err) } - return nil, err + return nil, 0, err } defer s.sem.Release(1) // Make sure the server is ready status, err := s.getServerStatusRetry(ctx) if err != nil { - return nil, err + return nil, 0, err } else if status != ServerStatusReady { - return nil, fmt.Errorf("unexpected server status: %s", status) + return nil, 0, fmt.Errorf("unexpected server status: %s", status) } data, err := json.Marshal(EmbeddingRequest{Content: input}) if err != nil { - return nil, fmt.Errorf("error marshaling embed data: %w", err) + return nil, 0, fmt.Errorf("error marshaling embed data: %w", err) } r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data)) if err != nil { - return nil, fmt.Errorf("error creating embed request: %w", err) + return nil, 0, fmt.Errorf("error creating embed request: %w", err) } r.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(r) if err != nil { - return nil, fmt.Errorf("do embedding request: %w", err) + return nil, 0, fmt.Errorf("do embedding request: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("error reading embed response: %w", err) + return nil, 0, fmt.Errorf("error reading embed response: %w", err) } if resp.StatusCode >= 400 { log.Printf("llm embedding error: %s", body) - return nil, fmt.Errorf("%s", body) + return nil, 0, api.StatusError{ + StatusCode: resp.StatusCode, + ErrorMessage: string(body), + } } var e EmbeddingResponse if err := json.Unmarshal(body, &e); err != nil { - return nil, fmt.Errorf("unmarshal tokenize response: %w", err) + return nil, 0, fmt.Errorf("unmarshal tokenize response: %w", err) } - return e.Embedding, nil + return e.Embedding, e.PromptEvalCount, nil } func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, error) { diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index a23ddd61..0f32fd2a 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -757,13 +757,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{ embedding: true, - - // TODO (jmorganca): this should be provided by the server via the - // request options and truncated here in the runner, instead of relying on - // the server's truncate logic - truncate: true, + truncate: false, }) if err != nil { + if errors.Is(err, errorInputTooLong) { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError) return } @@ -806,7 +806,8 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { embedding := <-seq.embedding if err 
:= json.NewEncoder(w).Encode(&llm.EmbeddingResponse{ - Embedding: embedding, + Embedding: embedding, + PromptEvalCount: seq.numPromptInputs, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 15339086..d0427662 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -146,12 +146,12 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe params.numKeep = min(params.numKeep, s.cache.numCtx-1) if int32(len(inputs)) > s.cache.numCtx { - discard := int32(len(inputs)) - s.cache.numCtx - if !params.truncate { return nil, errorInputTooLong } + discard := int32(len(inputs)) - s.cache.numCtx + promptStart := params.numKeep + discard // If we need to truncate in the middle of a unbreakable batch, remove the entire batch @@ -996,13 +996,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{ embedding: true, - - // TODO (jmorganca): this should be provided by the server via the - // request options and truncated here in the runner, instead of relying on - // the server's truncate logic - truncate: true, + truncate: false, }) if err != nil { + if errors.Is(err, errorInputTooLong) { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError) return } @@ -1043,7 +1043,8 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { } if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{ - Embedding: <-seq.embedding, + Embedding: <-seq.embedding, + PromptEvalCount: seq.numPromptInputs, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) } diff --git a/server/routes.go b/server/routes.go index e5e6dd5a..4dd870ed 100644 --- a/server/routes.go +++ b/server/routes.go @@ -22,6 +22,7 @@ import ( "os/signal" "slices" "strings" + "sync/atomic" "syscall" "time" @@ -649,11 +650,6 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } - truncate := true - if req.Truncate != nil && !*req.Truncate { - truncate = false - } - var input []string switch i := req.Input.(type) { @@ -701,55 +697,57 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } - var count int - for i, s := range input { - tokens, err := r.Tokenize(c.Request.Context(), s) + ctx := c.Request.Context() + + embedWithRetry := func(text string) ([]float32, int, error) { + emb, tokCount, err := r.Embedding(ctx, text) + if err == nil { + return emb, tokCount, nil + } + + var serr api.StatusError + if !errors.As(err, &serr) || serr.StatusCode != http.StatusBadRequest { + return nil, 0, err + } + if req.Truncate != nil && !*req.Truncate { + return nil, 0, err + } + + tokens, err := r.Tokenize(ctx, text) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return + return nil, 0, err } + // TODO @nicolepardal: avoid reaching into kvData here; pass required tokenizer metadata via model/options instead ctxLen := min(opts.NumCtx, int(kvData.ContextLength())) - if len(tokens) > ctxLen { - if !truncate { - c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"}) - return - } - - if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", 
true) { - ctxLen-- - } - - if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) { - ctxLen-- - } - - slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens)) - if ctxLen <= 0 { - // return error if the truncated input would be empty or just special tokens - c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"}) - return - } - - tokens = tokens[:ctxLen] - - s, err = r.Detokenize(c.Request.Context(), tokens) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } + if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); len(tokens) > 0 && tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) { + ctxLen-- + } + if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); len(tokens) > 0 && tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) { + ctxLen-- } - count += len(tokens) + if len(tokens) <= ctxLen { + return nil, 0, fmt.Errorf("input exceeds maximum context length and cannot be truncated further") + } + if ctxLen <= 0 { + return nil, 0, fmt.Errorf("input after truncation exceeds maximum context length") + } - input[i] = s + truncatedTokens := tokens[:ctxLen] + truncated, err := r.Detokenize(ctx, truncatedTokens) + if err != nil { + return nil, 0, err + } + return r.Embedding(ctx, truncated) } var g errgroup.Group embeddings := make([][]float32, len(input)) + var totalTokens uint64 for i, text := range input { g.Go(func() error { - embedding, err := r.Embedding(c.Request.Context(), text) + embedding, tokenCount, err := embedWithRetry(text) if err != nil { return err } @@ -759,12 +757,23 @@ func (s *Server) EmbedHandler(c *gin.Context) { embedding = normalize(embedding[:req.Dimensions]) } embeddings[i] = embedding + atomic.AddUint64(&totalTokens, uint64(tokenCount)) return nil }) } if err := g.Wait(); err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())}) + var serr api.StatusError + if errors.As(err, &serr) { + c.AbortWithStatusJSON(serr.StatusCode, gin.H{ + "error": strings.TrimSpace(serr.ErrorMessage), + }) + return + } + + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{ + "error": strings.TrimSpace(err.Error()), + }) return } @@ -773,7 +782,7 @@ func (s *Server) EmbedHandler(c *gin.Context) { Embeddings: embeddings, TotalDuration: time.Since(checkpointStart), LoadDuration: checkpointLoaded.Sub(checkpointStart), - PromptEvalCount: count, + PromptEvalCount: int(totalTokens), } c.JSON(http.StatusOK, resp) } @@ -819,7 +828,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { return } - embedding, err := r.Embedding(c.Request.Context(), req.Prompt) + embedding, _, err := r.Embedding(c.Request.Context(), req.Prompt) if err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())}) return diff --git a/server/sched_test.go b/server/sched_test.go index 678be954..480aafa4 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -780,8 +780,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn return s.completionResp } -func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) { - return s.embeddingResp, s.embeddingRespErr +func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, int, error) { + return s.embeddingResp, 0, s.embeddingRespErr } func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, 
error) { From 603ceefaa67feee627e01cae1df1e0642e1c868f Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 18 Nov 2025 15:17:03 -0800 Subject: [PATCH 07/35] refactor rope change to a flatter directory structure and group the options with the function update models to call rope in one place --- ml/nn/{fast => }/rope.go | 5 ++--- ml/nn/rope/{rope.go => options.go} | 1 + model/models/deepseek2/model.go | 16 +++++++--------- model/models/deepseekocr/model_text.go | 11 +++++------ model/models/gemma2/model.go | 11 +++++++---- model/models/gemma3/model_text.go | 11 +++++++---- model/models/gemma3n/model_text.go | 11 +++++++---- model/models/gptoss/model.go | 17 ++++++++--------- model/models/llama/model.go | 13 +++++++------ model/models/llama4/model_text.go | 11 +++++++---- model/models/mistral3/model_text.go | 11 +++++++---- model/models/mistral3/model_vision.go | 8 ++++---- model/models/mllama/model_text.go | 13 ++++++++----- model/models/nomicbert/model.go | 9 ++++++--- model/models/qwen2/model.go | 13 +++++++------ model/models/qwen25vl/model_text.go | 14 ++++++++++---- model/models/qwen25vl/model_vision.go | 8 ++++---- model/models/qwen3/model.go | 3 +-- model/models/qwen3vl/model.go | 2 +- model/models/qwen3vl/model_text.go | 9 ++++----- model/models/qwen3vl/model_vision.go | 8 ++++---- 21 files changed, 114 insertions(+), 91 deletions(-) rename ml/nn/{fast => }/rope.go (71%) rename ml/nn/rope/{rope.go => options.go} (97%) diff --git a/ml/nn/fast/rope.go b/ml/nn/rope.go similarity index 71% rename from ml/nn/fast/rope.go rename to ml/nn/rope.go index b45938eb..967aa94f 100644 --- a/ml/nn/fast/rope.go +++ b/ml/nn/rope.go @@ -1,5 +1,4 @@ -// fast provides implementations of fast (fused) operations for increased performance. -package fast +package nn import ( "github.com/ollama/ollama/ml" @@ -8,7 +7,7 @@ import ( // fastRoPE is an interface for tensors that support fast rotary positional embedding. type fastRoPE interface { - RoPE(ctx ml.Context, positionIDs ml.Tensor, dim int, base, scale float32, options ...func(*rope.Options)) ml.Tensor + RoPE(ctx ml.Context, positions ml.Tensor, dim int, base, scale float32, options ...func(*rope.Options)) ml.Tensor } // RoPE applies rotary positional embedding to tensor `t`. 
diff --git a/ml/nn/rope/rope.go b/ml/nn/rope/options.go similarity index 97% rename from ml/nn/rope/rope.go rename to ml/nn/rope/options.go index e01ac152..03cc5211 100644 --- a/ml/nn/rope/rope.go +++ b/ml/nn/rope/options.go @@ -1,3 +1,4 @@ +// Package rope provides options for RoPE package rope import "github.com/ollama/ollama/ml" diff --git a/model/models/deepseek2/model.go b/model/models/deepseek2/model.go index e3cab3b2..576076aa 100644 --- a/model/models/deepseek2/model.go +++ b/model/models/deepseek2/model.go @@ -10,7 +10,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" @@ -42,13 +41,12 @@ type Options struct { kqScale float64 } -func (o Options) RoPEOptions() []func(*rope.Options) { - attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale)))) - return []func(*rope.Options){ +func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, t, p, o.qkRopeHeadDim, o.ropeBase, 1./o.ropeScale, rope.WithOriginalContextLength(o.originalContextLength), rope.WithExtrapolationFactor(1.), - rope.WithAttentionFactor(attnFactor), - } + rope.WithAttentionFactor(float32(1.0/(1.0+0.1*math.Log(float64(o.ropeScale))))), + ) } type Attention struct { @@ -91,8 +89,8 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor compressedKV.Stride(1), compressedKV.Dim(1), ) - qRot := fast.RoPE(ctx, queryChunks[1], positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...) - kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...) 
+ qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions) + kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions) kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps) var attention ml.Tensor @@ -327,7 +325,7 @@ func New(c fs.Config) (model.Model, error) { } func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.qkRopeHeadDim, m.ropeBase, 1./m.ropeScale, m.RoPEOptions()...), nil + return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { diff --git a/model/models/deepseekocr/model_text.go b/model/models/deepseekocr/model_text.go index 1513b138..ab6221cc 100644 --- a/model/models/deepseekocr/model_text.go +++ b/model/models/deepseekocr/model_text.go @@ -6,7 +6,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" ) @@ -20,7 +19,7 @@ type textModel struct { } func (m *textModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return m.Options.applyRotaryPositionalEmbedding(ctx, key, shift), nil + return m.Options.applyRotaryPositionEmbeddings(ctx, key, shift), nil } type textOptions struct { @@ -38,8 +37,8 @@ func (o textOptions) headDim() int { return o.hiddenSize / o.numHeads } -func (o textOptions) applyRotaryPositionalEmbedding(ctx ml.Context, t, p ml.Tensor) ml.Tensor { - return fast.RoPE(ctx, t, p, o.headDim(), o.ropeBase, 1/o.ropeScale, rope.WithTypeNeoX()) +func (o textOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1/o.ropeScale, rope.WithTypeNeoX()) } type textBlock struct { @@ -83,8 +82,8 @@ func (m *textAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tenso value := m.Value.Forward(ctx, hiddenStates) value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, -1) - query = opts.applyRotaryPositionalEmbedding(ctx, query, positions) - key = opts.applyRotaryPositionalEmbedding(ctx, key, positions) + query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) + key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache) attention = attention.Reshape(ctx, -1, attention.Dim(2)) diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 06c71fc3..7b0aa2f0 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -7,7 +7,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" @@ -22,6 +21,10 @@ type Options struct { largeModelScaling bool } +func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, states, positions, o.attnKeyLen, o.ropeBase, 1./o.ropeScale, rope.WithTypeNeoX()) +} + type Model struct { model.Base model.SentencePiece @@ -88,7 +91,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) - q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, 
1./opts.ropeScale, rope.WithTypeNeoX()) + q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs) if opts.largeModelScaling { q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) @@ -98,7 +101,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) - k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) + k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) @@ -128,7 +131,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten } func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil + return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil } type MLP struct { diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index 8d1a1be6..ddb30c41 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -7,7 +7,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model/input" ) @@ -20,6 +19,10 @@ type TextConfig struct { largeModelScaling bool } +func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base float32) ml.Tensor { + return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, rope.WithTypeNeoX()) +} + type TextModel struct { TokenEmbedding *nn.Embedding `gguf:"token_embd"` Layers []TextLayer `gguf:"blk"` @@ -87,7 +90,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) q = sa.QueryNorm.Forward(ctx, q, opts.eps) - q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) + q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase) if opts.largeModelScaling { q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) @@ -98,7 +101,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) k = sa.KeyNorm.Forward(ctx, k, opts.eps) - k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) + k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) @@ -116,7 +119,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T ropeBase = m.TextConfig.ropeGlobalBase } - return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil + return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase), nil } type TextMLP struct { diff --git a/model/models/gemma3n/model_text.go b/model/models/gemma3n/model_text.go index 3a89afe7..89cc54b8 100644 --- a/model/models/gemma3n/model_text.go +++ b/model/models/gemma3n/model_text.go @@ -8,7 +8,6 @@ import ( 
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model/input" ) @@ -95,7 +94,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T ropeBase = m.ropeBaseLocal } - return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil + return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase), nil } type TextScaledWordEmbedding struct { @@ -256,14 +255,14 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten query := attn.Query.Forward(ctx, hiddenStates) query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize) query = attn.QueryNorm.Forward(ctx, query, opts.eps) - query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) + query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, ropeBase) var key, value ml.Tensor if !sharedKV { key = attn.Key.Forward(ctx, hiddenStates) key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize) key = attn.KeyNorm.Forward(ctx, key, opts.eps) - key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) + key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, ropeBase) value = attn.Value.Forward(ctx, hiddenStates) value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize) @@ -330,6 +329,10 @@ func (o *TextOptions) isLocal(i int) bool { return o.slidingWindowPattern[i] } +func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor, base float32) ml.Tensor { + return nn.RoPE(ctx, t, p, o.headDim(), base, 1./o.ropeScale, rope.WithTypeNeoX()) +} + func newTextModel(c fs.Config) *TextModel { return &TextModel{ TextLayers: make([]TextLayer, c.Uint("block_count")), diff --git a/model/models/gptoss/model.go b/model/models/gptoss/model.go index da08ed96..9d1520bf 100644 --- a/model/models/gptoss/model.go +++ b/model/models/gptoss/model.go @@ -9,7 +9,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" @@ -52,7 +51,7 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err } func (m *Transformer) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, m.RoPEOptions()...), nil + return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil } type Options struct { @@ -70,14 +69,14 @@ type Options struct { ropeScale float32 } -func (o Options) RoPEOptions() []func(*rope.Options) { - return []func(*rope.Options){ +func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1./o.ropeScale, rope.WithTypeNeoX(), rope.WithOriginalContextLength(o.originalContextLength), rope.WithExtrapolationFactor(1.), - // NOTE: ggml sets this implicitly so there's no need to set it here - // rope.WithAttentionFactor(0.1*float32(math.Log(float64(o.ropeScale))) + 1.0), - } + // NOTE: ggml sets this implicitly so there's no need to set it here + // rope.WithAttentionFactor(0.1*float32(math.Log(float64(o.ropeScale))) + 1.0), + ) } func (o Options) headDim() 
int { @@ -135,8 +134,8 @@ func (attn *AttentionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.T value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize) } - query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...) - key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...) + query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) + key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) attention := nn.AttentionWithSinks(ctx, query, key, value, attn.Sinks, 1/math.Sqrt(float64(opts.headDim())), cache) attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 52c66ba5..5ff4894e 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -8,7 +8,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" @@ -20,6 +19,10 @@ type Options struct { eps, ropeBase, ropeScale float32 } +func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions, factors ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, states, positions, cmp.Or(o.ropeDim, o.headDim, o.hiddenSize/o.numHeads), o.ropeBase, 1./o.ropeScale, rope.WithFactors(factors)) +} + type Model struct { model.Base model.TextProcessor @@ -115,7 +118,6 @@ type SelfAttention struct { func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { batchSize := hiddenState.Dim(1) headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads) - ropeDim := cmp.Or(opts.ropeDim, headDim) query := sa.Query.Forward(ctx, hiddenState) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) @@ -126,8 +128,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) - key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, sa.RopeFactors) + key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, sa.RopeFactors) attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) @@ -135,8 +137,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso } func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads) - return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil + return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.Layers[layer].SelfAttention.RopeFactors), nil } type MLP struct { diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go index 96b5d24d..c2bf0614 100644 --- a/model/models/llama4/model_text.go +++ b/model/models/llama4/model_text.go @@ -8,7 +8,6 @@ import ( "github.com/ollama/ollama/kvcache" 
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model/input" ) @@ -33,8 +32,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) if useRope { - query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) - key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, sa.RopeFactors) + key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, sa.RopeFactors) } if opts.useQKNorm { @@ -152,6 +151,10 @@ type TextOptions struct { attentionFloorScale float64 } +func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions, factors ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, rope.WithFactors(factors)) +} + type TextModel struct { Layers []TextLayer `gguf:"blk"` @@ -236,5 +239,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil + return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.Layers[layer].Attention.RopeFactors), nil } diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go index 624d3151..ebb7b3aa 100644 --- a/model/models/mistral3/model_text.go +++ b/model/models/mistral3/model_text.go @@ -8,7 +8,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/model/input" ) @@ -20,6 +19,10 @@ type TextOptions struct { ropeScalingBeta float32 } +func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale) +} + type TextModel struct { TokenEmbedding *nn.Embedding `gguf:"token_embd"` Layers []Layer `gguf:"blk"` @@ -42,11 +45,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, posit q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale) + q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs) k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale) + k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -61,7 +64,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, posit } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale), nil + return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil } type MLP struct { diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go index d763df7a..1de0412d 100644 --- 
a/model/models/mistral3/model_vision.go +++ b/model/models/mistral3/model_vision.go @@ -16,8 +16,8 @@ func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor { return x2.Scale(ctx, -1).Concat(ctx, x1, 0) } -func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor { - return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin)) +func applyRotaryPositionEmbeddings(ctx ml.Context, states, cos, sin ml.Tensor) ml.Tensor { + return states.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, states).Mul(ctx, sin)) } type VisionSelfAttention struct { @@ -36,8 +36,8 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize) value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize) - query = applyRotaryPositionalEmbedding(ctx, query, cos, sin) - key = applyRotaryPositionalEmbedding(ctx, key, cos, sin) + query = applyRotaryPositionEmbeddings(ctx, query, cos, sin) + key = applyRotaryPositionEmbeddings(ctx, key, cos, sin) attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil) attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize) diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index 65f0a827..afd674eb 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -8,7 +8,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" ) @@ -26,11 +25,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T query := sa.Query.Forward(ctx, hiddenState) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) - query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, sa.RopeFactors) key := sa.Key.Forward(ctx, hiddenState) key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, sa.RopeFactors) value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -44,8 +43,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { // This will only get called for layers in the cache, which are just the self attention layers - if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok { - return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil + if layer, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok { + return m.applyRotaryPositionEmbeddings(ctx, key, shift, layer.SelfAttention.RopeFactors), nil } return key, nil @@ -206,6 +205,10 @@ type TextModelOptions struct { crossAttentionLayers []int32 } +func (o TextModelOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions, factors ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, rope.WithFactors(factors)) +} + type TextModel struct { TokenEmbedding *nn.Embedding `gguf:"token_embd"` 
Transformer *TextDecoder `gguf:"blk"` diff --git a/model/models/nomicbert/model.go b/model/models/nomicbert/model.go index 0e742dfa..2510240d 100644 --- a/model/models/nomicbert/model.go +++ b/model/models/nomicbert/model.go @@ -7,7 +7,6 @@ import ( "github.com/ollama/ollama/fs" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/pooling" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" @@ -37,6 +36,10 @@ type Options struct { ropeFreqBase float32 } +func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, states, positions, o.headDim, o.ropeFreqBase, 1.0, rope.WithTypeNeoX()) +} + // Single Encoder Layer type EncoderLayer struct { *Attention @@ -105,8 +108,8 @@ func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions ml chunks := qkv.Chunk(ctx, 1, opts.numHeads) query, key, value := chunks[0], chunks[1], chunks[2] - query = fast.RoPE(ctx, query, positions, opts.headDim, opts.ropeFreqBase, 1.0, rope.WithTypeNeoX()) - key = fast.RoPE(ctx, key, positions, opts.headDim, opts.ropeFreqBase, 1.0, rope.WithTypeNeoX()) + query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) + key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(opts.headDim)), nil) diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go index 10a1e65c..66f546ae 100644 --- a/model/models/qwen2/model.go +++ b/model/models/qwen2/model.go @@ -10,7 +10,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" @@ -22,6 +21,10 @@ type Options struct { eps, ropeBase, ropeScale float32 } +func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, states, positions, cmp.Or(o.ropeDim, o.headDim, o.hiddenSize/o.numHeads), o.ropeBase, 1./o.ropeScale, rope.WithTypeNeoX()) +} + type Attention struct { Query *nn.Linear `gguf:"attn_q"` Key *nn.Linear `gguf:"attn_k"` @@ -32,7 +35,6 @@ type Attention struct { func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { batchSize := hiddenStates.Dim(1) headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads) - ropeDim := cmp.Or(opts.ropeDim, headDim) query := attn.Query.Forward(ctx, hiddenStates) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) @@ -43,8 +45,8 @@ func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, value := attn.Value.Forward(ctx, hiddenStates) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) - key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) + query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) + key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) @@ -123,8 +125,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { } func (m 
Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads) - return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil + return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil } func New(c fs.Config) (model.Model, error) { diff --git a/model/models/qwen25vl/model_text.go b/model/models/qwen25vl/model_text.go index e6c6e6c1..b4db6043 100644 --- a/model/models/qwen25vl/model_text.go +++ b/model/models/qwen25vl/model_text.go @@ -7,7 +7,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model/input" ) @@ -18,6 +17,13 @@ type TextOptions struct { eps, ropeBase, ropeScale float32 } +func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, + rope.WithOriginalContextLength(o.originalContextLength), + rope.WithTypeNeoX(), + ) +} + type TextModel struct { TokenEmbedding *nn.Embedding `gguf:"token_embd"` Layers []Layer `gguf:"blk"` @@ -60,11 +66,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) + q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs) k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) + k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -78,7 +84,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten // Shift applies rotary position embeddings to the key tensor for causal attention caching func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil + return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil } // MLP implements the feed-forward network component with SwiGLU activation diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go index 5cbb01f7..bfdafabe 100644 --- a/model/models/qwen25vl/model_vision.go +++ b/model/models/qwen25vl/model_vision.go @@ -18,8 +18,8 @@ func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor { return x2.Scale(ctx, -1).Concat(ctx, x1, 0) } -func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor { - return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin)) +func applyRotaryPositionEmbeddings(ctx ml.Context, states, cos, sin ml.Tensor) ml.Tensor { + return states.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, states).Mul(ctx, sin)) } func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int) ml.Tensor { @@ -67,8 +67,8 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, m key = 
key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize) value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize) - query = applyRotaryPositionalEmbedding(ctx, query, cos, sin) - key = applyRotaryPositionalEmbedding(ctx, key, cos, sin) + query = applyRotaryPositionEmbeddings(ctx, query, cos, sin) + key = applyRotaryPositionEmbeddings(ctx, key, cos, sin) // Scale factor for scaled dot-product attention scale := 1.0 / math.Sqrt(float64(opts.headDim)) diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go index 483439ac..d7747364 100644 --- a/model/models/qwen3/model.go +++ b/model/models/qwen3/model.go @@ -9,7 +9,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" @@ -46,7 +45,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions rope.WithAttentionFactor(attnFactor), ) } - return fast.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1./o.ropeScale, opts...) + return nn.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1./o.ropeScale, opts...) } type Attention struct { diff --git a/model/models/qwen3vl/model.go b/model/models/qwen3vl/model.go index 579863ae..cb1ce8d2 100644 --- a/model/models/qwen3vl/model.go +++ b/model/models/qwen3vl/model.go @@ -195,7 +195,7 @@ func New(c fs.Config) (model.Model, error) { m.Cache = kvcache.NewCausalCache(func(ctx ml.Context, layer int, key, positions ml.Tensor) (ml.Tensor, error) { m.positionCache = nil positions = positions.Repeat(ctx, 1, 4).Reshape(ctx, -1) - return m.Options.applyRotaryPositionalEmbedding(ctx, key, positions), nil + return m.Options.applyRotaryPositionEmbeddings(ctx, key, positions), nil }) return &m, nil } diff --git a/model/models/qwen3vl/model_text.go b/model/models/qwen3vl/model_text.go index 64a567b0..750c2473 100644 --- a/model/models/qwen3vl/model_text.go +++ b/model/models/qwen3vl/model_text.go @@ -10,7 +10,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" ) @@ -35,8 +34,8 @@ func (o TextOptions) headDim() int { return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads) } -func (o TextOptions) applyRotaryPositionalEmbedding(ctx ml.Context, t, p ml.Tensor) ml.Tensor { - return fast.RoPE(ctx, t, p, o.headDim(), o.ropeBase, 1/float32(math.Sqrt(float64(o.ropeScale))), +func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1/float32(math.Sqrt(float64(o.ropeScale))), rope.WithInterleaveMRoPE(o.mropeSections), ) } @@ -64,8 +63,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tens query = sa.QueryNorm.Forward(ctx, query, opts.eps) key = sa.KeyNorm.Forward(ctx, key, opts.eps) - query = opts.applyRotaryPositionalEmbedding(ctx, query, positions) - key = opts.applyRotaryPositionalEmbedding(ctx, key, positions) + query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) + key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache) attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) 
diff --git a/model/models/qwen3vl/model_vision.go b/model/models/qwen3vl/model_vision.go index b22ac305..761281ed 100644 --- a/model/models/qwen3vl/model_vision.go +++ b/model/models/qwen3vl/model_vision.go @@ -23,18 +23,18 @@ func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor { return x2.Scale(ctx, -1).Concat(ctx, x1, 0) } -func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor { - return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin)) +func applyRotaryPositionEmbeddings(ctx ml.Context, states, cos, sin ml.Tensor) ml.Tensor { + return states.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, states).Mul(ctx, sin)) } func (sa *VisionAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts VisionOptions) ml.Tensor { query := sa.Query.Forward(ctx, hiddenStates) query = query.Reshape(ctx, opts.headDim(), opts.numHeads, query.Dim(1)) - query = applyRotaryPositionalEmbedding(ctx, query, cos, sin) + query = applyRotaryPositionEmbeddings(ctx, query, cos, sin) key := sa.Key.Forward(ctx, hiddenStates) key = key.Reshape(ctx, opts.headDim(), opts.numHeads, key.Dim(1)) - key = applyRotaryPositionalEmbedding(ctx, key, cos, sin) + key = applyRotaryPositionEmbeddings(ctx, key, cos, sin) value := sa.Value.Forward(ctx, hiddenStates) value = value.Reshape(ctx, opts.headDim(), opts.numHeads, value.Dim(1)) From d2f334c1f7822efe3470f41720dc121e5b19e891 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Mon, 8 Dec 2025 16:49:17 -0800 Subject: [PATCH 08/35] model: add rnj-1 inference support (#13354) --- convert/convert_gemma3.go | 64 +++++++++++++++++---- ml/nn/rope/options.go | 12 ++++ model/models/gemma3/model.go | 57 ++++++++++++------- model/models/gemma3/model_text.go | 94 ++++++++++++++++++++++--------- parser/parser.go | 19 +++---- parser/parser_test.go | 31 ++++++++++ 6 files changed, 208 insertions(+), 69 deletions(-) diff --git a/convert/convert_gemma3.go b/convert/convert_gemma3.go index 27b99f57..5e6e6904 100644 --- a/convert/convert_gemma3.go +++ b/convert/convert_gemma3.go @@ -2,6 +2,7 @@ package convert import ( "cmp" + "slices" "github.com/ollama/ollama/fs/ggml" ) @@ -26,16 +27,26 @@ type gemma3Model struct { NumChannels uint32 `json:"num_channels"` // num_channels 3 PatchSize uint32 `json:"patch_size"` // patch_size 14 } `json:"vision_config"` - MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` - NumAttentionHeads uint32 `json:"num_attention_heads"` - NumKeyValueHeads uint32 `json:"num_key_value_heads"` - RMSNormEPS float32 `json:"rms_norm_eps"` - HeadDim uint32 `json:"head_dim"` - FinalLogitSoftcap float32 `json:"final_logit_softcapping"` - RopeLocalTheta float32 `json:"rope_local_base_freq"` - RopeGlobalTheta float32 `json:"rope_global_base_freq"` - SlidingWindow uint32 `json:"sliding_window"` - MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"` + MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` + NumAttentionHeads uint32 `json:"num_attention_heads"` + NumKeyValueHeads uint32 `json:"num_key_value_heads"` + RMSNormEPS float32 `json:"rms_norm_eps"` + HeadDim uint32 `json:"head_dim"` + FinalLogitSoftcap float32 `json:"final_logit_softcapping"` + RopeLocalTheta float32 `json:"rope_local_base_freq"` + RopeTheta float32 `json:"rope_theta"` + SlidingWindow uint32 `json:"sliding_window"` + SlidingWindowPattern *uint32 `json:"sliding_window_pattern"` + LayerTypes []string `json:"layer_types"` + MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"` + RopeScaling *struct { + Type string `json:"rope_type"` + 
Factor float32 `json:"factor"` + OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"` + ExtrapolationFactor float32 `json:"extrapolation_factor"` + BetaFast float32 `json:"beta_fast"` + BetaSlow float32 `json:"beta_slow"` + } `json:"rope_scaling"` } const ( @@ -81,9 +92,38 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV { kv["gemma3.attention.key_length"] = p.HeadDim kv["gemma3.attention.value_length"] = p.HeadDim kv["gemma3.attention.sliding_window"] = p.SlidingWindow - kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30) + + // The sliding window pattern is either provided as the sliding_window_pattern + // key (an int) or as the layer_types key (a list of strings). + if p.SlidingWindowPattern != nil || len(p.LayerTypes) > 0 { + kv["gemma3.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) { + for i := range numBlocks { + var isLocal bool + if len(p.LayerTypes) > 0 && int(i) < len(p.LayerTypes) { + isLocal = p.LayerTypes[i] == "sliding_attention" + } else if p.SlidingWindowPattern != nil && *p.SlidingWindowPattern > 0 { + isLocal = (i+1)%*p.SlidingWindowPattern != 0 + } + if !yield(isLocal) { + break + } + } + }) + } + if p.FinalLogitSoftcap > 0 { + kv["gemma3.final_logit_softcapping"] = p.FinalLogitSoftcap + } kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0) - kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0) + kv["gemma3.rope.freq_base"] = cmp.Or(p.RopeTheta, 1000000.0) + if p.RopeScaling != nil && p.RopeScaling.Type == "yarn" && p.RopeScaling.Factor > 0 { + kv["gemma3.rope.scaling.type"] = "yarn" + kv["gemma3.rope.scaling.factor"] = p.RopeScaling.Factor + kv["gemma3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeddings + kv["gemma3.rope.scaling.extrapolation_factor"] = cmp.Or(p.RopeScaling.ExtrapolationFactor, float32(1.0)) + kv["gemma3.rope.scaling.beta_fast"] = cmp.Or(p.RopeScaling.BetaFast, float32(64.0)) + kv["gemma3.rope.scaling.beta_slow"] = cmp.Or(p.RopeScaling.BetaSlow, float32(1.0)) + } + kv["gemma3.embedding_length"] = p.HiddenSize kv["gemma3.feed_forward_length"] = p.IntermediateSize default: diff --git a/ml/nn/rope/options.go b/ml/nn/rope/options.go index 03cc5211..84b92677 100644 --- a/ml/nn/rope/options.go +++ b/ml/nn/rope/options.go @@ -58,6 +58,18 @@ func WithAttentionFactor(attentionFactor float32) func(*Options) { } } +func WithBetaFast(betaFast float32) func(*Options) { + return func(opts *Options) { + opts.YaRN.BetaFast = betaFast + } +} + +func WithBetaSlow(betaSlow float32) func(*Options) { + return func(opts *Options) { + opts.YaRN.BetaSlow = betaSlow + } +} + func WithMRoPE(sections []int) func(*Options) { return func(opts *Options) { opts.Type |= 1 << 3 diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index 62f51074..e595f186 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -16,7 +16,7 @@ import ( type Model struct { model.Base - model.SentencePiece + model.TextProcessor *VisionModel `gguf:"v"` *TextModel @@ -54,24 +54,35 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i } func New(c fs.Config) (model.Model, error) { - m := Model{ - SentencePiece: model.NewSentencePiece( - &model.Vocabulary{ - Values: c.Strings("tokenizer.ggml.tokens"), - Scores: c.Floats("tokenizer.ggml.scores"), - Types: c.Ints("tokenizer.ggml.token_type"), - AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), - BOS: 
[]int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, - AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - EOS: append( - []int32{ - int32(c.Uint("tokenizer.ggml.eos_token_id")), - int32(c.Uint("tokenizer.ggml.eot_token_id", 106)), - }, - c.Ints("tokenizer.ggml.eos_token_ids")..., - ), + vocabulary := model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Scores: c.Floats("tokenizer.ggml.scores"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{ + int32(c.Uint("tokenizer.ggml.eos_token_id")), }, + c.Ints("tokenizer.ggml.eos_token_ids")..., ), + } + + var processor model.TextProcessor + switch c.String("tokenizer.ggml.model") { + case "gpt2": + processor = model.NewBytePairEncoding(&vocabulary) + default: + // Previous uploads of Gemma 3 on Ollama did not have token 106 + // (i.e. "") so we need to add in case it's not already present + vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106))) + processor = model.NewSentencePiece(&vocabulary) + } + + m := Model{ + TextProcessor: processor, ImageProcessor: newImageProcessor(c), VisionModel: newVisionModel(c), TextModel: newTextModel(c), @@ -141,8 +152,16 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache) - return m.Output.Forward(ctx, hiddenStates), nil + hiddenState := m.TextModel.Forward(ctx, batch, m.Cache) + hiddenState = m.Output.Forward(ctx, hiddenState) + + if m.TextConfig.finalLogitSoftcap > 0.0 { + hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextConfig.finalLogitSoftcap)) + hiddenState = hiddenState.Tanh(ctx) + hiddenState = hiddenState.Scale(ctx, float64(m.TextConfig.finalLogitSoftcap)) + } + + return hiddenState, nil } func init() { diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index ddb30c41..f76fba74 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -2,6 +2,7 @@ package gemma3 import ( "math" + "slices" "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" @@ -15,12 +16,32 @@ type TextConfig struct { hiddenSize, numHeads, numKVHeads int attnKeyLen, attnValLen int eps, ropeScale float32 - ropeLocalBase, ropeGlobalBase float32 + ropeLocalBase float32 largeModelScaling bool + slidingWindowPattern []bool + ropeBase float32 + ropeType string + ropeOriginalContext int + ropeExtrapolation float32 + ropeBetaFast float32 + ropeBetaSlow float32 + finalLogitSoftcap float32 } func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base float32) ml.Tensor { - return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, rope.WithTypeNeoX()) + ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()} + if o.ropeType == "yarn" { + attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale)))) + ropeOpts = append(ropeOpts, + rope.WithOriginalContextLength(o.ropeOriginalContext), + rope.WithExtrapolationFactor(o.ropeExtrapolation), + rope.WithAttentionFactor(attnFactor), + rope.WithBetaFast(o.ropeBetaFast), + rope.WithBetaSlow(o.ropeBetaSlow), + ) + } + + return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, 
ropeOpts...) } type TextModel struct { @@ -48,21 +69,35 @@ func newTextModel(c fs.Config) *TextModel { m := TextModel{ Layers: make([]TextLayer, numBlocks), TextConfig: &TextConfig{ - hiddenSize: int(c.Uint("embedding_length")), - numHeads: int(c.Uint("attention.head_count")), - numKVHeads: int(c.Uint("attention.head_count_kv")), - attnKeyLen: int(c.Uint("attention.key_length", 256)), - attnValLen: int(c.Uint("attention.value_length", 256)), - eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06), - ropeLocalBase: c.Float("rope.local.freq_base", 10000.0), - ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0), - ropeScale: 1, - // NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights - // (8 instead of 1) - // ropeScale: c.Float("rope.scaling.factor", 1.0), + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + attnKeyLen: int(c.Uint("attention.key_length", 256)), + attnValLen: int(c.Uint("attention.value_length", 256)), + eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06), + ropeLocalBase: c.Float("rope.local.freq_base", 10000.0), + ropeBase: c.Float("rope.freq_base", 1000000.0), + slidingWindowPattern: c.Bools("attention.sliding_window_pattern"), + ropeType: c.String("rope.scaling.type"), + ropeOriginalContext: int(c.Uint("rope.scaling.original_context_length")), + ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1.0), + ropeBetaFast: c.Float("rope.scaling.beta_fast", 64.0), + ropeBetaSlow: c.Float("rope.scaling.beta_slow", 1.0), + ropeScale: c.Float("rope.scaling.factor", 1.0), + finalLogitSoftcap: c.Float("final_logit_softcapping", 0.0), }, } + // Google's Gemma 3 release with sliding window attention does + // not use final logit softcapping, and so force it to 0.0 + // TODO (jmorganca): this should ideally be set to 0.0 in the + // model configuration instead of here, as future versions of + // models may include both sliding window attention and final + // logit softcapping. 
+ if slices.Contains(m.TextConfig.slidingWindowPattern, true) { + m.TextConfig.finalLogitSoftcap = 0.0 + } + if numBlocks == gemma27BLayerCount { m.largeModelScaling = true } @@ -79,13 +114,26 @@ type TextSelfAttention struct { Output *nn.Linear `gguf:"attn_output"` } +func (opts *TextConfig) ropeBaseForLayer(layer int) float32 { + if opts.slidingWindowPattern != nil && opts.slidingWindowPattern[layer] { + return opts.ropeLocalBase + } + + // Standard Gemma3: only every n-th layer is global, + // where n = gemmaGlobalCacheCount, otherwise use + // the local rope base + if (layer+1)%gemmaGlobalCacheCount > 0 { + return opts.ropeLocalBase + } + + // default to global rope base + return opts.ropeBase +} + func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor { batchSize := hiddenState.Dim(1) - ropeBase := opts.ropeLocalBase - if (layer+1)%gemmaGlobalCacheCount == 0 { - ropeBase = opts.ropeGlobalBase - } + ropeBase := opts.ropeBaseForLayer(layer) q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) @@ -114,12 +162,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - ropeBase := m.TextConfig.ropeLocalBase - if (layer+1)%gemmaGlobalCacheCount == 0 { - ropeBase = m.TextConfig.ropeGlobalBase - } - - return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase), nil + return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.TextConfig.ropeBaseForLayer(layer)), nil } type TextMLP struct { @@ -207,6 +250,5 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig) } - hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) - return hiddenState + return m.OutputNorm.Forward(ctx, hiddenState, m.eps) } diff --git a/parser/parser.go b/parser/parser.go index 7d52c338..1f476444 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -300,18 +300,13 @@ func filesForModel(path string) ([]string, error) { } files = append(files, js...) - // only include tokenizer.model is tokenizer.json is not present - if !slices.ContainsFunc(files, func(s string) bool { - return slices.Contains(strings.Split(s, string(os.PathSeparator)), "tokenizer.json") - }) { - if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 { - // add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob - // tokenizer.model might be a unresolved git lfs reference; error if it is - files = append(files, tks...) - } else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 { - // some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B) - files = append(files, tks...) - } + // add tokenizer.model if it exists (tokenizer.json is automatically picked up by the previous glob) + // tokenizer.model might be a unresolved git lfs reference; error if it is + if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 { + files = append(files, tks...) + } else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 { + // some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B) + files = append(files, tks...) 
} return files, nil diff --git a/parser/parser_test.go b/parser/parser_test.go index 3300aad3..4b97e8c2 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -888,6 +888,37 @@ func TestFilesForModel(t *testing.T) { "tokenizer.json", }, }, + { + name: "safetensors with both tokenizer.json and tokenizer.model", + setup: func(dir string) error { + // Create binary content for tokenizer.model (application/octet-stream) + binaryContent := make([]byte, 512) + for i := range binaryContent { + binaryContent[i] = byte(i % 256) + } + files := []string{ + "model-00001-of-00001.safetensors", + "config.json", + "tokenizer.json", + } + for _, file := range files { + if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil { + return err + } + } + // Write tokenizer.model as binary + if err := os.WriteFile(filepath.Join(dir, "tokenizer.model"), binaryContent, 0o644); err != nil { + return err + } + return nil + }, + wantFiles: []string{ + "model-00001-of-00001.safetensors", + "config.json", + "tokenizer.json", + "tokenizer.model", + }, + }, { name: "safetensors with consolidated files - prefers model files", setup: func(dir string) error { From d475d1f081e5455dcfdf9e958619223565b9bf52 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 8 Dec 2025 13:17:03 -0800 Subject: [PATCH 09/35] fix: qwen2.5vl metal argsort --- ...13-add-argsort-and-cuda-copy-for-i32.patch | 169 +++++++++++++++++- .../patches/0027-interleave-multi-rope.patch | 2 +- .../src/ggml-metal/ggml-metal-embed.metal | 146 +++++++++++++++ .../ggml/ggml/src/ggml-metal/ggml-metal.metal | 146 +++++++++++++++ 4 files changed, 455 insertions(+), 8 deletions(-) diff --git a/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch b/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch index 5e5bc110..26c6dca7 100644 --- a/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch +++ b/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch @@ -4,12 +4,12 @@ Date: Thu, 1 May 2025 13:45:12 -0700 Subject: [PATCH] add argsort and cuda copy for i32 --- - ggml/src/ggml-cpu/ops.cpp | 43 ++++++++++ - ggml/src/ggml-cuda/argsort.cu | 122 ++++++++++++++++++++++++--- - ggml/src/ggml-cuda/cpy-utils.cuh | 6 ++ - ggml/src/ggml-cuda/cpy.cu | 40 +++++++++ - ggml/src/ggml-metal/ggml-metal.metal | 69 +++++++++++++++ - 5 files changed, 268 insertions(+), 12 deletions(-) + ggml/src/ggml-cpu/ops.cpp | 43 ++++++ + ggml/src/ggml-cuda/argsort.cu | 122 +++++++++++++-- + ggml/src/ggml-cuda/cpy-utils.cuh | 6 + + ggml/src/ggml-cuda/cpy.cu | 40 +++++ + ggml/src/ggml-metal/ggml-metal.metal | 215 +++++++++++++++++++++++++++ + 5 files changed, 414 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 2745fc54e..40666bab6 100644 @@ -292,7 +292,7 @@ index c4ceb4fc5..0e53ecc39 100644 if (can_be_transposed) { ggml_cpy_scalar_cuda diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal -index 73b45c762..aed013a9d 100644 +index 73b45c762..8a6c834d1 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4721,8 +4721,77 @@ kernel void kernel_argsort_f32_i32( @@ -373,3 +373,158 @@ index 73b45c762..aed013a9d 100644 typedef void (argsort_merge_t)( constant ggml_metal_kargs_argsort_merge & args, +@@ -4877,8 +4946,154 @@ kernel void kernel_argsort_merge_f32_i32( + } + } + ++template ++kernel void kernel_argsort_merge_i32_i32( ++ constant ggml_metal_kargs_argsort_merge & args, ++ device const char * src0, ++ 
device const int32_t * tmp, ++ device int32_t * dst, ++ uint3 tgpig[[threadgroup_position_in_grid]], ++ ushort3 tpitg[[thread_position_in_threadgroup]], ++ ushort3 ntg[[threads_per_threadgroup]]) { ++ ++ const int im = tgpig[0] / args.ne01; ++ const int i01 = tgpig[0] % args.ne01; ++ const int i02 = tgpig[1]; ++ const int i03 = tgpig[2]; ++ ++ const int start = im * (2 * args.len); ++ ++ const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start))); ++ const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len))); ++ ++ const int total = len0 + len1; ++ ++ device const int32_t * tmp0 = tmp + start ++ + i01*args.ne0 ++ + i02*args.ne0*args.ne01 ++ + i03*args.ne0*args.ne01*args.ne02; ++ ++ device const int32_t * tmp1 = tmp0 + args.len; ++ ++ dst += start ++ + i01*args.top_k ++ + i02*args.top_k*args.ne01 ++ + i03*args.top_k*args.ne01*args.ne02; ++ ++ device const int32_t * src0_row = (device const int32_t *)(src0 ++ + args.nb01*i01 ++ + args.nb02*i02 ++ + args.nb03*i03); ++ ++ if (total == 0) { ++ return; ++ } ++ ++ const int chunk = (total + ntg.x - 1) / ntg.x; ++ ++ const int k0 = tpitg.x * chunk; ++ const int k1 = MIN(MIN(k0 + chunk, total), args.top_k); ++ ++ if (k0 >= args.top_k) { ++ return; ++ } ++ ++ if (k0 >= total) { ++ return; ++ } ++ ++ int low = k0 > len1 ? k0 - len1 : 0; ++ int high = MIN(k0, len0); ++ ++ // binary-search partition (i, j) such that i + j = k ++ while (low < high) { ++ const int mid = (low + high) >> 1; ++ ++ const int32_t idx0 = tmp0[mid]; ++ const int32_t idx1 = tmp1[k0 - mid - 1]; ++ ++ const int32_t val0 = src0_row[idx0]; ++ const int32_t val1 = src0_row[idx1]; ++ ++ bool take_left; ++ if (order == GGML_SORT_ORDER_ASC) { ++ take_left = (val0 <= val1); ++ } else { ++ take_left = (val0 >= val1); ++ } ++ ++ if (take_left) { ++ low = mid + 1; ++ } else { ++ high = mid; ++ } ++ } ++ ++ int i = low; ++ int j = k0 - i; ++ ++ // keep the merge fronts into registers ++ int32_t idx0 = 0; ++ int32_t val0 = 0.0f; ++ if (i < len0) { ++ idx0 = tmp0[i]; ++ val0 = src0_row[idx0]; ++ } ++ ++ int32_t idx1 = 0; ++ int32_t val1 = 0.0f; ++ if (j < len1) { ++ idx1 = tmp1[j]; ++ val1 = src0_row[idx1]; ++ } ++ ++ for (int k = k0; k < k1; ++k) { ++ int32_t out_idx; ++ ++ if (i >= len0) { ++ while (k < k1) { ++ dst[k++] = tmp1[j++]; ++ } ++ break; ++ } else if (j >= len1) { ++ while (k < k1) { ++ dst[k++] = tmp0[i++]; ++ } ++ break; ++ } else { ++ bool take_left; ++ ++ if (order == GGML_SORT_ORDER_ASC) { ++ take_left = (val0 <= val1); ++ } else { ++ take_left = (val0 >= val1); ++ } ++ ++ if (take_left) { ++ out_idx = idx0; ++ ++i; ++ if (i < len0) { ++ idx0 = tmp0[i]; ++ val0 = src0_row[idx0]; ++ } ++ } else { ++ out_idx = idx1; ++ ++j; ++ if (j < len1) { ++ idx1 = tmp1[j]; ++ val1 = src0_row[idx1]; ++ } ++ } ++ } ++ ++ dst[k] = out_idx; ++ } ++} ++ + template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32; + template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32; ++template [[host_name("kernel_argsort_merge_i32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_i32_i32; ++template [[host_name("kernel_argsort_merge_i32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_i32_i32; + + kernel void kernel_leaky_relu_f32( + constant ggml_metal_kargs_leaky_relu & args, diff --git a/llama/patches/0027-interleave-multi-rope.patch b/llama/patches/0027-interleave-multi-rope.patch index 1fee6b75..7d36d355 100644 --- 
a/llama/patches/0027-interleave-multi-rope.patch +++ b/llama/patches/0027-interleave-multi-rope.patch @@ -59,7 +59,7 @@ index 88ed79111..71ca60214 100644 } else { if (sector < sections.v[0]) { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal -index aed013a9d..a489de435 100644 +index 8a6c834d1..761b57a26 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4009,14 +4009,14 @@ kernel void kernel_rope_multi( diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal index da4c2bb0..9903af36 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal @@ -7723,8 +7723,154 @@ kernel void kernel_argsort_merge_f32_i32( } } +template +kernel void kernel_argsort_merge_i32_i32( + constant ggml_metal_kargs_argsort_merge & args, + device const char * src0, + device const int32_t * tmp, + device int32_t * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + + const int im = tgpig[0] / args.ne01; + const int i01 = tgpig[0] % args.ne01; + const int i02 = tgpig[1]; + const int i03 = tgpig[2]; + + const int start = im * (2 * args.len); + + const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start))); + const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len))); + + const int total = len0 + len1; + + device const int32_t * tmp0 = tmp + start + + i01*args.ne0 + + i02*args.ne0*args.ne01 + + i03*args.ne0*args.ne01*args.ne02; + + device const int32_t * tmp1 = tmp0 + args.len; + + dst += start + + i01*args.top_k + + i02*args.top_k*args.ne01 + + i03*args.top_k*args.ne01*args.ne02; + + device const int32_t * src0_row = (device const int32_t *)(src0 + + args.nb01*i01 + + args.nb02*i02 + + args.nb03*i03); + + if (total == 0) { + return; + } + + const int chunk = (total + ntg.x - 1) / ntg.x; + + const int k0 = tpitg.x * chunk; + const int k1 = MIN(MIN(k0 + chunk, total), args.top_k); + + if (k0 >= args.top_k) { + return; + } + + if (k0 >= total) { + return; + } + + int low = k0 > len1 ? 
k0 - len1 : 0; + int high = MIN(k0, len0); + + // binary-search partition (i, j) such that i + j = k + while (low < high) { + const int mid = (low + high) >> 1; + + const int32_t idx0 = tmp0[mid]; + const int32_t idx1 = tmp1[k0 - mid - 1]; + + const int32_t val0 = src0_row[idx0]; + const int32_t val1 = src0_row[idx1]; + + bool take_left; + if (order == GGML_SORT_ORDER_ASC) { + take_left = (val0 <= val1); + } else { + take_left = (val0 >= val1); + } + + if (take_left) { + low = mid + 1; + } else { + high = mid; + } + } + + int i = low; + int j = k0 - i; + + // keep the merge fronts into registers + int32_t idx0 = 0; + int32_t val0 = 0.0f; + if (i < len0) { + idx0 = tmp0[i]; + val0 = src0_row[idx0]; + } + + int32_t idx1 = 0; + int32_t val1 = 0.0f; + if (j < len1) { + idx1 = tmp1[j]; + val1 = src0_row[idx1]; + } + + for (int k = k0; k < k1; ++k) { + int32_t out_idx; + + if (i >= len0) { + while (k < k1) { + dst[k++] = tmp1[j++]; + } + break; + } else if (j >= len1) { + while (k < k1) { + dst[k++] = tmp0[i++]; + } + break; + } else { + bool take_left; + + if (order == GGML_SORT_ORDER_ASC) { + take_left = (val0 <= val1); + } else { + take_left = (val0 >= val1); + } + + if (take_left) { + out_idx = idx0; + ++i; + if (i < len0) { + idx0 = tmp0[i]; + val0 = src0_row[idx0]; + } + } else { + out_idx = idx1; + ++j; + if (j < len1) { + idx1 = tmp1[j]; + val1 = src0_row[idx1]; + } + } + } + + dst[k] = out_idx; + } +} + template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32; template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32; +template [[host_name("kernel_argsort_merge_i32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_i32_i32; +template [[host_name("kernel_argsort_merge_i32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_i32_i32; kernel void kernel_leaky_relu_f32( constant ggml_metal_kargs_leaky_relu & args, diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal index a489de43..761b57a2 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal @@ -4946,8 +4946,154 @@ kernel void kernel_argsort_merge_f32_i32( } } +template +kernel void kernel_argsort_merge_i32_i32( + constant ggml_metal_kargs_argsort_merge & args, + device const char * src0, + device const int32_t * tmp, + device int32_t * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + + const int im = tgpig[0] / args.ne01; + const int i01 = tgpig[0] % args.ne01; + const int i02 = tgpig[1]; + const int i03 = tgpig[2]; + + const int start = im * (2 * args.len); + + const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start))); + const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len))); + + const int total = len0 + len1; + + device const int32_t * tmp0 = tmp + start + + i01*args.ne0 + + i02*args.ne0*args.ne01 + + i03*args.ne0*args.ne01*args.ne02; + + device const int32_t * tmp1 = tmp0 + args.len; + + dst += start + + i01*args.top_k + + i02*args.top_k*args.ne01 + + i03*args.top_k*args.ne01*args.ne02; + + device const int32_t * src0_row = (device const int32_t *)(src0 + + args.nb01*i01 + + args.nb02*i02 + + args.nb03*i03); + + if (total == 0) { + return; + } + + const int chunk = (total + ntg.x - 1) / ntg.x; + + const int k0 = tpitg.x * chunk; + const int k1 = MIN(MIN(k0 + 
chunk, total), args.top_k); + + if (k0 >= args.top_k) { + return; + } + + if (k0 >= total) { + return; + } + + int low = k0 > len1 ? k0 - len1 : 0; + int high = MIN(k0, len0); + + // binary-search partition (i, j) such that i + j = k + while (low < high) { + const int mid = (low + high) >> 1; + + const int32_t idx0 = tmp0[mid]; + const int32_t idx1 = tmp1[k0 - mid - 1]; + + const int32_t val0 = src0_row[idx0]; + const int32_t val1 = src0_row[idx1]; + + bool take_left; + if (order == GGML_SORT_ORDER_ASC) { + take_left = (val0 <= val1); + } else { + take_left = (val0 >= val1); + } + + if (take_left) { + low = mid + 1; + } else { + high = mid; + } + } + + int i = low; + int j = k0 - i; + + // keep the merge fronts into registers + int32_t idx0 = 0; + int32_t val0 = 0.0f; + if (i < len0) { + idx0 = tmp0[i]; + val0 = src0_row[idx0]; + } + + int32_t idx1 = 0; + int32_t val1 = 0.0f; + if (j < len1) { + idx1 = tmp1[j]; + val1 = src0_row[idx1]; + } + + for (int k = k0; k < k1; ++k) { + int32_t out_idx; + + if (i >= len0) { + while (k < k1) { + dst[k++] = tmp1[j++]; + } + break; + } else if (j >= len1) { + while (k < k1) { + dst[k++] = tmp0[i++]; + } + break; + } else { + bool take_left; + + if (order == GGML_SORT_ORDER_ASC) { + take_left = (val0 <= val1); + } else { + take_left = (val0 >= val1); + } + + if (take_left) { + out_idx = idx0; + ++i; + if (i < len0) { + idx0 = tmp0[i]; + val0 = src0_row[idx0]; + } + } else { + out_idx = idx1; + ++j; + if (j < len1) { + idx1 = tmp1[j]; + val1 = src0_row[idx1]; + } + } + } + + dst[k] = out_idx; + } +} + template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32; template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32; +template [[host_name("kernel_argsort_merge_i32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_i32_i32; +template [[host_name("kernel_argsort_merge_i32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_i32_i32; kernel void kernel_leaky_relu_f32( constant ggml_metal_kargs_leaky_relu & args, From 0c5e5f66304ca4b078e3cfa5b3beba20e9175100 Mon Sep 17 00:00:00 2001 From: Parth Sareen Date: Tue, 9 Dec 2025 10:41:47 -0800 Subject: [PATCH 10/35] parsers/renderers: olmo3 think (#13290) --- model/parsers/olmo3_think.go | 170 ++++++++ model/parsers/olmo3_think_test.go | 390 ++++++++++++++++++ model/parsers/parsers.go | 2 + model/renderers/json.go | 45 ++ .../{qwen3vl_test.go => json_test.go} | 1 - model/renderers/olmo3_think.go | 130 ++++++ model/renderers/olmo3_think_test.go | 224 ++++++++++ model/renderers/qwen3vl.go | 40 -- model/renderers/renderer.go | 3 + 9 files changed, 964 insertions(+), 41 deletions(-) create mode 100644 model/parsers/olmo3_think.go create mode 100644 model/parsers/olmo3_think_test.go create mode 100644 model/renderers/json.go rename model/renderers/{qwen3vl_test.go => json_test.go} (99%) create mode 100644 model/renderers/olmo3_think.go create mode 100644 model/renderers/olmo3_think_test.go diff --git a/model/parsers/olmo3_think.go b/model/parsers/olmo3_think.go new file mode 100644 index 00000000..eddb9ff9 --- /dev/null +++ b/model/parsers/olmo3_think.go @@ -0,0 +1,170 @@ +package parsers + +import ( + "context" + "log/slog" + "strings" + "unicode" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/logutil" +) + +type olmo3ThinkParserState int + +const ( + olmo3CollectingThink olmo3ThinkParserState = iota + olmo3CollectingContent +) + +const ( + olmo3ThinkCloseTag = "</think>" +) + +type 
Olmo3ThinkParser struct { + state olmo3ThinkParserState + buffer strings.Builder +} + +func (p *Olmo3ThinkParser) HasToolSupport() bool { + return false +} + +func (p *Olmo3ThinkParser) HasThinkingSupport() bool { + return true +} + +func (p *Olmo3ThinkParser) setInitialState(lastMessage *api.Message) { + prefill := lastMessage != nil && lastMessage.Role == "assistant" + + // If prefilling with content, skip to content collection + if prefill && lastMessage.Content != "" { + p.state = olmo3CollectingContent + return + } + + // Model always thinks first (the tag is injected in the prompt) + p.state = olmo3CollectingThink +} + +func (p *Olmo3ThinkParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { + p.setInitialState(lastMessage) + return tools +} + +// Event types for internal parser communication +type olmo3Event interface { + isOlmo3Event() +} + +type olmo3EventThinkContent struct { + content string +} + +type olmo3EventContent struct { + content string +} + +func (olmo3EventThinkContent) isOlmo3Event() {} +func (olmo3EventContent) isOlmo3Event() {} + +func (p *Olmo3ThinkParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { + p.buffer.WriteString(s) + events := p.parseEvents() + + var contentSb strings.Builder + var thinkingSb strings.Builder + for _, event := range events { + switch event := event.(type) { + case olmo3EventThinkContent: + thinkingSb.WriteString(event.content) + case olmo3EventContent: + contentSb.WriteString(event.content) + } + } + + return contentSb.String(), thinkingSb.String(), nil, nil +} + +func (p *Olmo3ThinkParser) parseEvents() []olmo3Event { + var all []olmo3Event + + keepLooping := true + for keepLooping { + var events []olmo3Event + events, keepLooping = p.eat() + if len(events) > 0 { + all = append(all, events...) 
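+			// eat() returns true in its second value only after it consumes a
+			// complete close tag and advances the state; the loop stops once no
+			// further progress can be made on the buffered text.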
+ } + } + + if len(all) > 0 { + slog.Log(context.TODO(), logutil.LevelTrace, "olmo3 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String()) + } + + return all +} + +func (p *Olmo3ThinkParser) eat() ([]olmo3Event, bool) { + var events []olmo3Event + bufStr := p.buffer.String() + if bufStr == "" { + return events, false + } + + switch p.state { + case olmo3CollectingThink: + if strings.Contains(bufStr, olmo3ThinkCloseTag) { + // Found complete tag + split := strings.SplitN(bufStr, olmo3ThinkCloseTag, 2) + thinking := strings.TrimRightFunc(split[0], unicode.IsSpace) + remaining := strings.TrimLeftFunc(split[1], unicode.IsSpace) + + p.buffer.Reset() + p.buffer.WriteString(remaining) + p.state = olmo3CollectingContent + + if len(thinking) > 0 { + events = append(events, olmo3EventThinkContent{content: thinking}) + } + return events, true + } else if overlapLen := overlap(bufStr, olmo3ThinkCloseTag); overlapLen > 0 { + // Partial tag - withhold ambiguous content + beforePartialTag := bufStr[:len(bufStr)-overlapLen] + trailingLen := trailingWhitespaceLen(beforePartialTag) + ambiguousStart := len(beforePartialTag) - trailingLen + + unambiguous := bufStr[:ambiguousStart] + ambiguous := bufStr[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, olmo3EventThinkContent{content: unambiguous}) + } + return events, false + } else { + // Regular thinking content - withhold trailing whitespace in case follows + whitespaceLen := trailingWhitespaceLen(bufStr) + ambiguousStart := len(bufStr) - whitespaceLen + + unambiguous := bufStr[:ambiguousStart] + ambiguous := bufStr[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, olmo3EventThinkContent{content: unambiguous}) + } + return events, false + } + + case olmo3CollectingContent: + // Emit all content directly + p.buffer.Reset() + if len(bufStr) > 0 { + events = append(events, olmo3EventContent{content: bufStr}) + } + return events, false + } + + return events, false +} diff --git a/model/parsers/olmo3_think_test.go b/model/parsers/olmo3_think_test.go new file mode 100644 index 00000000..9479cef8 --- /dev/null +++ b/model/parsers/olmo3_think_test.go @@ -0,0 +1,390 @@ +package parsers + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" +) + +func TestOlmo3ThinkParser(t *testing.T) { + tests := []struct { + name string + input string + expectedContent string + expectedThinking string + lastMessage *api.Message + }{ + { + name: "thinking_only", + input: "I need to think about this.Here is my response.", + expectedContent: "Here is my response.", + expectedThinking: "I need to think about this.", + }, + { + name: "thinking_with_newlines", + input: "Let me think step by step.\n\n1. First point\n2. Second pointThe answer is 42.", + expectedContent: "The answer is 42.", + expectedThinking: "Let me think step by step.\n\n1. First point\n2. Second point", + }, + { + name: "thinking_then_content", + input: "Deep thinking here.Here is my detailed response with multiple sentences. I have thought carefully.", + expectedContent: "Here is my detailed response with multiple sentences. 
I have thought carefully.", + expectedThinking: "Deep thinking here.", + }, + { + name: "empty_thinking", + input: "Just content here.", + expectedContent: "Just content here.", + expectedThinking: "", + }, + { + name: "prefill_skips_thinking", + input: "Continuing from previous content.", + expectedContent: "Continuing from previous content.", + lastMessage: &api.Message{ + Role: "assistant", + Content: "Previous content", + }, + }, + { + name: "thinking_with_whitespace", + input: " Some thinking Content here ", + expectedContent: "Content here ", + expectedThinking: " Some thinking", + }, + { + name: "real_model_output_with_newlines", + input: "Yes, that should work. Let me go with that response.\n\n\n\nHi! I'm all set and ready to assist. How about you? How are you today? 😊", + expectedThinking: "Yes, that should work. Let me go with that response.", + expectedContent: "Hi! I'm all set and ready to assist. How about you? How are you today? 😊", + }, + // Edge cases + { + name: "nested_think_tags_in_thinking", + input: "I'm thinking nested more thinkingFinal content.", + expectedContent: "more thinkingFinal content.", + expectedThinking: "I'm thinking nested", + }, + { + name: "multiple_think_close_tags", + input: "First thinkingContentMore content.", + expectedContent: "ContentMore content.", + expectedThinking: "First thinking", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := &Olmo3ThinkParser{} + parser.Init(nil, tt.lastMessage, nil) + + content, thinking, toolCalls, err := parser.Add(tt.input, true) + if err != nil { + t.Fatalf("Add() error = %v", err) + } + + if diff := cmp.Diff(tt.expectedContent, content); diff != "" { + t.Errorf("content mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" { + t.Errorf("thinking mismatch (-want +got):\n%s", diff) + } + + // No tool calls expected + if len(toolCalls) > 0 { + t.Errorf("expected no tool calls, got %d", len(toolCalls)) + } + }) + } +} + +func TestOlmo3ThinkParser_Streaming(t *testing.T) { + parser := &Olmo3ThinkParser{} + parser.Init(nil, nil, nil) + + chunks := []string{ + "I am ", + "thinking about", + " this.Here ", + "is the response.", + } + + var finalContent, finalThinking strings.Builder + + for i, chunk := range chunks { + done := i == len(chunks)-1 + content, thinking, _, err := parser.Add(chunk, done) + if err != nil { + t.Fatalf("Add() error on chunk %d: %v", i, err) + } + + finalContent.WriteString(content) + finalThinking.WriteString(thinking) + } + + expectedContent := "Here is the response." + expectedThinking := "I am thinking about this." 
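+	// Accumulated across chunks, the streamed output should equal what a
+	// single-shot parse of the concatenated input would produce.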
+ + if finalContent.String() != expectedContent { + t.Errorf("expected content %q, got %q", expectedContent, finalContent.String()) + } + + if finalThinking.String() != expectedThinking { + t.Errorf("expected thinking %q, got %q", expectedThinking, finalThinking.String()) + } +} + +func TestOlmo3ThinkParser_StreamingEdgeCases(t *testing.T) { + tests := []struct { + name string + chunks []string + expectedContent string + expectedThinking string + }{ + { + name: "thinking_tag_split_across_chunks", + chunks: []string{ + "This is thinking content", + "", + "This is content.", + }, + expectedContent: "This is content.", + expectedThinking: "This is thinking content", + }, + { + name: "thinking_tag_split_mid_token", + chunks: []string{ + "Thinking?", + "Content here.", + }, + expectedContent: "Content here.", + expectedThinking: "Thinking?", + }, + { + name: "thinking_tag_split_at_angle_bracket", + chunks: []string{ + "Thinking<", + "/think>", + "Content.", + }, + expectedContent: "Content.", + expectedThinking: "Thinking", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := &Olmo3ThinkParser{} + parser.Init(nil, nil, nil) + + var finalContent, finalThinking strings.Builder + + for i, chunk := range tt.chunks { + done := i == len(tt.chunks)-1 + content, thinking, _, err := parser.Add(chunk, done) + if err != nil { + t.Fatalf("Add() error on chunk %d: %v", i, err) + } + + finalContent.WriteString(content) + finalThinking.WriteString(thinking) + } + + if finalContent.String() != tt.expectedContent { + t.Errorf("expected content %q, got %q", tt.expectedContent, finalContent.String()) + } + + if finalThinking.String() != tt.expectedThinking { + t.Errorf("expected thinking %q, got %q", tt.expectedThinking, finalThinking.String()) + } + }) + } +} + +// TestOlmo3ThinkParser_ThinkBoundary tests streaming thinking content +// where thinking chunks come in succession before the tag +func TestOlmo3ThinkParser_ThinkBoundary(t *testing.T) { + tests := []struct { + name string + chunks []string + expectedThinking string + expectedContent string + }{ + { + name: "multiple_thinking_chunks", + chunks: []string{ + "First part of thinking. ", + "Second part of thinking. ", + "Third part.", + "Content here.", + }, + expectedThinking: "First part of thinking. Second part of thinking. Third part.", + expectedContent: "Content here.", + }, + { + name: "thinking_chunks_with_newlines", + chunks: []string{ + "Step 1: Analyze the problem.\n", + "Step 2: Consider options.\n", + "Step 3: Make decision.", + "Here is my answer.", + }, + expectedThinking: "Step 1: Analyze the problem.\nStep 2: Consider options.\nStep 3: Make decision.", + expectedContent: "Here is my answer.", + }, + { + name: "single_char_thinking_chunks", + chunks: []string{ + "H", "e", "l", "l", "o", "", "World", + }, + expectedThinking: "Hello", + expectedContent: "World", + }, + { + name: "thinking_with_special_chars", + chunks: []string{ + "Let me think... ", + "Option A: $100 ", + "Option B: €200", + "I recommend Option A.", + }, + expectedThinking: "Let me think... Option A: $100 Option B: €200", + expectedContent: "I recommend Option A.", + }, + { + name: "long_thinking_multiple_chunks", + chunks: []string{ + "This is a very long thinking process. ", + "I need to consider many factors. ", + "First, let me look at the data. ", + "The numbers show interesting patterns. ", + "Based on my analysis, ", + "I can conclude that...", + "The answer is 42.", + }, + expectedThinking: "This is a very long thinking process. 
I need to consider many factors. First, let me look at the data. The numbers show interesting patterns. Based on my analysis, I can conclude that...", + expectedContent: "The answer is 42.", + }, + { + name: "thinking_ends_exactly_at_chunk_boundary", + chunks: []string{ + "Thinking content", + "", + "Content", + }, + expectedThinking: "Thinking content", + expectedContent: "Content", + }, + { + name: "empty_chunks_between_thinking", + chunks: []string{ + "Start thinking", + "", + " middle ", + "", + "end", + "Content", + }, + expectedThinking: "Start thinking middle end", + expectedContent: "Content", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := &Olmo3ThinkParser{} + parser.Init(nil, nil, nil) + + var finalContent, finalThinking strings.Builder + + for i, chunk := range tt.chunks { + done := i == len(tt.chunks)-1 + content, thinking, _, err := parser.Add(chunk, done) + if err != nil { + t.Fatalf("Add() error on chunk %d: %v", i, err) + } + + finalContent.WriteString(content) + finalThinking.WriteString(thinking) + } + + if finalThinking.String() != tt.expectedThinking { + t.Errorf("thinking mismatch:\nexpected: %q\ngot: %q", tt.expectedThinking, finalThinking.String()) + } + + if finalContent.String() != tt.expectedContent { + t.Errorf("content mismatch:\nexpected: %q\ngot: %q", tt.expectedContent, finalContent.String()) + } + }) + } +} + +// TestOlmo3ThinkParser_StateTransitions tests that state transitions work correctly +func TestOlmo3ThinkParser_StateTransitions(t *testing.T) { + t.Run("thinking_to_content", func(t *testing.T) { + parser := &Olmo3ThinkParser{} + parser.Init(nil, nil, nil) + + if parser.state != olmo3CollectingThink { + t.Errorf("initial state should be olmo3CollectingThink, got %v", parser.state) + } + + parser.Add("thinkingcontent", true) + + if parser.state != olmo3CollectingContent { + t.Errorf("state after should be olmo3CollectingContent, got %v", parser.state) + } + }) +} + +func TestOlmo3ThinkParser_HasToolSupport(t *testing.T) { + parser := &Olmo3ThinkParser{} + if parser.HasToolSupport() { + t.Error("Olmo3ThinkParser should NOT support tools") + } +} + +func TestOlmo3ThinkParser_HasThinkingSupport(t *testing.T) { + parser := &Olmo3ThinkParser{} + if !parser.HasThinkingSupport() { + t.Error("Olmo3ThinkParser should support thinking") + } +} + +func TestOlmo3ThinkParser_Init(t *testing.T) { + parser := &Olmo3ThinkParser{} + + tools := []api.Tool{ + {Function: api.ToolFunction{Name: "test_tool"}}, + } + + lastMessage := &api.Message{Role: "assistant", Content: "previous"} + + returnedTools := parser.Init(tools, lastMessage, nil) + + if len(returnedTools) != len(tools) { + t.Errorf("expected %d tools returned, got %d", len(tools), len(returnedTools)) + } + + // Should be in content collection mode due to prefill + if parser.state != olmo3CollectingContent { + t.Errorf("expected state olmo3CollectingContent, got %v", parser.state) + } +} + +func TestOlmo3ThinkParser_InitWithoutPrefill(t *testing.T) { + parser := &Olmo3ThinkParser{} + + parser.Init(nil, nil, nil) + + // Should be in thinking collection mode (model always thinks first) + if parser.state != olmo3CollectingThink { + t.Errorf("expected state olmo3CollectingThink, got %v", parser.state) + } +} diff --git a/model/parsers/parsers.go b/model/parsers/parsers.go index 24ab07fb..4e15dc93 100644 --- a/model/parsers/parsers.go +++ b/model/parsers/parsers.go @@ -58,6 +58,8 @@ func ParserForName(name string) Parser { return harmony.NewHarmonyMessageHandler() case 
"cogito": return &CogitoParser{} + case "olmo3-think": + return &Olmo3ThinkParser{} default: return nil } diff --git a/model/renderers/json.go b/model/renderers/json.go new file mode 100644 index 00000000..76d46a90 --- /dev/null +++ b/model/renderers/json.go @@ -0,0 +1,45 @@ +package renderers + +import "encoding/json" + +// marshalWithSpaces marshals v to JSON and adds a space after each ':' and ',' +// that appears outside of string values. This matches the formatting expected +// by certain model architectures. +func marshalWithSpaces(v any) ([]byte, error) { + b, err := json.Marshal(v) + if err != nil { + return nil, err + } + + out := make([]byte, 0, len(b)+len(b)/8) + inStr, esc := false, false + for _, c := range b { + if inStr { + out = append(out, c) + if esc { + esc = false + continue + } + if c == '\\' { + esc = true + continue + } + if c == '"' { + inStr = false + } + continue + } + switch c { + case '"': + inStr = true + out = append(out, c) + case ':': + out = append(out, ':', ' ') + case ',': + out = append(out, ',', ' ') + default: + out = append(out, c) + } + } + return out, nil +} diff --git a/model/renderers/qwen3vl_test.go b/model/renderers/json_test.go similarity index 99% rename from model/renderers/qwen3vl_test.go rename to model/renderers/json_test.go index 6810a7c9..c1ed05b9 100644 --- a/model/renderers/qwen3vl_test.go +++ b/model/renderers/json_test.go @@ -6,7 +6,6 @@ import ( "github.com/google/go-cmp/cmp" ) -// TODO(drifkin): this will be moved to utils in the near future and used by other renderers as well func TestMarshalWithSpaces(t *testing.T) { tests := []struct { name string diff --git a/model/renderers/olmo3_think.go b/model/renderers/olmo3_think.go new file mode 100644 index 00000000..b327d044 --- /dev/null +++ b/model/renderers/olmo3_think.go @@ -0,0 +1,130 @@ +package renderers + +import ( + "encoding/json" + "strings" + + "github.com/ollama/ollama/api" +) + +const ( + olmo3ThinkDefaultSystemMessage = "You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai." + olmo3ThinkNoFunctionsMessage = " You do not currently have access to any functions." 
+) + +type Olmo3ThinkRenderer struct{} + +type olmo3ThinkToolCall struct { + ID string `json:"id,omitempty"` + Type string `json:"type,omitempty"` + Function olmo3ThinkToolCallFunc `json:"function"` +} + +type olmo3ThinkToolCallFunc struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +func (r *Olmo3ThinkRenderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) { + var sb strings.Builder + + var systemMessage *api.Message + filteredMessages := make([]api.Message, 0, len(messages)) + for i, message := range messages { + if message.Role == "system" { + if systemMessage == nil { + systemMessage = &messages[i] + } + continue + } + filteredMessages = append(filteredMessages, message) + } + + systemContent := olmo3ThinkDefaultSystemMessage + if systemMessage != nil { + systemContent = systemMessage.Content + } + + sb.WriteString("<|im_start|>system\n") + sb.WriteString(systemContent) + + if len(tools) > 0 { + functionsJSON, err := marshalWithSpaces(tools) + if err != nil { + return "", err + } + sb.WriteString(" ") + sb.WriteString(string(functionsJSON)) + sb.WriteString("") + } else { + sb.WriteString(olmo3ThinkNoFunctionsMessage) + sb.WriteString(" ") + } + sb.WriteString("<|im_end|>\n") + + for i, message := range filteredMessages { + lastMessage := i == len(filteredMessages)-1 + + switch message.Role { + case "user": + sb.WriteString("<|im_start|>user\n") + sb.WriteString(message.Content) + sb.WriteString("<|im_end|>\n") + + case "assistant": + sb.WriteString("<|im_start|>assistant\n") + + if message.Content != "" { + sb.WriteString(message.Content) + } + + if len(message.ToolCalls) > 0 { + toolCalls := make([]olmo3ThinkToolCall, len(message.ToolCalls)) + for j, tc := range message.ToolCalls { + argsJSON, err := json.Marshal(tc.Function.Arguments) + if err != nil { + return "", err + } + toolCalls[j] = olmo3ThinkToolCall{ + ID: tc.ID, + Type: "function", + Function: olmo3ThinkToolCallFunc{ + Name: tc.Function.Name, + Arguments: string(argsJSON), + }, + } + } + toolCallsJSON, err := marshalWithSpaces(toolCalls) + if err != nil { + return "", err + } + sb.WriteString("") + sb.WriteString(string(toolCallsJSON)) + sb.WriteString("") + } + + if !lastMessage { + sb.WriteString("<|im_end|>\n") + } + + case "tool": + sb.WriteString("<|im_start|>environment\n") + sb.WriteString(message.Content) + sb.WriteString("<|im_end|>\n") + } + } + + needsGenerationPrompt := true + if len(filteredMessages) > 0 { + lastMsg := filteredMessages[len(filteredMessages)-1] + if lastMsg.Role == "assistant" && len(lastMsg.ToolCalls) == 0 && lastMsg.Content != "" { + needsGenerationPrompt = false + } + } + + if needsGenerationPrompt { + sb.WriteString("<|im_start|>assistant\n") + } + + return sb.String(), nil +} diff --git a/model/renderers/olmo3_think_test.go b/model/renderers/olmo3_think_test.go new file mode 100644 index 00000000..21e333e3 --- /dev/null +++ b/model/renderers/olmo3_think_test.go @@ -0,0 +1,224 @@ +package renderers + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" +) + +func TestOlmo3ThinkRenderer(t *testing.T) { + tests := []struct { + name string + msgs []api.Message + tools []api.Tool + expected string + }{ + { + name: "basic without system - adds default system", + msgs: []api.Message{ + {Role: "user", Content: "Hello!"}, + }, + expected: "<|im_start|>system\n" + + "You are OLMo, a helpful function-calling AI assistant built by Ai2. 
Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. You do not currently have access to any functions. <|im_end|>\n" + + "<|im_start|>user\n" + + "Hello!<|im_end|>\n" + + "<|im_start|>assistant\n" + + "", + }, + { + name: "with system message no tools", + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello!"}, + }, + expected: "<|im_start|>system\n" + + "You are a helpful assistant. You do not currently have access to any functions. <|im_end|>\n" + + "<|im_start|>user\n" + + "Hello!<|im_end|>\n" + + "<|im_start|>assistant\n" + + "", + }, + { + name: "with system message and tools", + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "What is the weather?"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"location"}, + Properties: map[string]api.ToolProperty{ + "location": {Type: api.PropertyType{"string"}, Description: "The city"}, + }, + }, + }, + }, + }, + expected: "<|im_start|>system\n" + + `You are a helpful assistant. [{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]<|im_end|>` + "\n" + + "<|im_start|>user\n" + + "What is the weather?<|im_end|>\n" + + "<|im_start|>assistant\n" + + "", + }, + { + name: "assistant with tool calls", + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "What is the weather in SF?"}, + { + Role: "assistant", + Content: "Let me check the weather.", + ToolCalls: []api.ToolCall{ + { + ID: "call_1", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{ + "location": "San Francisco", + }, + }, + }, + }, + }, + {Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"location"}, + Properties: map[string]api.ToolProperty{ + "location": {Type: api.PropertyType{"string"}, Description: "The city"}, + }, + }, + }, + }, + }, + expected: "<|im_start|>system\n" + + `You are a helpful assistant. 
[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]<|im_end|>` + "\n" + + "<|im_start|>user\n" + + "What is the weather in SF?<|im_end|>\n" + + "<|im_start|>assistant\n" + + `Let me check the weather.[{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"San Francisco\"}"}}]<|im_end|>` + "\n" + + "<|im_start|>environment\n" + + `{"temperature": 68}<|im_end|>` + "\n" + + "<|im_start|>assistant\n" + + "", + }, + { + name: "multi-turn conversation", + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "Hi there!"}, + {Role: "user", Content: "How are you?"}, + }, + expected: "<|im_start|>system\n" + + "You are a helpful assistant. You do not currently have access to any functions. <|im_end|>\n" + + "<|im_start|>user\n" + + "Hello<|im_end|>\n" + + "<|im_start|>assistant\n" + + "Hi there!<|im_end|>\n" + + "<|im_start|>user\n" + + "How are you?<|im_end|>\n" + + "<|im_start|>assistant\n" + + "", + }, + { + name: "parallel tool calls", + msgs: []api.Message{ + {Role: "user", Content: "Get weather in SF and NYC"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_1", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"location": "San Francisco"}, + }, + }, + { + ID: "call_2", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"location": "New York"}, + }, + }, + }, + }, + {Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"}, + {Role: "tool", Content: `{"temperature": 55}`, ToolName: "get_weather"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "location": {Type: api.PropertyType{"string"}}, + }, + }, + }, + }, + }, + expected: "<|im_start|>system\n" + + `You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. [{"type": "function", "function": {"name": "get_weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}}}]<|im_end|>` + "\n" + + "<|im_start|>user\n" + + "Get weather in SF and NYC<|im_end|>\n" + + "<|im_start|>assistant\n" + + `[{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"San Francisco\"}"}}, {"id": "call_2", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"New York\"}"}}]<|im_end|>` + "\n" + + "<|im_start|>environment\n" + + `{"temperature": 68}<|im_end|>` + "\n" + + "<|im_start|>environment\n" + + `{"temperature": 55}<|im_end|>` + "\n" + + "<|im_start|>assistant\n" + + "", + }, + { + name: "assistant message only content no tool calls", + msgs: []api.Message{ + {Role: "user", Content: "Tell me a joke"}, + {Role: "assistant", Content: "Why did the chicken cross the road?"}, + {Role: "user", Content: "I don't know, why?"}, + }, + expected: "<|im_start|>system\n" + + "You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. 
You do not currently have access to any functions. <|im_end|>\n" + + "<|im_start|>user\n" + + "Tell me a joke<|im_end|>\n" + + "<|im_start|>assistant\n" + + "Why did the chicken cross the road?<|im_end|>\n" + + "<|im_start|>user\n" + + "I don't know, why?<|im_end|>\n" + + "<|im_start|>assistant\n" + + "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rendered, err := (&Olmo3ThinkRenderer{}).Render(tt.msgs, tt.tools, nil) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(rendered, tt.expected); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + } +} diff --git a/model/renderers/qwen3vl.go b/model/renderers/qwen3vl.go index 8ea4abbb..50879d29 100644 --- a/model/renderers/qwen3vl.go +++ b/model/renderers/qwen3vl.go @@ -1,51 +1,11 @@ package renderers import ( - "encoding/json" "strings" "github.com/ollama/ollama/api" ) -func marshalWithSpaces(v any) ([]byte, error) { - b, err := json.Marshal(v) - if err != nil { - return nil, err - } - - out := make([]byte, 0, len(b)+len(b)/8) - inStr, esc := false, false - for _, c := range b { - if inStr { - out = append(out, c) - if esc { - esc = false - continue - } - if c == '\\' { - esc = true - continue - } - if c == '"' { - inStr = false - } - continue - } - switch c { - case '"': - inStr = true - out = append(out, c) - case ':': - out = append(out, ':', ' ') - case ',': - out = append(out, ',', ' ') - default: - out = append(out, c) - } - } - return out, nil -} - type Qwen3VLRenderer struct { isThinking bool diff --git a/model/renderers/renderer.go b/model/renderers/renderer.go index 84df1b8a..098b16a8 100644 --- a/model/renderers/renderer.go +++ b/model/renderers/renderer.go @@ -59,6 +59,9 @@ func rendererForName(name string) Renderer { case "cogito": renderer := &CogitoRenderer{isThinking: true} return renderer + case "olmo3-think": + renderer := &Olmo3ThinkRenderer{} + return renderer default: return nil } From 2bccf8c6249eed1e85758d41eeb614d951cafec4 Mon Sep 17 00:00:00 2001 From: Parth Sareen Date: Tue, 9 Dec 2025 11:12:27 -0800 Subject: [PATCH 11/35] renderers/parsers: olmo3 instruct (#13383) --- model/parsers/olmo3.go | 465 ++++++++++++++++++++++++++++++++ model/parsers/olmo3_test.go | 483 ++++++++++++++++++++++++++++++++++ model/parsers/parsers.go | 2 + model/renderers/olmo3.go | 147 +++++++++++ model/renderers/olmo3_test.go | 290 ++++++++++++++++++++ model/renderers/renderer.go | 3 + 6 files changed, 1390 insertions(+) create mode 100644 model/parsers/olmo3.go create mode 100644 model/parsers/olmo3_test.go create mode 100644 model/renderers/olmo3.go create mode 100644 model/renderers/olmo3_test.go diff --git a/model/parsers/olmo3.go b/model/parsers/olmo3.go new file mode 100644 index 00000000..ee4037a6 --- /dev/null +++ b/model/parsers/olmo3.go @@ -0,0 +1,465 @@ +package parsers + +import ( + "context" + "fmt" + "log/slog" + "regexp" + "strconv" + "strings" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/logutil" +) + +type olmo3ParserState int + +const ( + olmo3StateContent olmo3ParserState = iota + olmo3StateToolCalls + olmo3StateToolCallsDone +) + +const ( + olmo3FuncCallsOpenTag = "" + olmo3FuncCallsCloseTag = "" +) + +type Olmo3Parser struct { + state olmo3ParserState + buffer strings.Builder +} + +func (p *Olmo3Parser) HasToolSupport() bool { + return true +} + +func (p *Olmo3Parser) HasThinkingSupport() bool { + return false +} + +func (p *Olmo3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { + p.state = 
olmo3StateContent + return tools +} + +type olmo3ParserEvent interface { + isOlmo3ParserEvent() +} + +type olmo3ParserEventContent struct { + content string +} + +type olmo3ParserEventToolCalls struct { + calls []api.ToolCall +} + +func (olmo3ParserEventContent) isOlmo3ParserEvent() {} +func (olmo3ParserEventToolCalls) isOlmo3ParserEvent() {} + +func (p *Olmo3Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { + p.buffer.WriteString(s) + + if done { + // Drain any remaining content + bufStr := p.buffer.String() + p.buffer.Reset() + if p.state == olmo3StateContent && len(bufStr) > 0 { + return bufStr, "", nil, nil + } + return "", "", nil, nil + } + + events := p.parseEvents() + + var contentSb strings.Builder + var allCalls []api.ToolCall + for _, event := range events { + switch event := event.(type) { + case olmo3ParserEventContent: + contentSb.WriteString(event.content) + case olmo3ParserEventToolCalls: + allCalls = append(allCalls, event.calls...) + } + } + + return contentSb.String(), "", allCalls, nil +} + +func (p *Olmo3Parser) parseEvents() []olmo3ParserEvent { + var all []olmo3ParserEvent + + keepLooping := true + for keepLooping { + var events []olmo3ParserEvent + events, keepLooping = p.eat() + if len(events) > 0 { + all = append(all, events...) + } + } + + if len(all) > 0 { + slog.Log(context.TODO(), logutil.LevelTrace, "olmo3 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String()) + } + + return all +} + +func (p *Olmo3Parser) eat() ([]olmo3ParserEvent, bool) { + var events []olmo3ParserEvent + bufStr := p.buffer.String() + if bufStr == "" { + return events, false + } + + switch p.state { + case olmo3StateContent: + if strings.Contains(bufStr, olmo3FuncCallsOpenTag) { + // Found tag + split := strings.SplitN(bufStr, olmo3FuncCallsOpenTag, 2) + content := split[0] + remaining := split[1] + + p.buffer.Reset() + p.buffer.WriteString(remaining) + p.state = olmo3StateToolCalls + + if len(content) > 0 { + events = append(events, olmo3ParserEventContent{content: content}) + } + return events, true + } else if overlapLen := overlap(bufStr, olmo3FuncCallsOpenTag); overlapLen > 0 { + // Partial tag - withhold ambiguous content + unambiguous := bufStr[:len(bufStr)-overlapLen] + ambiguous := bufStr[len(bufStr)-overlapLen:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, olmo3ParserEventContent{content: unambiguous}) + } + return events, false + } else { + // Regular content - emit all + p.buffer.Reset() + if len(bufStr) > 0 { + events = append(events, olmo3ParserEventContent{content: bufStr}) + } + return events, false + } + + case olmo3StateToolCalls: + if strings.Contains(bufStr, olmo3FuncCallsCloseTag) { + // Found tag + split := strings.SplitN(bufStr, olmo3FuncCallsCloseTag, 2) + toolCallsStr := split[0] + remaining := split[1] + + p.buffer.Reset() + p.buffer.WriteString(remaining) + p.state = olmo3StateToolCallsDone + + // Parse the function calls + calls, err := parseOlmo3FunctionCalls(toolCallsStr) + if err != nil { + slog.Log(context.TODO(), logutil.LevelTrace, "failed to parse olmo3 function calls", "error", err, "content", toolCallsStr) + } else if len(calls) > 0 { + events = append(events, olmo3ParserEventToolCalls{calls: calls}) + } + return events, true + } else if overlapLen := overlap(bufStr, olmo3FuncCallsCloseTag); overlapLen > 0 { + // Partial tag - wait for more + return events, false + } + // Still collecting tool calls, wait for close tag 
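+		// Nothing is emitted as content while in this state; the buffered text is
+		// parsed into tool calls only once the closing tag arrives.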
+ return events, false + + case olmo3StateToolCallsDone: + // After tool calls, emit remaining content + p.buffer.Reset() + p.state = olmo3StateContent + if len(bufStr) > 0 { + events = append(events, olmo3ParserEventContent{content: bufStr}) + } + return events, false + } + + return events, false +} + +// parseOlmo3FunctionCalls parses function calls in Python-esque format: +// func_name(arg1="value1", arg2=123) +// Multiple calls are separated by newlines +func parseOlmo3FunctionCalls(s string) ([]api.ToolCall, error) { + var calls []api.ToolCall + s = strings.TrimSpace(s) + if s == "" { + return calls, nil + } + + // Split by newlines for multiple function calls + lines := strings.Split(s, "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + + call, err := parseOlmo3SingleFunctionCall(line) + if err != nil { + return nil, fmt.Errorf("failed to parse function call %q: %w", line, err) + } + calls = append(calls, call) + } + + return calls, nil +} + +// Regex to match function call: func_name(args) +var funcCallRegex = regexp.MustCompile(`^(\w+)\((.*)\)$`) + +func parseOlmo3SingleFunctionCall(s string) (api.ToolCall, error) { + matches := funcCallRegex.FindStringSubmatch(s) + if matches == nil { + return api.ToolCall{}, fmt.Errorf("invalid function call format") + } + + funcName := matches[1] + argsStr := matches[2] + + args, err := parseOlmo3Arguments(argsStr) + if err != nil { + return api.ToolCall{}, fmt.Errorf("failed to parse arguments: %w", err) + } + + return api.ToolCall{ + Function: api.ToolCallFunction{ + Name: funcName, + Arguments: args, + }, + }, nil +} + +// parseOlmo3Arguments parses comma-separated key=value pairs +// Handles nested parentheses, brackets, braces, and quoted strings +func parseOlmo3Arguments(s string) (map[string]any, error) { + args := make(map[string]any) + s = strings.TrimSpace(s) + if s == "" { + return args, nil + } + + // Split by commas, but respect nested structures and quotes + parts := splitArguments(s) + + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + + // Find the first = sign + eqIdx := strings.Index(part, "=") + if eqIdx == -1 { + return nil, fmt.Errorf("invalid argument format: %s", part) + } + + key := strings.TrimSpace(part[:eqIdx]) + valueStr := strings.TrimSpace(part[eqIdx+1:]) + + value, err := parseOlmo3Value(valueStr) + if err != nil { + return nil, fmt.Errorf("failed to parse value for %s: %w", key, err) + } + + args[key] = value + } + + return args, nil +} + +// splitArguments splits arguments by commas, respecting quotes and nested structures +func splitArguments(s string) []string { + var parts []string + var current strings.Builder + depth := 0 + inString := false + stringChar := byte(0) + escaped := false + + for i := range s { + c := s[i] + + if escaped { + current.WriteByte(c) + escaped = false + continue + } + + if c == '\\' && inString { + current.WriteByte(c) + escaped = true + continue + } + + if (c == '"' || c == '\'') && !inString { + inString = true + stringChar = c + current.WriteByte(c) + continue + } + + if c == stringChar && inString { + inString = false + stringChar = 0 + current.WriteByte(c) + continue + } + + if !inString { + switch c { + case '(', '[', '{': + depth++ + current.WriteByte(c) + case ')', ']', '}': + depth-- + current.WriteByte(c) + case ',': + if depth == 0 { + parts = append(parts, current.String()) + current.Reset() + continue + } + current.WriteByte(c) + default: + current.WriteByte(c) + } + } 
else { + current.WriteByte(c) + } + } + + if current.Len() > 0 { + parts = append(parts, current.String()) + } + + return parts +} + +// parseOlmo3Value parses a value which can be a string, number, boolean, null, array, or object +func parseOlmo3Value(s string) (any, error) { + s = strings.TrimSpace(s) + + // Check for quoted string + if (strings.HasPrefix(s, `"`) && strings.HasSuffix(s, `"`)) || + (strings.HasPrefix(s, `'`) && strings.HasSuffix(s, `'`)) { + // Remove quotes and unescape + inner := s[1 : len(s)-1] + return unescapeString(inner), nil + } + + // Check for boolean + if s == "true" || s == "True" { + return true, nil + } + if s == "false" || s == "False" { + return false, nil + } + + // Check for null/None + if s == "null" || s == "None" || s == "nil" { + return nil, nil + } + + // Check for number + if i, err := strconv.ParseInt(s, 10, 64); err == nil { + return i, nil + } + if f, err := strconv.ParseFloat(s, 64); err == nil { + return f, nil + } + + // Check for array [...] + if strings.HasPrefix(s, "[") && strings.HasSuffix(s, "]") { + return parseOlmo3Array(s[1 : len(s)-1]) + } + + // Check for object {...} + if strings.HasPrefix(s, "{") && strings.HasSuffix(s, "}") { + return parseOlmo3Object(s[1 : len(s)-1]) + } + + // Default to string without quotes + return s, nil +} + +func parseOlmo3Array(s string) ([]any, error) { + s = strings.TrimSpace(s) + if s == "" { + return []any{}, nil + } + + parts := splitArguments(s) + var arr []any + for _, part := range parts { + val, err := parseOlmo3Value(part) + if err != nil { + return nil, err + } + arr = append(arr, val) + } + return arr, nil +} + +func parseOlmo3Object(s string) (map[string]any, error) { + s = strings.TrimSpace(s) + if s == "" { + return map[string]any{}, nil + } + + // Objects use key: value or "key": value format + obj := make(map[string]any) + parts := splitArguments(s) + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + + // Find colon separator + colonIdx := strings.Index(part, ":") + if colonIdx == -1 { + return nil, fmt.Errorf("invalid object entry: %s", part) + } + + keyStr := strings.TrimSpace(part[:colonIdx]) + valueStr := strings.TrimSpace(part[colonIdx+1:]) + + // Remove quotes from key if present + if (strings.HasPrefix(keyStr, `"`) && strings.HasSuffix(keyStr, `"`)) || + (strings.HasPrefix(keyStr, `'`) && strings.HasSuffix(keyStr, `'`)) { + keyStr = keyStr[1 : len(keyStr)-1] + } + + val, err := parseOlmo3Value(valueStr) + if err != nil { + return nil, fmt.Errorf("failed to parse value for key %s: %w", keyStr, err) + } + + obj[keyStr] = val + } + + return obj, nil +} + +func unescapeString(s string) string { + // Handle common escape sequences + s = strings.ReplaceAll(s, `\\`, "\x00") // Placeholder for backslash + s = strings.ReplaceAll(s, `\"`, `"`) + s = strings.ReplaceAll(s, `\'`, `'`) + s = strings.ReplaceAll(s, `\n`, "\n") + s = strings.ReplaceAll(s, `\t`, "\t") + s = strings.ReplaceAll(s, `\r`, "\r") + s = strings.ReplaceAll(s, "\x00", `\`) // Restore backslash + return s +} diff --git a/model/parsers/olmo3_test.go b/model/parsers/olmo3_test.go new file mode 100644 index 00000000..6c5b57b8 --- /dev/null +++ b/model/parsers/olmo3_test.go @@ -0,0 +1,483 @@ +package parsers + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" +) + +func TestOlmo3Parser(t *testing.T) { + tests := []struct { + name string + input string + expectedContent string + expectedThinking string + expectedCalls []api.ToolCall + }{ + { 
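+			// Plain text with no function-call block should pass through as content unchanged.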
+ name: "simple content", + input: "Hello, how can I help you?", + expectedContent: "Hello, how can I help you?", + }, + { + name: "simple tool call", + input: `get_weather(location="San Francisco")`, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"location": "San Francisco"}, + }, + }, + }, + }, + { + name: "content then tool call", + input: `Let me check the weather.get_weather(location="NYC")`, + expectedContent: "Let me check the weather.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"location": "NYC"}, + }, + }, + }, + }, + { + name: "tool call with multiple arguments", + input: `book_flight(from="SFO", to="NYC", date="2024-01-15")`, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "book_flight", + Arguments: map[string]any{ + "from": "SFO", + "to": "NYC", + "date": "2024-01-15", + }, + }, + }, + }, + }, + { + name: "multiple tool calls", + input: `get_weather(location="San Francisco") +get_weather(location="New York")`, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"location": "San Francisco"}, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"location": "New York"}, + }, + }, + }, + }, + { + name: "tool call with numeric argument", + input: `set_temperature(value=72)`, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "set_temperature", + Arguments: map[string]any{"value": int64(72)}, + }, + }, + }, + }, + { + name: "tool call with float argument", + input: `set_price(amount=19.99)`, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "set_price", + Arguments: map[string]any{"amount": 19.99}, + }, + }, + }, + }, + { + name: "tool call with boolean argument", + input: `toggle_setting(enabled=true)`, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "toggle_setting", + Arguments: map[string]any{"enabled": true}, + }, + }, + }, + }, + { + name: "tool call with null argument", + input: `clear_value(field=null)`, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "clear_value", + Arguments: map[string]any{"field": nil}, + }, + }, + }, + }, + { + name: "tool call with array argument", + input: `process_items(items=["apple", "banana", "cherry"])`, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "process_items", + Arguments: map[string]any{"items": []any{"apple", "banana", "cherry"}}, + }, + }, + }, + }, + { + name: "tool call with dict argument", + input: `update_config(settings={"theme": "dark", "fontSize": 14})`, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "update_config", + Arguments: map[string]any{ + "settings": map[string]any{ + "theme": "dark", + "fontSize": int64(14), + }, + }, + }, + }, + }, + }, + { + name: "tool call with nested dict", + input: `create_request(data={"user": {"name": "John", "age": 30}, "active": true})`, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "create_request", + Arguments: map[string]any{ + "data": map[string]any{ + "user": map[string]any{ + "name": "John", + "age": int64(30), + }, + "active": true, + }, + }, + }, + }, + }, + }, + { + name: "tool call with no arguments", + input: `get_current_time()`, + expectedCalls: []api.ToolCall{ + { + 
Function: api.ToolCallFunction{ + Name: "get_current_time", + Arguments: map[string]any{}, + }, + }, + }, + }, + { + name: "tool call with single quotes", + input: `search(query='hello world')`, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "search", + Arguments: map[string]any{"query": "hello world"}, + }, + }, + }, + }, + { + name: "tool call with escaped quotes", + input: `search(query="say \"hello\"")`, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "search", + Arguments: map[string]any{"query": `say "hello"`}, + }, + }, + }, + }, + { + name: "tool call with mixed argument types", + input: `create_user(name="John", age=30, active=true)`, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "create_user", + Arguments: map[string]any{ + "name": "John", + "age": int64(30), + "active": true, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Olmo3Parser{} + p.Init(nil, nil, nil) + + content, thinking, calls, err := p.Add(tt.input, false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Drain remaining content + finalContent, finalThinking, finalCalls, err := p.Add("", true) + if err != nil { + t.Fatalf("unexpected error on done: %v", err) + } + content += finalContent + thinking += finalThinking + calls = append(calls, finalCalls...) + + if diff := cmp.Diff(content, tt.expectedContent); diff != "" { + t.Errorf("content mismatch (-got +want):\n%s", diff) + } + if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" { + t.Errorf("thinking mismatch (-got +want):\n%s", diff) + } + if diff := cmp.Diff(calls, tt.expectedCalls); diff != "" { + t.Errorf("calls mismatch (-got +want):\n%s", diff) + } + }) + } +} + +func TestOlmo3Parser_Streaming(t *testing.T) { + tests := []struct { + name string + chunks []string + expectedContent string + expectedCalls []api.ToolCall + }{ + { + name: "streaming content", + chunks: []string{"Hello, ", "how ", "can I help?"}, + expectedContent: "Hello, how can I help?", + }, + { + name: "streaming tool call", + chunks: []string{"get_weather", "(location=\"SF\")", ""}, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"location": "SF"}, + }, + }, + }, + }, + { + name: "streaming content then tool call", + chunks: []string{"Let me check.", "", "get_weather(location=\"NYC\")", ""}, + expectedContent: "Let me check.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"location": "NYC"}, + }, + }, + }, + }, + { + name: "tool call tag split across chunks", + chunks: []string{"test()"}, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "test", + Arguments: map[string]any{}, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Olmo3Parser{} + p.Init(nil, nil, nil) + + var allContent string + var allCalls []api.ToolCall + + for _, chunk := range tt.chunks { + content, _, calls, err := p.Add(chunk, false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + allContent += content + allCalls = append(allCalls, calls...) + } + + // Drain + content, _, calls, err := p.Add("", true) + if err != nil { + t.Fatalf("unexpected error on done: %v", err) + } + allContent += content + allCalls = append(allCalls, calls...) 
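+			// The final Add with done=true flushes any text still held in the
+			// parser's buffer before the comparisons below.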
+ + if diff := cmp.Diff(allContent, tt.expectedContent); diff != "" { + t.Errorf("content mismatch (-got +want):\n%s", diff) + } + if diff := cmp.Diff(allCalls, tt.expectedCalls); diff != "" { + t.Errorf("calls mismatch (-got +want):\n%s", diff) + } + }) + } +} + +func TestOlmo3Parser_HasToolSupport(t *testing.T) { + p := &Olmo3Parser{} + if !p.HasToolSupport() { + t.Error("expected HasToolSupport to return true") + } +} + +func TestOlmo3Parser_HasThinkingSupport(t *testing.T) { + p := &Olmo3Parser{} + if p.HasThinkingSupport() { + t.Error("expected HasThinkingSupport to return false") + } +} + +func TestParseOlmo3FunctionCalls(t *testing.T) { + tests := []struct { + name string + input string + expected []api.ToolCall + wantErr bool + }{ + { + name: "simple call", + input: `get_weather(location="SF")`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"location": "SF"}, + }, + }, + }, + }, + { + name: "multiple args", + input: `send_email(to="user@example.com", subject="Hello", body="Test message")`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "send_email", + Arguments: map[string]any{ + "to": "user@example.com", + "subject": "Hello", + "body": "Test message", + }, + }, + }, + }, + }, + { + name: "multiple calls with newlines", + input: `get_weather(location="SF") +get_time(timezone="PST")`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"location": "SF"}, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "get_time", + Arguments: map[string]any{"timezone": "PST"}, + }, + }, + }, + }, + { + name: "empty input", + input: "", + expected: nil, + }, + { + name: "whitespace only", + input: " \n ", + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + calls, err := parseOlmo3FunctionCalls(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("parseOlmo3FunctionCalls() error = %v, wantErr %v", err, tt.wantErr) + return + } + if diff := cmp.Diff(calls, tt.expected); diff != "" { + t.Errorf("calls mismatch (-got +want):\n%s", diff) + } + }) + } +} + +func TestParseOlmo3Value(t *testing.T) { + tests := []struct { + name string + input string + expected any + }{ + {"string double quotes", `"hello"`, "hello"}, + {"string single quotes", `'hello'`, "hello"}, + {"integer", "42", int64(42)}, + {"negative integer", "-10", int64(-10)}, + {"float", "3.14", 3.14}, + {"boolean true", "true", true}, + {"boolean True", "True", true}, + {"boolean false", "false", false}, + {"null", "null", nil}, + {"None", "None", nil}, + {"empty array", "[]", []any{}}, + {"array with strings", `["a", "b"]`, []any{"a", "b"}}, + {"array with numbers", "[1, 2, 3]", []any{int64(1), int64(2), int64(3)}}, + {"empty object", "{}", map[string]any{}}, + {"simple object", `{"name": "John"}`, map[string]any{"name": "John"}}, + {"object with number", `{"age": 30}`, map[string]any{"age": int64(30)}}, + {"object with multiple keys", `{"a": 1, "b": 2}`, map[string]any{"a": int64(1), "b": int64(2)}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseOlmo3Value(tt.input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if diff := cmp.Diff(result, tt.expected); diff != "" { + t.Errorf("value mismatch (-got +want):\n%s", diff) + } + }) + } +} diff --git a/model/parsers/parsers.go b/model/parsers/parsers.go index 4e15dc93..ab52267c 100644 --- a/model/parsers/parsers.go +++ 
b/model/parsers/parsers.go @@ -58,6 +58,8 @@ func ParserForName(name string) Parser { return harmony.NewHarmonyMessageHandler() case "cogito": return &CogitoParser{} + case "olmo3": + return &Olmo3Parser{} case "olmo3-think": return &Olmo3ThinkParser{} default: diff --git a/model/renderers/olmo3.go b/model/renderers/olmo3.go new file mode 100644 index 00000000..24ade20d --- /dev/null +++ b/model/renderers/olmo3.go @@ -0,0 +1,147 @@ +package renderers + +import ( + "encoding/json" + "fmt" + "sort" + "strings" + + "github.com/ollama/ollama/api" +) + +const ( + olmo3DefaultSystemMessage = "You are a helpful function-calling AI assistant. " + olmo3NoFunctionsMessage = "You do not currently have access to any functions. " + olmo3WithFunctionsMessage = "You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Output any function calls within XML tags. Do not make assumptions about what values to plug into functions." +) + +type Olmo3Renderer struct{} + +func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) { + var sb strings.Builder + + var systemMessage *api.Message + filteredMessages := make([]api.Message, 0, len(messages)) + for i, message := range messages { + if message.Role == "system" { + if systemMessage == nil { + systemMessage = &messages[i] + } + continue + } + filteredMessages = append(filteredMessages, message) + } + + // Render system message + if systemMessage != nil { + // Custom system message - single newline after "system" + sb.WriteString("<|im_start|>system\n") + sb.WriteString(systemMessage.Content) + + if len(tools) > 0 { + functionsJSON, err := marshalWithSpaces(tools) + if err != nil { + return "", err + } + sb.WriteString("") + sb.WriteString(string(functionsJSON)) + sb.WriteString("") + } + sb.WriteString("<|im_end|>\n") + } else { + // Default system message - single newline after "system" + sb.WriteString("<|im_start|>system\n") + sb.WriteString(olmo3DefaultSystemMessage) + + if len(tools) > 0 { + functionsJSON, err := marshalWithSpaces(tools) + if err != nil { + return "", err + } + sb.WriteString(olmo3WithFunctionsMessage) + sb.WriteString("") + sb.WriteString(string(functionsJSON)) + sb.WriteString("") + } else { + sb.WriteString(olmo3NoFunctionsMessage) + sb.WriteString("") + } + sb.WriteString("<|im_end|>\n") + } + + for i, message := range filteredMessages { + lastMessage := i == len(filteredMessages)-1 + + switch message.Role { + case "user": + sb.WriteString("<|im_start|>user\n") + sb.WriteString(message.Content) + sb.WriteString("<|im_end|>\n") + + case "assistant": + sb.WriteString("<|im_start|>assistant\n") + + if message.Content != "" { + sb.WriteString(message.Content) + } + + if len(message.ToolCalls) > 0 { + sb.WriteString("") + for j, tc := range message.ToolCalls { + // Format as function_name(arg1="value1", arg2="value2") + sb.WriteString(tc.Function.Name) + sb.WriteString("(") + + // Get sorted keys for deterministic output + keys := make([]string, 0, len(tc.Function.Arguments)) + for k := range tc.Function.Arguments { + keys = append(keys, k) + } + sort.Strings(keys) + + for k, key := range keys { + if k > 0 { + sb.WriteString(", ") + } + value, err := json.Marshal(tc.Function.Arguments[key]) + if err != nil { + return "", err + } + sb.WriteString(fmt.Sprintf("%s=%s", key, string(value))) + } + sb.WriteString(")") + + if j < len(message.ToolCalls)-1 { + sb.WriteString("\n") + } + } + sb.WriteString("") + } + + // Add end tag 
unless it's the last message with content only (prefill) + if !lastMessage || len(message.ToolCalls) > 0 { + sb.WriteString("<|im_end|>\n") + } + + case "tool": + sb.WriteString("<|im_start|>environment\n") + sb.WriteString(message.Content) + sb.WriteString("<|im_end|>\n") + } + } + + // Add generation prompt if needed + needsGenerationPrompt := true + if len(filteredMessages) > 0 { + lastMsg := filteredMessages[len(filteredMessages)-1] + if lastMsg.Role == "assistant" && len(lastMsg.ToolCalls) == 0 && lastMsg.Content != "" { + needsGenerationPrompt = false + } + } + + if needsGenerationPrompt { + sb.WriteString("<|im_start|>assistant\n\n") + } + + return sb.String(), nil +} diff --git a/model/renderers/olmo3_test.go b/model/renderers/olmo3_test.go new file mode 100644 index 00000000..56c79a23 --- /dev/null +++ b/model/renderers/olmo3_test.go @@ -0,0 +1,290 @@ +package renderers + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" +) + +func TestOlmo3Renderer(t *testing.T) { + tests := []struct { + name string + msgs []api.Message + tools []api.Tool + expected string + }{ + { + name: "basic without system - adds default system", + msgs: []api.Message{ + {Role: "user", Content: "Hello!"}, + }, + expected: "<|im_start|>system\n" + + "You are a helpful function-calling AI assistant. You do not currently have access to any functions. <|im_end|>\n" + + "<|im_start|>user\n" + + "Hello!<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "with system message no tools", + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello!"}, + }, + expected: "<|im_start|>system\n" + + "You are a helpful assistant.<|im_end|>\n" + + "<|im_start|>user\n" + + "Hello!<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "with system message and tools", + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "What is the weather?"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"location"}, + Properties: map[string]api.ToolProperty{ + "location": {Type: api.PropertyType{"string"}, Description: "The city"}, + }, + }, + }, + }, + }, + expected: "<|im_start|>system\n" + + `You are a helpful assistant.[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]<|im_end|>` + "\n" + + "<|im_start|>user\n" + + "What is the weather?<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "default system with tools - includes function instruction", + msgs: []api.Message{ + {Role: "user", Content: "What is the weather?"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"location"}, + Properties: map[string]api.ToolProperty{ + "location": {Type: api.PropertyType{"string"}, Description: "The city"}, + }, + }, + }, + }, + }, + expected: "<|im_start|>system\n" + + "You are a helpful function-calling AI assistant. " + + "You are provided with function signatures within XML tags. 
You may call one or more functions to assist with the user query. Output any function calls within XML tags. Do not make assumptions about what values to plug into functions." + + `[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]<|im_end|>` + "\n" + + "<|im_start|>user\n" + + "What is the weather?<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "assistant with tool calls - function call syntax", + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "What is the weather in SF?"}, + { + Role: "assistant", + Content: "Let me check the weather.", + ToolCalls: []api.ToolCall{ + { + ID: "call_1", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{ + "location": "San Francisco", + }, + }, + }, + }, + }, + {Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"location"}, + Properties: map[string]api.ToolProperty{ + "location": {Type: api.PropertyType{"string"}, Description: "The city"}, + }, + }, + }, + }, + }, + expected: "<|im_start|>system\n" + + `You are a helpful assistant.[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]<|im_end|>` + "\n" + + "<|im_start|>user\n" + + "What is the weather in SF?<|im_end|>\n" + + "<|im_start|>assistant\n" + + `Let me check the weather.get_weather(location="San Francisco")<|im_end|>` + "\n" + + "<|im_start|>environment\n" + + `{"temperature": 68}<|im_end|>` + "\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "multi-turn conversation", + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "Hi there!"}, + {Role: "user", Content: "How are you?"}, + }, + expected: "<|im_start|>system\n" + + "You are a helpful assistant.<|im_end|>\n" + + "<|im_start|>user\n" + + "Hello<|im_end|>\n" + + "<|im_start|>assistant\n" + + "Hi there!<|im_end|>\n" + + "<|im_start|>user\n" + + "How are you?<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "parallel tool calls - newline separated", + msgs: []api.Message{ + {Role: "user", Content: "Get weather in SF and NYC"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_1", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"location": "San Francisco"}, + }, + }, + { + ID: "call_2", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"location": "New York"}, + }, + }, + }, + }, + {Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"}, + {Role: "tool", Content: `{"temperature": 55}`, ToolName: "get_weather"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "location": {Type: api.PropertyType{"string"}}, + }, + }, + }, + }, + }, + expected: 
"<|im_start|>system\n" + + "You are a helpful function-calling AI assistant. " + + "You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Output any function calls within XML tags. Do not make assumptions about what values to plug into functions." + + `[{"type": "function", "function": {"name": "get_weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}}}]<|im_end|>` + "\n" + + "<|im_start|>user\n" + + "Get weather in SF and NYC<|im_end|>\n" + + "<|im_start|>assistant\n" + + `get_weather(location="San Francisco")` + "\n" + + `get_weather(location="New York")<|im_end|>` + "\n" + + "<|im_start|>environment\n" + + `{"temperature": 68}<|im_end|>` + "\n" + + "<|im_start|>environment\n" + + `{"temperature": 55}<|im_end|>` + "\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "tool call with multiple arguments", + msgs: []api.Message{ + {Role: "user", Content: "Book a flight"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_1", + Function: api.ToolCallFunction{ + Name: "book_flight", + Arguments: map[string]any{ + "from": "SFO", + "to": "NYC", + }, + }, + }, + }, + }, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "book_flight", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "from": {Type: api.PropertyType{"string"}}, + "to": {Type: api.PropertyType{"string"}}, + }, + }, + }, + }, + }, + expected: "<|im_start|>system\n" + + "You are a helpful function-calling AI assistant. " + + "You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Output any function calls within XML tags. Do not make assumptions about what values to plug into functions." + + `[{"type": "function", "function": {"name": "book_flight", "parameters": {"type": "object", "properties": {"from": {"type": "string"}, "to": {"type": "string"}}}}}]<|im_end|>` + "\n" + + "<|im_start|>user\n" + + "Book a flight<|im_end|>\n" + + "<|im_start|>assistant\n" + + `book_flight(from="SFO", to="NYC")<|im_end|>` + "\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "assistant prefill - no generation prompt", + msgs: []api.Message{ + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "Hi there!"}, + }, + expected: "<|im_start|>system\n" + + "You are a helpful function-calling AI assistant. You do not currently have access to any functions. 
<|im_end|>\n" + + "<|im_start|>user\n" + + "Hello<|im_end|>\n" + + "<|im_start|>assistant\n" + + "Hi there!", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rendered, err := (&Olmo3Renderer{}).Render(tt.msgs, tt.tools, nil) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(rendered, tt.expected); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + } +} diff --git a/model/renderers/renderer.go b/model/renderers/renderer.go index 098b16a8..66c2f8de 100644 --- a/model/renderers/renderer.go +++ b/model/renderers/renderer.go @@ -59,6 +59,9 @@ func rendererForName(name string) Renderer { case "cogito": renderer := &CogitoRenderer{isThinking: true} return renderer + case "olmo3": + renderer := &Olmo3Renderer{} + return renderer case "olmo3-think": renderer := &Olmo3ThinkRenderer{} return renderer From 76f88caf437c3f48b68ff6df649e283f7370b27b Mon Sep 17 00:00:00 2001 From: nicole pardal <109545900+npardal@users.noreply.github.com> Date: Tue, 9 Dec 2025 14:24:51 -0800 Subject: [PATCH 12/35] nomic-embed-text:v2: model implementation (#13162) --- convert/convert.go | 2 + convert/convert_nomicbert.go | 213 ++++++++++++++++++++++++++++++++ model/models/nomicbert/model.go | 107 +++++++++++++--- 3 files changed, 304 insertions(+), 18 deletions(-) create mode 100644 convert/convert_nomicbert.go diff --git a/convert/convert.go b/convert/convert.go index 15e31bf2..bc110c6f 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -202,6 +202,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error { conv = &qwen3VLModel{} case "BertModel": conv = &bertModel{} + case "NomicBertModel", "NomicBertMoEModel": + conv = &nomicbertModel{} case "CohereForCausalLM": conv = &commandrModel{} case "GptOssForCausalLM": diff --git a/convert/convert_nomicbert.go b/convert/convert_nomicbert.go new file mode 100644 index 00000000..6aed5ee7 --- /dev/null +++ b/convert/convert_nomicbert.go @@ -0,0 +1,213 @@ +package convert + +import ( + "cmp" + "encoding/json" + "io/fs" + "path/filepath" + "slices" + "strings" + + "github.com/ollama/ollama/fs/ggml" +) + +type nomicbertModel struct { + ModelParameters + NLayers uint32 `json:"n_layers"` + NumHiddenLayers uint32 `json:"num_hidden_layers"` + MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` + HiddenSize uint32 `json:"hidden_size"` + IntermediateSize uint32 `json:"intermediate_size"` + NumAttentionHeads uint32 `json:"num_attention_heads"` + NumKeyValueHeads uint32 `json:"num_key_value_heads"` + LayerNormEPS float32 `json:"layer_norm_eps"` + LayerNormEpsilon float32 `json:"layer_norm_epsilon"` + RopeFreqBase float32 `json:"rope_theta"` + normalizeEmbeddings bool + PoolingType uint32 + + // MoE parameters (only present in v2 models) + NumExperts uint32 `json:"num_local_experts"` + NumExpertsUsed uint32 `json:"num_experts_per_tok"` + MoEEveryNLayers uint32 `json:"moe_every_n_layers"` +} + +var ( + _ ModelConverter = (*nomicbertModel)(nil) + _ moreParser = (*nomicbertModel)(nil) +) + +func (p *nomicbertModel) parseMore(fsys fs.FS) error { + bts, err := fs.ReadFile(fsys, "modules.json") + if err != nil { + return err + } + + var modules []struct { + Type string `json:"type"` + Path string `json:"path"` + } + + if err := json.Unmarshal(bts, &modules); err != nil { + return err + } + + var pooling string + for _, m := range modules { + switch m.Type { + case "sentence_transformers.models.Pooling": + pooling = m.Path + case "sentence_transformers.models.Normalize": + p.normalizeEmbeddings = true + } + } + + if pooling 
!= "" { + bts, err := fs.ReadFile(fsys, filepath.Join(pooling, "config.json")) + if err != nil { + return err + } + + var pc struct { + PoolingModeCLSToken bool `json:"pooling_mode_cls_token"` + PoolingModeMeanTokens bool `json:"pooling_mode_mean_tokens"` + } + + if err := json.Unmarshal(bts, &pc); err != nil { + return err + } + + if pc.PoolingModeMeanTokens { + p.PoolingType = 1 + } else if pc.PoolingModeCLSToken { + p.PoolingType = 2 + } + } + + return nil +} + +func (p *nomicbertModel) KV(t *Tokenizer) ggml.KV { + kv := p.ModelParameters.KV(t) + + // Determine architecture based on MoE parameters (following qwen3 pattern) + arch := "nomic-bert" + if p.MoEEveryNLayers > 0 { + arch += "-moe" + } + + kv["general.architecture"] = arch + kv["attention.causal"] = false + kv["pooling_type"] = p.PoolingType + kv["normalize_embeddings"] = p.normalizeEmbeddings + + kv["block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers) + + if contextLength := p.MaxPositionEmbeddings; contextLength > 0 { + kv["context_length"] = contextLength + } + + if embeddingLength := p.HiddenSize; embeddingLength > 0 { + kv["embedding_length"] = p.HiddenSize + } + + if feedForwardLength := p.IntermediateSize; feedForwardLength > 0 { + kv["feed_forward_length"] = p.IntermediateSize + } + + if headCount := p.NumAttentionHeads; headCount > 0 { + kv["attention.head_count"] = p.NumAttentionHeads + } + + if kvHeadCount := p.NumKeyValueHeads; kvHeadCount > 0 { + kv["attention.head_count_kv"] = p.NumKeyValueHeads + } + + if layerNormEpsilon := cmp.Or(p.LayerNormEPS, p.LayerNormEpsilon); layerNormEpsilon > 0 { + kv["attention.layer_norm_epsilon"] = layerNormEpsilon + } + + if p.RopeFreqBase > 0 { + kv["rope.freq_base"] = p.RopeFreqBase + } + + // MoE specific parameters (only if MoE is enabled) + if p.NumExperts > 0 { + kv["expert_count"] = p.NumExperts + } + + if p.NumExpertsUsed > 0 { + kv["expert_used_count"] = p.NumExpertsUsed + } + + if p.MoEEveryNLayers > 0 { + kv["moe_every_n_layers"] = p.MoEEveryNLayers + } + + kv["tokenizer.ggml.model"] = "bert" + kv["tokenizer.ggml.token_type_count"] = uint32(2) + + // convert to phantom space tokens + for i, e := range t.Tokens { + switch { + case strings.HasPrefix(e, "[") && strings.HasSuffix(e, "]"): + // noop - keep special tokens as-is + case strings.HasPrefix(e, "##"): + t.Tokens[i] = e[2:] + default: + t.Tokens[i] = "\u2581" + e + } + } + + kv["tokenizer.ggml.tokens"] = t.Tokens + + return kv +} + +func (p *nomicbertModel) Tensors(ts []Tensor) []*ggml.Tensor { + out := make([]*ggml.Tensor, 0, len(ts)) + for _, t := range ts { + if slices.Contains([]string{ + "embeddings.position_ids", + "pooler.dense.weight", + "pooler.dense.bias", + }, t.Name()) { + continue + } + + out = append(out, &ggml.Tensor{ + Name: t.Name(), + Kind: t.Kind(), + Shape: t.Shape(), + WriterTo: t, + }) + } + + return out +} + +func (nomicbertModel) Replacements() []string { + return []string{ + "encoder.layer", "blk", + "encoder.layers", "blk", + "embeddings.word_embeddings", "token_embd", + "embeddings.token_type_embeddings", "token_types", + "embeddings.LayerNorm", "token_embd_norm", + + "attention.self.qkv", "attn_qkv", + + "attention.output.dense", "attn_output", + "attention.output.LayerNorm", "attn_output_norm", + + "mlp.up", "ffn_up", + "mlp.down", "ffn_down", + + "mlp.router", "ffn_gate_inp", + "mlp.experts.up", "ffn_up_exps", + "mlp.experts.down", "ffn_down_exps", + + "intermediate.dense", "ffn_up", + "output.dense", "ffn_down", + "output.LayerNorm", "layer_output_norm", + } +} diff --git 
a/model/models/nomicbert/model.go b/model/models/nomicbert/model.go index 2510240d..096d046a 100644 --- a/model/models/nomicbert/model.go +++ b/model/models/nomicbert/model.go @@ -34,19 +34,23 @@ type Options struct { poolingType pooling.Type normalize bool ropeFreqBase float32 + + // MoE specific options (used by v2 / MoE models only) + numExperts int + numExpertsUsed int + moeEveryNLayers int } func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { return nn.RoPE(ctx, states, positions, o.headDim, o.ropeFreqBase, 1.0, rope.WithTypeNeoX()) } -// Single Encoder Layer type EncoderLayer struct { *Attention AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"` - *MLP + FeedForward FeedForward MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"` } @@ -56,12 +60,63 @@ type Attention struct { Output *nn.Linear `gguf:"attn_output"` } -type MLP struct { +type FeedForward interface { + Forward(ml.Context, ml.Tensor, *Options) ml.Tensor +} + +type dense struct { Gate *nn.Linear `gguf:"ffn_gate"` Up *nn.Linear `gguf:"ffn_up"` Down *nn.Linear `gguf:"ffn_down"` } +func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor { + hidden := mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) + return mlp.Down.Forward(ctx, hidden) +} + +// denseGELU implements MLP with GELU activation for v2 MoE dense layers +type denseGELU struct { + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (mlp *denseGELU) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor { + return mlp.Down.Forward(ctx, mlp.Up.Forward(ctx, hiddenStates).GELU(ctx)) +} + +// sparse implements MoE with expert routing +type sparse struct { + Router *nn.Linear `gguf:"ffn_gate_inp"` + Up *nn.LinearBatch `gguf:"ffn_up_exps"` + Down *nn.LinearBatch `gguf:"ffn_down_exps"` +} + +func (moe *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2) + hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize) + + routerLogits := moe.Router.Forward(ctx, hiddenStates) + routingWeights := routerLogits.Softmax(ctx) + selectedExperts := routingWeights.TopK(ctx, opts.numExpertsUsed) + + routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, selectedExperts) + + hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1)) + + hiddenStates = moe.Up.Forward(ctx, hiddenStates, selectedExperts).GELU(ctx) + experts := moe.Down.Forward(ctx, hiddenStates, selectedExperts) + + experts = experts.Mul(ctx, routingWeights) + + nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2)) + for i := 1; i < opts.numExpertsUsed; i++ { + nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2))) + } + + return nextStates +} + func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) @@ -92,7 +147,7 @@ func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps) residual = hiddenStates - hiddenStates = e.MLP.Forward(ctx, hiddenStates) + hiddenStates = e.FeedForward.Forward(ctx, hiddenStates, opts) hiddenStates = hiddenStates.Add(ctx, residual) hiddenStates = 
e.MLPNorm.Forward(ctx, hiddenStates, opts.eps) @@ -118,12 +173,6 @@ func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions ml return a.Output.Forward(ctx, attention) } -func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor { - hidden := m.Gate.Forward(ctx, hiddenStates).SILU(ctx, m.Up.Forward(ctx, hiddenStates)) - - return m.Down.Forward(ctx, hidden) -} - func New(c fs.Config) (model.Model, error) { hiddenSize := int(c.Uint("embedding_length")) numHeads := int(c.Uint("attention.head_count")) @@ -152,17 +201,37 @@ func New(c fs.Config) (model.Model, error) { false, ) + blockCount := int(c.Uint("block_count")) + moeEveryNLayers := int(c.Uint("moe_every_n_layers", 0)) + layers := make([]EncoderLayer, blockCount) + + for i := range layers { + if moeEveryNLayers > 0 { + // Layer uses MoE if (i+1) % moe_every_n_layers == 0 + if (i+1)%moeEveryNLayers == 0 { + layers[i].FeedForward = &sparse{} + } else { + layers[i].FeedForward = &denseGELU{} + } + } else { + layers[i].FeedForward = &dense{} + } + } + return &Model{ TextProcessor: processor, - Layers: make([]EncoderLayer, c.Uint("block_count")), + Layers: layers, Options: Options{ - hiddenSize: hiddenSize, - numHeads: numHeads, - headDim: headDim, - eps: c.Float("attention.layer_norm_epsilon"), - poolingType: pooling.Type(c.Uint("pooling_type")), - normalize: c.Bool("normalize_embeddings", false), - ropeFreqBase: c.Float("rope.freq_base", 1000.0), + hiddenSize: hiddenSize, + numHeads: numHeads, + headDim: headDim, + eps: c.Float("attention.layer_norm_epsilon"), + poolingType: pooling.Type(c.Uint("pooling_type")), + normalize: c.Bool("normalize_embeddings", false), + ropeFreqBase: c.Float("rope.freq_base", 1000.0), + numExperts: int(c.Uint("expert_count")), + numExpertsUsed: int(c.Uint("expert_used_count")), + moeEveryNLayers: moeEveryNLayers, }, }, nil } @@ -170,4 +239,6 @@ func New(c fs.Config) (model.Model, error) { func init() { model.Register("nomic-bert", New) model.Register("nomic-bert_embed", New) + model.Register("nomic-bert-moe", New) + model.Register("nomic-bert-moe_embed", New) } From bbbb6b2a013b3fb27ddd7d7132a9dd6026f3e7ad Mon Sep 17 00:00:00 2001 From: Eva H <63033505+hoyyeva@users.noreply.github.com> Date: Wed, 10 Dec 2025 14:40:02 -0500 Subject: [PATCH 13/35] app/ui: fix model capabilities not updating after download completion (#13179) --- app/ui/app/src/hooks/useChats.ts | 19 ++++ app/ui/app/src/hooks/useDownloadModel.ts | 114 ----------------------- 2 files changed, 19 insertions(+), 114 deletions(-) delete mode 100644 app/ui/app/src/hooks/useDownloadModel.ts diff --git a/app/ui/app/src/hooks/useChats.ts b/app/ui/app/src/hooks/useChats.ts index b7517253..410a80e7 100644 --- a/app/ui/app/src/hooks/useChats.ts +++ b/app/ui/app/src/hooks/useChats.ts @@ -7,6 +7,7 @@ import { createQueryBatcher } from "./useQueryBatcher"; import { useRefetchModels } from "./useModels"; import { useStreamingContext } from "@/contexts/StreamingContext"; import { useSettings } from "./useSettings"; +import { getModelCapabilities } from "@/api"; export const useChats = () => { return useQuery({ @@ -606,6 +607,24 @@ export const useSendMessage = (chatId: string) => { queryClient.setQueryData(["staleModels"], newStaleMap); queryClient.invalidateQueries({ queryKey: ["models"] }); + + // Fetch fresh capabilities for the downloaded model + getModelCapabilities(selectedModel.model) + .then((capabilities) => { + queryClient.setQueryData( + ["modelCapabilities", selectedModel.model], + capabilities, + ); + }) + 
.catch((error) => { + console.error( + "Failed to fetch capabilities after download:", + error, + ); + queryClient.invalidateQueries({ + queryKey: ["modelCapabilities", selectedModel.model], + }); + }); } break; } diff --git a/app/ui/app/src/hooks/useDownloadModel.ts b/app/ui/app/src/hooks/useDownloadModel.ts deleted file mode 100644 index aa69edec..00000000 --- a/app/ui/app/src/hooks/useDownloadModel.ts +++ /dev/null @@ -1,114 +0,0 @@ -import { useMutation, useQueryClient } from "@tanstack/react-query"; -import { useState } from "react"; -import { pullModel } from "@/api"; -import { useSelectedModel } from "./useSelectedModel"; -import { useSettings } from "./useSettings"; - -interface DownloadProgress { - status: string; - digest?: string; - total?: number; - completed?: number; - done?: boolean; -} - -export function useDownloadModel(chatId?: string) { - const queryClient = useQueryClient(); - const { selectedModel } = useSelectedModel(chatId); - const { setSettings } = useSettings(); - const [downloadProgress, setDownloadProgress] = - useState(null); - const [abortController, setAbortController] = - useState(null); - const [downloadingChatIds, setDownloadingChatIds] = useState>( - new Set(), - ); - - const mutation = useMutation({ - mutationFn: async (modelName: string) => { - const controller = new AbortController(); - setAbortController(controller); - setDownloadProgress({ status: "Starting download..." }); - if (chatId) { - setDownloadingChatIds((prev) => new Set(prev).add(chatId)); - } - - try { - for await (const progress of pullModel(modelName, controller.signal)) { - setDownloadProgress(progress); - - if (progress.status === "success") { - // Update selected model to indicate it's now available locally - if (selectedModel && selectedModel.model === modelName) { - setSettings({ SelectedModel: modelName }); - } - // Invalidate models query to refresh the list - await queryClient.invalidateQueries({ queryKey: ["models"] }); - break; - } - } - } finally { - setAbortController(null); - if (chatId) { - setDownloadingChatIds((prev) => { - const newSet = new Set(prev); - newSet.delete(chatId); - return newSet; - }); - } - } - }, - onSuccess: () => { - setDownloadProgress(null); - if (chatId) { - setDownloadingChatIds((prev) => { - const newSet = new Set(prev); - newSet.delete(chatId); - return newSet; - }); - } - }, - onError: (error: Error) => { - const status = - error.name === "AbortError" ? "Download cancelled" : "Download failed"; - setDownloadProgress({ status, done: true }); - - // Clear error message after delay - const delay = error.name === "AbortError" ? 1500 : 3000; - setTimeout(() => { - setDownloadProgress(null); - if (chatId) { - setDownloadingChatIds((prev) => { - const newSet = new Set(prev); - newSet.delete(chatId); - return newSet; - }); - } - }, delay); - }, - }); - - const cancelDownload = () => { - if (abortController) { - abortController.abort(); - setAbortController(null); - if (chatId) { - setDownloadingChatIds((prev) => { - const newSet = new Set(prev); - newSet.delete(chatId); - return newSet; - }); - } - } - }; - - return { - downloadModel: mutation.mutate, - isDownloading: - mutation.isPending && chatId ? downloadingChatIds.has(chatId) : false, - downloadProgress: - chatId && downloadingChatIds.has(chatId) ? 
downloadProgress : null, - error: mutation.error, - cancelDownload, - }; -} From 7cf6f18c1fdeb3b3dac19fabc8e83817026a215c Mon Sep 17 00:00:00 2001 From: Eva H <63033505+hoyyeva@users.noreply.github.com> Date: Wed, 10 Dec 2025 15:24:31 -0500 Subject: [PATCH 14/35] app/ui: refactor to use Ollama endpoints for user auth and health checks (#13081) --- app/cmd/app/app.go | 17 ++- app/ui/app/codegen/gotypes.gen.ts | 24 ++-- app/ui/app/src/api.ts | 71 ++++++----- app/ui/app/src/components/Settings.tsx | 4 +- app/ui/app/src/hooks/useUser.ts | 32 ++--- app/ui/app/src/lib/config.ts | 3 + app/ui/app/src/main.tsx | 29 +---- app/ui/responses/types.go | 17 ++- app/ui/ui.go | 159 +++---------------------- 9 files changed, 106 insertions(+), 250 deletions(-) diff --git a/app/cmd/app/app.go b/app/cmd/app/app.go index d09e04b5..7e183b8d 100644 --- a/app/cmd/app/app.go +++ b/app/cmd/app/app.go @@ -273,10 +273,6 @@ func main() { Handler: uiServer.Handler(), } - if _, err := uiServer.UserData(ctx); err != nil { - slog.Warn("failed to load user data", "error", err) - } - // Start the UI server slog.Info("starting ui server", "port", port) go func() { @@ -320,6 +316,17 @@ func main() { slog.Debug("no URL scheme request to handle") } + go func() { + slog.Debug("waiting for ollama server to be ready") + if err := ui.WaitForServer(ctx, 10*time.Second); err != nil { + slog.Warn("ollama server not ready, continuing anyway", "error", err) + } + + if _, err := uiServer.UserData(ctx); err != nil { + slog.Warn("failed to load user data", "error", err) + } + }() + osRun(cancel, hasCompletedFirstRun, startHidden) slog.Info("shutting down desktop server") @@ -361,7 +368,7 @@ func checkUserLoggedIn(uiServerPort int) bool { return false } - resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/api/v1/me", uiServerPort)) + resp, err := http.Post(fmt.Sprintf("http://127.0.0.1:%d/api/me", uiServerPort), "application/json", nil) if err != nil { slog.Debug("failed to call local auth endpoint", "error", err) return false diff --git a/app/ui/app/codegen/gotypes.gen.ts b/app/ui/app/codegen/gotypes.gen.ts index a077c854..0bf86f2b 100644 --- a/app/ui/app/codegen/gotypes.gen.ts +++ b/app/ui/app/codegen/gotypes.gen.ts @@ -469,26 +469,24 @@ export class HealthResponse { } export class User { id: string; - name: string; email: string; - avatarURL: string; - plan: string; - bio: string; - firstName: string; - lastName: string; - overThreshold: boolean; + name: string; + bio?: string; + avatarurl?: string; + firstname?: string; + lastname?: string; + plan?: string; constructor(source: any = {}) { if ('string' === typeof source) source = JSON.parse(source); this.id = source["id"]; - this.name = source["name"]; this.email = source["email"]; - this.avatarURL = source["avatarURL"]; - this.plan = source["plan"]; + this.name = source["name"]; this.bio = source["bio"]; - this.firstName = source["firstName"]; - this.lastName = source["lastName"]; - this.overThreshold = source["overThreshold"]; + this.avatarurl = source["avatarurl"]; + this.firstname = source["firstname"]; + this.lastname = source["lastname"]; + this.plan = source["plan"]; } } export class Attachment { diff --git a/app/ui/app/src/api.ts b/app/ui/app/src/api.ts index a701a30a..273850d6 100644 --- a/app/ui/app/src/api.ts +++ b/app/ui/app/src/api.ts @@ -15,7 +15,7 @@ import { import { parseJsonlFromResponse } from "./util/jsonl-parsing"; import { ollamaClient as ollama } from "./lib/ollama-client"; import type { ModelResponse } from "ollama/browser"; -import { API_BASE } from 
"./lib/config"; +import { API_BASE, OLLAMA_DOT_COM } from "./lib/config"; // Extend Model class with utility methods declare module "@/gotypes" { @@ -27,7 +27,6 @@ declare module "@/gotypes" { Model.prototype.isCloud = function (): boolean { return this.model.endsWith("cloud"); }; - // Helper function to convert Uint8Array to base64 function uint8ArrayToBase64(uint8Array: Uint8Array): string { const chunkSize = 0x8000; // 32KB chunks to avoid stack overflow @@ -42,44 +41,50 @@ function uint8ArrayToBase64(uint8Array: Uint8Array): string { } export async function fetchUser(): Promise { - try { - const response = await fetch(`${API_BASE}/api/v1/me`, { - method: "GET", - headers: { - "Content-Type": "application/json", - }, - }); - - if (response.ok) { - const userData: User = await response.json(); - return userData; - } - - return null; - } catch (error) { - console.error("Error fetching user:", error); - return null; - } -} - -export async function fetchConnectUrl(): Promise { - const response = await fetch(`${API_BASE}/api/v1/connect`, { - method: "GET", + const response = await fetch(`${API_BASE}/api/me`, { + method: "POST", headers: { "Content-Type": "application/json", }, }); - if (!response.ok) { - throw new Error("Failed to fetch connect URL"); + if (response.ok) { + const userData: User = await response.json(); + + if (userData.avatarurl && !userData.avatarurl.startsWith("http")) { + userData.avatarurl = `${OLLAMA_DOT_COM}${userData.avatarurl}`; + } + + return userData; } - const data = await response.json(); - return data.connect_url; + if (response.status === 401 || response.status === 403) { + return null; + } + + throw new Error(`Failed to fetch user: ${response.status}`); +} + +export async function fetchConnectUrl(): Promise { + const response = await fetch(`${API_BASE}/api/me`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + }); + + if (response.status === 401) { + const data = await response.json(); + if (data.signin_url) { + return data.signin_url; + } + } + + throw new Error("Failed to fetch connect URL"); } export async function disconnectUser(): Promise { - const response = await fetch(`${API_BASE}/api/v1/disconnect`, { + const response = await fetch(`${API_BASE}/api/signout`, { method: "POST", headers: { "Content-Type": "application/json", @@ -389,7 +394,8 @@ export async function getInferenceCompute(): Promise { export async function fetchHealth(): Promise { try { - const response = await fetch(`${API_BASE}/api/v1/health`, { + // Use the /api/version endpoint as a health check + const response = await fetch(`${API_BASE}/api/version`, { method: "GET", headers: { "Content-Type": "application/json", @@ -398,7 +404,8 @@ export async function fetchHealth(): Promise { if (response.ok) { const data = await response.json(); - return data.healthy || false; + // If we get a version back, the server is healthy + return !!data.version; } return false; diff --git a/app/ui/app/src/components/Settings.tsx b/app/ui/app/src/components/Settings.tsx index c56a97b3..057f7477 100644 --- a/app/ui/app/src/components/Settings.tsx +++ b/app/ui/app/src/components/Settings.tsx @@ -299,9 +299,9 @@ export default function Settings() { - {user?.avatarURL && ( + {user?.avatarurl && ( {user?.name} { diff --git a/app/ui/app/src/hooks/useUser.ts b/app/ui/app/src/hooks/useUser.ts index 5f7a4dad..b4e6698e 100644 --- a/app/ui/app/src/hooks/useUser.ts +++ b/app/ui/app/src/hooks/useUser.ts @@ -1,29 +1,20 @@ import { useQuery, useMutation, useQueryClient } from 
"@tanstack/react-query"; -import { useEffect, useState } from "react"; import { fetchUser, fetchConnectUrl, disconnectUser } from "@/api"; export function useUser() { const queryClient = useQueryClient(); - const [initialDataLoaded, setInitialDataLoaded] = useState(false); - - // Wait for initial data to be loaded - useEffect(() => { - const initialPromise = window.__initialUserDataPromise; - if (initialPromise) { - initialPromise.finally(() => { - setInitialDataLoaded(true); - }); - } else { - setInitialDataLoaded(true); - } - }, []); const userQuery = useQuery({ queryKey: ["user"], - queryFn: () => fetchUser(), + queryFn: async () => { + const result = await fetchUser(); + return result; + }, staleTime: 5 * 60 * 1000, // Consider data stale after 5 minutes gcTime: 10 * 60 * 1000, // Keep in cache for 10 minutes - initialData: null, // Start with null to prevent flashing + retry: 10, + retryDelay: (attemptIndex) => Math.min(500 * attemptIndex, 2000), + refetchOnMount: true, // Always fetch when component mounts }); // Mutation to refresh user data @@ -49,14 +40,15 @@ export function useUser() { }, }); + const isLoading = userQuery.isLoading || userQuery.isFetching; + const isAuthenticated = Boolean(userQuery.data?.name); + return { user: userQuery.data, - isLoading: - !initialDataLoaded || - (userQuery.isLoading && userQuery.data === undefined), // Show loading until initial data is loaded + isLoading, isError: userQuery.isError, error: userQuery.error, - isAuthenticated: Boolean(userQuery.data?.name), + isAuthenticated, refreshUser: refreshUser.mutate, isRefreshing: refreshUser.isPending, refetchUser: userQuery.refetch, diff --git a/app/ui/app/src/lib/config.ts b/app/ui/app/src/lib/config.ts index c1124396..7c5385d7 100644 --- a/app/ui/app/src/lib/config.ts +++ b/app/ui/app/src/lib/config.ts @@ -8,3 +8,6 @@ export const API_BASE = import.meta.env.DEV ? DEV_API_URL : ""; export const OLLAMA_HOST = import.meta.env.DEV ? 
DEV_API_URL : window.location.origin; + +export const OLLAMA_DOT_COM = + import.meta.env.VITE_OLLAMA_DOT_COM_URL || "https://ollama.com"; diff --git a/app/ui/app/src/main.tsx b/app/ui/app/src/main.tsx index 1ffe37ef..3e325a3c 100644 --- a/app/ui/app/src/main.tsx +++ b/app/ui/app/src/main.tsx @@ -5,13 +5,6 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; import { routeTree } from "./routeTree.gen"; import { fetchUser } from "./api"; import { StreamingProvider } from "./contexts/StreamingContext"; -import { User } from "@/gotypes"; - -declare global { - interface Window { - __initialUserDataPromise?: Promise; - } -} const queryClient = new QueryClient({ defaultOptions: { @@ -24,27 +17,11 @@ const queryClient = new QueryClient({ }, }); -// Track initial user data fetch -let initialUserDataPromise: Promise | null = null; - -// Initialize user data on app startup -const initializeUserData = async () => { - try { - const userData = await fetchUser(); +fetchUser().then((userData) => { + if (userData) { queryClient.setQueryData(["user"], userData); - return userData; - } catch (error) { - console.error("Error initializing user data:", error); - queryClient.setQueryData(["user"], null); - return null; } -}; - -// Start initialization immediately and track the promise -initialUserDataPromise = initializeUserData(); - -// Export the promise so hooks can await it -window.__initialUserDataPromise = initialUserDataPromise; +}); const router = createRouter({ routeTree, diff --git a/app/ui/responses/types.go b/app/ui/responses/types.go index 438dd55e..2da6623f 100644 --- a/app/ui/responses/types.go +++ b/app/ui/responses/types.go @@ -101,15 +101,14 @@ type HealthResponse struct { } type User struct { - ID string `json:"id"` - Name string `json:"name"` - Email string `json:"email"` - AvatarURL string `json:"avatarURL"` - Plan string `json:"plan"` - Bio string `json:"bio"` - FirstName string `json:"firstName"` - LastName string `json:"lastName"` - OverThreshold bool `json:"overThreshold"` + ID string `json:"id"` + Email string `json:"email"` + Name string `json:"name"` + Bio string `json:"bio,omitempty"` + AvatarURL string `json:"avatarurl,omitempty"` + FirstName string `json:"firstname,omitempty"` + LastName string `json:"lastname,omitempty"` + Plan string `json:"plan,omitempty"` } type Attachment struct { diff --git a/app/ui/ui.go b/app/ui/ui.go index 1d0e2579..5a64705d 100644 --- a/app/ui/ui.go +++ b/app/ui/ui.go @@ -23,7 +23,6 @@ import ( "github.com/google/uuid" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/app/auth" "github.com/ollama/ollama/app/server" "github.com/ollama/ollama/app/store" "github.com/ollama/ollama/app/tools" @@ -264,11 +263,10 @@ func (s *Server) Handler() http.Handler { ollamaProxy := s.ollamaProxy() mux.Handle("GET /api/tags", ollamaProxy) mux.Handle("POST /api/show", ollamaProxy) - - mux.Handle("GET /api/v1/me", handle(s.me)) - mux.Handle("POST /api/v1/disconnect", handle(s.disconnect)) - mux.Handle("GET /api/v1/connect", handle(s.connectURL)) - mux.Handle("GET /api/v1/health", handle(s.health)) + mux.Handle("GET /api/version", ollamaProxy) + mux.Handle("HEAD /api/version", ollamaProxy) + mux.Handle("POST /api/me", ollamaProxy) + mux.Handle("POST /api/signout", ollamaProxy) // React app - catch all non-API routes and serve the React app mux.Handle("GET /", s.appHandler()) @@ -338,7 +336,7 @@ func (s *Server) doSelfSigned(ctx context.Context, method, path string) (*http.R } // UserData fetches user data from ollama.com API for the current 
ollama key -func (s *Server) UserData(ctx context.Context) (*responses.User, error) { +func (s *Server) UserData(ctx context.Context) (*api.UserResponse, error) { resp, err := s.doSelfSigned(ctx, http.MethodPost, "/api/me") if err != nil { return nil, fmt.Errorf("failed to call ollama.com/api/me: %w", err) @@ -349,7 +347,7 @@ func (s *Server) UserData(ctx context.Context) (*responses.User, error) { return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) } - var user responses.User + var user api.UserResponse if err := json.NewDecoder(resp.Body).Decode(&user); err != nil { return nil, fmt.Errorf("failed to parse user response: %w", err) } @@ -368,29 +366,27 @@ func (s *Server) UserData(ctx context.Context) (*responses.User, error) { return &user, nil } -func waitForServer(ctx context.Context) error { - timeout := time.Now().Add(10 * time.Second) - // TODO: this avoids an error on first load of the app - // however we should either show a loading state or - // wait for the Ollama server to be ready before redirecting - for { +// WaitForServer waits for the Ollama server to be ready +func WaitForServer(ctx context.Context, timeout time.Duration) error { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { c, err := api.ClientFromEnvironment() if err != nil { return err } if _, err := c.Version(ctx); err == nil { - break - } - if time.Now().After(timeout) { - return fmt.Errorf("timeout waiting for Ollama server to be ready") + slog.Debug("ollama server is ready") + return nil } time.Sleep(10 * time.Millisecond) } - return nil + return errors.New("timeout waiting for Ollama server to be ready") } func (s *Server) createChat(w http.ResponseWriter, r *http.Request) error { - waitForServer(r.Context()) + if err := WaitForServer(r.Context(), 10*time.Second); err != nil { + return err + } id, err := uuid.NewV7() if err != nil { @@ -1438,129 +1434,6 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error { }) } -func (s *Server) me(w http.ResponseWriter, r *http.Request) error { - if r.Method != http.MethodGet { - http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) - return nil - } - - user, err := s.UserData(r.Context()) - if err != nil { - // If fetching from API fails, try to return cached user data if available - if cachedUser, cacheErr := s.Store.User(); cacheErr == nil && cachedUser != nil { - s.log().Info("API request failed, returning cached user data", "error", err) - responseUser := &responses.User{ - Name: cachedUser.Name, - Email: cachedUser.Email, - Plan: cachedUser.Plan, - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - return json.NewEncoder(w).Encode(responseUser) - } - - s.log().Error("failed to get user data", "error", err) - w.WriteHeader(http.StatusInternalServerError) - return json.NewEncoder(w).Encode(responses.Error{ - Error: "failed to get user data", - }) - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - return json.NewEncoder(w).Encode(user) -} - -func (s *Server) disconnect(w http.ResponseWriter, r *http.Request) error { - if r.Method != http.MethodPost { - http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) - return nil - } - - if err := s.Store.ClearUser(); err != nil { - s.log().Warn("failed to clear cached user data", "error", err) - } - - // Get the SSH public key to encode for the delete request - pubKey, err := ollamaAuth.GetPublicKey() - if err != nil { - s.log().Error("failed to get public key", "error", 
err) - w.WriteHeader(http.StatusInternalServerError) - return json.NewEncoder(w).Encode(responses.Error{ - Error: "failed to get public key", - }) - } - - // Encode the key using base64 URL encoding - encodedKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey)) - - // Call the /api/user/keys/{encodedKey} endpoint with DELETE - resp, err := s.doSelfSigned(r.Context(), http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey)) - if err != nil { - s.log().Error("failed to call ollama.com/api/user/keys", "error", err) - w.WriteHeader(http.StatusInternalServerError) - return json.NewEncoder(w).Encode(responses.Error{ - Error: "failed to disconnect from ollama.com", - }) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - s.log().Error("disconnect request failed", "status", resp.StatusCode) - w.WriteHeader(http.StatusInternalServerError) - return json.NewEncoder(w).Encode(responses.Error{ - Error: "failed to disconnect from ollama.com", - }) - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - return json.NewEncoder(w).Encode(map[string]string{"status": "disconnected"}) -} - -func (s *Server) connectURL(w http.ResponseWriter, r *http.Request) error { - if r.Method != http.MethodGet { - http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) - return nil - } - - connectURL, err := auth.BuildConnectURL(OllamaDotCom) - if err != nil { - s.log().Error("failed to build connect URL", "error", err) - w.WriteHeader(http.StatusInternalServerError) - return json.NewEncoder(w).Encode(responses.Error{ - Error: "failed to build connect URL", - }) - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - return json.NewEncoder(w).Encode(map[string]string{ - "connect_url": connectURL, - }) -} - -func (s *Server) health(w http.ResponseWriter, r *http.Request) error { - if r.Method != http.MethodGet { - http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) - return nil - } - - healthy := false - c, err := api.ClientFromEnvironment() - if err == nil { - if _, err := c.Version(r.Context()); err == nil { - healthy = true - } - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - return json.NewEncoder(w).Encode(responses.HealthResponse{ - Healthy: healthy, - }) -} - func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error { ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond) defer cancel() From c34fc64688396bb9f2e17a2a89251a63f58ee4f7 Mon Sep 17 00:00:00 2001 From: Eva H <63033505+hoyyeva@users.noreply.github.com> Date: Wed, 10 Dec 2025 15:29:48 -0500 Subject: [PATCH 15/35] app/ui: use requestAnimationFrame to prevent bottom line cutoff in streaming thinking display (#13137) --- app/ui/app/src/components/Thinking.tsx | 30 ++++++++++++++++++-------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/app/ui/app/src/components/Thinking.tsx b/app/ui/app/src/components/Thinking.tsx index 7ab23e72..e8236424 100644 --- a/app/ui/app/src/components/Thinking.tsx +++ b/app/ui/app/src/components/Thinking.tsx @@ -50,21 +50,33 @@ export default function Thinking({ // Position content to show bottom when collapsed useEffect(() => { if (isCollapsed && contentRef.current && wrapperRef.current) { - const contentHeight = contentRef.current.scrollHeight; - const wrapperHeight = wrapperRef.current.clientHeight; - if (contentHeight > wrapperHeight) { - const translateY = -(contentHeight - wrapperHeight); - 
contentRef.current.style.transform = `translateY(${translateY}px)`; - setHasOverflow(true); - } else { - setHasOverflow(false); - } + requestAnimationFrame(() => { + if (!contentRef.current || !wrapperRef.current) return; + + const contentHeight = contentRef.current.scrollHeight; + const wrapperHeight = wrapperRef.current.clientHeight; + if (contentHeight > wrapperHeight) { + const translateY = -(contentHeight - wrapperHeight); + contentRef.current.style.transform = `translateY(${translateY}px)`; + setHasOverflow(true); + } else { + contentRef.current.style.transform = "translateY(0)"; + setHasOverflow(false); + } + }); } else if (contentRef.current) { contentRef.current.style.transform = "translateY(0)"; setHasOverflow(false); } }, [thinking, isCollapsed]); + useEffect(() => { + if (activelyThinking && wrapperRef.current && !isCollapsed) { + // When expanded and actively thinking, scroll to bottom + wrapperRef.current.scrollTop = wrapperRef.current.scrollHeight; + } + }, [thinking, activelyThinking, isCollapsed]); + const handleToggle = () => { setIsCollapsed(!isCollapsed); setHasUserInteracted(true); From b95693056c62a3acaa172f555e3bdb2b438684e3 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 10 Dec 2025 13:59:27 -0700 Subject: [PATCH 16/35] feat: llama.cpp bump (17f7f4) for SSM performance improvements (#13408) * feat: Bump llama.cpp to the latest master (17f7f4b) This brings in significant improvements to prefill performance for all models using the SSM_CONV and SSM_SCAN ops (granite4, jamba, falcon-h, nemotron-h, Qwen3 Next) on Apple Metal. See https://github.com/ggml-org/llama.cpp/pull/17876 Branch: LlamaCPPMetalSSMImprovements Signed-off-by: Gabe Goodhart * feat: Update patches 1-4 Branch: LlamaCPPMetalSSMImprovements Signed-off-by: Gabe Goodhart * fix: Update patches 5-12 Branch: LlamaCPPMetalSSMImprovements Signed-off-by: Gabe Goodhart * feat: Update patches 13-18 Branch: LlamaCPPMetalSSMImprovements Signed-off-by: Gabe Goodhart * feat: Update patch 20 Branch: LlamaCPPMetalSSMImprovements Signed-off-by: Gabe Goodhart * feat: Update patches 21-31 Branch: LlamaCPPMetalSSMImprovements Signed-off-by: Gabe Goodhart * feat: Sync vendored code The two files I'm not sure about here are the swap from gemma3-iswa.cpp to gemma3.cpp (I chose to include this because I think it's required), and the inclusion of `ggml-zendnn.h` which I chose to omit. 
Branch: LlamaCPPMetalSSMImprovements Signed-off-by: Gabe Goodhart --------- Signed-off-by: Gabe Goodhart --- Makefile.sync | 2 +- llama/build-info.cpp | 2 +- llama/llama.cpp/common/common.cpp | 77 +- llama/llama.cpp/common/common.h | 28 +- .../common/json-schema-to-grammar.cpp | 2 +- llama/llama.cpp/common/log.cpp | 45 +- llama/llama.cpp/common/log.h | 31 +- llama/llama.cpp/src/llama-arch.cpp | 31 +- llama/llama.cpp/src/llama-arch.h | 3 + llama/llama.cpp/src/llama-context.cpp | 12 +- llama/llama.cpp/src/llama-context.h | 2 +- llama/llama.cpp/src/llama-grammar.cpp | 265 +++- llama/llama.cpp/src/llama-grammar.h | 21 +- llama/llama.cpp/src/llama-graph.cpp | 11 +- llama/llama.cpp/src/llama-hparams.h | 4 +- llama/llama.cpp/src/llama-impl.h | 2 +- llama/llama.cpp/src/llama-mmap.cpp | 2 +- llama/llama.cpp/src/llama-model.cpp | 88 +- llama/llama.cpp/src/llama-quant.cpp | 29 - llama/llama.cpp/src/llama-vocab.cpp | 3 +- llama/llama.cpp/src/models/deepseek2.cpp | 18 + .../models/{gemma3-iswa.cpp => gemma3.cpp} | 35 +- llama/llama.cpp/src/models/mistral3.cpp | 160 +++ llama/llama.cpp/src/models/models.h | 9 +- llama/llama.cpp/src/unicode.cpp | 4 +- llama/llama.cpp/tools/mtmd/clip.cpp | 70 +- llama/llama.cpp/tools/mtmd/clip.h | 1 + llama/llama.cpp/tools/mtmd/mtmd.cpp | 6 + llama/llama.cpp/tools/mtmd/mtmd.h | 1 + ...loc-and-free-using-the-same-compiler.patch | 32 +- llama/patches/0002-pretokenizer.patch | 2 +- llama/patches/0003-clip-unicode.patch | 6 +- llama/patches/0004-solar-pro.patch | 36 +- .../0005-fix-deepseek-deseret-regex.patch | 4 +- ...ntain-ordering-for-rules-for-grammar.patch | 2 +- .../patches/0007-sort-devices-by-score.patch | 14 +- ...target-ggml-cpu-for-all-cpu-variants.patch | 6 +- llama/patches/0009-remove-amx.patch | 4 +- .../0010-fix-string-arr-kv-loading.patch | 4 +- llama/patches/0011-ollama-debug-tensor.patch | 4 +- ...add-ollama-vocab-for-grammar-support.patch | 43 +- ...13-add-argsort-and-cuda-copy-for-i32.patch | 12 +- ...14-graph-memory-reporting-on-failure.patch | 18 +- .../patches/0015-ggml-Export-GPU-UUIDs.patch | 14 +- .../0016-add-C-API-for-mtmd_input_text.patch | 4 +- ...-no-power-throttling-win32-with-gnuc.patch | 4 +- .../0018-ggml-Add-batch-size-hint.patch | 30 +- llama/patches/0020-ggml-No-alloc-mode.patch | 60 +- .../0021-decode-disable-output_all.patch | 2 +- ...gml-Enable-resetting-backend-devices.patch | 12 +- .../0024-GPU-discovery-enhancements.patch | 42 +- .../0026-report-LoadLibrary-failures.patch | 4 +- .../patches/0027-interleave-multi-rope.patch | 6 +- ...-Add-memory-detection-using-DXGI-PDH.patch | 16 +- .../0029-ggml-cuda-skip-large-batches.patch | 4 +- .../0030-win-exit-instead-of-abort.patch | 4 +- .../0031-fix-bakllava-regression.patch | 4 +- ml/backend/ggml/ggml/include/ggml-rpc.h | 3 +- ml/backend/ggml/ggml/include/ggml.h | 31 +- ml/backend/ggml/ggml/src/CMakeLists.txt | 5 +- ml/backend/ggml/ggml/src/ggml-alloc.c | 1 + ml/backend/ggml/ggml/src/ggml-backend-reg.cpp | 22 +- ml/backend/ggml/ggml/src/ggml-backend.cpp | 43 +- .../ggml/ggml/src/ggml-cpu/CMakeLists.txt | 3 + .../ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp | 4 + .../ggml/src/ggml-cpu/arch/arm/repack.cpp | 2 - ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c | 34 +- .../ggml/src/ggml-cpu/llamafile/sgemm.cpp | 176 +-- .../ggml/ggml/src/ggml-cpu/llamafile/sgemm.h | 6 + ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp | 110 +- ml/backend/ggml/ggml/src/ggml-cuda/common.cuh | 58 +- ml/backend/ggml/ggml/src/ggml-cuda/cumsum.cu | 237 +++ ml/backend/ggml/ggml/src/ggml-cuda/cumsum.cuh | 5 + 
ml/backend/ggml/ggml/src/ggml-cuda/diag.cu | 77 + ml/backend/ggml/ggml/src/ggml-cuda/diag.cuh | 5 + .../ggml/ggml/src/ggml-cuda/fattn-common.cuh | 28 +- .../ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh | 1279 +++++++++-------- .../ggml/ggml/src/ggml-cuda/fattn-tile.cuh | 49 +- .../ggml/ggml/src/ggml-cuda/fattn-vec.cuh | 20 +- .../ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu | 19 +- .../ggml/src/ggml-cuda/fattn-wmma-f16.cuh | 4 +- ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu | 38 +- ml/backend/ggml/ggml/src/ggml-cuda/fill.cu | 37 + ml/backend/ggml/ggml/src/ggml-cuda/fill.cuh | 3 + .../ggml/ggml/src/ggml-cuda/ggml-cuda.cu | 90 +- ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh | 378 +++-- ml/backend/ggml/ggml/src/ggml-cuda/mmf.cu | 4 +- ml/backend/ggml/ggml/src/ggml-cuda/mmf.cuh | 59 +- ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu | 7 +- ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh | 6 +- ml/backend/ggml/ggml/src/ggml-cuda/pad.cu | 99 +- .../ggml/ggml/src/ggml-cuda/solve_tri.cu | 68 +- ml/backend/ggml/ggml/src/ggml-cuda/tri.cu | 136 ++ ml/backend/ggml/ggml/src/ggml-cuda/tri.cuh | 5 + ml/backend/ggml/ggml/src/ggml-cuda/upscale.cu | 81 +- .../ggml/src/ggml-metal/ggml-metal-context.m | 42 +- .../ggml/src/ggml-metal/ggml-metal-device.cpp | 794 +++++----- .../ggml/src/ggml-metal/ggml-metal-device.h | 146 +- .../ggml/src/ggml-metal/ggml-metal-device.m | 352 +++-- .../src/ggml-metal/ggml-metal-embed.metal | 283 +++- .../ggml/src/ggml-metal/ggml-metal-impl.h | 24 + .../ggml/src/ggml-metal/ggml-metal-ops.cpp | 284 ++-- .../ggml/ggml/src/ggml-metal/ggml-metal-ops.h | 2 + .../ggml/ggml/src/ggml-metal/ggml-metal.metal | 259 +++- .../ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 587 ++++---- .../ggml-vulkan/vulkan-shaders/conv2d_mm.comp | 21 +- .../vulkan-shaders/mul_mat_vec_iq1_m.comp | 86 +- .../src/ggml-vulkan/vulkan-shaders/pad.comp | 25 +- .../ggml-vulkan/vulkan-shaders/rms_norm.comp | 6 +- .../ggml-vulkan/vulkan-shaders/solve_tri.comp | 67 +- .../vulkan-shaders/topk_argsort.comp | 19 +- .../ggml-vulkan/vulkan-shaders/topk_moe.comp | 19 +- .../vulkan-shaders/topk_nary_search.comp | 99 +- ml/backend/ggml/ggml/src/ggml.c | 55 + ml/backend/ggml/ggml/src/gguf.cpp | 2 +- 115 files changed, 5176 insertions(+), 2585 deletions(-) rename llama/llama.cpp/src/models/{gemma3-iswa.cpp => gemma3.cpp} (78%) create mode 100644 llama/llama.cpp/src/models/mistral3.cpp create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/cumsum.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/cumsum.cuh create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/diag.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/diag.cuh create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/fill.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/fill.cuh create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/tri.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/tri.cuh diff --git a/Makefile.sync b/Makefile.sync index a485d6f2..2966d43f 100644 --- a/Makefile.sync +++ b/Makefile.sync @@ -1,6 +1,6 @@ UPSTREAM=https://github.com/ggml-org/llama.cpp.git WORKDIR=llama/vendor -FETCH_HEAD=7f8ef50cce40e3e7e4526a3696cb45658190e69a +FETCH_HEAD=17f7f4baad8b3a716ee139da7bb56ae984e8c0fa .PHONY: help help: diff --git a/llama/build-info.cpp b/llama/build-info.cpp index 0122c7ed..5666fbc4 100644 --- a/llama/build-info.cpp +++ b/llama/build-info.cpp @@ -1,4 +1,4 @@ int LLAMA_BUILD_NUMBER = 0; -char const *LLAMA_COMMIT = "7f8ef50cce40e3e7e4526a3696cb45658190e69a"; +char const *LLAMA_COMMIT = "17f7f4baad8b3a716ee139da7bb56ae984e8c0fa"; char const *LLAMA_COMPILER = 
""; char const *LLAMA_BUILD_TARGET = ""; diff --git a/llama/llama.cpp/common/common.cpp b/llama/llama.cpp/common/common.cpp index 0d7fd9a9..0497f90a 100644 --- a/llama/llama.cpp/common/common.cpp +++ b/llama/llama.cpp/common/common.cpp @@ -694,7 +694,7 @@ bool string_parse_kv_override(const char * data, std::vector= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs || c == 0xFFFD // Replacement Character (UTF-8) || c == 0xFEFF // Byte Order Mark (BOM) - || c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters + || c == ':' || c == '*' // Illegal characters || c == '?' || c == '"' || c == '<' || c == '>' || c == '|') { return false; } + if (!allow_subdirs && (c == '/' || c == '\\')) { + // Subdirectories not allowed, reject path separators + return false; + } } // Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename @@ -782,11 +786,29 @@ bool fs_validate_filename(const std::string & filename) { #include +#ifdef _WIN32 +static std::wstring utf8_to_wstring(const std::string & str) { + if (str.empty()) { + return std::wstring(); + } + + int size = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0); + + if (size <= 0) { + return std::wstring(); + } + + std::wstring wstr(size, 0); + MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstr[0], size); + + return wstr; +} +#endif + // returns true if successful, false otherwise bool fs_create_directory_with_parents(const std::string & path) { #ifdef _WIN32 - std::wstring_convert> converter; - std::wstring wpath = converter.from_bytes(path); + std::wstring wpath = utf8_to_wstring(path); // if the path already exists, check whether it's a directory const DWORD attributes = GetFileAttributesW(wpath.c_str()); @@ -859,6 +881,11 @@ bool fs_create_directory_with_parents(const std::string & path) { #endif // _WIN32 } +bool fs_is_directory(const std::string & path) { + std::filesystem::path dir(path); + return std::filesystem::exists(dir) && std::filesystem::is_directory(dir); +} + std::string fs_get_cache_directory() { std::string cache_directory = ""; auto ensure_trailing_slash = [](std::string p) { @@ -893,6 +920,8 @@ std::string fs_get_cache_directory() { cache_directory = std::getenv("HOME") + std::string("/Library/Caches/"); #elif defined(_WIN32) cache_directory = std::getenv("LOCALAPPDATA"); +#elif defined(__EMSCRIPTEN__) + GGML_ABORT("not implemented on this platform"); #else # error Unknown architecture #endif @@ -912,7 +941,7 @@ std::string fs_get_cache_file(const std::string & filename) { return cache_directory + filename; } -std::vector fs_list_files(const std::string & path) { +std::vector fs_list(const std::string & path, bool include_directories) { std::vector files; if (path.empty()) return files; @@ -927,14 +956,22 @@ std::vector fs_list_files(const std::string & path) { const auto & p = entry.path(); if (std::filesystem::is_regular_file(p)) { common_file_info info; - info.path = p.string(); - info.name = p.filename().string(); + info.path = p.string(); + info.name = p.filename().string(); + info.is_dir = false; try { info.size = static_cast(std::filesystem::file_size(p)); } catch (const std::filesystem::filesystem_error &) { info.size = 0; } files.push_back(std::move(info)); + } else if (include_directories && std::filesystem::is_directory(p)) { + common_file_info info; + info.path = p.string(); + info.name = p.filename().string(); + info.size = 0; // Directories have no size + info.is_dir = true; + 
files.push_back(std::move(info)); } } catch (const std::filesystem::filesystem_error &) { // skip entries we cannot inspect @@ -945,6 +982,32 @@ std::vector fs_list_files(const std::string & path) { return files; } +// +// TTY utils +// + +bool tty_can_use_colors() { + // Check NO_COLOR environment variable (https://no-color.org/) + if (const char * no_color = std::getenv("NO_COLOR")) { + if (no_color[0] != '\0') { + return false; + } + } + + // Check TERM environment variable + if (const char * term = std::getenv("TERM")) { + if (std::strcmp(term, "dumb") == 0) { + return false; + } + } + + // Check if stdout and stderr are connected to a terminal + // We check both because log messages can go to either + bool stdout_is_tty = isatty(fileno(stdout)); + bool stderr_is_tty = isatty(fileno(stderr)); + + return stdout_is_tty || stderr_is_tty; +} // // Model utils diff --git a/llama/llama.cpp/common/common.h b/llama/llama.cpp/common/common.h index 2f23d0ba..d28e4899 100644 --- a/llama/llama.cpp/common/common.h +++ b/llama/llama.cpp/common/common.h @@ -12,6 +12,10 @@ #include #include +#if defined(_WIN32) && !defined(_WIN32_WINNT) +#define _WIN32_WINNT 0x0A00 +#endif + #ifdef _WIN32 #define DIRECTORY_SEPARATOR '\\' #else @@ -26,8 +30,6 @@ fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \ } while(0) -#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf" - struct common_time_meas { common_time_meas(int64_t & t_acc, bool disable = false); ~common_time_meas(); @@ -223,6 +225,7 @@ struct common_params_model { std::string hf_repo = ""; // HF repo // NOLINT std::string hf_file = ""; // HF file // NOLINT std::string docker_repo = ""; // Docker repo // NOLINT + std::string name = ""; // in format /[:] (tag is optional) // NOLINT }; struct common_params_speculative { @@ -369,7 +372,7 @@ struct common_params { std::vector control_vectors; // control vector with user defined scale - int32_t verbosity = 0; + int32_t verbosity = 3; // LOG_LEVEL_INFO int32_t control_vector_layer_start = -1; // layer range for control vector int32_t control_vector_layer_end = -1; // layer range for control vector bool offline = false; @@ -478,9 +481,15 @@ struct common_params { bool endpoint_props = false; // only control POST requests, not GET bool endpoint_metrics = false; + // router server configs + std::string models_dir = ""; // directory containing models for the router server + int models_max = 4; // maximum number of models to load simultaneously + bool models_autoload = true; // automatically load models when requested via the router server + bool log_json = false; std::string slot_save_path; + std::string media_path; // path to directory for loading media files float slot_prompt_similarity = 0.1f; @@ -631,8 +640,9 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat // Filesystem utils // -bool fs_validate_filename(const std::string & filename); +bool fs_validate_filename(const std::string & filename, bool allow_subdirs = false); bool fs_create_directory_with_parents(const std::string & path); +bool fs_is_directory(const std::string & path); std::string fs_get_cache_directory(); std::string fs_get_cache_file(const std::string & filename); @@ -641,8 +651,16 @@ struct common_file_info { std::string path; std::string name; size_t size = 0; // in bytes + bool is_dir = false; }; -std::vector fs_list_files(const std::string & path); +std::vector fs_list(const std::string & path, bool include_directories); + +// +// TTY utils +// + +// 
Auto-detect if colors can be enabled based on terminal and environment +bool tty_can_use_colors(); // // Model utils diff --git a/llama/llama.cpp/common/json-schema-to-grammar.cpp b/llama/llama.cpp/common/json-schema-to-grammar.cpp index cb659915..6be55282 100644 --- a/llama/llama.cpp/common/json-schema-to-grammar.cpp +++ b/llama/llama.cpp/common/json-schema-to-grammar.cpp @@ -974,7 +974,7 @@ public: void check_errors() { if (!_errors.empty()) { - throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n")); + throw std::invalid_argument("JSON schema conversion failed:\n" + string_join(_errors, "\n")); } if (!_warnings.empty()) { fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str()); diff --git a/llama/llama.cpp/common/log.cpp b/llama/llama.cpp/common/log.cpp index a24782b7..00a03f15 100644 --- a/llama/llama.cpp/common/log.cpp +++ b/llama/llama.cpp/common/log.cpp @@ -1,3 +1,4 @@ +#include "common.h" #include "log.h" #include @@ -26,30 +27,6 @@ void common_log_set_verbosity_thold(int verbosity) { common_log_verbosity_thold = verbosity; } -// Auto-detect if colors should be enabled based on terminal and environment -static bool common_log_should_use_colors_auto() { - // Check NO_COLOR environment variable (https://no-color.org/) - if (const char * no_color = std::getenv("NO_COLOR")) { - if (no_color[0] != '\0') { - return false; - } - } - - // Check TERM environment variable - if (const char * term = std::getenv("TERM")) { - if (std::strcmp(term, "dumb") == 0) { - return false; - } - } - - // Check if stdout and stderr are connected to a terminal - // We check both because log messages can go to either - bool stdout_is_tty = isatty(fileno(stdout)); - bool stderr_is_tty = isatty(fileno(stderr)); - - return stdout_is_tty || stderr_is_tty; -} - static int64_t t_us() { return std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); } @@ -391,7 +368,7 @@ struct common_log * common_log_main() { static std::once_flag init_flag; std::call_once(init_flag, [&]() { // Set default to auto-detect colors - log.set_colors(common_log_should_use_colors_auto()); + log.set_colors(tty_can_use_colors()); }); return &log; @@ -422,7 +399,7 @@ void common_log_set_file(struct common_log * log, const char * file) { void common_log_set_colors(struct common_log * log, log_colors colors) { if (colors == LOG_COLORS_AUTO) { - log->set_colors(common_log_should_use_colors_auto()); + log->set_colors(tty_can_use_colors()); return; } @@ -443,8 +420,22 @@ void common_log_set_timestamps(struct common_log * log, bool timestamps) { log->set_timestamps(timestamps); } +static int common_get_verbosity(enum ggml_log_level level) { + switch (level) { + case GGML_LOG_LEVEL_DEBUG: return LOG_LEVEL_DEBUG; + case GGML_LOG_LEVEL_INFO: return LOG_LEVEL_INFO; + case GGML_LOG_LEVEL_WARN: return LOG_LEVEL_WARN; + case GGML_LOG_LEVEL_ERROR: return LOG_LEVEL_ERROR; + case GGML_LOG_LEVEL_CONT: return LOG_LEVEL_INFO; // same as INFO + case GGML_LOG_LEVEL_NONE: + default: + return LOG_LEVEL_OUTPUT; + } +} + void common_log_default_callback(enum ggml_log_level level, const char * text, void * /*user_data*/) { - if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) { + auto verbosity = common_get_verbosity(level); + if (verbosity <= common_log_verbosity_thold) { common_log_add(common_log_main(), level, "%s", text); } } diff --git a/llama/llama.cpp/common/log.h b/llama/llama.cpp/common/log.h index 7edb239a..b24f5f00 100644 --- 
a/llama/llama.cpp/common/log.h +++ b/llama/llama.cpp/common/log.h @@ -21,8 +21,14 @@ # define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) #endif -#define LOG_DEFAULT_DEBUG 1 -#define LOG_DEFAULT_LLAMA 0 +#define LOG_LEVEL_DEBUG 4 +#define LOG_LEVEL_INFO 3 +#define LOG_LEVEL_WARN 2 +#define LOG_LEVEL_ERROR 1 +#define LOG_LEVEL_OUTPUT 0 // output data from tools + +#define LOG_DEFAULT_DEBUG LOG_LEVEL_DEBUG +#define LOG_DEFAULT_LLAMA LOG_LEVEL_INFO enum log_colors { LOG_COLORS_AUTO = -1, @@ -67,10 +73,11 @@ void common_log_add(struct common_log * log, enum ggml_log_level level, const ch // 0.00.090.578 I llm_load_tensors: offloading 32 repeating layers to GPU // 0.00.090.579 I llm_load_tensors: offloading non-repeating layers to GPU // -// I - info (stdout, V = 0) -// W - warning (stderr, V = 0) -// E - error (stderr, V = 0) // D - debug (stderr, V = LOG_DEFAULT_DEBUG) +// I - info (stdout, V = LOG_DEFAULT_INFO) +// W - warning (stderr, V = LOG_DEFAULT_WARN) +// E - error (stderr, V = LOG_DEFAULT_ERROR) +// O - output (stdout, V = LOG_DEFAULT_OUTPUT) // void common_log_set_file (struct common_log * log, const char * file); // not thread-safe @@ -95,14 +102,14 @@ void common_log_set_timestamps(struct common_log * log, bool timestamps); // w } \ } while (0) -#define LOG(...) LOG_TMPL(GGML_LOG_LEVEL_NONE, 0, __VA_ARGS__) -#define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__) +#define LOG(...) LOG_TMPL(GGML_LOG_LEVEL_NONE, LOG_LEVEL_OUTPUT, __VA_ARGS__) +#define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__) -#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, 0, __VA_ARGS__) -#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, 0, __VA_ARGS__) -#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, 0, __VA_ARGS__) -#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_DEFAULT_DEBUG, __VA_ARGS__) -#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, 0, __VA_ARGS__) +#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_LEVEL_DEBUG, __VA_ARGS__) +#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, LOG_LEVEL_INFO, __VA_ARGS__) +#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, LOG_LEVEL_WARN, __VA_ARGS__) +#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, LOG_LEVEL_ERROR, __VA_ARGS__) +#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, LOG_LEVEL_INFO, __VA_ARGS__) // same as INFO #define LOG_INFV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_INFO, verbosity, __VA_ARGS__) #define LOG_WRNV(verbosity, ...) 
LOG_TMPL(GGML_LOG_LEVEL_WARN, verbosity, __VA_ARGS__) diff --git a/llama/llama.cpp/src/llama-arch.cpp b/llama/llama.cpp/src/llama-arch.cpp index b6bde25d..a5fe4f66 100644 --- a/llama/llama.cpp/src/llama-arch.cpp +++ b/llama/llama.cpp/src/llama-arch.cpp @@ -112,6 +112,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_COGVLM, "cogvlm" }, { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, + { LLM_ARCH_MISTRAL3, "mistral3" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -205,6 +206,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" }, { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" }, + { LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" }, { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, @@ -855,7 +857,7 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, - { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_A_NOSCAN, "blk.%d.ssm_a" }, { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, { LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" }, @@ -2532,6 +2534,32 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_MISTRAL3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -2631,6 +2659,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_FFN_ACT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}}, {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}}, + {LLM_TENSOR_SSM_A_NOSCAN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // a version of SSM_A used for MUL instead of SSM_SCAN {LLM_TENSOR_SSM_DT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_B_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_C_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, diff --git a/llama/llama.cpp/src/llama-arch.h b/llama/llama.cpp/src/llama-arch.h index 3936a468..ec9e3a6d 100644 --- a/llama/llama.cpp/src/llama-arch.h +++ b/llama/llama.cpp/src/llama-arch.h @@ -116,6 +116,7 @@ enum llm_arch { LLM_ARCH_COGVLM, LLM_ARCH_RND1, LLM_ARCH_PANGU_EMBED, + 
LLM_ARCH_MISTRAL3, LLM_ARCH_UNKNOWN, }; @@ -209,6 +210,7 @@ enum llm_kv { LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_OUTPUT_SCALE, LLM_KV_ATTENTION_TEMPERATURE_LENGTH, + LLM_KV_ATTENTION_TEMPERATURE_SCALE, LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, @@ -379,6 +381,7 @@ enum llm_tensor { LLM_TENSOR_SSM_DT, LLM_TENSOR_SSM_DT_NORM, LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_A_NOSCAN, // qwen3next special case with MUL instead of SSM_SCAN LLM_TENSOR_SSM_B_NORM, LLM_TENSOR_SSM_C_NORM, LLM_TENSOR_SSM_D, diff --git a/llama/llama.cpp/src/llama-context.cpp b/llama/llama.cpp/src/llama-context.cpp index 1359c614..87f407f9 100644 --- a/llama/llama.cpp/src/llama-context.cpp +++ b/llama/llama.cpp/src/llama-context.cpp @@ -248,7 +248,10 @@ llama_context::llama_context( LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size()); - const size_t max_nodes = this->graph_max_nodes(); + const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); + + const size_t max_nodes = this->graph_max_nodes(n_tokens); LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); @@ -300,9 +303,6 @@ llama_context::llama_context( cross.v_embd.clear(); - const uint32_t n_seqs = cparams.n_seq_max; - const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - // avoid reserving graphs with zero outputs - assume one output per sequence n_outputs = n_seqs; @@ -1385,9 +1385,9 @@ void llama_context::output_reorder() { // graph // -uint32_t llama_context::graph_max_nodes() const { +uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const { if (model.arch == LLM_ARCH_QWEN3NEXT) { - return std::max(8192u, 32u*model.n_tensors()); + return std::max(n_tokens * 40, 32u * model.n_tensors()); } return std::max(1024u, 8u*model.n_tensors()); } diff --git a/llama/llama.cpp/src/llama-context.h b/llama/llama.cpp/src/llama-context.h index 20cbd789..cd26eafe 100644 --- a/llama/llama.cpp/src/llama-context.h +++ b/llama/llama.cpp/src/llama-context.h @@ -197,7 +197,7 @@ private: // public: - uint32_t graph_max_nodes() const; + uint32_t graph_max_nodes(uint32_t n_tokens) const; // can reuse the llm_graph_result instance of the context (for example to update a memory module) llm_graph_result * get_gf_res_reserve() const; diff --git a/llama/llama.cpp/src/llama-grammar.cpp b/llama/llama.cpp/src/llama-grammar.cpp index a7307c47..a0299d18 100644 --- a/llama/llama.cpp/src/llama-grammar.cpp +++ b/llama/llama.cpp/src/llama-grammar.cpp @@ -181,6 +181,52 @@ static std::pair parse_char(const char * src) { throw std::runtime_error("unexpected end of input"); } +static std::pair parse_token(const llama_vocab * vocab, const char * src) { + const char * pos = src; + if (*pos != '<') { + throw std::runtime_error(std::string("expecting '<' at ") + pos); + } + pos++; + + // Parse <[id]> + if (*pos == '[') { + pos++; + const char * int_end = parse_int(pos); + uint32_t token_id = std::stoul(std::string(pos, int_end - pos)); + pos = int_end; + if (*pos != ']') { + throw std::runtime_error(std::string("expecting ']' at ") + pos); + } + pos++; + if (*pos != '>') { + throw std::runtime_error(std::string("expecting '>' at ") + pos); + } + pos++; + return std::make_pair(token_id, pos); + } + + if (vocab == nullptr) { + throw std::runtime_error(std::string("no vocab to parse token at ") + src); + } + + // Parse and tokenize to obtain the token id + while (*pos != 0 && *pos != '>') { + pos++; + } + if (*pos != '>') { + throw 
std::runtime_error(std::string("expecting '>' at ") + pos); + } + pos++; + + llama_token tokens[2]; + int32_t n_tokens = vocab->tokenize(src, static_cast(pos - src), tokens, 2, false, true); + if (n_tokens != 1) { + // must tokenize to exactly 1 token + throw std::runtime_error("invalid token '" + std::string(src, pos - src) + "'"); + } + return std::make_pair(tokens[0], pos); +} + static void print_grammar_char(FILE * file, uint32_t c) { if (0x20 <= c && c <= 0x7f) { fprintf(file, "%c", static_cast(c)); @@ -212,6 +258,8 @@ static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) { case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break; + case LLAMA_GRETYPE_TOKEN: fprintf(file, "TOKEN"); break; + case LLAMA_GRETYPE_TOKEN_NOT: fprintf(file, "TOKEN_NOT"); break; } switch (elem.type) { case LLAMA_GRETYPE_END: @@ -228,6 +276,17 @@ static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) { print_grammar_char(file, elem.value); fprintf(file, "\") "); break; + case LLAMA_GRETYPE_TOKEN: + fprintf(file, "<["); + fprintf(file, "%u", elem.value); + fprintf(file, "]> "); + break; + case LLAMA_GRETYPE_TOKEN_NOT: + fprintf(file, "!"); + fprintf(file, "<["); + fprintf(file, "%u", elem.value); + fprintf(file, "]> "); + break; } } fprintf(file, "\n"); @@ -284,6 +343,17 @@ static void print_rule( case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "."); break; + case LLAMA_GRETYPE_TOKEN: + fprintf(file, "<["); + fprintf(file, "%u", elem.value); + fprintf(file, "]> "); + break; + case LLAMA_GRETYPE_TOKEN_NOT: + fprintf(file, "!"); + fprintf(file, "<["); + fprintf(file, "%u", elem.value); + fprintf(file, "]> "); + break; } if (is_char_element(elem)) { switch (rule[i + 1].type) { @@ -444,6 +514,17 @@ const char * llama_grammar_parser::parse_sequence( } } pos = parse_space(pos + 1, is_nested); + } else if (*pos == '<' || *pos == '!') { // token + auto type = LLAMA_GRETYPE_TOKEN; + if (*pos == '!') { // token inverse + type = LLAMA_GRETYPE_TOKEN_NOT; + pos++; + } + auto token_pair = parse_token(vocab, pos); + const char * token_end = token_pair.second; + last_sym_start = rule.size(); + rule.push_back({type, token_pair.first}); + pos = parse_space(token_end, is_nested); } else if (is_word_char(*pos)) { // rule reference const char * name_end = parse_name(pos); uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos); @@ -691,6 +772,21 @@ static bool llama_grammar_match_partial_char( return !is_positive_char; } +// returns true iff token matches the rule at pos (regular or inverse) +// asserts that pos is pointing to a token element +static bool llama_grammar_match_token( + const llama_grammar_element * pos, + const llama_token token) { + GGML_ASSERT(pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT); + if (pos->type == LLAMA_GRETYPE_TOKEN) { + return pos->value == static_cast(token); + } + if (pos->type == LLAMA_GRETYPE_TOKEN_NOT) { + return pos->value != static_cast(token); + } + return false; +} + // transforms a grammar pushdown stack into N possible stacks, all ending // at a character range (terminal element) static void llama_grammar_advance_stack( @@ -738,6 +834,8 @@ static void llama_grammar_advance_stack( case LLAMA_GRETYPE_CHAR: case LLAMA_GRETYPE_CHAR_NOT: case LLAMA_GRETYPE_CHAR_ANY: + case LLAMA_GRETYPE_TOKEN: + case LLAMA_GRETYPE_TOKEN_NOT: if (std::find(new_stacks.begin(), new_stacks.end(), stack) == 
new_stacks.end()) { // only add the stack if it's not a duplicate of one we already have new_stacks.emplace_back(stack); @@ -831,26 +929,38 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) return grammar->stacks; } +static void llama_grammar_accept_chr( + struct llama_grammar & grammar, + const llama_grammar_stack & stack, + uint32_t chr, + llama_grammar_stacks & new_stacks) { + if (stack.empty()) { + return; + } + + const llama_grammar_element * pos = stack.back(); + + // ignore if this turns into a token + if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) { + return; + } + + auto match = llama_grammar_match_char(pos, chr); + if (match.first) { + llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(match.second)) { + new_stack.push_back(match.second); + } + llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks); + } +} + void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) { llama_grammar_stacks stacks_new; stacks_new.reserve(grammar->stacks.size()); for (const auto & stack : grammar->stacks) { - if (stack.empty()) { - continue; - } - - auto match = llama_grammar_match_char(stack.back(), chr); - if (match.first) { - const llama_grammar_element * pos = match.second; - - // update top of stack to next element, if any - llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); - if (!llama_grammar_is_end_of_sequence(pos)) { - new_stack.push_back(pos); - } - llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new); - } + llama_grammar_accept_chr(*grammar, stack, chr, stacks_new); } grammar->stacks = std::move(stacks_new); @@ -875,6 +985,22 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( const llama_grammar_element * stack_pos = stack.back(); + // if the top of the stack is a token rule, then we only need to check the token id + if (stack_pos->type == LLAMA_GRETYPE_TOKEN || stack_pos->type == LLAMA_GRETYPE_TOKEN_NOT) { + for (const auto & tok : candidates) { + if (*tok.code_points == 0) { + // reached the end of a token consumed by char rules, reject iff it ended + // in a partial response + if (tok.partial_utf8.n_remain != 0) { + rejects.push_back(tok); + } + } else if (!llama_grammar_match_token(stack_pos, tok.id)) { + rejects.push_back(tok); + } + } + return rejects; + } + llama_grammar_candidates next_candidates; next_candidates.reserve(candidates.size()); @@ -887,7 +1013,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( rejects.push_back(tok); } } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) { - next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 }); + next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8, tok.id }); } else { rejects.push_back(tok); } @@ -905,7 +1031,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates); for (const auto & tok : next_rejects) { - rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 }); + rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8, tok.id }); } return rejects; @@ -974,12 +1100,13 @@ struct llama_grammar * llama_grammar_init_impl( ollama_vocab, std::move(vec_rules), std::move(stacks), - /* .partial_utf8 = */ {}, - /* .lazy =*/ false, - /* .awaiting_trigger = */ false, - /* .trigger_buffer = */ "", - /* .trigger_tokens = */ {}, - /* 
.trigger_patterns = */ {}, + /* .partial_utf8 = */ {}, + /* .lazy = */ false, + /* .awaiting_trigger = */ false, + /* .trigger_buffer = */ "", + /* .trigger_buffer_positions = */ {}, + /* .trigger_tokens = */ {}, + /* .trigger_patterns = */ {}, }; } @@ -993,7 +1120,7 @@ struct llama_grammar * llama_grammar_init_impl( size_t num_trigger_patterns, const llama_token * trigger_tokens, size_t num_trigger_tokens) { - llama_grammar_parser parser; + llama_grammar_parser parser(vocab); // if there is a grammar, parse it // rules will be empty (default) if there are parse errors @@ -1081,10 +1208,11 @@ struct llama_grammar * llama_grammar_init_impl( ollama_vocab, std::move(vec_rules), std::move(stacks), - /* .partial_utf8 = */ {}, - /* .lazy = */ lazy, - /* .awaiting_trigger = */ lazy, - /* .trigger_buffer = */ "", + /* .partial_utf8 = */ {}, + /* .lazy = */ lazy, + /* .awaiting_trigger = */ lazy, + /* .trigger_buffer = */ "", + /* .trigger_buffer_positions = */ {}, std::move(vec_trigger_tokens), std::move(vec_trigger_patterns), }; @@ -1108,6 +1236,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra grammar.lazy, grammar.awaiting_trigger, grammar.trigger_buffer, + grammar.trigger_buffer_positions, grammar.trigger_tokens, grammar.trigger_patterns, }; @@ -1164,7 +1293,7 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ cur_p->data[i].logit = -INFINITY; } else { candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8)); - candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); + candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second, id }); } } @@ -1184,10 +1313,12 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) { grammar.awaiting_trigger = false; grammar.trigger_buffer.clear(); - llama_grammar_accept_str(grammar, piece); + llama_grammar_accept_token(grammar, token, piece); LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str()); return; } else { + auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size()); + grammar.trigger_buffer_positions.push_back(std::make_pair(token, position)); grammar.trigger_buffer += piece; std::smatch match; @@ -1205,10 +1336,23 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token if (start == std::string::npos) { start = match.position(0); } + + // replay tokens that overlap with [start, end) + for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) { + auto [tok_start, tok_end] = tok_pos; + if (tok_end <= start) { + continue; + } + + size_t piece_start = (tok_start < start) ? 
start : tok_start; // allow for partial token pieces + size_t piece_len = tok_end - piece_start; + auto tok_piece = grammar.trigger_buffer.substr(piece_start, piece_len); + llama_grammar_accept_token(grammar, tok, tok_piece); + } + auto constrained_str = grammar.trigger_buffer.substr(start); - // std::string constrained_str(match[1].first, grammar.trigger_buffer.end()); grammar.trigger_buffer.clear(); - llama_grammar_accept_str(grammar, constrained_str); + grammar.trigger_buffer_positions.clear(); LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str()); return; } @@ -1228,7 +1372,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token GGML_ABORT("grammar error: end of grammar token received but grammar stack is not empty"); } - llama_grammar_accept_str(grammar, piece); + llama_grammar_accept_token(grammar, token, piece); } void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) { @@ -1246,6 +1390,61 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string } } +void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) { + // Note terminating 0 in decoded string + const auto decoded = decode_utf8(piece, grammar.partial_utf8); + const auto & code_points = decoded.first; + + llama_grammar_stacks stacks_new; + stacks_new.reserve(grammar.stacks.size()); + + for (const auto & stack : grammar.stacks) { + if (stack.empty()) { + continue; + } + + const llama_grammar_element * pos = stack.back(); + + if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) { + if (llama_grammar_match_token(pos, token)) { + llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos + 1)) { + new_stack.push_back(pos + 1); + } + llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new); + } + } else { + llama_grammar_stacks current_stacks = {stack}; + + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + llama_grammar_stacks next_stacks; + + for (const auto & cur_stack : current_stacks) { + llama_grammar_accept_chr(grammar, cur_stack, *it, next_stacks); + } + + current_stacks = std::move(next_stacks); + if (current_stacks.empty()) { + break; + } + } + + for (auto & surviving_stack : current_stacks) { + if (std::find(stacks_new.begin(), stacks_new.end(), surviving_stack) == stacks_new.end()) { + stacks_new.emplace_back(surviving_stack); + } + } + } + } + + grammar.stacks = std::move(stacks_new); + grammar.partial_utf8 = decoded.second; + + if (grammar.stacks.empty()) { + throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")"); + } +} + const std::string & ollama_vocab::token_to_piece(const uint32_t token) const { try { diff --git a/llama/llama.cpp/src/llama-grammar.h b/llama/llama.cpp/src/llama-grammar.h index 2a3a62db..5c0da404 100644 --- a/llama/llama.cpp/src/llama-grammar.h +++ b/llama/llama.cpp/src/llama-grammar.h @@ -47,11 +47,17 @@ enum llama_gretype { // any character (.) 
LLAMA_GRETYPE_CHAR_ANY = 7, + + // terminal element: token (<[token-id]>) + LLAMA_GRETYPE_TOKEN = 8, + + // inverse token (!<[token-id]>) + LLAMA_GRETYPE_TOKEN_NOT = 9, }; typedef struct llama_grammar_element { enum llama_gretype type; - uint32_t value; // Unicode code point or rule ID + uint32_t value; // Unicode code point, rule ID, or token ID } llama_grammar_element; struct llama_partial_utf8 { @@ -63,6 +69,7 @@ struct llama_grammar_candidate { size_t index; const uint32_t * code_points; llama_partial_utf8 partial_utf8; + llama_token id; }; using llama_grammar_rule = std::vector< llama_grammar_element>; @@ -88,10 +95,13 @@ std::vector llama_grammar_reject_candidates_for_stack( const llama_grammar_candidates & candidates); struct llama_grammar_parser { + const llama_vocab * vocab; std::map symbol_ids; llama_grammar_rules rules; + llama_grammar_parser(const struct llama_vocab * vocab = nullptr) : vocab(vocab) {} + llama_grammar_stack c_rules() const; uint32_t get_symbol_id(const char * src, size_t len); @@ -123,6 +133,9 @@ struct llama_grammar_trigger_pattern { }; struct llama_grammar { + // maintain a list of llama_tokens and their positions in the trigger_buffer + using token_pos = std::pair>; + // note: allow null vocab for testing (not great) const llama_vocab * vocab; const ollama_vocab * o_vocab; @@ -139,6 +152,7 @@ struct llama_grammar { bool lazy = false; bool awaiting_trigger = false; // Initialized to true for lazy grammars only std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found. + std::vector trigger_buffer_positions; // Tokens buffered by lazy grammar. Used to replay when a trigger is found. std::vector trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special). std::vector trigger_patterns; // Regular expressions that trigger a lazy grammar. 
Must be a full match of the entire generated @@ -185,3 +199,8 @@ void llama_grammar_accept_impl( void llama_grammar_accept_str( struct llama_grammar & grammar, const std::string & piece); + +void llama_grammar_accept_token( + struct llama_grammar & grammar, + llama_token token, + const std::string & piece); diff --git a/llama/llama.cpp/src/llama-graph.cpp b/llama/llama.cpp/src/llama-graph.cpp index 1d012e09..43620df7 100644 --- a/llama/llama.cpp/src/llama-graph.cpp +++ b/llama/llama.cpp/src/llama-graph.cpp @@ -71,6 +71,9 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) { if (ubatch->pos && attn_scale) { const int64_t n_tokens = ubatch->n_tokens; + GGML_ASSERT(f_attn_temp_scale != 0.0f); + GGML_ASSERT(n_attn_temp_floor_scale != 0); + std::vector attn_scale_data(n_tokens, 0.0f); for (int i = 0; i < n_tokens; ++i) { const float pos = ubatch->pos[i]; @@ -810,9 +813,6 @@ ggml_tensor * llm_graph_context::build_ffn( GGML_ABORT("fatal error"); } - //expand here so that we can fuse ffn gate - ggml_build_forward_expand(gf, cur); - if (gate && type_gate == LLM_FFN_PAR) { cur = ggml_mul(ctx0, cur, tmp); cb(cur, "ffn_gate_par", il); @@ -973,7 +973,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( // mask out the other groups selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens] - selection_probs = ggml_set_rows(ctx0, ggml_scale_bias(ctx0, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens] + selection_probs = ggml_set_rows(ctx0, ggml_fill(ctx0, selection_groups, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens] selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens] cb(selection_probs, "ffn_moe_probs_masked", il); } @@ -1093,9 +1093,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn( GGML_ABORT("fatal error"); } - //expand here so that we can fuse ffn gate - ggml_build_forward_expand(gf, cur); - experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens] cb(experts, "ffn_moe_down", il); diff --git a/llama/llama.cpp/src/llama-hparams.h b/llama/llama.cpp/src/llama-hparams.h index 2ffe7dd3..a778fc3c 100644 --- a/llama/llama.cpp/src/llama-hparams.h +++ b/llama/llama.cpp/src/llama-hparams.h @@ -164,8 +164,8 @@ struct llama_hparams { // llama4 smallthinker uint32_t n_moe_layer_step = 0; uint32_t n_no_rope_layer_step = 4; - uint32_t n_attn_temp_floor_scale = 8192; - float f_attn_temp_scale = 0.1; + uint32_t n_attn_temp_floor_scale = 0; + float f_attn_temp_scale = 0.0f; // gemma3n altup uint32_t n_altup = 4; // altup_num_inputs diff --git a/llama/llama.cpp/src/llama-impl.h b/llama/llama.cpp/src/llama-impl.h index c5163e92..c3391e79 100644 --- a/llama/llama.cpp/src/llama-impl.h +++ b/llama/llama.cpp/src/llama-impl.h @@ -37,7 +37,7 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void * template struct no_init { T value; - no_init() { /* do nothing */ } + no_init() = default; }; struct time_meas { diff --git a/llama/llama.cpp/src/llama-mmap.cpp b/llama/llama.cpp/src/llama-mmap.cpp index 47497cf9..0641c2d2 100644 --- a/llama/llama.cpp/src/llama-mmap.cpp +++ b/llama/llama.cpp/src/llama-mmap.cpp @@ -485,7 +485,7 @@ struct llama_mlock::impl { if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) { suggest = false; } - if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) { + if (suggest && 
((uint64_t)lock_limit.rlim_max > (uint64_t)lock_limit.rlim_cur + size)) { suggest = false; } #endif diff --git a/llama/llama.cpp/src/llama-model.cpp b/llama/llama.cpp/src/llama-model.cpp index 4468de2f..3c503b42 100644 --- a/llama/llama.cpp/src/llama-model.cpp +++ b/llama/llama.cpp/src/llama-model.cpp @@ -423,8 +423,8 @@ static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, llama_split_mode s } struct llama_model::impl { - impl() {} - ~impl() {} + impl() = default; + ~impl() = default; uint64_t n_elements = 0; @@ -461,7 +461,7 @@ llama_model::llama_model(const llama_model_params & params) : params(params), pi pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern; } -llama_model::~llama_model() {} +llama_model::~llama_model() = default; void llama_model::load_stats(llama_model_loader & ml) { pimpl->n_elements = ml.n_elements; @@ -663,8 +663,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_NONE; hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope } else { - hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; - hparams.n_swa = 8192; + hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; + hparams.n_swa = 8192; + hparams.n_attn_temp_floor_scale = 8192; + hparams.f_attn_temp_scale = 0.1f; hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full } @@ -1262,18 +1264,25 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_GEMMA3: { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(6); + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.set_swa_pattern(6); - hparams.rope_freq_base_train_swa = 10000.0f; - hparams.rope_freq_scale_train_swa = 1.0f; + hparams.rope_freq_base_train_swa = 10000.0f; + hparams.rope_freq_scale_train_swa = 1.0f; + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + hparams.f_final_logit_softcapping = 0.0f; + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { case 18: type = LLM_TYPE_270M; break; case 26: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_8B; break; // Rnj-1 case 34: type = LLM_TYPE_4B; break; case 48: type = LLM_TYPE_12B; break; case 62: type = LLM_TYPE_27B; break; @@ -1597,8 +1606,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); - switch (hparams.n_layer) { - case 28: type = LLM_TYPE_20B; break; + switch (hparams.n_ff_exp) { + case 1408: type = LLM_TYPE_16B; break; + case 1792: type = LLM_TYPE_20B; break; default: type = LLM_TYPE_UNKNOWN; } } break; @@ -1626,6 +1636,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { } ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false); + // (optional) temperature tuning - used by mistral-large + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false); + switch (hparams.n_layer) { case 27: type = LLM_TYPE_16B; break; case 60: type = LLM_TYPE_236B; break; @@ -2262,6 +2276,42 @@ void llama_model::load_hparams(llama_model_loader & ml) { 
default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_MISTRAL3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); + + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false); + + // TODO: maybe add n_attn_temp_floor_scale as a separate KV? + if (hparams.f_attn_temp_scale != 0.0f) { + hparams.n_attn_temp_floor_scale = hparams.n_ctx_orig_yarn; + if (hparams.n_attn_temp_floor_scale == 0) { + throw std::runtime_error("invalid n_ctx_orig_yarn for attention temperature scaling"); + } + } + + // TODO: this seems to be correct with the case of mscale == mscale_all_dims == 1.0f + // but may need further verification with other values + if (hparams.rope_yarn_log_mul != 0.0f) { + float factor = 1.0f / hparams.rope_freq_scale_train; + float mscale = 1.0f; + float mscale_all_dims = hparams.rope_yarn_log_mul; + static auto get_mscale = [](float scale, float mscale) { + return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f); + }; + hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims); + } + + switch (hparams.n_layer) { + case 26: type = LLM_TYPE_3B; break; + case 34: type = LLM_TYPE_8B; break; + case 40: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -2575,6 +2625,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_MINICPM: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_MISTRAL3: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -6530,7 +6581,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, 0); layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_dim }, 0); layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); @@ -7304,7 +7355,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_GEMMA3: { - llm = std::make_unique(*this, params); + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + llm = std::make_unique>(*this, params); + } else { + llm = std::make_unique>(*this, params); + } } break; case LLM_ARCH_GEMMA3N: { @@ -7569,6 +7624,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_MISTRAL3: + { + llm = std::make_unique(*this, params); + } break; default: GGML_ABORT("fatal error"); } @@ -7738,6 +7797,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ARCEE: case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: + case LLM_ARCH_MISTRAL3: return 
LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 diff --git a/llama/llama.cpp/src/llama-quant.cpp b/llama/llama.cpp/src/llama-quant.cpp index 0b23eaef..351dcb7b 100644 --- a/llama/llama.cpp/src/llama-quant.cpp +++ b/llama/llama.cpp/src/llama-quant.cpp @@ -666,7 +666,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: std::map mapped; int blk_id = 0; - int pruned_attention_w = 0; // make a list of weights std::vector tensors; @@ -674,11 +673,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: for (const auto & it : ml.weights_map) { const std::string remapped_name(remap_layer(it.first, prune_list, mapped, blk_id)); if (remapped_name.empty()) { - if (it.first.find("attn_v.weight") != std::string::npos || - it.first.find("attn_qkv.weight") != std::string::npos || - it.first.find("attn_kv_b.weight") != std::string::npos) { - pruned_attention_w++; - } LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str()); continue; } @@ -703,7 +697,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: }); } - bool is_clip_model = false; for (const auto * it : tensors) { const struct ggml_tensor * tensor = it->tensor; @@ -717,32 +710,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) { qs.has_output = true; } - - is_clip_model |= name.rfind("mm.", 0) == 0; // check the "mm." prefix } qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer; - // sanity checks for models that have attention layers - if (qs.n_attention_wv != 0 && !is_clip_model) - { - const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin(); - // attention layers have a non-zero number of kv heads - int32_t n_layer_attn = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0); - if (llama_model_has_encoder(&model)) { - // now n_layer_attn is the number of attention layers in the encoder - // for each decoder block, there are 2 attention layers - n_layer_attn += 2 * model.hparams.dec_n_layer; - } - - // note: for linear-attention models (such as Qwen3 Next) this is the number of linear layers - const int32_t n_layer_recr = std::count(model.hparams.recurrent_layer_arr.begin(), model.hparams.recurrent_layer_arr.end(), true); - - LLAMA_LOG_INFO("%s: n_layer_attn = %d, n_layer_recr = %d, pruned_attention_w = %d\n", __func__, n_layer_attn, n_layer_recr, pruned_attention_w); - - GGML_ASSERT((qs.n_attention_wv == n_layer_attn - pruned_attention_w - n_layer_recr) && "n_attention_wv is unexpected"); - } - size_t total_size_org = 0; size_t total_size_new = 0; diff --git a/llama/llama.cpp/src/llama-vocab.cpp b/llama/llama.cpp/src/llama-vocab.cpp index ea450c36..f72f321b 100644 --- a/llama/llama.cpp/src/llama-vocab.cpp +++ b/llama/llama.cpp/src/llama-vocab.cpp @@ -3243,8 +3243,7 @@ void llama_vocab::impl::print_info() const { llama_vocab::llama_vocab() : pimpl(new impl(*this)) { } -llama_vocab::~llama_vocab() { -} +llama_vocab::~llama_vocab() = default; void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) { pimpl->load(ml, kv); diff --git a/llama/llama.cpp/src/models/deepseek2.cpp b/llama/llama.cpp/src/models/deepseek2.cpp index 0b41f7ba..dbaa8297 100644 --- a/llama/llama.cpp/src/models/deepseek2.cpp +++ b/llama/llama.cpp/src/models/deepseek2.cpp @@ -30,6 +30,12 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr 
// {n_embd, n_tokens} inpL = build_inp_embd(model.tok_embd); + // (optional) temperature tuning - used by mistral-large + ggml_tensor * inp_attn_scale = nullptr; + if (hparams.f_attn_temp_scale != 0.0f) { + inp_attn_scale = build_inp_attn_scale(); + } + // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); @@ -128,6 +134,12 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr ggml_tensor * Vcur = kv_cmpr; cb(Vcur, "Vcur", il); + if (inp_attn_scale) { + // apply llama 4 temperature scaling + Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale); + cb(Qcur, "Qcur_attn_temp_scaled", il); + } + // note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group) cur = build_attn(inp_attn, model.layers[il].wo, NULL, @@ -160,6 +172,12 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0); cb(Kcur, "Kcur", il); + if (inp_attn_scale) { + // apply llama 4 temperature scaling + Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale); + cb(Qcur, "Qcur_attn_temp_scaled", il); + } + // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups) cur = build_attn(inp_attn, model.layers[il].wo, NULL, diff --git a/llama/llama.cpp/src/models/gemma3-iswa.cpp b/llama/llama.cpp/src/models/gemma3.cpp similarity index 78% rename from llama/llama.cpp/src/models/gemma3-iswa.cpp rename to llama/llama.cpp/src/models/gemma3.cpp index 839ff6d3..ae60ef47 100644 --- a/llama/llama.cpp/src/models/gemma3-iswa.cpp +++ b/llama/llama.cpp/src/models/gemma3.cpp @@ -1,6 +1,7 @@ #include "models.h" -llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +template +llm_build_gemma3::llm_build_gemma3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k; ggml_tensor * cur; @@ -17,13 +18,28 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll ggml_tensor * inp_pos = build_inp_pos(); // TODO: is causal == true correct? 
might need some changes - auto * inp_attn = build_attn_inp_kv_iswa(); + using inp_attn_type = std::conditional_t; + inp_attn_type * inp_attn = nullptr; + + if constexpr (iswa) { + inp_attn = build_attn_inp_kv_iswa(); + } else { + inp_attn = build_attn_inp_kv(); + } ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { - const float freq_base_l = model.get_rope_freq_base (cparams, il); - const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + float freq_base_l = 0.0f; + float freq_scale_l = 0.0f; + + if constexpr (iswa) { + freq_base_l = model.get_rope_freq_base (cparams, il); + freq_scale_l = model.get_rope_freq_scale(cparams, il); + } else { + freq_base_l = freq_base; + freq_scale_l = freq_scale; + } // norm cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); @@ -102,7 +118,7 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, -1); - cb(cur, "ffn_post_norm", -1); + cb(cur, "ffn_post_norm", il); cur = ggml_add(ctx0, cur, sa_out); @@ -124,8 +140,17 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll // lm_head cur = build_lora_mm(model.output, cur); + if (hparams.f_final_logit_softcapping) { + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); + cur = ggml_tanh(ctx0, cur); + cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); + } + cb(cur, "result_output", -1); res->t_logits = cur; ggml_build_forward_expand(gf, cur); } + +template struct llm_build_gemma3; +template struct llm_build_gemma3; diff --git a/llama/llama.cpp/src/models/mistral3.cpp b/llama/llama.cpp/src/models/mistral3.cpp new file mode 100644 index 00000000..0b672235 --- /dev/null +++ b/llama/llama.cpp/src/models/mistral3.cpp @@ -0,0 +1,160 @@ +#include "models.h" + +llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // (optional) temperature tuning + ggml_tensor * inp_attn_scale = nullptr; + if (hparams.f_attn_temp_scale != 0.0f) { + inp_attn_scale = build_inp_attn_scale(); + } + + auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 
1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + if (inp_attn_scale) { + // apply llama 4 temperature scaling + Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale); + cb(Qcur, "Qcur_attn_temp_scaled", il); + } + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network (non-MoE) + if (model.layers[il].ffn_gate_inp == nullptr) { + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur, "ffn_moe_out", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + 
cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/llama/llama.cpp/src/models/models.h b/llama/llama.cpp/src/models/models.h index 71fea796..e0aec822 100644 --- a/llama/llama.cpp/src/models/models.h +++ b/llama/llama.cpp/src/models/models.h @@ -179,8 +179,9 @@ struct llm_build_gemma2_iswa : public llm_graph_context { llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_gemma3_iswa : public llm_graph_context { - llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params); +template +struct llm_build_gemma3 : public llm_graph_context { + llm_build_gemma3(const llama_model & model, const llm_graph_params & params); }; struct llm_build_gemma3n_iswa : public llm_graph_context { @@ -322,6 +323,10 @@ struct llm_build_minimax_m2 : public llm_graph_context { llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_mistral3 : public llm_graph_context { + llm_build_mistral3(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_mpt : public llm_graph_context { llm_build_mpt(const llama_model & model, const llm_graph_params & params); }; diff --git a/llama/llama.cpp/src/unicode.cpp b/llama/llama.cpp/src/unicode.cpp index 040518e1..13ced055 100644 --- a/llama/llama.cpp/src/unicode.cpp +++ b/llama/llama.cpp/src/unicode.cpp @@ -520,7 +520,7 @@ static std::vector unicode_regex_split_custom_llama3(const std::string & // use std::wregex to split the text static std::vector unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector & offsets) { - std::wregex expr(regex_expr); + std::wregex expr(regex_expr, std::regex_constants::optimize | std::regex_constants::nosubs); std::vector bpe_offsets; // store the offset of each word bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size size_t start = 0; @@ -550,7 +550,7 @@ static std::vector unicode_regex_split_stl(const std::wstring & wtext, c // use std::regex to split the text static std::vector unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector & offsets) { - std::regex expr(regex_expr); + std::regex expr(regex_expr, std::regex_constants::optimize | std::regex_constants::nosubs); std::vector bpe_offsets; // store the offset of each word bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size size_t start = 0; diff --git a/llama/llama.cpp/tools/mtmd/clip.cpp b/llama/llama.cpp/tools/mtmd/clip.cpp index 3334ff25..2a325c72 100644 --- a/llama/llama.cpp/tools/mtmd/clip.cpp +++ b/llama/llama.cpp/tools/mtmd/clip.cpp @@ -441,6 +441,7 @@ struct clip_ctx { int max_nodes = 8192; ggml_backend_sched_ptr sched; clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO; + bool is_allocated = false; // for debugging bool debug_graph = false; @@ -2033,7 +2034,7 @@ private: ggml_tensor * pos_embd = model.position_embeddings; const int height = img.ny / patch_size; const int width = img.nx / patch_size; - const uint32_t mode = GGML_SCALE_MODE_BILINEAR; + const uint32_t mode = GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS; const int n_per_side = (int)std::sqrt(pos_embd->ne[1]); GGML_ASSERT(pos_embd); @@ -2812,7 +2813,8 @@ struct clip_model_loader { { get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); // ref: 
https://huggingface.co/LiquidAI/LFM2-VL-3B/blob/main/preprocessor_config.json - hparams.set_limit_image_tokens(64, 256); + // config above specifies number of tokens after downsampling, while here it is before, relax lowerbound to 64 + hparams.set_limit_image_tokens(64, 1024); } break; case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_LIGHTONOCR: @@ -3347,12 +3349,30 @@ struct clip_model_loader { }; static void warmup(clip_ctx & ctx_clip) { + // create a fake batch + const auto & hparams = ctx_clip.model.hparams; + clip_image_f32_batch batch; + clip_image_f32_ptr img(clip_image_f32_init()); + if (ctx_clip.model.modality == CLIP_MODALITY_VISION) { + img->nx = hparams.warmup_image_size; + img->ny = hparams.warmup_image_size; + LOG_INF("%s: warmup with image size = %d x %d\n", __func__, img->nx, img->ny); + } else { + img->nx = hparams.warmup_audio_size; + img->ny = hparams.n_mel_bins; + LOG_INF("%s: warmup with audio size = %d\n", __func__, img->nx); + } + batch.entries.push_back(std::move(img)); + warmup(ctx_clip, batch); + } + + static void warmup(clip_ctx & ctx_clip, const clip_image_f32_batch & batch) { support_info_graph info; if (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_AUTO) { // try to enable flash attention to see if it's supported ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_ENABLED; - info = alloc_compute_meta(ctx_clip); + info = alloc_compute_meta(ctx_clip, batch); if (!info.fattn && info.fattn_op) { auto op = info.fattn_op; LOG_WRN("%s: *****************************************************************\n", __func__); @@ -3371,15 +3391,17 @@ struct clip_model_loader { LOG_WRN("%s: please report this on github as an issue\n", __func__); LOG_WRN("%s: *****************************************************************\n", __func__); ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_DISABLED; - alloc_compute_meta(ctx_clip); + alloc_compute_meta(ctx_clip, batch); } } else { - info = alloc_compute_meta(ctx_clip); + info = alloc_compute_meta(ctx_clip, batch); if (!info.fattn && ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) { LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__); } } + ctx_clip.is_allocated = true; // mark buffers as allocated + LOG_INF("%s: flash attention is %s\n", __func__, (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? 
"enabled" : "disabled"); @@ -3411,24 +3433,9 @@ struct clip_model_loader { } } - static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip) { - const auto & hparams = ctx_clip.model.hparams; + static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip, const clip_image_f32_batch & batch) { ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead()); - // create a fake batch - clip_image_f32_batch batch; - clip_image_f32_ptr img(clip_image_f32_init()); - if (ctx_clip.model.modality == CLIP_MODALITY_VISION) { - img->nx = hparams.warmup_image_size; - img->ny = hparams.warmup_image_size; - LOG_INF("%s: warmup with image size = %d x %d\n", __func__, img->nx, img->ny); - } else { - img->nx = hparams.warmup_audio_size; - img->ny = hparams.n_mel_bins; - LOG_INF("%s: warmup with audio size = %d\n", __func__, img->nx); - } - batch.entries.push_back(std::move(img)); - ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch); ggml_backend_sched_reserve(ctx_clip.sched.get(), gf); @@ -3568,14 +3575,18 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_vision = new clip_ctx(ctx_params); loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION); loader.load_tensors(*ctx_vision); - loader.warmup(*ctx_vision); + if (ctx_params.warmup) { + loader.warmup(*ctx_vision); + } } if (loader.has_audio) { ctx_audio = new clip_ctx(ctx_params); loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO); loader.load_tensors(*ctx_audio); - loader.warmup(*ctx_audio); + if (ctx_params.warmup) { + loader.warmup(*ctx_audio); + } } } catch (const std::exception & e) { @@ -3788,12 +3799,13 @@ struct img_tool { const int width = inp_size.width; const int height = inp_size.height; + auto round_by_factor = [f = align_size](float x) { return static_cast(std::round(x / static_cast(f))) * f; }; auto ceil_by_factor = [f = align_size](float x) { return static_cast(std::ceil(x / static_cast(f))) * f; }; auto floor_by_factor = [f = align_size](float x) { return static_cast(std::floor(x / static_cast(f))) * f; }; // always align up first - int h_bar = std::max(align_size, ceil_by_factor(height)); - int w_bar = std::max(align_size, ceil_by_factor(width)); + int h_bar = std::max(align_size, round_by_factor(height)); + int w_bar = std::max(align_size, round_by_factor(width)); if (h_bar * w_bar > max_pixels) { const auto beta = std::sqrt(static_cast(height * width) / max_pixels); @@ -4408,7 +4420,8 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str const std::array pad_color = {122, 116, 104}; clip_image_u8 resized_img; - img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BILINEAR, true, pad_color); + const bool pad = (ctx->proj_type() != PROJECTOR_TYPE_LFM2); + img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BILINEAR, pad, pad_color); clip_image_f32_ptr res(clip_image_f32_init()); normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(res)); @@ -4666,6 +4679,11 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima return false; // only support batch size of 1 } + // if buffers are not allocated, we need to do a warmup run to allocate them + if (!ctx->is_allocated) { + clip_model_loader::warmup(*ctx, *imgs_c_ptr); + } + // build the inference graph ctx->debug_print_tensors.clear(); ggml_backend_sched_reset(ctx->sched.get()); diff --git a/llama/llama.cpp/tools/mtmd/clip.h 
b/llama/llama.cpp/tools/mtmd/clip.h index c1442afe..e8aeb206 100644 --- a/llama/llama.cpp/tools/mtmd/clip.h +++ b/llama/llama.cpp/tools/mtmd/clip.h @@ -34,6 +34,7 @@ struct clip_context_params { enum clip_flash_attn_type flash_attn_type; int image_min_tokens; int image_max_tokens; + bool warmup; }; struct clip_init_result { diff --git a/llama/llama.cpp/tools/mtmd/mtmd.cpp b/llama/llama.cpp/tools/mtmd/mtmd.cpp index 9858de63..0f5712e2 100644 --- a/llama/llama.cpp/tools/mtmd/mtmd.cpp +++ b/llama/llama.cpp/tools/mtmd/mtmd.cpp @@ -118,6 +118,7 @@ mtmd_context_params mtmd_context_params_default() { /* image_marker */ MTMD_DEFAULT_IMAGE_MARKER, /* media_marker */ mtmd_default_marker(), /* flash_attn_type */ LLAMA_FLASH_ATTN_TYPE_AUTO, + /* warmup */ true, /* image_min_tokens */ -1, /* image_max_tokens */ -1, }; @@ -187,6 +188,7 @@ struct mtmd_context { /* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_AUTO, /* image_min_tokens */ ctx_params.image_min_tokens, /* image_max_tokens */ ctx_params.image_max_tokens, + /* warmup */ ctx_params.warmup, }; auto res = clip_init(mmproj_fname, ctx_clip_params); @@ -314,6 +316,10 @@ struct mtmd_context { img_beg = "<|im_start|>"; img_end = "<|im_end|>"; + } else if (proj == PROJECTOR_TYPE_LFM2) { + img_beg = "<|image_start|>"; + img_end = "<|image_end|>"; + } } diff --git a/llama/llama.cpp/tools/mtmd/mtmd.h b/llama/llama.cpp/tools/mtmd/mtmd.h index 8d3fa5d3..a6a1af3b 100644 --- a/llama/llama.cpp/tools/mtmd/mtmd.h +++ b/llama/llama.cpp/tools/mtmd/mtmd.h @@ -85,6 +85,7 @@ struct mtmd_context_params { const char * image_marker; // deprecated, use media_marker instead const char * media_marker; enum llama_flash_attn_type flash_attn_type; + bool warmup; // whether to run a warmup encode pass after initialization // limit number of image tokens, only for vision models with dynamic resolution int image_min_tokens; // minimum number of tokens for image input (default: read from metadata) diff --git a/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch b/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch index 4a2ee02f..7a91351e 100644 --- a/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch +++ b/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch @@ -23,7 +23,7 @@ problem. 
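The warmup flag added to clip_context_params and mtmd_context_params above lets a caller skip the warmup encode pass at load time; compute buffers are then reserved lazily by the is_allocated check in clip_image_batch_encode. A hypothetical caller-side sketch, assuming the usual mtmd_init_from_file entry point (not shown in this diff):

// Hypothetical usage sketch, not part of the patch: disable the warmup pass so compute
// buffers are only allocated on the first real encode.
#include "mtmd.h"

static mtmd_context * init_mtmd_lazy(const char * mmproj_path, const llama_model * text_model) {
    mtmd_context_params params = mtmd_context_params_default();
    params.warmup = false; // defer buffer allocation to the first encode
    return mtmd_init_from_file(mmproj_path, text_model, params);
}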
8 files changed, 21 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp -index 4cf377e7f..4882541c8 100644 +index 08681f35e..afde2f0b7 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -113,7 +113,6 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) { @@ -42,7 +42,7 @@ index 4cf377e7f..4882541c8 100644 } static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { -@@ -2079,6 +2079,11 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { +@@ -2106,6 +2106,11 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { GGML_ASSERT(buffer); ggml_aligned_free(buffer->context, buffer->size); @@ -54,7 +54,7 @@ index 4cf377e7f..4882541c8 100644 } static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { -@@ -2131,7 +2136,7 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = { +@@ -2158,7 +2163,7 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = { }; static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = { @@ -64,7 +64,7 @@ index 4cf377e7f..4882541c8 100644 /* .init_tensor = */ NULL, // no initialization required /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp -index df28d67fb..1f6a56ba2 100644 +index 81288464c..866758782 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -831,6 +831,7 @@ static bool ggml_backend_buffer_is_cann(ggml_backend_buffer_t buffer) { @@ -84,10 +84,10 @@ index df28d67fb..1f6a56ba2 100644 /** diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index fa7e1e13a..8f3b1c173 100644 +index 279679a4e..5145c1e88 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu -@@ -579,6 +579,7 @@ struct ggml_backend_cuda_buffer_context { +@@ -583,6 +583,7 @@ struct ggml_backend_cuda_buffer_context { static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; delete ctx; @@ -95,7 +95,7 @@ index fa7e1e13a..8f3b1c173 100644 } static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) { -@@ -834,6 +835,7 @@ struct ggml_backend_cuda_split_buffer_context { +@@ -838,6 +839,7 @@ struct ggml_backend_cuda_split_buffer_context { static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context; delete ctx; @@ -103,7 +103,7 @@ index fa7e1e13a..8f3b1c173 100644 } static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buffer) { -@@ -1115,6 +1117,7 @@ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) { +@@ -1119,6 +1121,7 @@ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) { static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { CUDA_CHECK(cudaFreeHost(buffer->context)); @@ -132,10 +132,10 @@ index 70bf6f3d9..f2b7fe692 100644 static void * ggml_backend_metal_buffer_private_get_base(ggml_backend_buffer_t buffer) { diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp 
b/ggml/src/ggml-opencl/ggml-opencl.cpp -index e5302f455..43fa83e8f 100644 +index 0d37587f6..ff373d413 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp -@@ -3412,6 +3412,7 @@ struct ggml_backend_opencl_buffer_context { +@@ -3417,6 +3417,7 @@ struct ggml_backend_opencl_buffer_context { static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; delete ctx; @@ -144,10 +144,10 @@ index e5302f455..43fa83e8f 100644 static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) { diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp -index 48fd99a76..da2aab3df 100644 +index 18a45d2d9..89041805e 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp -@@ -555,6 +555,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) { +@@ -556,6 +556,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) { bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0); RPC_STATUS_ASSERT(status); delete ctx; @@ -156,7 +156,7 @@ index 48fd99a76..da2aab3df 100644 static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) { diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp -index 3f1bdfb9f..a95c2f305 100644 +index 7449a9160..e69a1ff5f 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -355,6 +355,7 @@ ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try { @@ -184,10 +184,10 @@ index 3f1bdfb9f..a95c2f305 100644 static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -index 66dd0bfab..83cdec29e 100644 +index c6f5809cc..c801d2fd2 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -@@ -12368,6 +12368,7 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) { +@@ -12271,6 +12271,7 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; ggml_vk_destroy_buffer(ctx->dev_buffer); delete ctx; @@ -195,7 +195,7 @@ index 66dd0bfab..83cdec29e 100644 } static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) { -@@ -12511,6 +12512,7 @@ static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffe +@@ -12414,6 +12415,7 @@ static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffe static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()"); ggml_vk_host_free(vk_instance.devices[0], buffer->context); diff --git a/llama/patches/0002-pretokenizer.patch b/llama/patches/0002-pretokenizer.patch index 096d5f4e..7bb5f48a 100644 --- a/llama/patches/0002-pretokenizer.patch +++ b/llama/patches/0002-pretokenizer.patch @@ -10,7 +10,7 @@ logs instead of throwing an error 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp -index a73c4c448..b9f0631f4 100644 +index e2cca66e4..8246a0a14 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1825,16 +1825,7 @@ void 
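The clip-unicode patch above works around file opens failing on Windows when the model path contains non-ASCII characters: the UTF-8 path is converted to UTF-16 before opening. A self-contained sketch of that idea, assuming MSVC's wide-path ifstream constructor (the helper name is illustrative, not the patch's):

// Sketch: open a file whose path may contain non-ASCII characters. On Windows the
// UTF-8 path is converted to UTF-16 first; elsewhere the narrow path is used directly.
#include <fstream>
#include <string>

#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#include <windows.h>

static std::wstring utf8_to_utf16(const std::string & s) {
    int n = MultiByteToWideChar(CP_UTF8, 0, s.c_str(), -1, nullptr, 0); // count includes the null
    if (n <= 0) {
        return std::wstring();
    }
    std::wstring w(n, L'\0');
    MultiByteToWideChar(CP_UTF8, 0, s.c_str(), -1, &w[0], n);
    w.resize(n - 1); // drop the embedded terminating null
    return w;
}
#endif

static std::ifstream open_binary(const std::string & path) {
#if defined(_WIN32)
    return std::ifstream(utf8_to_utf16(path).c_str(), std::ios::binary); // MSVC wide-path extension
#else
    return std::ifstream(path, std::ios::binary);
#endif
}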
llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { diff --git a/llama/patches/0003-clip-unicode.patch b/llama/patches/0003-clip-unicode.patch index 1f83a77e..d05b01d9 100644 --- a/llama/patches/0003-clip-unicode.patch +++ b/llama/patches/0003-clip-unicode.patch @@ -10,7 +10,7 @@ filesystems for paths that include wide characters 1 file changed, 39 insertions(+) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp -index 05777d2d9..f4c4d2c48 100644 +index 3ed08a0fe..6be1470ad 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -24,6 +24,19 @@ @@ -33,7 +33,7 @@ index 05777d2d9..f4c4d2c48 100644 struct clip_logger_state g_logger_state = {clip_log_callback_default, NULL}; enum ffn_op_type { -@@ -3255,7 +3268,29 @@ struct clip_model_loader { +@@ -3257,7 +3270,29 @@ struct clip_model_loader { { std::vector read_buf; @@ -63,7 +63,7 @@ index 05777d2d9..f4c4d2c48 100644 if (!fin) { throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str())); } -@@ -3282,7 +3317,11 @@ struct clip_model_loader { +@@ -3284,7 +3319,11 @@ struct clip_model_loader { ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes); } } diff --git a/llama/patches/0004-solar-pro.patch b/llama/patches/0004-solar-pro.patch index 82241b87..7adce420 100644 --- a/llama/patches/0004-solar-pro.patch +++ b/llama/patches/0004-solar-pro.patch @@ -19,7 +19,7 @@ adds support for the Solar Pro architecture create mode 100644 src/models/solar.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt -index 67c7807e0..fda881640 100644 +index 4192af7c0..bd44d73e7 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -125,6 +125,7 @@ add_library(llama @@ -31,7 +31,7 @@ index 67c7807e0..fda881640 100644 models/starcoder.cpp models/starcoder2.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp -index 8571a2e02..b6bde25d5 100644 +index 64ad1b776..a5fe4f66c 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -85,6 +85,7 @@ static const std::map LLM_ARCH_NAMES = { @@ -42,15 +42,15 @@ index 8571a2e02..b6bde25d5 100644 { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, { LLM_ARCH_PLM, "plm" }, { LLM_ARCH_BAILINGMOE, "bailingmoe" }, -@@ -204,6 +205,7 @@ static const std::map LLM_KV_NAMES = { - { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, +@@ -206,6 +207,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" }, { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" }, + { LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" }, + { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, -@@ -2023,6 +2025,24 @@ static const std::map> LLM_TENSOR_N +@@ -2025,6 +2027,24 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, }, }, @@ -75,7 +75,7 @@ index 8571a2e02..b6bde25d5 100644 { LLM_ARCH_WAVTOKENIZER_DEC, { -@@ -2681,6 +2701,7 @@ static const std::map LLM_TENSOR_INFOS = { +@@ -2710,6 +2730,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // this tensor is loaded for T5, but never used {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}}, @@ -84,7 +84,7 @@ index 8571a2e02..b6bde25d5 100644 {LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_POS_NET_NORM1, 
{LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, diff --git a/src/llama-arch.h b/src/llama-arch.h -index 150646478..3936a4687 100644 +index e11318002..ec9e3a6df 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -89,6 +89,7 @@ enum llm_arch { @@ -95,15 +95,15 @@ index 150646478..3936a4687 100644 LLM_ARCH_WAVTOKENIZER_DEC, LLM_ARCH_PLM, LLM_ARCH_BAILINGMOE, -@@ -208,6 +209,7 @@ enum llm_kv { - LLM_KV_ATTENTION_SCALE, +@@ -210,6 +211,7 @@ enum llm_kv { LLM_KV_ATTENTION_OUTPUT_SCALE, LLM_KV_ATTENTION_TEMPERATURE_LENGTH, + LLM_KV_ATTENTION_TEMPERATURE_SCALE, + LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, -@@ -459,6 +461,7 @@ enum llm_tensor { +@@ -462,6 +464,7 @@ enum llm_tensor { LLM_TENSOR_ENC_OUTPUT_NORM, LLM_TENSOR_CLS, LLM_TENSOR_CLS_OUT, @@ -131,7 +131,7 @@ index 8cdbaf69f..41127bf91 100644 if (il < n_layer) { return swa_layers[il]; diff --git a/src/llama-hparams.h b/src/llama-hparams.h -index c3a53be79..2ffe7dd30 100644 +index 6eff334a5..a778fc3cf 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -64,6 +64,8 @@ struct llama_hparams { @@ -167,10 +167,10 @@ index aa3a65f87..ee303bd58 100644 llama_model_loader::llama_model_loader( const std::string & fname, diff --git a/src/llama-model.cpp b/src/llama-model.cpp -index c2a545531..4468de2f9 100644 +index 04fccc979..3c503b424 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp -@@ -1961,6 +1961,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { +@@ -1975,6 +1975,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; @@ -192,7 +192,7 @@ index c2a545531..4468de2f9 100644 case LLM_ARCH_WAVTOKENIZER_DEC: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); -@@ -5350,6 +5365,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) { +@@ -5401,6 +5416,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -227,7 +227,7 @@ index c2a545531..4468de2f9 100644 layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); -@@ -7425,6 +7468,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { +@@ -7480,6 +7523,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; @@ -238,7 +238,7 @@ index c2a545531..4468de2f9 100644 case LLM_ARCH_WAVTOKENIZER_DEC: { llm = std::make_unique(*this, params); -@@ -7684,6 +7731,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { +@@ -7743,6 +7790,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_GRANITE_HYBRID: case LLM_ARCH_CHAMELEON: @@ -268,10 +268,10 @@ index f8342cf2c..cbf4e1bfa 100644 struct llama_layer_convnext convnext; diff --git a/src/models/models.h b/src/models/models.h -index 7ba225b47..71fea796d 100644 +index 6494f5450..e0aec822c 100644 --- a/src/models/models.h +++ b/src/models/models.h -@@ -510,6 +510,11 @@ struct llm_build_smollm3 : public llm_graph_context { +@@ -515,6 +515,11 @@ struct llm_build_smollm3 : public llm_graph_context { llm_build_smollm3(const llama_model & model, const llm_graph_params & params); }; diff --git 
a/llama/patches/0005-fix-deepseek-deseret-regex.patch b/llama/patches/0005-fix-deepseek-deseret-regex.patch index 0cebdb58..22f3cd9f 100644 --- a/llama/patches/0005-fix-deepseek-deseret-regex.patch +++ b/llama/patches/0005-fix-deepseek-deseret-regex.patch @@ -12,7 +12,7 @@ regex 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp -index b9f0631f4..1525283d7 100644 +index 8246a0a14..dfba7778b 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -299,7 +299,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { @@ -25,7 +25,7 @@ index b9f0631f4..1525283d7 100644 "\\s+$", "[一-龥ࠀ-一가-퟿]+", diff --git a/src/unicode.cpp b/src/unicode.cpp -index 77ba4fc46..040518e1e 100644 +index bb44edfad..13ced055f 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -2,6 +2,11 @@ diff --git a/llama/patches/0006-maintain-ordering-for-rules-for-grammar.patch b/llama/patches/0006-maintain-ordering-for-rules-for-grammar.patch index 66ac01c1..83061168 100644 --- a/llama/patches/0006-maintain-ordering-for-rules-for-grammar.patch +++ b/llama/patches/0006-maintain-ordering-for-rules-for-grammar.patch @@ -8,7 +8,7 @@ Subject: [PATCH] maintain ordering for rules for grammar 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp -index c8421e1e8..cb659915d 100644 +index c3b4e5d9d..6be552826 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -310,7 +310,7 @@ private: diff --git a/llama/patches/0007-sort-devices-by-score.patch b/llama/patches/0007-sort-devices-by-score.patch index 6bf45ae5..f45da396 100644 --- a/llama/patches/0007-sort-devices-by-score.patch +++ b/llama/patches/0007-sort-devices-by-score.patch @@ -11,10 +11,10 @@ with the fastest acceleration is loaded 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp -index e96b5c403..a55d9b280 100644 +index 4181a714a..079dba211 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp -@@ -179,7 +179,7 @@ struct ggml_backend_reg_entry { +@@ -183,7 +183,7 @@ struct ggml_backend_reg_entry { struct ggml_backend_registry { std::vector backends; @@ -23,7 +23,7 @@ index e96b5c403..a55d9b280 100644 ggml_backend_registry() { #ifdef GGML_USE_CUDA -@@ -230,7 +230,7 @@ struct ggml_backend_registry { +@@ -237,7 +237,7 @@ struct ggml_backend_registry { } } @@ -32,7 +32,7 @@ index e96b5c403..a55d9b280 100644 if (!reg) { return; } -@@ -241,15 +241,20 @@ struct ggml_backend_registry { +@@ -248,15 +248,20 @@ struct ggml_backend_registry { #endif backends.push_back({ reg, std::move(handle) }); for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) { @@ -56,7 +56,7 @@ index e96b5c403..a55d9b280 100644 } ggml_backend_reg_t load_backend(const fs::path & path, bool silent) { -@@ -293,7 +298,7 @@ struct ggml_backend_registry { +@@ -300,7 +305,7 @@ struct ggml_backend_registry { GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_str(path).c_str()); @@ -65,7 +65,7 @@ index e96b5c403..a55d9b280 100644 return reg; } -@@ -316,7 +321,7 @@ struct ggml_backend_registry { +@@ -323,7 +328,7 @@ struct ggml_backend_registry { // remove devices devices.erase( std::remove_if(devices.begin(), devices.end(), @@ -74,7 +74,7 @@ index e96b5c403..a55d9b280 100644 devices.end()); // remove backend -@@ -374,7 +379,7 @@ size_t ggml_backend_dev_count() { +@@ -381,7 +386,7 @@ size_t ggml_backend_dev_count() { 
ggml_backend_dev_t ggml_backend_dev_get(size_t index) { GGML_ASSERT(index < ggml_backend_dev_count()); diff --git a/llama/patches/0008-add-phony-target-ggml-cpu-for-all-cpu-variants.patch b/llama/patches/0008-add-phony-target-ggml-cpu-for-all-cpu-variants.patch index d3ab6500..315613e0 100644 --- a/llama/patches/0008-add-phony-target-ggml-cpu-for-all-cpu-variants.patch +++ b/llama/patches/0008-add-phony-target-ggml-cpu-for-all-cpu-variants.patch @@ -8,10 +8,10 @@ Subject: [PATCH] add phony target ggml-cpu for all cpu variants 1 file changed, 2 insertions(+) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt -index d93664b8b..800f98b65 100644 +index 4c04c3300..f4747f262 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt -@@ -349,6 +349,7 @@ function(ggml_add_cpu_backend_variant tag_name) +@@ -345,6 +345,7 @@ function(ggml_add_cpu_backend_variant tag_name) endif() ggml_add_cpu_backend_variant_impl(${tag_name}) @@ -19,7 +19,7 @@ index d93664b8b..800f98b65 100644 endfunction() ggml_add_backend(CPU) -@@ -359,6 +360,7 @@ if (GGML_CPU_ALL_VARIANTS) +@@ -355,6 +356,7 @@ if (GGML_CPU_ALL_VARIANTS) elseif (GGML_CPU_ARM_ARCH) message(FATAL_ERROR "Cannot use both GGML_CPU_ARM_ARCH and GGML_CPU_ALL_VARIANTS") endif() diff --git a/llama/patches/0009-remove-amx.patch b/llama/patches/0009-remove-amx.patch index bfb3727a..cace86f9 100644 --- a/llama/patches/0009-remove-amx.patch +++ b/llama/patches/0009-remove-amx.patch @@ -9,10 +9,10 @@ disable amx as it reduces performance on some systems 1 file changed, 4 deletions(-) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt -index 800f98b65..6d493a4ff 100644 +index f4747f262..d55aed348 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt -@@ -369,10 +369,6 @@ if (GGML_CPU_ALL_VARIANTS) +@@ -365,10 +365,6 @@ if (GGML_CPU_ALL_VARIANTS) ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512) ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI) ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI) diff --git a/llama/patches/0010-fix-string-arr-kv-loading.patch b/llama/patches/0010-fix-string-arr-kv-loading.patch index ce151948..622783d9 100644 --- a/llama/patches/0010-fix-string-arr-kv-loading.patch +++ b/llama/patches/0010-fix-string-arr-kv-loading.patch @@ -25,7 +25,7 @@ index 79ee20206..3efb22f01 100644 // get ith C string from array with given key_id GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i); diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp -index 8cc4ef1cf..d950dbdf5 100644 +index b165d8bdc..f91d4faba 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -805,10 +805,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id @@ -53,7 +53,7 @@ index 8cc4ef1cf..d950dbdf5 100644 } diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp -index 1525283d7..ea450c361 100644 +index dfba7778b..f72f321b9 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1781,9 +1781,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { diff --git a/llama/patches/0011-ollama-debug-tensor.patch b/llama/patches/0011-ollama-debug-tensor.patch index 76db920f..8680c91d 100644 --- a/llama/patches/0011-ollama-debug-tensor.patch +++ b/llama/patches/0011-ollama-debug-tensor.patch @@ -8,7 +8,7 @@ Subject: [PATCH] ollama debug tensor 1 file changed, 6 insertions(+) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c 
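The sort-devices-by-score patch above keeps (score, device) pairs in the backend registry and orders them so that index 0 is always the preferred device. A small stand-in sketch of that pattern (illustrative types, not the ggml registry API):

// Sketch: register devices with a score and keep the list sorted so the highest-scoring
// (fastest) device is always found first.
#include <algorithm>
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

struct device { std::string name; };

struct registry {
    std::vector<std::pair<int, device>> devices; // (score, device), highest score first

    void add(device dev, int score) {
        devices.push_back({score, std::move(dev)});
        // stable sort so devices with equal scores keep their registration order
        std::stable_sort(devices.begin(), devices.end(),
                         [](const auto & a, const auto & b) { return a.first > b.first; });
    }

    const device & get(size_t i) const { return devices[i].second; }
    size_t count() const { return devices.size(); }
};

int main() {
    registry reg;
    reg.add({"CPU"},   10);
    reg.add({"CUDA0"}, 100);
    std::printf("preferred: %s\n", reg.get(0).name.c_str()); // CUDA0
}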
b/ggml/src/ggml-cpu/ggml-cpu.c -index 3247af8bb..5be08d6f4 100644 +index b468b115a..bb65985b4 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -15,6 +15,8 @@ @@ -20,7 +20,7 @@ index 3247af8bb..5be08d6f4 100644 #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) -@@ -2922,6 +2924,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { +@@ -2928,6 +2930,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { ggml_compute_forward(¶ms, node); diff --git a/llama/patches/0012-add-ollama-vocab-for-grammar-support.patch b/llama/patches/0012-add-ollama-vocab-for-grammar-support.patch index e84bc875..f26e1bc2 100644 --- a/llama/patches/0012-add-ollama-vocab-for-grammar-support.patch +++ b/llama/patches/0012-add-ollama-vocab-for-grammar-support.patch @@ -4,16 +4,16 @@ Date: Mon, 21 Apr 2025 13:30:31 -0700 Subject: [PATCH] add ollama vocab for grammar support --- - src/llama-grammar.cpp | 49 ++++++++++++++++++++++++++++++++++++------ + src/llama-grammar.cpp | 48 ++++++++++++++++++++++++++++++++++++------ src/llama-grammar.h | 14 ++++++++++++ src/llama-sampling.cpp | 6 +++--- - 3 files changed, 59 insertions(+), 10 deletions(-) + 3 files changed, 58 insertions(+), 10 deletions(-) diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp -index b3c5eb571..a7307c47f 100644 +index 75d5d750c..a0299d181 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp -@@ -915,6 +915,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( +@@ -1041,6 +1041,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( struct llama_grammar * llama_grammar_init_impl( const struct llama_vocab * vocab, @@ -21,15 +21,15 @@ index b3c5eb571..a7307c47f 100644 const llama_grammar_element ** rules, size_t n_rules, size_t start_rule_index) { -@@ -970,6 +971,7 @@ struct llama_grammar * llama_grammar_init_impl( +@@ -1096,6 +1097,7 @@ struct llama_grammar * llama_grammar_init_impl( // then the pointers would be invalidated when the local vec_rules goes out of scope. return new llama_grammar { vocab, + ollama_vocab, std::move(vec_rules), std::move(stacks), - /* .partial_utf8 = */ {}, -@@ -983,6 +985,7 @@ struct llama_grammar * llama_grammar_init_impl( + /* .partial_utf8 = */ {}, +@@ -1110,6 +1112,7 @@ struct llama_grammar * llama_grammar_init_impl( struct llama_grammar * llama_grammar_init_impl( const struct llama_vocab * vocab, @@ -37,15 +37,15 @@ index b3c5eb571..a7307c47f 100644 const char * grammar_str, const char * grammar_root, bool lazy, -@@ -1075,6 +1078,7 @@ struct llama_grammar * llama_grammar_init_impl( +@@ -1202,6 +1205,7 @@ struct llama_grammar * llama_grammar_init_impl( // then the pointers would be invalidated when the local vec_rules goes out of scope. 
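The ollama-debug-tensor patch above adds a hook right after each node is executed in the CPU graph worker. A rough sketch of what such a per-node hook can look like; this is not ollama's actual ollama_debug, just an op/shape dump gated on an environment variable:

// Sketch of a per-node debug hook: after a node has been computed, print its name,
// op and shape when OLLAMA_DEBUG is set in the environment.
#include <cstdio>
#include <cstdlib>
#include "ggml.h"

static void debug_node(const struct ggml_tensor * node) {
    static const bool enabled = getenv("OLLAMA_DEBUG") != nullptr;
    if (!enabled) {
        return;
    }
    fprintf(stderr, "node %-24s op=%-12s shape=[%lld, %lld, %lld, %lld]\n",
            node->name, ggml_op_name(node->op),
            (long long) node->ne[0], (long long) node->ne[1],
            (long long) node->ne[2], (long long) node->ne[3]);
}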
return new llama_grammar { vocab, + ollama_vocab, std::move(vec_rules), std::move(stacks), - /* .partial_utf8 = */ {}, -@@ -1097,6 +1101,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) { + /* .partial_utf8 = */ {}, +@@ -1225,6 +1229,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) { struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) { auto * result = new llama_grammar { grammar.vocab, @@ -53,7 +53,7 @@ index b3c5eb571..a7307c47f 100644 grammar.rules, grammar.stacks, grammar.partial_utf8, -@@ -1124,7 +1129,6 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra +@@ -1253,7 +1258,6 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra } void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) { @@ -61,7 +61,7 @@ index b3c5eb571..a7307c47f 100644 if (grammar.awaiting_trigger) { return; -@@ -1146,9 +1150,13 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ +@@ -1275,9 +1279,13 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ for (size_t i = 0; i < cur_p->size; ++i) { const llama_token id = cur_p->data[i].id; @@ -77,7 +77,7 @@ index b3c5eb571..a7307c47f 100644 if (!allow_eog) { cur_p->data[i].logit = -INFINITY; } -@@ -1167,9 +1175,10 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ +@@ -1296,9 +1304,10 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ } void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) { @@ -90,7 +90,7 @@ index b3c5eb571..a7307c47f 100644 if (grammar.awaiting_trigger) { if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) { -@@ -1209,13 +1218,14 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token +@@ -1353,13 +1362,14 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token } } @@ -106,12 +106,11 @@ index b3c5eb571..a7307c47f 100644 + GGML_ABORT("grammar error: end of grammar token received but grammar stack is not empty"); } - llama_grammar_accept_str(grammar, piece); -@@ -1235,3 +1245,28 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string - throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece); + llama_grammar_accept_token(grammar, token, piece); +@@ -1435,3 +1445,27 @@ void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token toke } } -+ + + +const std::string & ollama_vocab::token_to_piece(const uint32_t token) const { + try { @@ -137,7 +136,7 @@ index b3c5eb571..a7307c47f 100644 + } +} diff --git a/src/llama-grammar.h b/src/llama-grammar.h -index f8c291de9..2a3a62db3 100644 +index a4c978ac1..5c0da4049 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -6,8 +6,19 @@ @@ -160,15 +159,15 @@ index f8c291de9..2a3a62db3 100644 // grammar element type enum llama_gretype { -@@ -114,6 +125,7 @@ struct llama_grammar_trigger_pattern { - struct llama_grammar { +@@ -127,6 +138,7 @@ struct llama_grammar { + // note: allow null vocab for testing (not great) const llama_vocab * vocab; + const ollama_vocab * o_vocab; const llama_grammar_rules rules; // TODO: shared ptr llama_grammar_stacks stacks; -@@ -141,12 +153,14 @@ struct llama_grammar { +@@ -155,12 +167,14 @@ struct llama_grammar { // note: needed for tests (not great) struct 
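The ollama-vocab patch above routes token-piece lookups and end-of-generation checks through an externally supplied vocabulary when one is attached to the grammar. A hedged sketch of such a mapping (member names are guesses for illustration, not the patch's struct layout):

// Sketch of an external vocabulary the grammar can consult: token id -> piece text,
// plus a set of end-of-generation token ids.
#include <cstdint>
#include <map>
#include <set>
#include <stdexcept>
#include <string>

struct external_vocab {
    std::map<uint32_t, std::string> pieces;
    std::set<uint32_t>              eog_tokens;

    void add_token(uint32_t id, std::string piece, bool is_eog) {
        pieces[id] = std::move(piece);
        if (is_eog) {
            eog_tokens.insert(id);
        }
    }

    const std::string & token_to_piece(uint32_t id) const {
        try {
            return pieces.at(id);
        } catch (const std::out_of_range &) {
            throw std::runtime_error("token not found in vocab: " + std::to_string(id));
        }
    }

    bool is_eog(uint32_t id) const { return eog_tokens.count(id) > 0; }
};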
llama_grammar * llama_grammar_init_impl( const struct llama_vocab * vocab, diff --git a/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch b/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch index 26c6dca7..a022e33e 100644 --- a/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch +++ b/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch @@ -12,10 +12,10 @@ Subject: [PATCH] add argsort and cuda copy for i32 5 files changed, 414 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp -index 2745fc54e..40666bab6 100644 +index 303278397..7d1733adb 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp -@@ -7846,6 +7846,45 @@ static void ggml_compute_forward_argsort_f32( +@@ -7932,6 +7932,45 @@ static void ggml_compute_forward_argsort_f32( } } @@ -61,7 +61,7 @@ index 2745fc54e..40666bab6 100644 void ggml_compute_forward_argsort( const ggml_compute_params * params, ggml_tensor * dst) { -@@ -7857,6 +7896,10 @@ void ggml_compute_forward_argsort( +@@ -7943,6 +7982,10 @@ void ggml_compute_forward_argsort( { ggml_compute_forward_argsort_f32(params, dst); } break; @@ -292,10 +292,10 @@ index c4ceb4fc5..0e53ecc39 100644 if (can_be_transposed) { ggml_cpy_scalar_cuda diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal -index 73b45c762..8a6c834d1 100644 +index 51bcbae30..236838e9e 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal -@@ -4721,8 +4721,77 @@ kernel void kernel_argsort_f32_i32( +@@ -4954,8 +4954,77 @@ kernel void kernel_argsort_f32_i32( } } @@ -373,7 +373,7 @@ index 73b45c762..8a6c834d1 100644 typedef void (argsort_merge_t)( constant ggml_metal_kargs_argsort_merge & args, -@@ -4877,8 +4946,154 @@ kernel void kernel_argsort_merge_f32_i32( +@@ -5110,8 +5179,154 @@ kernel void kernel_argsort_merge_f32_i32( } } diff --git a/llama/patches/0014-graph-memory-reporting-on-failure.patch b/llama/patches/0014-graph-memory-reporting-on-failure.patch index fdb462c9..aa466862 100644 --- a/llama/patches/0014-graph-memory-reporting-on-failure.patch +++ b/llama/patches/0014-graph-memory-reporting-on-failure.patch @@ -35,10 +35,10 @@ index f1b740785..c54ff98bf 100644 GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node); diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c -index 218222ece..06ee502ab 100644 +index a5995fdc2..dbfd8b5b2 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c -@@ -493,6 +493,7 @@ struct node_alloc { +@@ -494,6 +494,7 @@ struct node_alloc { struct ggml_gallocr { ggml_backend_buffer_type_t * bufts; // [n_buffers] struct vbuffer ** buffers; // [n_buffers] @@ -46,7 +46,7 @@ index 218222ece..06ee502ab 100644 struct ggml_dyn_tallocr ** buf_tallocs; // [n_buffers] int n_buffers; -@@ -516,6 +517,9 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs +@@ -517,6 +518,9 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs galloc->buffers = calloc(n_bufs, sizeof(struct vbuffer *)); GGML_ASSERT(galloc->buffers != NULL); @@ -56,7 +56,7 @@ index 218222ece..06ee502ab 100644 galloc->buf_tallocs = calloc(n_bufs, sizeof(struct ggml_dyn_tallocr *)); GGML_ASSERT(galloc->buf_tallocs != NULL); -@@ -583,6 +587,7 @@ void ggml_gallocr_free(ggml_gallocr_t galloc) { +@@ -584,6 
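The argsort patch above adds an I32 path to ggml's argsort on CPU, CUDA and Metal. The scalar idea is an index sort keyed on the source values; a generic sketch (ggml's kernels do this per row with an asc/desc flag):

// Sketch of argsort: return the permutation of indices that orders the input values.
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

static std::vector<int32_t> argsort_i32(const std::vector<int32_t> & values, bool ascending) {
    std::vector<int32_t> idx(values.size());
    std::iota(idx.begin(), idx.end(), 0); // 0, 1, 2, ...
    std::sort(idx.begin(), idx.end(), [&](int32_t a, int32_t b) {
        return ascending ? values[a] < values[b] : values[a] > values[b];
    });
    return idx;
}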
+588,7 @@ void ggml_gallocr_free(ggml_gallocr_t galloc) { ggml_hash_set_free(&galloc->hash_set); free(galloc->hash_values); free(galloc->bufts); @@ -64,7 +64,7 @@ index 218222ece..06ee502ab 100644 free(galloc->buffers); free(galloc->buf_tallocs); free(galloc->node_allocs); -@@ -898,6 +903,8 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c +@@ -899,6 +904,8 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c } } @@ -73,7 +73,7 @@ index 218222ece..06ee502ab 100644 // reallocate buffers if needed for (int i = 0; i < galloc->n_buffers; i++) { // if the buffer type is used multiple times, we reuse the same buffer -@@ -932,14 +939,19 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c +@@ -933,14 +940,19 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c #endif ggml_vbuffer_free(galloc->buffers[i]); galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); @@ -96,7 +96,7 @@ index 218222ece..06ee502ab 100644 } bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) { -@@ -1094,6 +1106,22 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) { +@@ -1095,6 +1107,22 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) { return ggml_vbuffer_size(galloc->buffers[buffer_id]); } @@ -120,10 +120,10 @@ index 218222ece..06ee502ab 100644 static void free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp -index 4882541c8..ff41c7712 100644 +index afde2f0b7..dbf8486a0 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp -@@ -1813,6 +1813,13 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe +@@ -1840,6 +1840,13 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe return ggml_gallocr_get_buffer_size(sched->galloc, backend_index); } diff --git a/llama/patches/0015-ggml-Export-GPU-UUIDs.patch b/llama/patches/0015-ggml-Export-GPU-UUIDs.patch index c3836536..1ae032ab 100644 --- a/llama/patches/0015-ggml-Export-GPU-UUIDs.patch +++ b/llama/patches/0015-ggml-Export-GPU-UUIDs.patch @@ -22,10 +22,10 @@ index c54ff98bf..229bf387b 100644 size_t memory_total; // device type diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index 8f3b1c173..e803f4af6 100644 +index 5145c1e88..f641c1016 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu -@@ -185,6 +185,51 @@ static int ggml_cuda_parse_id(char devName[]) { +@@ -189,6 +189,51 @@ static int ggml_cuda_parse_id(char devName[]) { } #endif // defined(GGML_USE_HIP) @@ -77,7 +77,7 @@ index 8f3b1c173..e803f4af6 100644 static ggml_cuda_device_info ggml_cuda_init() { ggml_cuda_device_info info = {}; -@@ -251,22 +296,24 @@ static ggml_cuda_device_info ggml_cuda_init() { +@@ -255,22 +300,24 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].cc += prop.minor * 0x10; } } @@ -108,7 +108,7 @@ index 8f3b1c173..e803f4af6 100644 std::string device_name(prop.name); if (device_name == "NVIDIA GeForce MX450") { turing_devices_without_mma.push_back({ id, device_name }); -@@ -4048,6 +4095,7 @@ struct ggml_backend_cuda_device_context { +@@ -4110,6 +4157,7 @@ struct ggml_backend_cuda_device_context { std::string name; std::string description; std::string pci_bus_id; @@ -116,7 +116,7 @@ index 
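The graph-memory-reporting patch above records the size each compute buffer would have needed even when allocation fails, so a caller can find out how much memory a reservation attempted. A simplified stand-in for that bookkeeping (the real logic lives in ggml_gallocr_reserve_n and ggml_backend_sched_get_attempted_buffer_size):

// Sketch: attempt to allocate each buffer, but always remember the requested size so a
// failed reservation can still report how much memory it would have needed.
// The caller owns any buffers that were successfully allocated.
#include <cstddef>
#include <cstdlib>
#include <vector>

struct reserve_result {
    bool ok;
    std::vector<size_t> attempted_sizes; // per buffer, valid even when ok == false
};

static reserve_result reserve_buffers(const std::vector<size_t> & needed, std::vector<void *> & out) {
    reserve_result res{true, {}};
    out.assign(needed.size(), nullptr);
    for (size_t i = 0; i < needed.size(); i++) {
        res.attempted_sizes.push_back(needed[i]);
        out[i] = std::malloc(needed[i]); // stand-in for the backend buffer allocation
        if (out[i] == nullptr) {
            res.ok = false;              // keep going so every attempted size is recorded
        }
    }
    return res;
}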
8f3b1c173..e803f4af6 100644 }; static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { -@@ -4136,6 +4184,11 @@ static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_k +@@ -4198,6 +4246,11 @@ static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_k } #endif // defined(__linux__) @@ -128,7 +128,7 @@ index 8f3b1c173..e803f4af6 100644 static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; ggml_cuda_set_device(ctx->device); -@@ -4176,6 +4229,7 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back +@@ -4238,6 +4291,7 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back props->name = ggml_backend_cuda_device_get_name(dev); props->description = ggml_backend_cuda_device_get_description(dev); @@ -136,7 +136,7 @@ index 8f3b1c173..e803f4af6 100644 props->type = ggml_backend_cuda_device_get_type(dev); props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total); -@@ -4767,6 +4821,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { +@@ -4833,6 +4887,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { cudaDeviceProp prop; CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); dev_ctx->description = prop.name; diff --git a/llama/patches/0016-add-C-API-for-mtmd_input_text.patch b/llama/patches/0016-add-C-API-for-mtmd_input_text.patch index fa371e8e..19c4d25d 100644 --- a/llama/patches/0016-add-C-API-for-mtmd_input_text.patch +++ b/llama/patches/0016-add-C-API-for-mtmd_input_text.patch @@ -10,7 +10,7 @@ Signed-off-by: Gabe Goodhart 2 files changed, 13 insertions(+) diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp -index dfad9cd79..9858de630 100644 +index d06fa42e6..0f5712e21 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -87,6 +87,16 @@ enum mtmd_slice_tmpl { @@ -31,7 +31,7 @@ index dfad9cd79..9858de630 100644 return "<__media__>"; } diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h -index 015119be8..8d3fa5d34 100644 +index b3df24c29..a6a1af3b8 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h @@ -75,6 +75,9 @@ typedef struct mtmd_input_chunk mtmd_input_chunk; diff --git a/llama/patches/0017-no-power-throttling-win32-with-gnuc.patch b/llama/patches/0017-no-power-throttling-win32-with-gnuc.patch index 549e48fa..a788a562 100644 --- a/llama/patches/0017-no-power-throttling-win32-with-gnuc.patch +++ b/llama/patches/0017-no-power-throttling-win32-with-gnuc.patch @@ -8,10 +8,10 @@ Subject: [PATCH] no power throttling win32 with gnuc 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c -index 5be08d6f4..7a0df30c3 100644 +index bb65985b4..47089a62e 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c -@@ -2463,7 +2463,7 @@ static bool ggml_thread_apply_priority(int32_t prio) { +@@ -2464,7 +2464,7 @@ static bool ggml_thread_apply_priority(int32_t prio) { // Newer Windows 11 versions aggresively park (offline) CPU cores and often place // all our threads onto the first 4 cores which results in terrible performance with // n_threads > 4 diff --git a/llama/patches/0018-ggml-Add-batch-size-hint.patch b/llama/patches/0018-ggml-Add-batch-size-hint.patch index f917f397..cef00be5 100644 --- a/llama/patches/0018-ggml-Add-batch-size-hint.patch +++ 
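The GPU-UUIDs patch above exposes a stable per-device identifier so devices can be matched across APIs and processes. Rendering a 16-byte UUID in the canonical 8-4-4-4-12 grouping looks roughly like this (a generic helper; whether ggml_cuda_parse_uuid formats it exactly this way is not shown in the hunks above):

// Sketch: render a 16-byte device UUID as "GPU-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx".
#include <cstdint>
#include <cstdio>
#include <string>

static std::string format_gpu_uuid(const uint8_t bytes[16]) {
    char buf[64];
    std::snprintf(buf, sizeof(buf),
                  "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
                  bytes[0], bytes[1], bytes[2],  bytes[3],  bytes[4],  bytes[5],  bytes[6],  bytes[7],
                  bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15]);
    return buf;
}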
b/llama/patches/0018-ggml-Add-batch-size-hint.patch @@ -58,7 +58,7 @@ index 6792ba986..0f5b03cef 100644 // (optional) event synchronization // record an event on this stream diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp -index ff41c7712..f511e8d76 100644 +index dbf8486a0..312ca873c 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -348,14 +348,14 @@ enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_ba @@ -86,9 +86,9 @@ index ff41c7712..f511e8d76 100644 + int batch_size; // a hint on the batch size to optimize processing, -1 to use heuristics + int debug; - }; -@@ -814,7 +816,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st + // used for debugging graph reallocations [GGML_SCHED_DEBUG_REALLOC] +@@ -820,7 +822,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor); // check if a backend with higher prio wants to offload the op @@ -97,7 +97,7 @@ index ff41c7712..f511e8d76 100644 for (int b = 0; b < src_backend_id; b++) { if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) { SET_CAUSE(tensor, "1.off"); -@@ -1556,7 +1558,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s +@@ -1572,7 +1574,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s } if (!sched->callback_eval) { @@ -106,7 +106,7 @@ index ff41c7712..f511e8d76 100644 if (ec != GGML_STATUS_SUCCESS) { return ec; } -@@ -1578,7 +1580,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s +@@ -1594,7 +1596,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1); @@ -115,7 +115,7 @@ index ff41c7712..f511e8d76 100644 if (ec != GGML_STATUS_SUCCESS) { return ec; } -@@ -1657,6 +1659,7 @@ ggml_backend_sched_t ggml_backend_sched_new( +@@ -1684,6 +1686,7 @@ ggml_backend_sched_t ggml_backend_sched_new( sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends); sched->op_offload = op_offload; @@ -123,7 +123,7 @@ index ff41c7712..f511e8d76 100644 ggml_backend_sched_reset(sched); -@@ -1688,6 +1691,10 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) { +@@ -1715,6 +1718,10 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) { free(sched); } @@ -178,10 +178,10 @@ index 3191faaa4..32f14c811 100644 static const struct ggml_backend_i ggml_backend_cpu_i = { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index e803f4af6..78fb2d8b3 100644 +index f641c1016..17062697b 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu -@@ -2885,7 +2885,7 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { +@@ -2901,7 +2901,7 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { #ifdef USE_CUDA_GRAPH static bool check_node_graph_compatibility(ggml_cgraph * cgraph, @@ -190,7 +190,7 @@ index e803f4af6..78fb2d8b3 100644 // Loop over nodes in GGML graph to obtain info needed for CUDA graph -@@ -2918,24 +2918,34 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph, +@@ -2934,24 +2934,34 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph, #endif } @@ 
-241,7 +241,7 @@ index e803f4af6..78fb2d8b3 100644 } if (!use_cuda_graph) { -@@ -3679,7 +3689,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx +@@ -3742,7 +3752,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx } } @@ -250,7 +250,7 @@ index e803f4af6..78fb2d8b3 100644 ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; ggml_cuda_set_device(cuda_ctx->device); -@@ -3717,7 +3727,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, +@@ -3780,7 +3790,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, if (use_cuda_graph) { cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph); @@ -278,10 +278,10 @@ index 8fc1c2fb5..ba95b4acc 100644 static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -index 83cdec29e..a36c6560c 100644 +index c801d2fd2..b2c0d0cee 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -@@ -13103,7 +13103,7 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru +@@ -13006,7 +13006,7 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru return num_adds; } @@ -290,7 +290,7 @@ index 83cdec29e..a36c6560c 100644 VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; -@@ -13320,6 +13320,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg +@@ -13241,6 +13241,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg return GGML_STATUS_SUCCESS; UNUSED(backend); diff --git a/llama/patches/0020-ggml-No-alloc-mode.patch b/llama/patches/0020-ggml-No-alloc-mode.patch index 01a42690..95962f82 100644 --- a/llama/patches/0020-ggml-No-alloc-mode.patch +++ b/llama/patches/0020-ggml-No-alloc-mode.patch @@ -75,7 +75,7 @@ index 0f5b03cef..7bdf9d81f 100644 struct ggml_backend { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp -index f511e8d76..74b7f070c 100644 +index 312ca873c..4092dfe8a 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -41,6 +41,19 @@ ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t @@ -121,10 +121,10 @@ index f511e8d76..74b7f070c 100644 void * base = buffer->iface.get_base(buffer); GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL"); -@@ -725,6 +745,12 @@ struct ggml_backend_sched { - int batch_size; // a hint on the batch size to optimize processing, -1 to use heuristics - - int debug; +@@ -731,6 +751,12 @@ struct ggml_backend_sched { + int debug_realloc; + int debug_graph_size; + int debug_prev_graph_size; + + // allocate buffers on attached ggml_backend_buffer_type_t's and during reservation + // if false, dummy buffers are used for faster memory sizing calculations @@ -134,7 +134,7 @@ index f511e8d76..74b7f070c 100644 }; #define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor) -@@ -1614,6 +1640,17 @@ ggml_backend_sched_t ggml_backend_sched_new( +@@ -1630,6 +1656,17 @@ ggml_backend_sched_t ggml_backend_sched_new( size_t graph_size, bool parallel, bool op_offload) { @@ -152,7 +152,7 @@ index f511e8d76..74b7f070c 100644 GGML_ASSERT(n_backends > 0); GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS); 
GGML_ASSERT(ggml_backend_dev_type(ggml_backend_get_device(backends[n_backends - 1])) == GGML_BACKEND_DEVICE_TYPE_CPU); -@@ -1655,11 +1692,14 @@ ggml_backend_sched_t ggml_backend_sched_new( +@@ -1682,11 +1719,14 @@ ggml_backend_sched_t ggml_backend_sched_new( sched->events[b][c] = ggml_backend_event_new(backends[b]->device); } } @@ -167,7 +167,7 @@ index f511e8d76..74b7f070c 100644 ggml_backend_sched_reset(sched); -@@ -1674,6 +1714,10 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) { +@@ -1701,6 +1741,10 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) { for (int c = 0; c < sched->n_copies; c++) { ggml_backend_event_free(sched->events[b][c]); } @@ -178,7 +178,7 @@ index f511e8d76..74b7f070c 100644 } ggml_gallocr_free(sched->galloc); ggml_free(sched->ctx); -@@ -1719,6 +1763,24 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * +@@ -1746,6 +1790,24 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * return false; } @@ -203,7 +203,7 @@ index f511e8d76..74b7f070c 100644 ggml_backend_sched_reset(sched); return true; -@@ -1824,7 +1886,13 @@ size_t ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, +@@ -1851,7 +1913,13 @@ size_t ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, int backend_index = ggml_backend_sched_backend_id(sched, backend); GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); @@ -219,7 +219,7 @@ index f511e8d76..74b7f070c 100644 void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh -index 611341deb..ee463af9c 100644 +index c4529f5d9..8b0fb5d42 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -37,6 +37,41 @@ @@ -264,7 +264,7 @@ index 611341deb..ee463af9c 100644 #define STRINGIZE_IMPL(...) #__VA_ARGS__ #define STRINGIZE(...) 
STRINGIZE_IMPL(__VA_ARGS__) -@@ -891,6 +926,9 @@ struct ggml_cuda_pool { +@@ -938,6 +973,9 @@ struct ggml_cuda_pool { virtual void * alloc(size_t size, size_t * actual_size) = 0; virtual void free(void * ptr, size_t size) = 0; @@ -274,7 +274,7 @@ index 611341deb..ee463af9c 100644 }; template -@@ -1179,11 +1217,15 @@ struct ggml_backend_cuda_context { +@@ -1229,11 +1267,15 @@ struct ggml_backend_cuda_context { // pool std::unique_ptr pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; @@ -292,7 +292,7 @@ index 611341deb..ee463af9c 100644 } return *pools[device][curr_stream_no]; } -@@ -1191,6 +1233,22 @@ struct ggml_backend_cuda_context { +@@ -1241,6 +1283,22 @@ struct ggml_backend_cuda_context { ggml_cuda_pool & pool() { return pool(device); } @@ -316,10 +316,10 @@ index 611341deb..ee463af9c 100644 struct ggml_cuda_mm_fusion_args_host { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index 78fb2d8b3..f1c178f31 100644 +index 17062697b..ede1d089a 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu -@@ -361,6 +361,8 @@ const ggml_cuda_device_info & ggml_cuda_info() { +@@ -365,6 +365,8 @@ const ggml_cuda_device_info & ggml_cuda_info() { // #define DEBUG_CUDA_MALLOC @@ -328,7 +328,7 @@ index 78fb2d8b3..f1c178f31 100644 // buffer pool for cuda (legacy) struct ggml_cuda_pool_leg : public ggml_cuda_pool { static const int MAX_BUFFERS = 256; -@@ -373,9 +375,12 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { +@@ -377,9 +379,12 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {}; size_t pool_size = 0; @@ -343,7 +343,7 @@ index 78fb2d8b3..f1c178f31 100644 } ~ggml_cuda_pool_leg() { -@@ -383,7 +388,9 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { +@@ -387,7 +392,9 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { for (int i = 0; i < MAX_BUFFERS; ++i) { ggml_cuda_buffer & b = buffer_pool[i]; if (b.ptr != nullptr) { @@ -354,7 +354,7 @@ index 78fb2d8b3..f1c178f31 100644 pool_size -= b.size; } } -@@ -431,8 +438,15 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { +@@ -435,8 +442,15 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { void * ptr; size_t look_ahead_size = (size_t) (1.05 * size); look_ahead_size = 256 * ((look_ahead_size + 255)/256); @@ -372,7 +372,7 @@ index 78fb2d8b3..f1c178f31 100644 *actual_size = look_ahead_size; pool_size += look_ahead_size; #ifdef DEBUG_CUDA_MALLOC -@@ -452,10 +466,20 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { +@@ -456,10 +470,20 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { } } GGML_LOG_DEBUG(GGML_CUDA_NAME " buffer pool full, increase MAX_CUDA_BUFFERS\n"); @@ -395,7 +395,7 @@ index 78fb2d8b3..f1c178f31 100644 }; // pool with virtual memory -@@ -467,18 +491,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { +@@ -471,18 +495,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { CUdeviceptr pool_addr = 0; size_t pool_used = 0; size_t pool_size = 0; @@ -423,7 +423,7 @@ index 78fb2d8b3..f1c178f31 100644 #if defined(GGML_USE_HIP) // Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285 for (std::pair & mapping : mappings) { -@@ -505,35 +535,49 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { +@@ -509,35 +539,49 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { GGML_ASSERT(pool_size + reserve_size <= CUDA_POOL_VMM_MAX_SIZE); @@ -499,7 +499,7 @@ index 78fb2d8b3..f1c178f31 100644 // add to the pool pool_size += reserve_size; -@@ -566,17 +610,27 @@ struct ggml_cuda_pool_vmm : public 
ggml_cuda_pool { +@@ -570,17 +614,27 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { // all deallocations must be in reverse order of the allocations GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used)); } @@ -530,7 +530,7 @@ index 78fb2d8b3..f1c178f31 100644 } // destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error -@@ -760,11 +814,20 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac +@@ -764,11 +818,20 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac } static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { @@ -552,7 +552,7 @@ index 78fb2d8b3..f1c178f31 100644 static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { size_t size = ggml_nbytes(tensor); int64_t ne0 = tensor->ne[0]; -@@ -788,6 +851,7 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface +@@ -792,6 +855,7 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface /* .get_max_size = */ NULL, // defaults to SIZE_MAX /* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size, /* .is_host = */ NULL, @@ -560,7 +560,7 @@ index 78fb2d8b3..f1c178f31 100644 }; ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) { -@@ -3258,6 +3322,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, +@@ -3274,6 +3338,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { @@ -568,7 +568,7 @@ index 78fb2d8b3..f1c178f31 100644 // flag used to determine whether it is an integrated_gpu const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; -@@ -3347,6 +3412,10 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx +@@ -3410,6 +3475,10 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx continue; } @@ -579,7 +579,7 @@ index 78fb2d8b3..f1c178f31 100644 // start of fusion operations static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); -@@ -3691,6 +3760,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx +@@ -3754,6 +3823,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; @@ -587,7 +587,7 @@ index 78fb2d8b3..f1c178f31 100644 ggml_cuda_set_device(cuda_ctx->device); -@@ -3766,6 +3836,77 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, +@@ -3829,6 +3899,77 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, return GGML_STATUS_SUCCESS; } @@ -665,7 +665,7 @@ index 78fb2d8b3..f1c178f31 100644 static void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; -@@ -4035,6 +4176,9 @@ static const ggml_backend_i ggml_backend_cuda_interface = { +@@ -4097,6 +4238,9 @@ static const ggml_backend_i ggml_backend_cuda_interface = { /* .event_record 
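The no-alloc-mode changes above let a CUDA pool be created that only accounts for the bytes it would have handed out, so graph reservation can act as a sizing pass before any real VRAM is committed. A host-side stand-in for that split (illustrative only; the real pools wrap cudaMalloc and the CUDA VMM API):

// Sketch: a pool that either really allocates or just accounts for the bytes requested.
// In measurement-only mode alloc() returns nullptr and must not be dereferenced;
// callers only read alloc_size() to learn how much memory a full run would need.
#include <cstddef>
#include <cstdlib>

struct sizing_pool {
    bool   do_alloc;      // false = measurement-only mode
    size_t allocated = 0; // bytes currently handed out (real or virtual)

    explicit sizing_pool(bool do_alloc_) : do_alloc(do_alloc_) {}

    void * alloc(size_t size, size_t * actual_size) {
        *actual_size = size;
        allocated += size;
        return do_alloc ? std::malloc(size) : nullptr;
    }

    void free(void * ptr, size_t size) {
        allocated -= size;
        if (do_alloc) {
            std::free(ptr);
        }
    }

    size_t alloc_size()   const { return allocated; }
    bool   alloc_memory() const { return do_alloc; }
};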
= */ ggml_backend_cuda_event_record, /* .event_wait = */ ggml_backend_cuda_event_wait, /* .graph_optimize = */ ggml_backend_cuda_graph_optimize, diff --git a/llama/patches/0021-decode-disable-output_all.patch b/llama/patches/0021-decode-disable-output_all.patch index c92e3910..7de5e378 100644 --- a/llama/patches/0021-decode-disable-output_all.patch +++ b/llama/patches/0021-decode-disable-output_all.patch @@ -8,7 +8,7 @@ Subject: [PATCH] decode: disable output_all 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp -index e04f0fc4f..1359c614b 100644 +index 417140071..87f407f99 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -999,8 +999,7 @@ int llama_context::decode(const llama_batch & batch_inp) { diff --git a/llama/patches/0022-ggml-Enable-resetting-backend-devices.patch b/llama/patches/0022-ggml-Enable-resetting-backend-devices.patch index 04a6b0be..1bcc0e31 100644 --- a/llama/patches/0022-ggml-Enable-resetting-backend-devices.patch +++ b/llama/patches/0022-ggml-Enable-resetting-backend-devices.patch @@ -43,7 +43,7 @@ index 7bdf9d81f..21b35ac5c 100644 struct ggml_backend_device { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp -index 74b7f070c..8d2cc167f 100644 +index 4092dfe8a..a1a19fe51 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -526,6 +526,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par @@ -62,10 +62,10 @@ index 74b7f070c..8d2cc167f 100644 GGML_ASSERT(device); return device->iface.get_buffer_type(device); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index f1c178f31..1110ca372 100644 +index ede1d089a..ec63cadab 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu -@@ -109,6 +109,11 @@ int ggml_cuda_get_device() { +@@ -113,6 +113,11 @@ int ggml_cuda_get_device() { return id; } @@ -77,7 +77,7 @@ index f1c178f31..1110ca372 100644 static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) { ggml_cuda_set_device(device); cudaError_t err; -@@ -4386,7 +4391,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back +@@ -4448,7 +4453,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back props->id = ggml_backend_cuda_device_get_id(dev); props->type = ggml_backend_cuda_device_get_type(dev); props->device_id = ctx->pci_bus_id.empty() ? 
nullptr : ctx->pci_bus_id.c_str(); @@ -89,7 +89,7 @@ index f1c178f31..1110ca372 100644 bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; #ifdef GGML_CUDA_NO_PEER_COPY -@@ -4841,6 +4849,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g +@@ -4907,6 +4915,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context)); } @@ -101,7 +101,7 @@ index f1c178f31..1110ca372 100644 static const ggml_backend_device_i ggml_backend_cuda_device_interface = { /* .get_name = */ ggml_backend_cuda_device_get_name, /* .get_description = */ ggml_backend_cuda_device_get_description, -@@ -4857,6 +4870,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = { +@@ -4923,6 +4936,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = { /* .event_new = */ ggml_backend_cuda_device_event_new, /* .event_free = */ ggml_backend_cuda_device_event_free, /* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize, diff --git a/llama/patches/0024-GPU-discovery-enhancements.patch b/llama/patches/0024-GPU-discovery-enhancements.patch index e4cebfae..86f57122 100644 --- a/llama/patches/0024-GPU-discovery-enhancements.patch +++ b/llama/patches/0024-GPU-discovery-enhancements.patch @@ -45,10 +45,10 @@ index 69223c488..6510e0cba 100644 GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device); diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt -index 6d493a4ff..ac8f38464 100644 +index d55aed348..99ae293cc 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt -@@ -209,6 +209,8 @@ add_library(ggml-base +@@ -205,6 +205,8 @@ add_library(ggml-base ggml-threading.h ggml-quants.c ggml-quants.h @@ -58,10 +58,10 @@ index 6d493a4ff..ac8f38464 100644 set_target_properties(ggml-base PROPERTIES diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index 1110ca372..c1bfadb3e 100644 +index ec63cadab..cd71902df 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu -@@ -263,6 +263,16 @@ static ggml_cuda_device_info ggml_cuda_init() { +@@ -267,6 +267,16 @@ static ggml_cuda_device_info ggml_cuda_init() { for (int id = 0; id < info.device_count; ++id) { int device_vmm = 0; @@ -78,7 +78,7 @@ index 1110ca372..c1bfadb3e 100644 #if defined(GGML_USE_VMM) CUdevice device; CU_CHECK(cuDeviceGet(&device, id)); -@@ -316,6 +326,11 @@ static ggml_cuda_device_info ggml_cuda_init() { +@@ -320,6 +330,11 @@ static ggml_cuda_device_info ggml_cuda_init() { #else info.devices[id].smpbo = prop.sharedMemPerBlockOptin; info.devices[id].cc = 100*prop.major + 10*prop.minor; @@ -90,7 +90,7 @@ index 1110ca372..c1bfadb3e 100644 GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, ID: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? 
"yes" : "no", ggml_cuda_parse_uuid(prop, id).c_str()); -@@ -4255,6 +4270,11 @@ struct ggml_backend_cuda_device_context { +@@ -4317,6 +4332,11 @@ struct ggml_backend_cuda_device_context { std::string description; std::string pci_bus_id; std::string id; @@ -102,7 +102,7 @@ index 1110ca372..c1bfadb3e 100644 }; static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { -@@ -4351,6 +4371,28 @@ static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) { +@@ -4413,6 +4433,28 @@ static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) { static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; ggml_cuda_set_device(ctx->device); @@ -131,7 +131,7 @@ index 1110ca372..c1bfadb3e 100644 CUDA_CHECK(cudaMemGetInfo(free, total)); // ref: https://github.com/ggml-org/llama.cpp/pull/17368 -@@ -4383,6 +4425,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend +@@ -4445,6 +4487,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend return GGML_BACKEND_DEVICE_TYPE_GPU; } @@ -139,7 +139,7 @@ index 1110ca372..c1bfadb3e 100644 static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; -@@ -4396,6 +4439,19 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back +@@ -4458,6 +4501,19 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back // If you need the memory data, call ggml_backend_dev_memory() explicitly. props->memory_total = props->memory_free = 0; @@ -159,7 +159,7 @@ index 1110ca372..c1bfadb3e 100644 bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; #ifdef GGML_CUDA_NO_PEER_COPY bool events = false; -@@ -4980,6 +5036,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { +@@ -5046,6 +5102,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { std::lock_guard lock(mutex); if (!initialized) { ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context; @@ -167,7 +167,7 @@ index 1110ca372..c1bfadb3e 100644 for (int i = 0; i < ggml_cuda_info().device_count; i++) { ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context; -@@ -4995,6 +5052,14 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { +@@ -5061,6 +5118,14 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID); dev_ctx->pci_bus_id = pci_bus_id; @@ -243,7 +243,7 @@ index ba95b4acc..f6f8f7a10 100644 /* .async = */ true, /* .host_buffer = */ false, diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -index a36c6560c..a234eda2e 100644 +index b2c0d0cee..d9f4d34f5 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -236,6 +236,7 @@ class vk_memory_logger; @@ -254,7 +254,7 @@ index a36c6560c..a234eda2e 100644 static constexpr uint32_t mul_mat_vec_max_cols = 8; static constexpr uint32_t p021_max_gqa_ratio = 8; -@@ -12353,6 +12354,29 @@ static void ggml_vk_get_device_description(int device, char * description, size_ +@@ -12256,6 +12257,29 @@ static void ggml_vk_get_device_description(int device, char * description, size_ snprintf(description, description_size, "%s", props.deviceName.data()); } @@ -284,7 
+284,7 @@ index a36c6560c..a234eda2e 100644 // backend interface #define UNUSED GGML_UNUSED -@@ -13614,15 +13638,72 @@ void ggml_backend_vk_get_device_description(int device, char * description, size +@@ -13535,15 +13559,72 @@ void ggml_backend_vk_get_device_description(int device, char * description, size ggml_vk_get_device_description(dev_idx, description, description_size); } @@ -361,7 +361,7 @@ index a36c6560c..a234eda2e 100644 if (membudget_supported) { memprops.pNext = &budgetprops; -@@ -13674,8 +13755,13 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { +@@ -13595,8 +13676,13 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { } } @@ -376,7 +376,7 @@ index a36c6560c..a234eda2e 100644 } vk::PhysicalDeviceProperties2 props = {}; -@@ -13692,19 +13778,24 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { +@@ -13613,19 +13699,24 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { char pci_bus_id[16] = {}; snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function); @@ -410,7 +410,7 @@ index a36c6560c..a234eda2e 100644 static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; -@@ -13716,9 +13807,14 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de +@@ -13637,9 +13728,14 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de return ctx->description.c_str(); } @@ -426,7 +426,7 @@ index a36c6560c..a234eda2e 100644 } static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) { -@@ -13742,8 +13838,9 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml +@@ -13663,8 +13759,9 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml props->name = ggml_backend_vk_device_get_name(dev); props->description = ggml_backend_vk_device_get_description(dev); @@ -437,7 +437,7 @@ index a36c6560c..a234eda2e 100644 ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); props->caps = { /* .async = */ false, -@@ -13751,6 +13848,13 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml +@@ -13672,6 +13769,13 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml /* .buffer_from_host_ptr = */ false, /* .events = */ false, }; @@ -451,7 +451,7 @@ index a36c6560c..a234eda2e 100644 } static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) { -@@ -14319,6 +14423,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, +@@ -14236,6 +14340,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, static std::mutex mutex; std::lock_guard lock(mutex); if (!initialized) { @@ -460,7 +460,7 @@ index a36c6560c..a234eda2e 100644 for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) { ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; char desc[256]; -@@ -14327,12 +14433,41 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, +@@ -14244,12 +14350,41 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, ctx->name = GGML_VK_NAME + std::to_string(i); ctx->description = desc; ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == 
vk::PhysicalDeviceType::eIntegratedGpu; diff --git a/llama/patches/0026-report-LoadLibrary-failures.patch b/llama/patches/0026-report-LoadLibrary-failures.patch index 2adec160..7f0e9be9 100644 --- a/llama/patches/0026-report-LoadLibrary-failures.patch +++ b/llama/patches/0026-report-LoadLibrary-failures.patch @@ -8,10 +8,10 @@ Subject: [PATCH] report LoadLibrary failures 1 file changed, 12 insertions(+) diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp -index a55d9b280..ec6f7f1e9 100644 +index 079dba211..2474e0ed6 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp -@@ -122,6 +122,18 @@ static dl_handle * dl_load_library(const fs::path & path) { +@@ -126,6 +126,18 @@ static dl_handle * dl_load_library(const fs::path & path) { SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); HMODULE handle = LoadLibraryW(path.wstring().c_str()); diff --git a/llama/patches/0027-interleave-multi-rope.patch b/llama/patches/0027-interleave-multi-rope.patch index 7d36d355..6ca94029 100644 --- a/llama/patches/0027-interleave-multi-rope.patch +++ b/llama/patches/0027-interleave-multi-rope.patch @@ -13,7 +13,7 @@ interleaved version used for qwen3vl 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp -index 40666bab6..3155cb4bb 100644 +index 7d1733adb..f4aae5332 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -5599,14 +5599,14 @@ static void ggml_mrope_cache_init( @@ -59,10 +59,10 @@ index 88ed79111..71ca60214 100644 } else { if (sector < sections.v[0]) { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal -index 8a6c834d1..761b57a26 100644 +index 236838e9e..c98d269d1 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal -@@ -4009,14 +4009,14 @@ kernel void kernel_rope_multi( +@@ -4242,14 +4242,14 @@ kernel void kernel_rope_multi( float theta_base; if (FC_rope_is_imrope) { diff --git a/llama/patches/0028-Add-memory-detection-using-DXGI-PDH.patch b/llama/patches/0028-Add-memory-detection-using-DXGI-PDH.patch index 17656838..f8106f0f 100644 --- a/llama/patches/0028-Add-memory-detection-using-DXGI-PDH.patch +++ b/llama/patches/0028-Add-memory-detection-using-DXGI-PDH.patch @@ -12,10 +12,10 @@ Subject: [PATCH] Add memory detection using DXGI + PDH create mode 100644 ggml/src/mem_dxgi_pdh.cpp diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt -index ac8f38464..faa1beed2 100644 +index 99ae293cc..9a134b7af 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt -@@ -211,6 +211,7 @@ add_library(ggml-base +@@ -207,6 +207,7 @@ add_library(ggml-base ggml-quants.h mem_hip.cpp mem_nvml.cpp @@ -38,7 +38,7 @@ index 1c07e767a..0da3e065b 100644 #ifdef __cplusplus } diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -index a234eda2e..c98f98c73 100644 +index d9f4d34f5..8a83427fb 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -74,6 +74,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); @@ -49,7 +49,7 @@ index a234eda2e..c98f98c73 100644 typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR { VkStructureType sType; -@@ -13655,6 +13656,7 @@ struct ggml_backend_vk_device_context { +@@ -13576,6 +13577,7 @@ struct ggml_backend_vk_device_context { std::string pci_id; std::string id; std::string uuid; @@ -57,7 +57,7 @@ index a234eda2e..c98f98c73 100644 int major; int minor; int driver_major; -@@ -13673,6 
+13675,20 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size +@@ -13594,6 +13596,20 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size vk::PhysicalDeviceProperties2 props2; vkdev.getProperties2(&props2); @@ -78,7 +78,7 @@ index a234eda2e..c98f98c73 100644 if (!is_integrated_gpu) { -@@ -13704,7 +13720,6 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size +@@ -13625,7 +13641,6 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size } // else fallback to memory budget if supported @@ -86,7 +86,7 @@ index a234eda2e..c98f98c73 100644 if (membudget_supported) { memprops.pNext = &budgetprops; } -@@ -14440,7 +14455,6 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, +@@ -14357,7 +14372,6 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, /* .reg = */ reg, /* .context = */ ctx, }); @@ -94,7 +94,7 @@ index a234eda2e..c98f98c73 100644 // Gather additional information about the device int dev_idx = vk_instance.device_indices[i]; vk::PhysicalDeviceProperties props1; -@@ -14463,6 +14477,14 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, +@@ -14380,6 +14394,14 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, } } ctx->uuid = oss.str(); diff --git a/llama/patches/0029-ggml-cuda-skip-large-batches.patch b/llama/patches/0029-ggml-cuda-skip-large-batches.patch index 834b6e9d..86f1840c 100644 --- a/llama/patches/0029-ggml-cuda-skip-large-batches.patch +++ b/llama/patches/0029-ggml-cuda-skip-large-batches.patch @@ -10,10 +10,10 @@ fallback to cpu 1 file changed, 3 insertions(+) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index c1bfadb3e..16c166a08 100644 +index cd71902df..d69d62193 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu -@@ -4570,6 +4570,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g +@@ -4632,6 +4632,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) { return false; } diff --git a/llama/patches/0030-win-exit-instead-of-abort.patch b/llama/patches/0030-win-exit-instead-of-abort.patch index 9f1a65ea..7dc156e4 100644 --- a/llama/patches/0030-win-exit-instead-of-abort.patch +++ b/llama/patches/0030-win-exit-instead-of-abort.patch @@ -8,10 +8,10 @@ Subject: [PATCH] win: exit instead of abort 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c -index b99345a2e..1c9e0bc05 100644 +index 530ff7b95..fc0196eb7 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c -@@ -229,8 +229,13 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) { +@@ -250,8 +250,13 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) 
{ fprintf(stderr, "%s\n", message); ggml_print_backtrace(); } diff --git a/llama/patches/0031-fix-bakllava-regression.patch b/llama/patches/0031-fix-bakllava-regression.patch index 9481f87a..fa306191 100644 --- a/llama/patches/0031-fix-bakllava-regression.patch +++ b/llama/patches/0031-fix-bakllava-regression.patch @@ -9,10 +9,10 @@ Rever to prior logic of assuming an empty projector type is mlp 1 file changed, 4 insertions(+) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp -index f4c4d2c48..3334ff25b 100644 +index 6be1470ad..2a325c726 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp -@@ -2648,6 +2648,10 @@ struct clip_model_loader { +@@ -2649,6 +2649,10 @@ struct clip_model_loader { if (proj_type.empty()) { if (modality == CLIP_MODALITY_VISION) { get_string(KEY_VISION_PROJ_TYPE, proj_type, false); diff --git a/ml/backend/ggml/ggml/include/ggml-rpc.h b/ml/backend/ggml/ggml/include/ggml-rpc.h index 832c26c6..df1ad2a5 100644 --- a/ml/backend/ggml/ggml/include/ggml-rpc.h +++ b/ml/backend/ggml/ggml/include/ggml-rpc.h @@ -1,6 +1,5 @@ #pragma once -#include "ggml.h" #include "ggml-backend.h" #ifdef __cplusplus @@ -8,7 +7,7 @@ extern "C" { #endif #define RPC_PROTO_MAJOR_VERSION 3 -#define RPC_PROTO_MINOR_VERSION 5 +#define RPC_PROTO_MINOR_VERSION 6 #define RPC_PROTO_PATCH_VERSION 0 #define GGML_RPC_MAX_SERVERS 16 diff --git a/ml/backend/ggml/ggml/include/ggml.h b/ml/backend/ggml/ggml/include/ggml.h index 4dbca868..6bc762c0 100644 --- a/ml/backend/ggml/ggml/include/ggml.h +++ b/ml/backend/ggml/ggml/include/ggml.h @@ -204,6 +204,10 @@ # define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) #endif +#if defined(_WIN32) && !defined(_WIN32_WINNT) +# define _WIN32_WINNT 0x0A00 +#endif + #include #include #include @@ -2148,7 +2152,8 @@ extern "C" { }; enum ggml_scale_flag { - GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8) + GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8), + GGML_SCALE_FLAG_ANTIALIAS = (1 << 9), }; // interpolate @@ -2191,6 +2196,15 @@ extern "C" { int p2, int p3); + // pad each dimension with values on the other side of the torus (looping around) + GGML_API struct ggml_tensor * ggml_pad_circular( + struct ggml_context * ctx, + struct ggml_tensor * a, + int p0, + int p1, + int p2, + int p3); + GGML_API struct ggml_tensor * ggml_pad_ext( struct ggml_context * ctx, struct ggml_tensor * a, @@ -2204,6 +2218,19 @@ extern "C" { int rp3 ); + // pad each dimension with values on the other side of the torus (looping around) + GGML_API struct ggml_tensor * ggml_pad_ext_circular( + struct ggml_context * ctx, + struct ggml_tensor * a, + int lp0, + int rp0, + int lp1, + int rp1, + int lp2, + int rp2, + int lp3, + int rp3); + // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c] GGML_API struct ggml_tensor * ggml_pad_reflect_1d( struct ggml_context * ctx, @@ -2278,7 +2305,7 @@ extern "C" { float stop, float step); -#define GGML_KQ_MASK_PAD 64 +#define GGML_KQ_MASK_PAD 1 // q: [n_embd_k, n_batch, n_head, ne3 ] // k: [n_embd_k, n_kv, n_head_kv, ne3 ] diff --git a/ml/backend/ggml/ggml/src/CMakeLists.txt b/ml/backend/ggml/ggml/src/CMakeLists.txt index faa1beed..9a134b7a 100644 --- a/ml/backend/ggml/ggml/src/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/CMakeLists.txt @@ -127,10 +127,6 @@ if (NOT MSVC) endif() endif() -if (MINGW) - add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER}) -endif() - # # POSIX conformance # @@ -445,6 +441,7 @@ ggml_add_backend(WebGPU) ggml_add_backend(zDNN) ggml_add_backend(OpenCL) ggml_add_backend(Hexagon) 
+ggml_add_backend(ZenDNN) foreach (target ggml-base ggml) target_include_directories(${target} PUBLIC $ $) diff --git a/ml/backend/ggml/ggml/src/ggml-alloc.c b/ml/backend/ggml/ggml/src/ggml-alloc.c index 06ee502a..dbfd8b5b 100644 --- a/ml/backend/ggml/ggml/src/ggml-alloc.c +++ b/ml/backend/ggml/ggml/src/ggml-alloc.c @@ -25,6 +25,7 @@ static bool ggml_is_view(const struct ggml_tensor * t) { // ops that return true for this function must not use restrict pointers for their backend implementations bool ggml_op_can_inplace(enum ggml_op op) { switch (op) { + case GGML_OP_FILL: case GGML_OP_SCALE: case GGML_OP_DIAG_MASK_ZERO: case GGML_OP_DIAG_MASK_INF: diff --git a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp index ec6f7f1e..2474e0ed 100644 --- a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp +++ b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp @@ -73,6 +73,10 @@ #include "ggml-cann.h" #endif +#ifdef GGML_USE_ZENDNN +#include "ggml-zendnn.h" +#endif + // disable C++17 deprecation warning for std::codecvt_utf8 #if defined(__clang__) # pragma clang diagnostic push @@ -215,6 +219,9 @@ struct ggml_backend_registry { #ifdef GGML_USE_OPENCL register_backend(ggml_backend_opencl_reg()); #endif +#ifdef GGML_USE_ZENDNN + register_backend(ggml_backend_zendnn_reg()); +#endif #ifdef GGML_USE_HEXAGON register_backend(ggml_backend_hexagon_reg()); #endif @@ -551,8 +558,12 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, fs::path best_path; for (const auto & search_path : search_paths) { - if (!fs::exists(search_path)) { - GGML_LOG_DEBUG("%s: search path %s does not exist\n", __func__, path_str(search_path).c_str()); + if (std::error_code ec; !fs::exists(search_path, ec)) { + if (ec) { + GGML_LOG_DEBUG("%s: posix_stat(%s) failure, error-message: %s\n", __func__, path_str(search_path).c_str(), ec.message().c_str()); + } else { + GGML_LOG_DEBUG("%s: search path %s does not exist\n", __func__, path_str(search_path).c_str()); + } continue; } fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied); @@ -592,8 +603,12 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, for (const auto & search_path : search_paths) { fs::path filename = backend_filename_prefix().native() + name_path.native() + backend_filename_extension().native(); fs::path path = search_path / filename; - if (fs::exists(path)) { + if (std::error_code ec; fs::exists(path, ec)) { return get_reg().load_backend(path, silent); + } else { + if (ec) { + GGML_LOG_DEBUG("%s: posix_stat(%s) failure, error-message: %s\n", __func__, path_str(path).c_str(), ec.message().c_str()); + } } } return nullptr; @@ -614,6 +629,7 @@ void ggml_backend_load_all_from_path(const char * dir_path) { #endif ggml_backend_load_best("blas", silent, dir_path); + ggml_backend_load_best("zendnn", silent, dir_path); ggml_backend_load_best("cann", silent, dir_path); ggml_backend_load_best("cuda", silent, dir_path); ggml_backend_load_best("hip", silent, dir_path); diff --git a/ml/backend/ggml/ggml/src/ggml-backend.cpp b/ml/backend/ggml/ggml/src/ggml-backend.cpp index 8d2cc167..a1a19fe5 100644 --- a/ml/backend/ggml/ggml/src/ggml-backend.cpp +++ b/ml/backend/ggml/ggml/src/ggml-backend.cpp @@ -754,6 +754,12 @@ struct ggml_backend_sched { int debug; + // used for debugging graph reallocations [GGML_SCHED_DEBUG_REALLOC] + // ref: https://github.com/ggml-org/llama.cpp/pull/17617 + int debug_realloc; + int debug_graph_size; + int 
debug_prev_graph_size; + // allocate buffers on attached ggml_backend_buffer_type_t's and during reservation // if false, dummy buffers are used for faster memory sizing calculations // the scheduler needs to be recreated with allocated buffers before it can be used @@ -1270,10 +1276,8 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra tensor_copy = ggml_dup_tensor_layout(sched->ctx, src); ggml_format_name(tensor_copy, "%s#%s#%d", ggml_backend_name(backend), src->name, c); } - if (sched->n_copies > 1) { - ggml_set_input(tensor_copy); - ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor - } + ggml_set_input(tensor_copy); + ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor tensor_id_copy(src_id, src_backend_id, c) = tensor_copy; SET_CAUSE(tensor_copy, "4.cpy"); } @@ -1325,6 +1329,11 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra } int graph_size = std::max(graph->n_nodes, graph->n_leafs) + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sched->n_copies; + + // remember the actual graph_size for performing reallocation checks later [GGML_SCHED_DEBUG_REALLOC] + sched->debug_prev_graph_size = sched->debug_graph_size; + sched->debug_graph_size = graph_size; + if (sched->graph.size < graph_size) { sched->graph.size = graph_size; sched->graph.nodes = (ggml_tensor **) realloc(sched->graph.nodes, graph_size * sizeof(struct ggml_tensor *)); @@ -1431,14 +1440,21 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) { // allocate graph if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { -#ifdef GGML_SCHED_NO_REALLOC - GGML_ABORT("%s: failed to allocate graph, but graph re-allocation is disabled by GGML_SCHED_NO_REALLOC\n", __func__); -#endif - #ifndef NDEBUG GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed); #endif + if (sched->debug_realloc > 0) { + // we are interested only in situations where the graph was reallocated even though its size remained the same [GGML_SCHED_DEBUG_REALLOC] + // example: https://github.com/ggml-org/llama.cpp/pull/17143 + const bool unexpected = !backend_ids_changed && sched->debug_prev_graph_size == sched->debug_graph_size; + + if (unexpected || sched->debug_realloc > 1) { + GGML_ABORT("%s: unexpected graph reallocation (graph size = %d, nodes = %d, leafs = %d), debug_realloc = %d\n", __func__, + sched->debug_graph_size, sched->graph.n_nodes, sched->graph.n_leafs, sched->debug_realloc); + } + } + // the re-allocation may cause the split inputs to be moved to a different address // synchronize without ggml_backend_sched_synchronize to avoid changing cur_copy for (int i = 0; i < sched->n_backends; i++) { @@ -1667,6 +1683,14 @@ ggml_backend_sched_t ggml_backend_sched_new_ext( const char * GGML_SCHED_DEBUG = getenv("GGML_SCHED_DEBUG"); sched->debug = GGML_SCHED_DEBUG ? atoi(GGML_SCHED_DEBUG) : 0; + + sched->debug_realloc = 0; +#ifdef GGML_SCHED_NO_REALLOC + sched->debug_realloc = 1; +#endif + const char * GGML_SCHED_DEBUG_REALLOC = getenv("GGML_SCHED_DEBUG_REALLOC"); + sched->debug_realloc = GGML_SCHED_DEBUG_REALLOC ? atoi(GGML_SCHED_DEBUG_REALLOC) : sched->debug_realloc; + sched->n_backends = n_backends; sched->n_copies = parallel ? 
GGML_SCHED_MAX_COPIES : 1; @@ -1683,6 +1707,9 @@ ggml_backend_sched_t ggml_backend_sched_new_ext( sched->prev_node_backend_ids = (int *) calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0])); sched->prev_leaf_backend_ids = (int *) calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0])); + sched->debug_graph_size = 0; + sched->debug_prev_graph_size = 0; + sched->context_buffer_size = ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false); sched->context_buffer = (char *) malloc(sched->context_buffer_size); diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt index 7e53a57b..fc31089f 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt @@ -469,6 +469,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_RV_ZICBOP) string(APPEND MARCH_STR "_zicbop") endif() + if (GGML_RV_ZIHINTPAUSE) + string(APPEND MARCH_STR "_zihintpause") + endif() list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d) else() # Begin with the lowest baseline diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp index 67369147..c460c549 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp @@ -8,6 +8,10 @@ #include #endif +#if !defined(HWCAP2_SVE2) +#define HWCAP2_SVE2 (1 << 1) +#endif + #if !defined(HWCAP2_I8MM) #define HWCAP2_I8MM (1 << 13) #endif diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/arch/arm/repack.cpp index 082bd2bf..683ed8d2 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -505,7 +505,6 @@ void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo constexpr int blocklen = 8; assert(n % qk == 0); - assert(nr % 4 == 0); assert(nc % ncols_interleaved == 0); UNUSED(nb); @@ -645,7 +644,6 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, constexpr int blocklen = 8; assert(n % qk == 0); - assert(nr % 4 == 0); assert(nc % ncols_interleaved == 0); UNUSED(nb); diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c index 7a0df30c..47089a62 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c @@ -492,6 +492,15 @@ static inline void ggml_thread_cpu_relax(void) { static inline void ggml_thread_cpu_relax(void) { _mm_pause(); } +#elif defined(__riscv) +static inline void ggml_thread_cpu_relax(void) { + #ifdef __riscv_zihintpause + __asm__ __volatile__ ("pause"); + #else + /* Encoding of the pause instruction */ + __asm__ __volatile__ (".4byte 0x100000F"); + #endif +} #else static inline void ggml_thread_cpu_relax(void) {;} #endif @@ -685,22 +694,14 @@ bool ggml_is_numa(void) { } #if defined(__ARM_ARCH) - -#if defined(__linux__) && defined(__aarch64__) -#include -#endif - -static void ggml_init_arm_arch_features(void) { #if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) -#if defined(__linux__) - ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL); -#else - // TODO: add support of SVE for non-linux systems -#error "TODO: SVE is not supported on this platform. To use SVE, sve_cnt needs to be initialized here." 
-#endif -#endif +#include +static void ggml_init_arm_arch_features(void) { + ggml_arm_arch_features.sve_cnt = svcntb(); } - +#else +static void ggml_init_arm_arch_features(void) {} +#endif #endif // __ARM_ARCH struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) { @@ -2708,6 +2709,11 @@ struct ggml_cplan ggml_graph_plan( n_threads = threadpool ? threadpool->n_threads_max : GGML_DEFAULT_N_THREADS; } +#if defined(__EMSCRIPTEN__) && !defined(__EMSCRIPTEN_PTHREADS__) + // Emscripten without pthreads support can only use a single thread + n_threads = 1; +#endif + size_t work_size = 0; struct ggml_cplan cplan; diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp index 2c4ad9d5..a0cce10a 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -117,8 +117,7 @@ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); } #endif #if defined(__MMA__) -typedef vector unsigned char vec_t; -typedef __vector_quad acc_t; +#include "sgemm-ppc.h" #endif //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED FUSED MULTIPLY ADD @@ -1573,95 +1572,35 @@ class tinyBLAS_BF16_PPC { const int nth; }; -template -class tinyBLAS_Q0_PPC { - public: - tinyBLAS_Q0_PPC(int64_t k, - const TA *A, int64_t lda, - const block_q8_0 *B, int64_t ldb, - float *C, int64_t ldc, - int ith, int nth) + template + tinyBLAS_Q0_PPC::tinyBLAS_Q0_PPC(int64_t k, + const TA *A, int64_t lda, + const block_q8_0 *B, int64_t ldb, + float *C, int64_t ldc, + int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { + kc = 64; } - void matmul(int64_t m, int64_t n) { - mnpack(0, m, 0, n); - } - - private: - - inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) { - for (int I = 0; I < RM; I++) { - for (int J = 0; J < RN; J++) { - *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J); - } - } - } - - template - inline void compute(acc_t* ACC, int c_idx, int s_idx, std::array& comparray, vector float* vs, vector float* fin_res) { - vector signed int vec_C[4]; - vector float CA[4] = {0}; - vector float res[4] = {0}; - __builtin_mma_disassemble_acc(vec_C, ACC); - for (int i = 0; i < 4; i++) { - CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0)); - res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]); - fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]); - } - } - /* This function processes quantized data from block_q4_0 elements. - * First the we try to extract the two int4 values stored in single int8_t into two signed int8. - * And then we subtract each of the resultant element with 8, to convert signed int8 to unsigned int8. - * Also compute the rowsum which is required to compensate the above conversion. 
*/ - inline void process_q4_elements(vector signed char (&c)[2], int* ca) { - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - const vector signed char v8 = vec_splats((signed char)0x8); - vector signed int vsum = {0}; - vector signed int vsum2 = {0}; - c[0] = vec_and(c[1], lowMask); - c[1] = vec_sr(c[1], v4); - c[0] = vec_sub(c[0], v8); - c[1] = vec_sub(c[1], v8); - vsum = vec_sum4s(c[0], vsum); - vsum2 = vec_sum4s(c[1], vsum2); - vsum = vec_add(vsum, vsum2); - *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - } - - template - inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) { - vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; - vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; - vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; - vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; - V2 t1, t2, t3, t4, t5, t6, t7, t8; - vector unsigned char xor_vector; - uint8_t flip_vec = 0x80; - xor_vector = vec_splats(flip_vec); - t1 = vec_perm(s1, s2, swiz1); - t2 = vec_perm(s1, s2, swiz2); - t3 = vec_perm(s3, s4, swiz1); - t4 = vec_perm(s3, s4, swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); + template + void tinyBLAS_Q0_PPC::matmul(int64_t m, int64_t n) { + int mc = 64; int nc = 64; + if (n % 8 == 0 && n < nc) { + nc = n; + mc = 32 ; + kc = 32; + } + const bool is_aligned = ((m & (mc - 1)) == 0) & ((n & (nc - 1)) == 0) & ((k & (kc - 1)) == 0); + if (is_aligned) { + this->matmul_tiled_q0(m, n, mc, nc, kc); + } else { + mnpack(0, m, 0, n); } - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); } - template - void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array& comparray) { + template + template + void tinyBLAS_Q0_PPC::packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array& comparray) { int64_t i, j; TA *aoffset = NULL; int8_t *vecOffset = NULL; @@ -1781,8 +1720,10 @@ class tinyBLAS_Q0_PPC { } } } + + template template - void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) { + void tinyBLAS_Q0_PPC::packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) { int64_t i, j; block_q8_0 *aoffset = NULL; VA *vecOffset = NULL; @@ -1822,7 +1763,6 @@ class tinyBLAS_Q0_PPC { j--; } while(j > 0); } - if (rows & 4) { aoffsets[0] = aoffset; for (int it = 1; it < 4; it++ ) @@ -1878,7 +1818,8 @@ class tinyBLAS_Q0_PPC { } } - void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + template + void tinyBLAS_Q0_PPC::mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { int m_rem = MIN(m - m0, 16); int n_rem = MIN(n - n0, 16); @@ -1915,7 +1856,8 @@ class tinyBLAS_Q0_PPC { } - void KERNEL_4x8(int64_t ii, int64_t jj) { + template + void tinyBLAS_Q0_PPC::KERNEL_4x8(int64_t ii, int64_t jj) { vec_t vec_A[8], vec_B[16] = {0}; acc_t acc_0, acc_1; std::array comparray {}; @@ -1953,14 +1895,15 @@ class tinyBLAS_Q0_PPC { aoffset += lda; } } - compute<4>(&acc_0, 0, 0, comparray, vs, fin_res); - compute<4>(&acc_1, 
0, 4, comparray, vs, fin_res); + compute(&acc_0, 0, 0, comparray, vs, fin_res); + compute(&acc_1, 0, 4, comparray, vs, fin_res); } save_res(ii, jj, 0, fin_res); save_res(ii, jj+4, 4, fin_res); } - void KERNEL_8x4(int64_t ii, int64_t jj) { + template + void tinyBLAS_Q0_PPC::KERNEL_8x4(int64_t ii, int64_t jj) { vec_t vec_A[16], vec_B[8] = {0}; acc_t acc_0, acc_1; std::array comparray {}; @@ -1997,16 +1940,18 @@ class tinyBLAS_Q0_PPC { aoffset += lda; } } - compute<8>(&acc_0, 0, 0, comparray, vs, fin_res); - compute<8>(&acc_1, 4, 4, comparray, vs, fin_res); + compute(&acc_0, 0, 0, comparray, vs, fin_res); + compute(&acc_1, 4, 4, comparray, vs, fin_res); } save_res(ii, jj, 0, fin_res); save_res(ii+4, jj, 4, fin_res); } - void KERNEL_8x8(int64_t ii, int64_t jj) { + template + void tinyBLAS_Q0_PPC::KERNEL_8x8(int64_t ii, int64_t jj) { vec_t vec_A[16], vec_B[16] = {0}; acc_t acc_0, acc_1, acc_2, acc_3; + acc_t acc_4, acc_5, acc_6, acc_7; std::array comparray {}; vector float fin_res[16] = {0}; vector float vs[16] = {0}; @@ -2046,10 +1991,10 @@ class tinyBLAS_Q0_PPC { aoffset += lda; } } - compute<8>(&acc_0, 0, 0, comparray, vs, fin_res); - compute<8>(&acc_1, 4, 4, comparray, vs, fin_res); - compute<8>(&acc_2, 0, 8, comparray, vs, fin_res); - compute<8>(&acc_3, 4, 12, comparray, vs, fin_res); + compute(&acc_0, 0, 0, comparray, vs, fin_res); + compute(&acc_1, 4, 4, comparray, vs, fin_res); + compute(&acc_2, 0, 8, comparray, vs, fin_res); + compute(&acc_3, 4, 12, comparray, vs, fin_res); } save_res(ii, jj, 0, fin_res); save_res(ii+4, jj, 4, fin_res); @@ -2057,7 +2002,8 @@ class tinyBLAS_Q0_PPC { save_res(ii+4, jj+4, 12, fin_res); } - void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { + template + void tinyBLAS_Q0_PPC::gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { int64_t ytiles = (m - m0) / RM; int64_t xtiles = (n - n0) / RN; int64_t tiles = xtiles * ytiles; @@ -2125,21 +2071,9 @@ class tinyBLAS_Q0_PPC { } } - template - inline void kernel(int64_t ii, int64_t jj) { - if constexpr(RM == 4 && RN == 8) { - KERNEL_4x8(ii,jj); - } else if constexpr(RM == 8 && RN == 4) { - KERNEL_8x4(ii,jj); - } else if constexpr(RM == 8 && RN == 8) { - KERNEL_8x8(ii,jj); - } else { - assert(false && "RN/RM values not supported"); - } - } - + template template - NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + NOINLINE void tinyBLAS_Q0_PPC::gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { int64_t ytiles = (m - m0) / RM; int64_t xtiles = (n - n0) / RN; int64_t tiles = xtiles * ytiles; @@ -2151,20 +2085,12 @@ class tinyBLAS_Q0_PPC { for (int64_t job = start; job < end; ++job) { int64_t ii = m0 + job / xtiles * RM; int64_t jj = n0 + job % xtiles * RN; - kernel(ii, jj); + this->kernel(ii, jj); } } - const TA *const A; - const block_q8_0 *const B; - float *C; - const int64_t k; - const int64_t lda; - const int64_t ldb; - const int64_t ldc; - const int ith; - const int nth; -}; +template class tinyBLAS_Q0_PPC; +template class tinyBLAS_Q0_PPC; class tinyBLAS_PPC { public: diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.h b/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.h index 729e8853..867b0c04 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.h +++ b/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.h @@ -6,6 +6,12 @@ #include #endif +#ifdef _MSC_VER +#define NOINLINE __declspec(noinline) +#else +#define NOINLINE __attribute__((__noinline__)) +#endif + #ifdef __cplusplus extern "C" { #endif diff --git 
a/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp index 3155cb4b..f4aae533 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp @@ -6383,7 +6383,7 @@ static void ggml_compute_forward_im2col_3d_f16( const int64_t iih = ioh*s1 + ikh*d1 - p1; const int64_t iid = iod*s2 + ikd*d2 - p2; - if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { + if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0; } else { const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW] @@ -6554,8 +6554,13 @@ static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params ggml_compute_forward_mul_mat(params, &dst); } +static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) { + return (coord + size) % size; // adding size avoids negative number weirdness +} + // ggml_compute_forward_conv_2d + static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params, const ggml_tensor * kernel, // [KW, KH, IC, OC] const ggml_tensor * src, // [W, H, C, N] @@ -7420,6 +7425,65 @@ static void ggml_compute_forward_upscale_f32( } } } + } else if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) { + // Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True) + // https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp + auto triangle_filter = [](float x) -> float { + return std::max(1.0f - fabsf(x), 0.0f); + }; + + // support and invscale, minimum 1 pixel for bilinear + const float support1 = std::max(1.0f, 1.0f / sf1); + const float invscale1 = 1.0f / support1; + const float support0 = std::max(1.0f, 1.0f / sf0); + const float invscale0 = 1.0f / support0; + + for (int64_t i3 = 0; i3 < ne3; i3++) { + const int64_t i03 = i3 / sf3; + for (int64_t i2 = ith; i2 < ne2; i2 += nth) { + const int64_t i02 = i2 / sf2; + for (int64_t i1 = 0; i1 < ne1; i1++) { + const float y = ((float) i1 + pixel_offset) / sf1; + for (int64_t i0 = 0; i0 < ne0; i0++) { + const float x = ((float) i0 + pixel_offset) / sf0; + + // the range of source pixels that contribute + const int64_t x_min = std::max(x - support0 + pixel_offset, 0); + const int64_t x_max = std::min(x + support0 + pixel_offset, ne00); + const int64_t y_min = std::max(y - support1 + pixel_offset, 0); + const int64_t y_max = std::min(y + support1 + pixel_offset, ne01); + + // bilinear filter with antialiasing + float val = 0.0f; + float total_weight = 0.0f; + + for (int64_t sy = y_min; sy < y_max; sy++) { + const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1); + + for (int64_t sx = x_min; sx < x_max; sx++) { + const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0); + const float weight = weight_x * weight_y; + + if (weight <= 0.0f) { + continue; + } + + const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03); + val += pixel * weight; + total_weight += weight; + } + } + + if (total_weight > 0.0f) { + val /= total_weight; + } + + float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); + *dst_ptr = val; + } + } + } + } } else if (mode == GGML_SCALE_MODE_BILINEAR) { for (int64_t i3 = 0; i3 < ne3; i3++) { const int64_t i03 = i3 / sf3; @@ -7532,6 +7596,7 
@@ void ggml_compute_forward_upscale( // ggml_compute_forward_pad +template static void ggml_compute_forward_pad_f32( const ggml_compute_params * params, ggml_tensor * dst) { @@ -7556,23 +7621,40 @@ static void ggml_compute_forward_pad_f32( const int32_t lp3 = ggml_get_op_params_i32(dst, 6); const int32_t rp3 = ggml_get_op_params_i32(dst, 7); - // TODO: optimize for (int64_t i2 = 0; i2 < ne2; ++i2) { for (int64_t i1 = ith; i1 < ne1; i1 += nth) { for (int64_t i0 = 0; i0 < ne0; ++i0) { for (int64_t i3 = 0; i3 < ne3; ++i3) { - const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; - if ((i0 >= lp0 && i0 < ne0 - rp0) \ - && (i1 >= lp1 && i1 < ne1 - rp1) \ - && (i2 >= lp2 && i2 < ne2 - rp2) \ - && (i3 >= lp3 && i3 < ne3 - rp3)) { - const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00; + // circular means wrap around on a torus, so x and y loop around + if constexpr (circular_t) { + const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; + const int64_t src_i0 = ggml_wrap_around(i0 - lp0, ne00); + const int64_t src_i1 = ggml_wrap_around(i1 - lp1, ne01); + const int64_t src_i2 = ggml_wrap_around(i2 - lp2, ne02); + const int64_t src_i3 = ggml_wrap_around(i3 - lp3, ne03); + + const int64_t src_idx = + src_i3*nb03 + + src_i2*nb02 + + src_i1*nb01 + + src_i0*nb00; + const float * src_ptr = (const float *)((char *) src0->data + src_idx); dst_ptr[dst_idx] = *src_ptr; } else { - dst_ptr[dst_idx] = 0; + const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; + if ((i0 >= lp0 && i0 < ne0 - rp0) \ + && (i1 >= lp1 && i1 < ne1 - rp1) \ + && (i2 >= lp2 && i2 < ne2 - rp2) \ + && (i3 >= lp3 && i3 < ne3 - rp3)) { + const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00; + const float * src_ptr = (const float *)((char *) src0->data + src_idx); + dst_ptr[dst_idx] = *src_ptr; + } else { + dst_ptr[dst_idx] = 0; + } } } } @@ -7580,16 +7662,20 @@ static void ggml_compute_forward_pad_f32( } } + void ggml_compute_forward_pad( const ggml_compute_params * params, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - + const bool circular = (bool) ggml_get_op_params_i32(dst, 8); switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_pad_f32(params, dst); + if (circular) { + ggml_compute_forward_pad_f32(params, dst); + } else { + ggml_compute_forward_pad_f32(params, dst); + } } break; default: { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh index ee463af9..8b0fb5d4 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh @@ -261,7 +261,7 @@ static const char * cu_get_error_str(CUresult err) { #define AMD_MFMA_AVAILABLE #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) -#if defined(GGML_USE_HIP) && defined(RDNA4) +#if defined(GGML_USE_HIP) && (defined(RDNA4) || defined(RDNA3)) #define AMD_WMMA_AVAILABLE #endif // defined(GGML_USE_HIP) && defined(RDNA4) @@ -329,7 +329,7 @@ static bool amd_mfma_available(const int cc) { } static bool amd_wmma_available(const int cc) { - return GGML_CUDA_CC_IS_RDNA4(cc); + return (GGML_CUDA_CC_IS_RDNA4(cc) || GGML_CUDA_CC_IS_RDNA3(cc)); } static bool volta_mma_available(const int cc) { @@ -498,6 +498,53 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { return x; } +template +static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) { + const int lane_id = threadIdx.x % width; 
+#pragma unroll + for (int offset = 1; offset < width; offset <<= 1) { + const T t = __shfl_up_sync(0xffffffff, x, offset, width); + if (lane_id >= offset) { + x += t; + } + } + return x; +} + +template +static __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) { + const int lane_id = threadIdx.x % width; +#pragma unroll + for (int offset = 1; offset < width; offset <<= 1) { + const float t_x = __shfl_up_sync(0xffffffff, a.x, offset, width); + const float t_y = __shfl_up_sync(0xffffffff, a.y, offset, width); + if (lane_id >= offset) { + a.x += t_x; + a.y += t_y; + } + } + return a; +} + +template +static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) { +#ifdef FP16_AVAILABLE + const int lane_id = threadIdx.x % width; +#pragma unroll + for (int offset = 1; offset < width; offset <<= 1) { + const half2 t = __shfl_up_sync(0xffffffff, a, offset, width); + if (lane_id >= offset) { + a = __hadd2(a, t); + } + } + return a; + +#else + NO_DEVICE_CODE; + return a; +#endif // FP16_AVAILABLE +} + static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) { #ifdef FP16_AVAILABLE @@ -1027,6 +1074,10 @@ struct ggml_cuda_concurrent_event { int n_streams = 0; std::unordered_map stream_mapping; + // Original order of nodes in this concurrent region (before interleaving) + // Used to restore grouping for fusion within streams + std::vector original_order; + const ggml_tensor * join_node; ggml_cuda_concurrent_event() = default; @@ -1049,6 +1100,7 @@ struct ggml_cuda_concurrent_event { , fork_event(other.fork_event) , n_streams(other.n_streams) , stream_mapping(std::move(other.stream_mapping)) + , original_order(std::move(other.original_order)) , join_node(other.join_node) { other.fork_event = nullptr; } @@ -1159,11 +1211,9 @@ struct ggml_cuda_concurrent_event { }; struct ggml_cuda_stream_context { - std::vector original_nodes; std::unordered_map concurrent_events; void reset() { - original_nodes.clear(); concurrent_events.clear(); } }; diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/cumsum.cu b/ml/backend/ggml/ggml/src/ggml-cuda/cumsum.cu new file mode 100644 index 00000000..d2f2def8 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/cumsum.cu @@ -0,0 +1,237 @@ +#include +#include "cumsum.cuh" +#include "convert.cuh" +#include "ggml-cuda/common.cuh" +#include "ggml.h" + +#ifdef GGML_CUDA_USE_CUB +# include +#endif // GGML_CUDA_USE_CUB + +template +static __global__ void cumsum_cub_kernel( + const T * __restrict__ src, + T * __restrict__ dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t s1, const int64_t s2, const int64_t s3) { +#ifdef GGML_CUDA_USE_CUB + using BlockScan = cub::BlockScan; + + __shared__ typename BlockScan::TempStorage temp_storage; + __shared__ T block_carry; // carry from previous tile + + const int tid = threadIdx.x; + + const int64_t i1 = blockIdx.x; + const int64_t i2 = blockIdx.y; + const int64_t i3 = blockIdx.z; + + if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) { + return; + } + + const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03; + T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3; + + if (tid == 0) { + block_carry = 0; + } + __syncthreads(); + + for (int64_t start = 0; start < ne00; start += BLOCK_SIZE) { + int64_t idx = start + tid; + T x = (idx < ne00) ? 
src_row[idx] : T(0); + + T inclusive; + T block_total; + BlockScan(temp_storage).InclusiveSum(x, inclusive, block_total); + + __syncthreads(); + + T final_val = inclusive + block_carry; + + // store result + if (idx < ne00) { + dst_row[idx] = final_val; + } + + __syncthreads(); + + if (tid == 0) { + block_carry += block_total; + } + + __syncthreads(); + } +#else + NO_DEVICE_CODE; +#endif // GGML_CUDA_USE_CUB +} + +// Fallback kernel implementation (original) +template +static __global__ void cumsum_kernel( + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t s00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t s0, const int64_t s1, const int64_t s2, const int64_t s3) { + + GGML_UNUSED_VARS(s00, s0); + + const int tid = threadIdx.x; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + const int lane = tid % warp_size; + const int warp = tid / warp_size; + const int warps_per_block = blockDim.x / warp_size; + + extern __shared__ float smem[]; + float * s_vals = smem; + float * s_warp_sums = smem + blockDim.x; + float * s_carry = smem + blockDim.x + warps_per_block; + float * s_chunk_total = s_carry + 1; + + // Initialize carry + if (tid == 0) { + *s_carry = 0.0f; + } + __syncthreads(); + + const int64_t i3 = blockIdx.z; + const int64_t i2 = blockIdx.y; + const int64_t i1 = blockIdx.x; + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03; + T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3; + + for (int64_t start = 0; start < ne00; start += blockDim.x) { + int64_t idx = start + tid; + float val = (idx < ne00) ? ggml_cuda_cast(src_row[idx]) : 0.0f; + + // 1. Warp inclusive scan + val = warp_prefix_inclusive_sum(val); + s_vals[tid] = val; + + // Store warp total + if (lane == warp_size - 1) { + s_warp_sums[warp] = val; + } + __syncthreads(); + + // 2. Exclusive scan of warp sums (warp 0 only) + if (warp == 0) { + float w = (tid < warps_per_block) ? 
s_warp_sums[tid] : 0.0f; + float inc = warp_prefix_inclusive_sum(w); + if (tid < warps_per_block) { + s_warp_sums[tid] = inc - w; // exclusive sum + } + if (tid == warps_per_block - 1) { + *s_chunk_total = inc; // total sum of this chunk + } + } + __syncthreads(); + + float carry = *s_carry; + float final_val = s_vals[tid] + s_warp_sums[warp] + carry; + if (idx < ne00) { + dst_row[idx] = ggml_cuda_cast(final_val); + } + __syncthreads(); + + // Update carry for next chunk + if (tid == 0) { + *s_carry += *s_chunk_total; + } + __syncthreads(); + } +} + +template +static void cumsum_cuda( + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, + cudaStream_t stream) { + + const size_t type_size = sizeof(T); + bool use_cub = false; +#ifdef GGML_CUDA_USE_CUB + // Check if we can use CUB (data must be contiguous along innermost dimension) + const bool is_contiguous = (nb00 == type_size) && (nb0 == type_size); + + if (is_contiguous) { + use_cub = true; + } +#endif // GGML_CUDA_USE_CUB + dim3 grid_dims(ne01, ne02, ne03); + const auto &info = ggml_cuda_info().devices[ggml_cuda_get_device()]; + const int warp_size = info.warp_size; + const int num_warps = (ne00 + warp_size - 1) / warp_size; + int block_size = num_warps * warp_size; + block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE); + dim3 block_dims(block_size, 1, 1); + const int warps_per_block = block_size / warp_size; + const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float); + + if (use_cub) { + cumsum_cub_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb1 / type_size, nb2 / type_size, nb3 / type_size + ); + } else { + cumsum_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size + ); + } +} + +void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == dst->type); + switch(src0->type) { + case GGML_TYPE_F32: + { + cumsum_cuda( + (const float *)src0->data, (float *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + stream + ); + } break; + // We do not support those on CPU for now anyway, so comment them out because they cause errors on some CI platforms + /*case GGML_TYPE_F16: + { + cumsum_cuda( + (const half *)src0->data, (half *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + stream + ); + } break; + case GGML_TYPE_BF16: + { + cumsum_cuda( + (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + stream + ); + } break;*/ + default: + GGML_ABORT("fatal error"); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/cumsum.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/cumsum.cuh new file mode 100644 index 00000000..782d1d92 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/cumsum.cuh @@ 
-0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_CUMSUM_BLOCK_SIZE 256 + +void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/diag.cu b/ml/backend/ggml/ggml/src/ggml-cuda/diag.cu new file mode 100644 index 00000000..5cea2105 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/diag.cu @@ -0,0 +1,77 @@ +#include "convert.cuh" +#include "diag.cuh" +#include "ggml.h" + +template +static __global__ void diag_kernel(T * __restrict__ dst, + const T * __restrict__ src, + const int64_t ne0, + const int64_t ne1, + const int64_t ne2, + const int64_t ne3, + const int64_t total_elements) { + const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (global_idx >= total_elements) { + return; + } + + const int64_t i0 = global_idx % ne0; + const int64_t i1 = (global_idx / ne0) % ne1; + const int64_t i2 = (global_idx / (ne0 * ne1)) % ne2; + const int64_t i3 = global_idx / (ne0 * ne1 * ne2); + + const int64_t dst_idx = ((i3 * ne2 + i2) * ne1 + i1) * ne0 + i0; + + if (i0 == i1) { + const int64_t batch_idx = i3 * ne2 + i2; + const int64_t src_idx = batch_idx * ne0 + i0; + dst[dst_idx] = src[src_idx]; + } else { + dst[dst_idx] = ggml_cuda_cast(0); + } + GGML_UNUSED_VARS(ne3); +} + +void ggml_cuda_op_diag(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + void * dst_d = dst->data; + const void * src0_d = src0->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + GGML_ASSERT(ne00 == ne0); + GGML_ASSERT(ne01 == 1); + GGML_ASSERT(ne02 == ne2); + GGML_ASSERT(ne03 == ne3); + + const int64_t n_elems = ggml_nelements(dst); + const int64_t num_blocks = (n_elems + CUDA_DIAG_BLOCK_SIZE - 1) / CUDA_DIAG_BLOCK_SIZE; + + switch (dst->type) { + case GGML_TYPE_F32: + diag_kernel<<>>((float *) dst_d, (const float *) src0_d, ne0, + ne1, ne2, ne3, n_elems); + break; + case GGML_TYPE_F16: + diag_kernel<<>>((half *) dst_d, (const half *) src0_d, ne0, + ne1, ne2, ne3, n_elems); + break; + default: + GGML_ABORT("unsupported type"); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/diag.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/diag.cuh new file mode 100644 index 00000000..7d73e6a8 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/diag.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_DIAG_BLOCK_SIZE 256 + +void ggml_cuda_op_diag(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh index 5cdd4bb2..2750117a 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh @@ -10,6 +10,12 @@ #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. #define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs. +// log(2) = 0.6931, by adding this to the KQ maximum used for the softmax the numerical range representable +// by the VKQ accumulators is effectively being shifted up by a factor of 8. 
+// This reduces issues with numerical overflow but also causes larger values to be flushed to zero. +// However, as the output from FlashAttention will usually be used as an input for a matrix multiplication this should be negligible. +#define FATTN_KQ_MAX_OFFSET 0.6931f + typedef void (* fattn_kernel_t)( const char * __restrict__ Q, const char * __restrict__ K, @@ -25,7 +31,7 @@ typedef void (* fattn_kernel_t)( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, @@ -621,7 +627,8 @@ static __global__ void flash_attn_mask_to_KV_max( template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup( - float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) { + float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11, + const int nbatch_fa) { constexpr int ncols = ncols1*ncols2; const int bidx0 = blockIdx.x; @@ -632,8 +639,8 @@ static __global__ void flash_attn_stream_k_fixup( const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); - const int iter_k = ne11 / FATTN_KQ_STRIDE; - const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; @@ -765,7 +772,7 @@ static __global__ void flash_attn_combine_results( template void launch_fattn( ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, - const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE + const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE ) { constexpr int ncols = ncols1 * ncols2; @@ -790,8 +797,6 @@ void launch_fattn( GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V)); GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); - GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && - "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); @@ -915,7 +920,7 @@ void launch_fattn( dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float)); } else { - const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. + const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size. 
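
The FATTN_KQ_MAX_OFFSET constant introduced here relies on the shift invariance of softmax: subtracting (max + c) instead of max scales every exponential by exp(-c), and that common factor cancels once the row is divided by its sum, so only the magnitude of the intermediate rowsum/VKQ accumulators changes. A small host-side check of the identity (plain C++, unrelated to the kernel code itself):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // softmax computed with the maximum shifted by an extra offset c:
    // exp(x_i - (m + c)) = exp(-c) * exp(x_i - m), and the common factor exp(-c)
    // cancels once the row is divided by its sum.
    static std::vector<float> softmax_shifted(const std::vector<float> & x, float c) {
        float m = x[0];
        for (float v : x) m = v > m ? v : m;

        std::vector<float> e(x.size());
        float sum = 0.0f;
        for (size_t i = 0; i < x.size(); ++i) {
            e[i] = std::exp(x[i] - (m + c));
            sum += e[i];
        }
        for (float & v : e) v /= sum;
        return e;
    }

    int main() {
        const std::vector<float> logits = {1.5f, -0.25f, 3.0f, 0.0f};
        const std::vector<float> a = softmax_shifted(logits, 0.0f);
        const std::vector<float> b = softmax_shifted(logits, 0.6931f); // log(2)
        for (size_t i = 0; i < logits.size(); ++i) {
            // identical up to rounding; only the unnormalized exponentials are
            // scaled by exp(-c) before the normalization
            printf("%zu: %.6f vs %.6f\n", i, a[i], b[i]);
        }
        return 0;
    }
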
// parallel_blocks must not be larger than what the tensor size allows: parallel_blocks = std::min(parallel_blocks, ntiles_KQ); @@ -970,6 +975,9 @@ void launch_fattn( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + // TODO other tensor dimensions after removal of WMMA kernel: + const uint3 ne01 = init_fastdiv_values(Q->ne[1]); + GGML_ASSERT(block_dim.x % warp_size == 0); fattn_kernel<<>>( (const char *) Q->data, @@ -980,7 +988,7 @@ void launch_fattn( KV_max.ptr, !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], + Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13, nb21, nb22, nb23, mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0, @@ -995,7 +1003,7 @@ void launch_fattn( flash_attn_stream_k_fixup <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]); + ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa); } } else if (parallel_blocks > 1) { const dim3 block_dim_combine(DV, 1, 1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 57defb0c..d51537f7 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -5,284 +5,211 @@ using namespace ggml_cuda_mma; -typedef tile<16, 8, half2> tile_A; -typedef tile< 8, 8, half2> tile_B; -typedef tile<16, 8, half2> tile_B_16; -typedef tile<16, 8, float> tile_C_KQ; -typedef tile<16, 16, float> tile_C_KQ_16; -typedef tile<16, 4, half2> tile_C_VKQ; -typedef tile<16, 8, half2> tile_C_VKQ_16; - -// Config options for specific head sizes. +// Config options for the MMA kernel. // Should not affect results, only speed/register pressure/shared memory use. -// -// nbatch_fa: number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators. -// nwarps_max: maximum number of warps per CUDA block, up to 8 warps in total can run per SM (given enough shared memory). -// Q_in_reg: whether the Q values should be kept permanently in registers. -// nstages_target: targeted number of pipeline stages for cp_async (if available), 0 means synchronous data loading. -// nbatch_K2: number of K half2 values in direction of DKQ to load in parallel. -// nbatch_V2: number of V half2 values in direction of DV to load in parallel. -// nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel. +struct fattn_mma_config { + int nthreads; // Number of threads per CUDA block. + int occupancy; // Targeted occupancy for the MMA kernel. + int nbatch_fa; // Number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators. + int nbatch_K2; // Number of K half2 values in direction of DKQ to load in parallel. + int nbatch_V2; // Number of V half2 values in direction of DV to load in parallel. + int nbatch_combine; // Number of VKQ half2 values in direction of DV to combine in parallel. + int nstages_target; // Number of pipeline stages to use ideally, 1 == always load data synchronously, 2 == preload data if there is hardware support. + bool Q_in_reg; // Whether the Q values should be kept permanently in registers. 
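
The fattn_mma_config struct documented above replaces the old per-head-size template specializations with a table of per-(DKQ, DV, ncols) entries that is consulted both on the host, when picking launch parameters, and inside the kernels as a compile-time constant. The real lookup is split per architecture and marked __host__ __device__; what follows is a stripped-down, single-table analogue of the pattern, with all names and numbers invented for illustration.

    #include <cstdio>

    // Hypothetical tuning parameters, in the spirit of fattn_mma_config.
    struct tune_config {
        int nthreads;
        int nbatch;
        constexpr tune_config(int nthreads, int nbatch) : nthreads(nthreads), nbatch(nbatch) {}
    };

    // One constexpr function serves as the whole lookup table. Because it is
    // constexpr it can be evaluated at compile time inside a template (so the
    // values become compile-time constants in the kernel) and at run time on
    // the host when choosing launch parameters.
    static constexpr tune_config get_tune_config(const int head_size, const int ncols) {
        if (head_size ==  64 && ncols ==  8) { return tune_config(128, 128); }
        if (head_size ==  64 && ncols == 32) { return tune_config(128,  64); }
        if (head_size == 128 && ncols == 32) { return tune_config(256,  64); }
        return tune_config(32, 0); // fallback: signals "no tuning entry"
    }

    template <int head_size, int ncols>
    static void launch_example_kernel() {
        constexpr tune_config cfg = get_tune_config(head_size, ncols); // compile-time lookup
        static_assert(cfg.nbatch > 0, "missing tuning entry");
        printf("head_size=%d ncols=%d -> nthreads=%d nbatch=%d\n",
               head_size, ncols, cfg.nthreads, cfg.nbatch);
    }

    int main() {
        launch_example_kernel< 64,  8>();
        launch_example_kernel<128, 32>();

        // Host-side, run-time use of the very same table:
        const tune_config cfg = get_tune_config(64, 32);
        printf("runtime lookup: nthreads=%d\n", cfg.nthreads);
        return 0;
    }

The point of keeping the lookup constexpr is that a templated kernel can bake values such as nbatch_fa into its code while the host still queries the same table when sizing grids and shared memory.
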
-template -struct fattn_mma_f16_config; - -template <> -struct fattn_mma_f16_config< 64, 64> { - static constexpr int nbatch_fa = 64; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; - - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 32; - } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 32; - } - - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 32; - } - - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 32; - } - - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 32; - } - - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 32; - } + constexpr __host__ __device__ fattn_mma_config( + int nthreads, int occupancy, int nbatch_fa, int nbatch_K2, int nbatch_V2, int nbatch_combine, int nstages_target, bool Q_in_reg) : + nthreads(nthreads), occupancy(occupancy), nbatch_fa(nbatch_fa), nbatch_K2(nbatch_K2), nbatch_V2(nbatch_V2), nbatch_combine(nbatch_combine), + nstages_target(nstages_target), Q_in_reg(Q_in_reg) {} }; -template <> -struct fattn_mma_f16_config< 80, 80> { - static constexpr int nbatch_fa = 64; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; +#define GGML_CUDA_FATTN_MMA_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads_, occupancy_, nbatch_fa_, nbatch_K2_, nbatch_V2_, nbatch_combine_, nstages_target_, Q_in_reg_) \ + if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \ + static_assert((nthreads_) % 32 == 0 && (nthreads_) <= 512, "bad nthreads"); \ + static_assert( (occupancy_) <= 8, "bad occupancy"); \ + static_assert((nbatch_fa_) % 32 == 0 && (nbatch_fa_) <= 256, "bad nbatch_fa"); \ + static_assert((nbatch_K2_) % 4 == 0 && (nbatch_K2_) <= 512, "bad nbatch_K2"); \ + static_assert((nbatch_V2_) % 4 == 0 && (nbatch_V2_) <= 256, "bad nbatch_V2"); \ + static_assert((nbatch_combine_) % 4 == 0 && (nbatch_combine_) <= 128, "bad nbatch_combine"); \ + static_assert((nstages_target_) >= 1 && (nstages_target_) <= 2, "bad nstages_target"); \ + return fattn_mma_config{(nthreads_), (occupancy_), (nbatch_fa_), (nbatch_K2_), (nbatch_V2_), (nbatch_combine_), (nstages_target_), (Q_in_reg_)}; \ + } \ - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 40; +static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_ampere(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 128, 2, 64, 32, 32, 32, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 128, 2, 64, 40, 40, 40, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 128, 2, 64, 48, 
48, 48, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 128, 2, 64, 56, 56, 56, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); + + return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); +} + +static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_turing(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 128, 2, 64, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); + + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); +} + +static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false); + + // TODO tune specifically for Volta + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); +} + +static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) { + if (ampere_mma_available(cc)) { + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 40; + if (turing_mma_available(cc)) { + return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols); } + GGML_ASSERT(volta_mma_available(cc)); + return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols); +} - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 40; - } - - static constexpr __device__ int 
get_nbatch_V2_device(int /*ncols*/) { - return 40; - } - - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 40; - } - - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 40; - } -}; - -template <> -struct fattn_mma_f16_config< 96, 96> { - static constexpr int nbatch_fa = 64; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; - - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 48; - } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 48; - } - - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 48; - } - - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 48; - } - - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 48; - } - - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 48; - } -}; - -template <> -struct fattn_mma_f16_config<112, 112> { - static constexpr int nbatch_fa = 64; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; - - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 56; - } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 56; - } - - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 56; - } - - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 56; - } - - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 56; - } - - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 56; - } -}; - -template <> -struct fattn_mma_f16_config<128, 128> { - static constexpr int nbatch_fa = 64; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; - - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 64; - } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 64; - } - - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 64; - } - - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 64; - } - - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 64; - } - - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 64; - } -}; - -template <> -struct fattn_mma_f16_config<256, 256> { - static constexpr int nbatch_fa = 32; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; - - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 128; - } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 128; - } - - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 128; - } - - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 128; - } - - static int get_nbatch_combine_host(const int cc, const int ncols) { - if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { - return ncols <= 16 ? 128 : 64; - } - return 64; - } - - static constexpr __device__ int get_nbatch_combine_device(int ncols) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING - return ncols <= 16 ? 
128 : 64; +static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols) { +#if defined(AMPERE_MMA_AVAILABLE) + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); +#elif defined(TURING_MMA_AVAILABLE) + return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols); +#elif defined(VOLTA_MMA_AVAILABLE) + return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols); #else - GGML_UNUSED(ncols); - return 128; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING - } -}; + GGML_UNUSED_VARS(DKQ, DV, ncols); + return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); +#endif // defined(AMPERE_MMA_AVAILABLE) +} -template <> -struct fattn_mma_f16_config<576, 512> { - static constexpr int nbatch_fa = 32; - static constexpr int nwarps_max = 8; - static constexpr bool Q_in_reg = false; - static constexpr int nstages_target = 1; +static __host__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nthreads; +} - static int get_nbatch_K2_host(const int cc, const int ncols) { - if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { - return ncols <= 16 ? 96 : 160; - } - return ncols <= 16 ? 288 : 160; - } +static constexpr __device__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nthreads; +} - static constexpr __device__ int get_nbatch_K2_device(int ncols) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING - return ncols <= 16 ? 96 : 160; -#else - return ncols <= 16 ? 288 : 160; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING - } +static __host__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).occupancy; +} - static int get_nbatch_V2_host(const int cc, const int ncols) { - if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { - return ncols <= 16 ? 64 : 128; - } - return ncols <= 16 ? 256 : 128; - } +static constexpr __device__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).occupancy; +} - static constexpr __device__ int get_nbatch_V2_device(int ncols) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING - return ncols <= 16 ? 64 : 128; -#else - return ncols <= 16 ? 
256 : 128; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING - } +static __host__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_fa; +} - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 128; - } +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_fa; +} - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 128; - } -}; +static __host__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_K2; +} + +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_K2; +} + +static __host__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_V2; +} + +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_V2; +} + +static __host__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_combine; +} + +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_combine; +} + +static __host__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nstages_target; +} + +static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nstages_target; +} + +static __host__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).Q_in_reg; +} + +static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg; +} // ------------------------------------------------------------------------------------------------------------------ -template -static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( - const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) { +static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) { + return cp_async_available(cc) && ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2, cc) : 0; +} +static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2) { +#ifdef CP_ASYNC_AVAILABLE + return ncols2 >= 2 ? 
ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2) : 0; +#else + GGML_UNUSED_VARS(DKQ, DV, ncols1, ncols2); + return 0; +#endif // CP_ASYNC_AVAILABLE +} + +// ------------------------------------------------------------------------------------------------------------------ + +template +static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) { // K/V data is loaded with decreasing granularity for D for better memory bandwidth. // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. - - if (use_cp_async) { + if constexpr (use_cp_async) { + static_assert(!oob_check, "OOB check not compatible with cp_async"); constexpr int preload = 64; constexpr int h2_per_chunk = 16/sizeof(half2); const int chunks_per_row = D2 / h2_per_chunk; @@ -315,9 +242,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( } } }; - ggml_cuda_unroll<5>{}(load); + // 1: max 32*16=512 bytes, 256 half + // 2: max 16*16=256 bytes, 128 half + // 3: max 8*16=128 bytes, 64 half + // 4: max 4*16= 64 bytes, 32 half + // 5: max 2*16= 32 bytes, 16 half + // 6: max 1*16= 16 bytes, 8 half + ggml_cuda_unroll<6>{}(load); } else { - static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds"); + // TODO use ggml_cuda_memcpy_1 auto load = [&] __device__ (const int n) { const int stride_k = WARP_SIZE >> n; const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k); @@ -340,20 +273,25 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - tile_KV[i*stride_tile + k] = KV[i*stride_KV + k]; + tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f); } } }; - ggml_cuda_unroll<3>{}(load); + // 1: max 32* 4=128 bytes, 64 half + // 2: max 16* 4= 64 bytes, 32 half + // 3: max 8* 4= 32 bytes, 16 half + // 4: max 4* 4= 16 bytes, 8 half + ggml_cuda_unroll<4>{}(load); } } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( - const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) { - static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter"); - - if (use_cp_async) { + const half * const __restrict__ mask_h, half * const __restrict__ tile_mask, + const int stride_mask, const int i_sup, const int j0, const uint3 ne01) { + if constexpr (use_cp_async) { + static_assert(nbatch_fa <= 8*WARP_SIZE && nbatch_fa % 8 == 0, "bad nbatch_fa"); + static_assert(!oob_check, "OOB check incompatible with cp_async"); constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64; constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa; constexpr int stride_j = nwarps * cols_per_warp; @@ -361,50 +299,85 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask); #pragma unroll - for (int j0 = 0; j0 < ncols1; j0 += stride_j) { - const int j = j0 + threadIdx.y*cols_per_warp + - (nbatch_fa == 2*WARP_SIZE ? 
threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/cols_per_warp)); + for (int j1 = 0; j1 < ncols1; j1 += stride_j) { + const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp); + const int j_vram = fastmodulo(j0 + j_sram, ne01); - if (j0 + stride_j > ncols1 && j >= ncols1) { + if (j1 + stride_j > ncols1 && j_sram >= ncols1) { break; } - const int i = 4 * (threadIdx.x % (nbatch_fa/8)); + const int i = 8 * (threadIdx.x % (nbatch_fa/8)); - cp_async_cg_16(tile_mask_32 + j*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i); + cp_async_cg_16(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i); } - return; - } - - constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa; - constexpr int stride_j = nwarps * cols_per_warp; + } else if constexpr (oob_check) { #pragma unroll - for (int j0 = 0; j0 < ncols1; j0 += stride_j) { - const int j = j0 + threadIdx.y*cols_per_warp + (nbatch_fa == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/cols_per_warp)); + for (int j1 = 0; j1 < ncols1; j1 += nwarps) { + const int j_sram = j1 + threadIdx.y; + const int j_vram = fastmodulo(j0 + j_sram, ne01); - if (j0 + stride_j > ncols1 && j >= ncols1) { - break; + if (j1 + nwarps > ncols1 && j_sram >= ncols1) { + break; + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f); + } } + } else if constexpr (nbatch_fa < 2*WARP_SIZE) { + constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa; + constexpr int stride_j = nwarps * cols_per_warp; +#pragma unroll + for (int j1 = 0; j1 < ncols1; j1 += stride_j) { + const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp); + const int j_vram = fastmodulo(j0 + j_sram, ne01); - const int i = nbatch_fa == 2*WARP_SIZE ? 
threadIdx.x : threadIdx.x % (WARP_SIZE/cols_per_warp); + if (j1 + stride_j > ncols1 && j_sram >= ncols1) { + break; + } - tile_mask[j*(nbatch_fa/2 + 4) + i] = mask_h2[j*stride_mask + i]; + const int i = threadIdx.x % (WARP_SIZE/cols_per_warp); + + ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i); + } + } else { +#pragma unroll + for (int j1 = 0; j1 < ncols1; j1 += nwarps) { + const int j_sram = j1 + threadIdx.y; + const int j_vram = fastmodulo(j0 + j_sram, ne01); + + if (j1 + nwarps > ncols1 && j_sram >= ncols1) { + break; + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) { + const int i = i0 + 2*threadIdx.x; + + ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i); + } + } } } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, - const half2 * const __restrict__ mask_h2, + const half * const __restrict__ mask_h, float2 * const __restrict__ dstk, float2 * const __restrict__ dstk_fixup, const float scale, const float slope, const float logit_softcap, - const int ne01, + const uint3 ne01, const int ne02, const int stride_K, const int stride_V, @@ -412,27 +385,24 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( half2 * const __restrict__ tile_Q, half2 * const __restrict__ tile_K, half2 * const __restrict__ tile_V, - half2 * const __restrict__ tile_mask, - const tile_B * const __restrict__ Q_B, - tile_C_VKQ * const __restrict__ VKQ_C, + half * const __restrict__ tile_mask, + T_B_KQ * const __restrict__ Q_B, + T_C_VKQ * const __restrict__ VKQ_C, float * const __restrict__ KQ_max, float * const __restrict__ KQ_rowsum, - const int kb0) { -#ifdef TURING_MMA_AVAILABLE - typedef fattn_mma_f16_config c; - -#ifdef CP_ASYNC_AVAILABLE - constexpr int nstages = c::nstages_target; -#else - constexpr int nstages = 0; -#endif // CP_ASYNC_AVAILABLE - - constexpr int cols_per_warp = ntiles * tile_B::I; - constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; - constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. - constexpr int ncols = ncols1 * ncols2; - constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); - constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); + const int jt, + const int kb0, + const int k_VKQ_sup) { +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + constexpr int ncols = ncols1 * ncols2; + constexpr int cols_per_warp = T_B_KQ::I; + constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column. + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); + constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols); + constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols); + constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols); + constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2); constexpr int stride_tile_Q = DKQ/2 + 4; constexpr int stride_tile_K = nbatch_K2 + 4; @@ -440,26 +410,27 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); constexpr int stride_tile_V = mla ? 
stride_tile_K : nbatch_V2 + 4; - const int k_VKQ_0 = kb0 * c::nbatch_fa; - tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles]; - - // Use wide variants of tiles if ntiles >= 2. - tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; - tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; - tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; + const int k_VKQ_0 = kb0 * nbatch_fa; +#if defined(TURING_MMA_AVAILABLE) + T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))]; +#else // Volta + T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)]; +#endif // defined(TURING_MMA_AVAILABLE) if constexpr (nstages > 1) { + static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline"); static_assert(!mla, "multi-stage loading not implemented for MLA"); static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); constexpr bool use_cp_async = true; cp_async_wait_all(); __syncthreads(); - flash_attn_ext_f16_load_tile - (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V); + flash_attn_ext_f16_load_tile + (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V, k_VKQ_sup); } else { constexpr bool use_cp_async = nstages == 1; - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); + if (ncols2 > 1 || mask_h) { + flash_attn_ext_f16_load_mask + (mask_h + k_VKQ_0, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01); } } @@ -468,10 +439,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2; const int k0_diff = k0_stop - k0_start; - if (nstages <= 1) { + if constexpr (nstages <= 1) { constexpr bool use_cp_async = nstages == 1; - flash_attn_ext_f16_load_tile - (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K); + flash_attn_ext_f16_load_tile + (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K, k_VKQ_sup); if (use_cp_async) { cp_async_wait_all(); } @@ -479,55 +450,53 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } // Calculate tile of KQ: - if constexpr (c::Q_in_reg) { + if constexpr (Q_in_reg) { #pragma unroll - for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { - const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; + for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I; #pragma unroll - for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { - tile_A K_A; + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) { + T_A_KQ K_A; load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); - if (ntiles == 1) { - mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); + if constexpr (cols_per_warp == 8) { + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]); } else { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - // Wide version of KQ_C is column-major => swap A and B. - mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A); - } + // Wide version of KQ_C is column-major => swap A and B. 
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A); } } } } else { - static_assert(ntiles == 2, "ntiles != 2 not implemented"); + static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented"); #pragma unroll - for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { - load_ldmatrix(Q_B_16[0], tile_Q + (threadIdx.y / np)*(tile_B_16::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) { + load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); #pragma unroll - for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { - const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; + for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I; - tile_A K_A; + T_A_KQ K_A; load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); // Wide version of KQ_C is column-major => swap A and B. - mma(KQ_C_16[i_KQ_00/(np*tile_A::I)], Q_B_16[0], K_A); + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A); } } } - if (nstages <= 1) { + if constexpr (nstages <= 1) { __syncthreads(); // Only needed if tile_K == tile_V. } } if (use_logit_softcap) { - static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); + constexpr int stride = cols_per_warp == 8 ? np*T_C_KQ::I : np*T_C_KQ::J; + static_assert(nbatch_fa % stride == 0, "bad loop size"); #pragma unroll - for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) { + for (int i = 0; i < nbatch_fa/stride; ++i) { #pragma unroll - for (int l = 0; l < tile_C_KQ::ne; ++l) { + for (int l = 0; l < T_C_KQ::ne; ++l) { KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); } } @@ -540,34 +509,35 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } float KQ_rowsum_add[cols_per_thread] = {0.0f}; - if (ntiles == 1) { - if (ncols2 > 1 || mask_h2) { + if constexpr (cols_per_warp == 8) { + if (ncols2 > 1 || mask_h) { #pragma unroll - for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) { - const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; + for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::I) { + const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::I; #pragma unroll - for (int l = 0; l < tile_C_KQ::ne; ++l) { - const int i = i0 + tile_C_KQ::get_i(l); - const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2; + for (int l = 0; l < T_C_KQ::ne; ++l) { + const int i = i0 + T_C_KQ::get_i(l); + const int j = ((threadIdx.y / np)*T_C_KQ::J + T_C_KQ::get_j(l)) / ncols2; - KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope * - __half2float(((const half *) tile_mask)[j*(c::nbatch_fa + 8) + i]); + KQ_C[i00/(np*T_C_KQ::I)].x[l] += slope * __half2float(tile_mask[j*(nbatch_fa + 8) + i]); } } } // Calculate softmax for each KQ column using the current max. value. // The divisor is stored in KQ_rowsum and will be applied at the end. 
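
The KQ_max / KQ_rowsum / KQ_max_scale bookkeeping in this loop is the standard online softmax used by FlashAttention: whenever a new KV tile raises the running maximum, the running sum and the VKQ accumulators are rescaled by exp(old_max - new_max), and the division by the rowsum happens only once at the very end. A scalar, single-row sketch of that update rule (plain C++, independent of the tile layout and of the FATTN_KQ_MAX_OFFSET shift):

    #include <cmath>
    #include <cstdio>

    int main() {
        // One attention row: scores s[k] and values v[k], processed in tiles of 2.
        const float s[4] = {0.3f, 2.0f, -1.0f, 1.2f};
        const float v[4] = {1.0f, 2.0f,  3.0f, 4.0f};

        float running_max = -INFINITY;
        float rowsum      = 0.0f;  // sum of exp(s - running_max) seen so far
        float acc         = 0.0f;  // sum of exp(s - running_max) * v seen so far

        for (int tile = 0; tile < 4; tile += 2) {
            // new maximum over the tile
            float tile_max = running_max;
            for (int k = tile; k < tile + 2; ++k) {
                tile_max = std::fmax(tile_max, s[k]);
            }

            // rescale previous partial results to the new maximum
            const float scale = std::exp(running_max - tile_max);
            rowsum *= scale;
            acc    *= scale;
            running_max = tile_max;

            // accumulate the tile
            for (int k = tile; k < tile + 2; ++k) {
                const float p = std::exp(s[k] - running_max);
                rowsum += p;
                acc    += p * v[k];
            }
        }
        // The division by the rowsum is applied once at the end, as in the kernel.
        printf("online:    %.6f\n", acc / rowsum);

        // reference: straightforward softmax-weighted sum over the whole row
        float m = s[0], sum = 0.0f, ref = 0.0f;
        for (int k = 1; k < 4; ++k) m = std::fmax(m, s[k]);
        for (int k = 0; k < 4; ++k) sum += std::exp(s[k] - m);
        for (int k = 0; k < 4; ++k) ref += std::exp(s[k] - m) / sum * v[k];
        printf("reference: %.6f\n", ref);
        return 0;
    }
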
- static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) { #pragma unroll - for (int l = 0; l < tile_C_KQ::ne; ++l) { - KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]); + for (int l = 0; l < T_C_KQ::ne; ++l) { + if (!oob_check || k0 + T_C_KQ::get_i(l) < k_VKQ_sup) { + KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET); + } } } - // Values per KQ column are spread across 8 threads, does not need full warp reduce: + // Values per KQ column are spread across 8 threads: #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll @@ -576,73 +546,78 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } - static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) { #pragma unroll - for (int l = 0; l < tile_C_KQ::ne; ++l) { - KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]); - - KQ_rowsum_add[l % 2] += KQ_C[k].x[l]; + for (int l = 0; l < T_C_KQ::ne; ++l) { + if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) { + KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[l % 2]); + KQ_rowsum_add[l % 2] += KQ_C[k0/(np*T_C_KQ::I)].x[l]; + } else { + KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f; + } } } - } else { // ntiles > 1 - if (ncols2 > 1 || mask_h2) { + } else { // not Turing mma or T_B_KQ::I > 8 + if (ncols2 > 1 || mask_h) { #pragma unroll - for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ_16::J) { - const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J; + for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::J) { + const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::J; #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { -#pragma unroll - for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) { - const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2; - const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2; + for (int l0 = 0; l0 < T_C_KQ::ne; l0 += 2) { + const int i = (i0 + T_C_KQ::get_j(l0)) / 2; + const int j = ((threadIdx.y / np)*cols_per_warp + T_C_KQ::get_i(l0)) / ncols2; - const float2 tmp = __half22float2(tile_mask[j*(c::nbatch_fa/2 + 4) + i]); - const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t; - KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x; - KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y; - } + const float2 tmp = __half22float2(((const half2 *)tile_mask)[j*(nbatch_fa/2 + 4) + i]); + KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 0] += slope*tmp.x; + KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 1] += slope*tmp.y; } } } // Calculate softmax for each KQ column using the current max. value. // The divisor is stored in KQ_rowsum and will be applied at the end. 
- static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) { #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { -#pragma unroll - for (int l = 0; l < tile_C_KQ_16::ne; ++l) { - const int KQ_index = 2*t + (l/2) % 2; - KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]); + for (int l = 0; l < T_C_KQ::ne; ++l) { + if (!oob_check || k0 + T_C_KQ::get_j(l) < k_VKQ_sup) { + // Turing + Volta: + KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET); } } } - // Values per KQ column are spread across 4 threads, does not need full warp reduce: #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { +#if defined(TURING_MMA_AVAILABLE) + // Values per KQ column are spread across 4 threads: + constexpr int offset_first = 2; + constexpr int offset_last = 1; +#else + // Values per KQ column are spread across 2 threads: + constexpr int offset_first = 2; + constexpr int offset_last = 2; +#endif // defined(TURING_MMA_AVAILABLE) #pragma unroll - for (int offset = 2; offset >= 1; offset >>= 1) { + for (int offset = offset_first; offset >= offset_last; offset >>= 1) { KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); } } - static_assert(c::nbatch_fa % (np*tile_C_KQ_16::J) == 0, "bad loop size"); + static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) { #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { -#pragma unroll - for (int l = 0; l < tile_C_KQ_16::ne; ++l) { - const int KQ_index = 2*t + (l/2) % 2; - - KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]); - - KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l]; + for (int l = 0; l < T_C_KQ::ne; ++l) { + // Turing + Volta: + if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) { + KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[(l/2) % 2]); + KQ_rowsum_add[(l/2) % 2] += KQ_C[(k0/(np*T_C_KQ::J))].x[l]; + } else { + KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f; } } } @@ -662,12 +637,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col]; } - if (ntiles == 1) { +#if defined(TURING_MMA_AVAILABLE) + if constexpr (cols_per_warp == 8) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); #pragma unroll - for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { + for (int i = 0; i < DV/T_C_VKQ::I; ++i) { #pragma unroll - for (int l = 0; l < tile_C_VKQ::ne; ++l) { + for (int l = 0; l < T_C_VKQ::ne; ++l) { VKQ_C[i].x[l] *= KQ_max_scale_h2; } } @@ -676,46 +652,53 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int col = 0; col < cols_per_thread; ++col) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); #pragma unroll - for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) { + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { #pragma unroll - for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { - VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; + for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) { + VKQ_C[i].x[l0 + col] *= 
KQ_max_scale_h2; } } } } +#else // Volta + const half2 KQ_max_scale_h2 = make_half2( + KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]); +#pragma unroll + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } +#endif // defined(TURING_MMA_AVAILABLE) } // Convert KQ C tiles into B tiles for VKQ calculation: - tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles]; - tile_B_16 * B_16 = (tile_B_16 *) B; - static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size"); - if (ntiles == 1) { + T_B_VKQ B[nbatch_fa/(np*2*T_B_VKQ::J)]; + static_assert(nbatch_fa % (np*2*T_B_VKQ::J) == 0, "bad loop size"); + if constexpr (cols_per_warp == 8) { #pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) { + for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) { B[k] = get_transposed(get_half2(KQ_C[k])); } } else { - for (int k = 0; k < c::nbatch_fa/(np*2*tile_B_16::J); ++k) { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]); - } + for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) { + B[k] = get_half2(KQ_C[k]); } } - if (nstages > 1) { + if constexpr (nstages > 1) { // Preload K tile for next iteration: constexpr bool use_cp_async = true; cp_async_wait_all(); __syncthreads(); if (!last_iter) { - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask - (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask); + if (ncols2 > 1 || mask_h) { + flash_attn_ext_f16_load_mask + (mask_h + k_VKQ_0 + nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01); } - flash_attn_ext_f16_load_tile - (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K); + flash_attn_ext_f16_load_tile + (K_h2 + int64_t(k_VKQ_0 + nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup); } } @@ -724,72 +707,119 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( // Therefore, iterate over V in reverse and re-use the data if possible. static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented"); constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV; + + // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V: #pragma unroll for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) { const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0; const int i0_diff = i0_stop - i0_start; - if (nstages <= 1 && i0_start < reusable_cutoff) { - constexpr bool use_cp_async = nstages == 1; - flash_attn_ext_f16_load_tile - (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); - if (use_cp_async) { - cp_async_wait_all(); + if constexpr (nstages <= 1) { + if (i0_start < reusable_cutoff) { + constexpr bool use_cp_async = nstages == 1; + flash_attn_ext_f16_load_tile + (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup); + if (use_cp_async) { + cp_async_wait_all(); + } + __syncthreads(); } - __syncthreads(); } const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2; - // Calculate VKQ tile: +#if defined(TURING_MMA_AVAILABLE) + constexpr int i0_stride = cols_per_warp == 8 ? 
T_C_VKQ::I : 2*T_C_VKQ::J; #pragma unroll - for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += tile_C_VKQ::I) { - static_assert((c::nbatch_fa/2) % (np*tile_A::J) == 0, "bad loop size"); + for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) { + static_assert((nbatch_fa/2) % (np*T_A_VKQ::J) == 0, "bad loop size"); #pragma unroll - for (int k00 = 0; k00 < c::nbatch_fa/2; k00 += np*tile_A::J) { - const int k0 = k00 + (threadIdx.y % np)*tile_A::J; + for (int k00 = 0; k00 < nbatch_fa/2; k00 += np*T_A_VKQ::J) { + const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J; - tile_A A; + T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load. load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); - if (ntiles == 1) { - mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); + if constexpr (T_B_KQ::I == 8) { + mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); } else { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - // Wide version of VKQ_C is column-major => swap A and B. - mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A); - } + // Wide version of VKQ_C is column-major => swap A and B. + mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A); } } } +#else // Volta + constexpr int i0_stride = 2*T_C_VKQ::J; +#pragma unroll + for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) { + static_assert(nbatch_fa % (np*T_A_VKQ::I) == 0, "bad loop size"); + static_assert(2*T_B_VKQ::J == T_A_VKQ::I, "bad tile sizes"); +#pragma unroll + for (int k00 = 0; k00 < nbatch_fa; k00 += np*T_A_VKQ::I) { + const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::I; - if (nstages <= 1) { + T_A_VKQ A; // Transposed in both SRAM and registers, load normally. + load_ldmatrix(A, tile_V_i + k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); + mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A); + } + } +#endif // defined(TURING_MMA_AVAILABLE) + + if constexpr (nstages <= 1) { __syncthreads(); // Only needed if tile_K == tile_V. 
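
When nstages > 1 the kernel above overlaps loading the next K tile (and the mask) with the math on the current one, issuing cp.async copies and waiting on them with cp_async_wait_all() around the __syncthreads() barriers. The buffering structure behind that idea, stripped of cp.async and of the attention math, looks roughly like the ping-pong sketch below; this is a generic illustration of software pipelining through shared memory, not the kernel's actual staging scheme.

    #include <cstdio>
    #include <cuda_runtime.h>

    #define TILE 128

    // Double-buffered tile processing: while tile t is consumed from one
    // shared-memory buffer, tile t+1 is staged into the other one. The real
    // kernel does the staging with cp.async so the copy overlaps with the
    // arithmetic; plain loads are used here to keep the sketch short.
    __global__ void double_buffer_sum(const float * src, float * out, const int ntiles) {
        __shared__ float buf[2][TILE];

        float acc = 0.0f;

        // preload tile 0
        buf[0][threadIdx.x] = src[threadIdx.x];
        __syncthreads();

        for (int t = 0; t < ntiles; ++t) {
            const int cur = t & 1;

            // stage the next tile into the buffer that is not being read
            if (t + 1 < ntiles) {
                buf[cur ^ 1][threadIdx.x] = src[(t + 1)*TILE + threadIdx.x];
            }

            // consume the current tile (made visible by the previous __syncthreads)
            acc += buf[cur][threadIdx.x];

            __syncthreads(); // the next tile is now complete in shared memory
        }

        out[threadIdx.x] = acc;
    }

    int main() {
        const int ntiles = 8, n = ntiles*TILE;
        float *src, *out;
        cudaMallocManaged(&src, n*sizeof(float));
        cudaMallocManaged(&out, TILE*sizeof(float));
        for (int i = 0; i < n; ++i) src[i] = 1.0f;
        double_buffer_sum<<<1, TILE>>>(src, out, ntiles);
        cudaDeviceSynchronize();
        printf("out[0] = %.1f (expected %d)\n", out[0], ntiles);
        cudaFree(src);
        cudaFree(out);
        return 0;
    }
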
} } #else - GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, + GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); NO_DEVICE_CODE; -#endif // TURING_MMA_AVAILABLE +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } -template +#if defined(TURING_MMA_AVAILABLE) +template struct mma_tile_sizes { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile<16, 8, half2>; // column-major + using T_C_KQ = tile<16, 16, float>; // column-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile<16, 8, half2>; // column-major + using T_C_VKQ = tile<16, 8, half2>; // column-major +}; +template<> struct mma_tile_sizes<8> { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile< 8, 8, half2>; // column-major + using T_C_KQ = tile<16, 8, float>; // row-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile< 8, 8, half2>; // column-major + using T_C_VKQ = tile<16, 4, half2>; // row-major +}; +#else // Volta +template struct mma_tile_sizes { + using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_KQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major + using T_C_KQ = tile<32, 8, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile< 8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED>; // column-major + using T_B_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major + using T_C_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major +}; +#endif // defined(TURING_MMA_AVAILABLE) + +template static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, - const half2 * const __restrict__ mask_h2, + const half * const __restrict__ mask_h, const float * const __restrict__ sinks_f, float2 * const __restrict__ dstk, float2 * const __restrict__ dstk_fixup, const float scale, const float slope, const float logit_softcap, - const int ne01, + const uint3 ne01, const int ne02, + const int ne11, const int stride_Q1, const int stride_Q2, const int stride_K, @@ -798,23 +828,31 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int jt, const int kb0_start, const int kb0_stop) { -#ifdef TURING_MMA_AVAILABLE +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - typedef fattn_mma_f16_config c; + constexpr int ncols = ncols1 * ncols2; + using T_A_KQ = typename mma_tile_sizes::T_A_KQ; + using T_B_KQ = typename mma_tile_sizes::T_B_KQ; + using T_C_KQ = typename mma_tile_sizes::T_C_KQ; + using T_A_VKQ = typename mma_tile_sizes::T_A_VKQ; + using T_B_VKQ = typename mma_tile_sizes::T_B_VKQ; + using T_C_VKQ = typename mma_tile_sizes::T_C_VKQ; -#ifdef CP_ASYNC_AVAILABLE - constexpr int nstages = c::nstages_target; -#else - constexpr int nstages = 0; -#endif // CP_ASYNC_AVAILABLE + constexpr int cols_per_warp = T_B_KQ::I; + constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column. + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. 
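
The mma_tile_sizes template above is a plain type trait: the kernel selects its MMA tile types from the number of Q columns a warp owns (and, through the preprocessor, from the target architecture) instead of hard-coding tile<16, 8, half2> throughout. A minimal, self-contained version of that dispatch pattern, with invented placeholder types standing in for the real ggml_cuda_mma tiles:

    #include <cstdio>

    // Placeholder stand-ins for the real tile types.
    template <int I_, int J_> struct toy_tile {
        static constexpr int I = I_;
        static constexpr int J = J_;
    };

    // Primary template: the layout used when a warp owns 16 or more columns.
    template <int cols_per_warp> struct toy_tile_sizes {
        using T_A = toy_tile<16,  8>;
        using T_C = toy_tile<16, 16>;
    };

    // Specialization: a different register layout when a warp owns only 8 columns.
    template <> struct toy_tile_sizes<8> {
        using T_A = toy_tile<16, 8>;
        using T_C = toy_tile<16, 8>;
    };

    template <int cols_per_warp>
    static void describe() {
        using T_C = typename toy_tile_sizes<cols_per_warp>::T_C;
        printf("cols_per_warp=%d -> C tile %dx%d\n", cols_per_warp, T_C::I, T_C::J);
    }

    int main() {
        describe<8>();   // uses the specialization
        describe<16>();  // uses the primary template
        return 0;
    }
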
+ constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols); + constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols); + constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols); + constexpr int nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols); + constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols); + constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2); - constexpr int ncols = ncols1 * ncols2; - constexpr int cols_per_warp = ntiles * tile_B::I; - constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; - constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. - constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); - constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); + if (cols_per_warp > ncols) { + NO_DEVICE_CODE; + return; + } static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); @@ -826,15 +864,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V; extern __shared__ half2 tile_Q[]; - half2 * tile_K = c::Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q; - half2 * tile_V = nstages > 1 ? tile_K + c::nbatch_fa * stride_tile_K : tile_K; - half2 * tile_mask = nstages > 1 ? tile_V + c::nbatch_fa * stride_tile_V : tile_V + c::nbatch_fa * stride_tile_KV_max; + half2 * tile_K = Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q; + half2 * tile_V = nstages > 1 ? tile_K + nbatch_fa * stride_tile_K : tile_K; + half * tile_mask = (half *) (nstages > 1 ? tile_V + nbatch_fa * stride_tile_V : tile_V + nbatch_fa * stride_tile_KV_max); - tile_B Q_B[(c::Q_in_reg ? DKQ/(2*tile_B::J) : 1) * ntiles]; - tile_C_VKQ VKQ_C[DV/tile_C_VKQ::I * ntiles]; - - tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; - tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; + T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)]; +#if defined(TURING_MMA_AVAILABLE) + T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)]; +#else // Volta + T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)]; +#endif // defined(TURING_MMA_AVAILABLE) float KQ_rowsum[cols_per_thread] = {0.0f}; float KQ_max[cols_per_thread]; @@ -868,7 +907,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j = jc / ncols2; const int c = jc % ncols2; - if (jt*ncols1 + j < ne01) { + if (jt*ncols1 + j < int(ne01.z)) { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? 
threadIdx.x : threadIdx.x % stride_k); @@ -889,63 +928,93 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); - if (c::Q_in_reg) { + if (Q_in_reg) { const int j0 = (threadIdx.y / np) * cols_per_warp; #pragma unroll - for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) { - if (ntiles == 1) { - load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q); - } else { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t], - tile_Q + (j0 + t*tile_B_16::I)*stride_tile_Q + k0, stride_tile_Q); - } - } + for (int k0 = 0; k0 < DKQ/2; k0 += T_B_KQ::J) { + load_ldmatrix(Q_B[k0/T_B_KQ::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q); } } __syncthreads(); + int kb0 = kb0_start; + // Preload mask and K data for first iteration when using cp_async with multiple stages: if constexpr (nstages > 1) { static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline"); constexpr bool use_cp_async = true; - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask - (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask); + constexpr bool oob_check = false; + constexpr int k_VKQ_sup = nbatch_fa; + if (ncols2 > 1 || mask_h) { + flash_attn_ext_f16_load_mask + (mask_h + kb0*nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01); } - flash_attn_ext_f16_load_tile - (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K); + flash_attn_ext_f16_load_tile + (K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup); } - // Iterate over ne11 == previous tokens: - int kb0 = kb0_start; - for (; kb0 < kb0_stop-1; ++kb0) { - constexpr bool last_iter = false; - flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); - } - { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. + // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. 
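
The restructured loop that follows peels the last KV tile out of the main loop: only that tile can run past ne11, so it gets its own k_VKQ_sup = ne11 - kb0*nbatch_fa while the in-loop tiles keep the full nbatch_fa, and the oob_check template flag decides whether the bounds test is compiled in at all (only the ncols2 == 1 case needs it). The underlying loop-peeling idea, reduced to a standalone sketch with made-up names:

    #include <cstdio>

    // Process one tile; the bounds check is a template parameter so the common
    // case compiles to branch-free code.
    template <bool oob_check>
    static void process_tile(const float * data, int tile_start, int tile_size, int n, float & sum) {
        for (int i = 0; i < tile_size; ++i) {
            if (!oob_check || tile_start + i < n) { // constant false branch when oob_check == false
                sum += data[tile_start + i];
            }
        }
    }

    int main() {
        const int n = 10, tile_size = 4;
        float data[n];
        for (int i = 0; i < n; ++i) data[i] = 1.0f;

        float sum = 0.0f;
        int t = 0;
        // all tiles except the last: guaranteed in bounds, no check compiled in
        for (; t + tile_size < n; t += tile_size) {
            process_tile<false>(data, t, tile_size, n, sum);
        }
        // last (possibly partial) tile: checked variant
        process_tile<true>(data, t, tile_size, n, sum);

        printf("sum = %.1f (expected 10.0)\n", sum);
        return 0;
    }
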
+ if constexpr (ncols2 == 1) { + constexpr bool oob_check = true; + for (; kb0 < kb0_stop-1; ++kb0) { + constexpr bool last_iter = false; + constexpr int k_VKQ_sup = nbatch_fa; + flash_attn_ext_f16_iter + + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); + } constexpr bool last_iter = true; - flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + const int k_VKQ_sup = ne11 - kb0*nbatch_fa; + flash_attn_ext_f16_iter + + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); + } else { + constexpr bool oob_check = false; + for (; kb0 < kb0_stop-1; ++kb0) { + constexpr bool last_iter = false; + constexpr int k_VKQ_sup = nbatch_fa; + flash_attn_ext_f16_iter + + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); + } + constexpr bool last_iter = true; + constexpr int k_VKQ_sup = nbatch_fa; + flash_attn_ext_f16_iter + + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); } // With multi-stage loading there is no __syncthreads at the end of the iter, // there can be a race condition on shared memory access for combining/writing back results. - if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) { + if constexpr (nstages > 1 && nwarps*cols_per_warp > nbatch_fa) { __syncthreads(); } // Finally, sum up partial KQ rowsums. - // The partial sums are spread across 8/4 threads each, does not need full reduce. { - constexpr int offset_first = ntiles == 1 ? 16 : 2; - constexpr int offset_last = ntiles == 1 ? 4 : 1; +#if defined(TURING_MMA_AVAILABLE) + // The partial sums are spread across 8/4 threads. + constexpr int offset_first = cols_per_warp == 8 ? 16 : 2; + constexpr int offset_last = cols_per_warp == 8 ? 4 : 1; +#else // Volta + // The partial sums are spread across 2 threads. + constexpr int offset_first = 2; + constexpr int offset_last = 2; +#endif // defined(TURING_MMA_AVAILABLE) #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll @@ -962,8 +1031,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float KQ_max_scale[cols_per_thread]; #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { - static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented"); - const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col); + const int jc = cols_per_warp == 8 ? 
T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col); const float sink = sinks_f[jc % ncols2]; const float KQ_max_new = fmaxf(KQ_max[col], sink); @@ -977,12 +1045,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add; } - if (ntiles == 1) { +#if defined(TURING_MMA_AVAILABLE) + if constexpr (cols_per_warp == 8) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); #pragma unroll - for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { + for (int i = 0; i < DV/T_C_VKQ::I; ++i) { #pragma unroll - for (int l = 0; l < tile_C_VKQ::ne; ++l) { + for (int l = 0; l < T_C_VKQ::ne; ++l) { VKQ_C[i].x[l] *= KQ_max_scale_h2; } } @@ -991,30 +1060,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( for (int col = 0; col < cols_per_thread; ++col) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); #pragma unroll - for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) { + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { #pragma unroll - for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { - VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; + for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) { + VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2; } } } } +#else // Volta + const int col = (threadIdx.x / 2) % 2; + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); +#pragma unroll + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } +#endif // defined(TURING_MMA_AVAILABLE) } // Combine VKQ accumulator values if np > 1. // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. // So also write VKQ accumulators to shared memory in column-major format if np == 1. - constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols); - constexpr int tile_stride = nbatch_combine + 4; + constexpr int tile_stride = nbatch_combine + 4; static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine"); - if constexpr (ntiles == 1) { - const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset - const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta + if constexpr (cols_per_warp == 8) { + const int jc_cwmo = (threadIdx.x % (2*T_C_VKQ::J)) / T_C_VKQ::J; // jc combine write meta offset + const int jc_cwm = threadIdx.y*(2*T_C_VKQ::J) + 2*T_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum - if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) { + if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*T_C_VKQ::J) { // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; } @@ -1023,24 +1102,30 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( if (np == 1) { // No combination is needed, the meta data can be directly written from registers to VRAM. 
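// Illustrative scalar sketch (not taken from this patch) of the online-softmax update that
// the sink handling above performs per column: when a new, larger KQ maximum appears, the
// running rowsum and the VKQ accumulator are rescaled by exp(old_max - new_max) before new
// contributions are added; KQ_max_scale plays the role of `scale` below.
static void online_softmax_update_sketch(float & kq_max, float & kq_rowsum, float & vkq,
                                         float kq_new, float v_new) {
    const float kq_max_new = fmaxf(kq_max, kq_new);
    const float scale      = expf(kq_max - kq_max_new);   // rescale factor for old partial sums
    const float p          = expf(kq_new - kq_max_new);   // softmax weight of the new value
    kq_rowsum = scale*kq_rowsum + p;
    vkq       = scale*vkq       + p*v_new;
    kq_max    = kq_max_new;
}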
- if (needs_fixup && threadIdx.x < tile_B::I) { + if (needs_fixup && threadIdx.x < T_B_KQ::I) { float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; dstk_fixup_meta[jc_cwm] = KQ_cmr; } - if (is_fixup && threadIdx.x < tile_B::I) { + if (is_fixup && threadIdx.x < T_B_KQ::I) { float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; dstk_fixup_meta[jc_cwm] = KQ_cmr; } } } else { - static_assert(ntiles == 2 || ntiles == 4, "bad ntiles"); - const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta - + (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0) - + tile_C_VKQ_16::get_i(threadIdx.x % 4); - const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum + // jc_cwm = jc combine write meta + // KQ_cmr = KQ combine max rowsum + // Use the 16 bytes of padding in each Q column to store the meta data: KQ max, KQ rowsum, KQ max scale. +#if defined(TURING_MMA_AVAILABLE) + const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4); + const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); + const bool thread_should_write = threadIdx.x % 4 < cols_per_thread; +#else // Volta + const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2); + const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]); + const bool thread_should_write = T_C_KQ::J == 8 || T_C_KQ::get_j(threadIdx.x & 2) < 8; +#endif // defined(TURING_MMA_AVAILABLE) - if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) { - // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. + if (((!needs_fixup && !is_fixup) || np > 1) && thread_should_write) { ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; } @@ -1048,18 +1133,17 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( if (np == 1) { // No combination is needed, the meta data can be directly written from registers to VRAM. - if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + if (needs_fixup && thread_should_write) { float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; dstk_fixup_meta[jc_cwm] = KQ_cmr; } - if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + if (is_fixup && thread_should_write) { float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; dstk_fixup_meta[jc_cwm] = KQ_cmr; } } } - static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles"); if (np > 1 && threadIdx.y % np == 0) { // Combine the meta data for parallel warps via shared memory. // Warps with threadIdx.y % np != 0 must NOT return early. @@ -1135,32 +1219,29 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) { - if (ntiles == 1) { - const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data + if constexpr (cols_per_warp == 8) { + const int jc_cwd = threadIdx.y*T_B_KQ::I + T_B_KQ::get_i(-1); // jc combine write data #pragma unroll - for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) { - const tile_B B = get_transposed(VKQ_C[(k00 + k0)/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. 
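// Illustrative sketch (hypothetical sizes, not part of the hunks above or below) of the
// shared-memory row layout used for combining: each row holds nbatch_combine half2 payload
// values plus 4 half2 (16 bytes) of padding, and the row's meta data (KQ max, KQ rowsum)
// is stored as a float2 inside that padding, as the writes above do.
static size_t combine_payload_index_sketch(int jc, int k, int nbatch_combine) {
    const int tile_stride = nbatch_combine + 4;              // payload + 16-byte padding per row
    return (size_t) jc*tile_stride + k;                      // half2 index of payload element (jc, k)
}
static size_t combine_meta_index_sketch(int jc, int nbatch_combine) {
    const int tile_stride = nbatch_combine + 4;
    return (size_t) jc*(tile_stride/2) + nbatch_combine/2;   // float2 index of the row's meta data
}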
+ for (int k1 = 0; k1 < nbatch_combine; k1 += T_B_KQ::J) { + const T_B_KQ B = get_transposed(VKQ_C[(k00 + k1)/T_B_KQ::J]); // Conversion of C to B matrix puts it in column-major format. #pragma unroll - for (int l = 0; l < tile_B::ne; ++l) { - const int k = k0 + tile_B::get_j(l); + for (int l = 0; l < T_B_KQ::ne; ++l) { + const int k = k1 + T_B_KQ::get_j(l); tile_Q[jc_cwd*tile_stride + k] = B.x[l]; } } } else { + const int j0 = threadIdx.y*cols_per_warp; #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I; + for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) { #pragma unroll - for (int k0 = 0; k0 < nbatch_combine; k0 += tile_C_VKQ_16::J) { -#pragma unroll - for (int l = 0; l < tile_C_VKQ_16::ne; ++l) { - const int j = j0 + tile_C_VKQ_16::get_i(l); - const int k = k0 + tile_C_VKQ_16::get_j(l); + for (int l = 0; l < T_C_VKQ::ne; ++l) { + const int j = j0 + T_C_VKQ::get_i(l); + const int k = k1 + T_C_VKQ::get_j(l); - tile_Q[j*tile_stride + k] = VKQ_C_16[(k00 + k0)/tile_C_VKQ_16::J * ntiles/2 + t].x[l]; - } + tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l]; } } } @@ -1195,7 +1276,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j_dst = jc_dst / ncols2; const int c_dst = jc_dst % ncols2; - if (!is_fixup && jt*ncols1 + j_dst >= ne01) { + if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) { continue; } @@ -1233,16 +1314,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } #else - GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dstk_fixup, + GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); NO_DEVICE_CODE; -#endif // TURING_MMA_AVAILABLE +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } -template -__launch_bounds__(nwarps*WARP_SIZE, 1) +template +__launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2)) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, @@ -1258,14 +1339,14 @@ static __global__ void flash_attn_ext_f16( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE) +#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { @@ -1281,23 +1362,22 @@ static __global__ void flash_attn_ext_f16( static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV"); - typedef fattn_mma_f16_config c; - - static_assert(FATTN_KQ_STRIDE % fattn_mma_f16_config::nbatch_fa == 0, "bad nbatch_fa"); + constexpr int ncols = ncols1 * ncols2; + constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, 
DV, ncols); + constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols); + constexpr int nwarps = nthreads / WARP_SIZE; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const int stride_Q1 = nb01 / sizeof(float2); const int stride_Q2 = nb02 / sizeof(float2); const int stride_K = nb11 / sizeof(half2); - const int stride_mask = nb31 / sizeof(half2); + const int stride_mask = nb31 / sizeof(half); const int stride_V = mla ? stride_K : nb21 / sizeof(half2); - const int iter_k = ne11 / FATTN_KQ_STRIDE; - const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; - - constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice. + const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; + const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1; // kbc == k block continuous, current index in continuous ijk space. int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; @@ -1318,35 +1398,31 @@ static __global__ void flash_attn_ext_f16( const int head0 = zt * ncols2; - const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); - const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); - const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : - (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1); - float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const half * mask_h = ncols2 == 1 && !mask ? nullptr : + (const half *) (mask + nb33*(sequence % ne33)); + float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; - const int kb0_start_kernel = kb0_start * kb_niter; - int kb0_stop_kernel = kb0_stop * kb_niter; - if (KV_max) { - kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa); + kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa); } - constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. - flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); } else { - constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. - flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile. 
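// Illustrative sketch (hypothetical names, not part of the hunks above or below): how the
// flattened (KV block, Q tile, head group, sequence) index space is divided into one
// contiguous range per CUDA block, mirroring kbc = (blockIdx.x + 0)*total/gridDim.x above.
// A block whose range begins in the middle of a tile is missing that tile's beginning and
// writes its partial result to the fixup buffer (needs_fixup) instead of directly to dst;
// a block's final, partially covered tile likewise goes to the fixup buffer (is_fixup) to
// avoid data races with the block that continues it.
static void block_work_range_sketch(int total, int gridDim_x, int blockIdx_x,
                                    int & kbc, int & kbc_stop) {
    kbc      = (blockIdx_x + 0)*total / gridDim_x;   // first flattened index owned by this block
    kbc_stop = (blockIdx_x + 1)*total / gridDim_x;   // one past the last owned index
}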
+ flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); } kbc += iter_k; @@ -1366,29 +1442,26 @@ static __global__ void flash_attn_ext_f16( const int head0 = zt * ncols2; - const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); - const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); - const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : - (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1); - float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const half * mask_h = ncols2 == 1 && !mask ? nullptr : + (const half *) (mask + nb33*(sequence % ne33)); + float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; - const int kb0_start_kernel = kb0_start * kb_niter; - int kb0_stop_kernel = kb0_stop * kb_niter; - if (KV_max) { - kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa); + kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa); } constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; - flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); #else GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, max_bias, m0, m1, n_head_log2, logit_softcap, @@ -1400,7 +1473,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE) +#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) } template @@ -1409,36 +1482,30 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml const int id = ggml_cuda_get_device(); const int cc = ggml_cuda_info().devices[id].cc; - typedef fattn_mma_f16_config c; + constexpr int ncols = ncols1 * ncols2; - const int nstages = cp_async_available(cc) ? 
c::nstages_target : 0; + const int nthreads = ggml_cuda_fattn_mma_get_nthreads (DKQ, DV, ncols, cc); + const int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols, cc); + const int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols, cc); + const int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols, cc); + const int nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols, cc); + const bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols, cc); + const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc); - constexpr int ncols = ncols1 * ncols2; - constexpr int ntiles = ncols <= 8 ? 1 : 2; // Number of tiles per warp. - constexpr int cols_per_warp = ntiles * tile_B::I; - constexpr int nwarps_max_x = ncols / cols_per_warp; - constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I; - constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max; + const int cols_per_warp = std::min(ncols, turing_mma_available(cc) ? 16 : 32); + const int nwarps = nthreads / WARP_SIZE; constexpr bool mla = DKQ == 576; - const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols); - const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols); - const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols); - - static_assert(DKQ % tile_B::J == 0, "bad DKQ"); - static_assert(DV % tile_A::J == 0, "bad DV"); - static_assert(ncols % cols_per_warp == 0, "bad ncols"); - - const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2); - const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2); const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); - const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2); + const size_t nbytes_shared_mask = ncols1 * (nbatch_fa/2 + 4) * sizeof(half2); const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage; - const size_t nbytes_shared_total = std::max(nbytes_shared_combine, c::Q_in_reg ? + const size_t nbytes_shared_total = std::max(nbytes_shared_combine, Q_in_reg ? 
std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) : nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask); @@ -1448,7 +1515,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1459,7 +1526,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) } else { constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1471,7 +1538,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml } launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true); + (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true); } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh index 3e58d64f..7c4d6fe6 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh @@ -501,6 +501,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter( const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, const half * const __restrict__ mask, + const uint3 ne01, const float logit_softcap, const float slope, T_KQ * const KQ, @@ -512,7 +513,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter( float * const KQ_sum, T_acc * const VKQ, const int k_VKQ_0, - const int k_VKQ_max) { + const int k_VKQ_max, + const int col_Q_0) { constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); constexpr int cpy_ne = cpy_nb / 4; @@ -556,12 +558,18 @@ static __device__ __forceinline__ void flash_attn_tile_iter( // Apply logit softcap + mask, update KQ_max: #pragma unroll for (int jc0 = 0; jc0 < cpw; ++jc0) { - const int j = (jc0 + (threadIdx.y / np)*cpw)/ncols2; + const int j = fastmodulo(col_Q_0 + (jc0 + (threadIdx.y / np)*cpw)/ncols2, ne01); #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x; +#if defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE) + // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation. + // Therefore, scale down Q values and apply the inverse scale the FP32 KQ values afterwards again. + KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0] *= 4.0f; +#endif // defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE) + if (use_logit_softcap) { KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); } @@ -570,7 +578,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter( KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ? 
slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f; - KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); + KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] + FATTN_KQ_MAX_OFFSET); } } @@ -736,7 +744,7 @@ static __global__ void flash_attn_tile( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, @@ -781,11 +789,11 @@ static __global__ void flash_attn_tile( const int sequence = blockIdx.z / (ne02/ncols2); const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == blockIdx.z % (ne02/ncols2) const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0 + nb01*col_Q_0); + const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0); const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape - const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33) + nb31*col_Q_0) : nullptr; + const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33)) : nullptr; const int stride_K2 = nb11 / sizeof(half2); const int stride_V2 = nb21 / sizeof(half2); @@ -842,11 +850,9 @@ static __global__ void flash_attn_tile( for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) { if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) { float tmp_f[cpy_ne_D] = {0.0f}; - if (ncols1 == 1 || col_Q_0 + j < ne01) { - ggml_cuda_memcpy_1 - (tmp_f, &Q_f[c*(nb02/sizeof(float)) + j*(nb01/sizeof(float)) - + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]); - } + ggml_cuda_memcpy_1 + (tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float)) + + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]); #pragma unroll for (int i1 = 0; i1 < cpy_ne_D; ++i1) { @@ -858,6 +864,11 @@ static __global__ void flash_attn_tile( #pragma unroll for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) { tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]); +#if defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE) + // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation. + // Therefore, scale down Q values and apply the inverse scale the FP32 KQ values afterwards again. 
+ tmp_h2[i1/2] *= make_half2(0.25f, 0.25f); +#endif // defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE) } ggml_cuda_memcpy_1( &Q_tmp[jc*(DKQ/2) + i0/2 + (threadIdx.y % np)*(warp_size*cpy_ne_D/2) + threadIdx.x*(cpy_ne_D/2)], @@ -881,23 +892,23 @@ static __global__ void flash_attn_tile( while (k_VKQ_0 < k_VKQ_max - nbatch_fa) { constexpr bool oob_check = false; flash_attn_tile_iter - (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, - stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); k_VKQ_0 += gridDim.y*nbatch_fa; } if (k_VKQ_0 < k_VKQ_max) { constexpr bool oob_check = true; flash_attn_tile_iter - (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, - stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); } } else { // Branch without out-of-bounds checks. for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) { constexpr bool oob_check = false; flash_attn_tile_iter - (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, - stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); } } @@ -1010,13 +1021,13 @@ static __global__ void flash_attn_tile( const int j = jc / ncols2; const int c = jc % ncols2; - if (ncols1 > 1 && col_Q_0 + j >= ne01) { + if (ncols1 > 1 && col_Q_0 + j >= int(ne01.z)) { return; } const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f; - const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y; + const int j_dst_unrolled = ((sequence*int(ne01.z) + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y; #ifdef FAST_FP16_AVAILABLE constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec.cuh index 67aa67ec..4d167b95 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec.cuh @@ -33,7 +33,7 @@ static __global__ void flash_attn_ext_vec( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, @@ -150,7 +150,7 @@ static __global__ void flash_attn_ext_vec( float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); // Set memory to zero if out of bounds: - if (ncols > 1 && ic0 + j >= ne01) { + if (ncols > 1 && ic0 + j >= int(ne01.z)) { #pragma unroll for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; @@ -201,7 +201,7 @@ static __global__ void flash_attn_ext_vec( const int i = i0 + (nthreads_KQ == WARP_SIZE ? 
threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne; float2 tmp[cpy_ne] = {{0.0f, 0.0f}}; - if (ncols == 1 || ic0 + j < ne01) { + if (ncols == 1 || ic0 + j < int(ne01.z)) { ggml_cuda_memcpy_1(tmp, &Q_j[i]); ggml_cuda_memcpy_1(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]); } @@ -222,7 +222,7 @@ static __global__ void flash_attn_ext_vec( #pragma unroll for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) { const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne; - if (ncols == 1 || ic0 + j < ne01) { + if (ncols == 1 || ic0 + j < int(ne01.z)) { ggml_cuda_memcpy_1(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]); ggml_cuda_memcpy_1(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]); } @@ -266,11 +266,11 @@ static __global__ void flash_attn_ext_vec( sum = logit_softcap*tanhf(sum); } - if (mask) { + if (mask && (ncols == 1 || ic0 + j < int(ne01.z))) { sum += slope*__half2float(maskh[j*ne11 + i_KQ]); } - KQ_max_new[j] = fmaxf(KQ_max_new[j], sum); + KQ_max_new[j] = fmaxf(KQ_max_new[j], sum + FATTN_KQ_MAX_OFFSET); if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) { KQ_reg[j] = sum; @@ -412,7 +412,7 @@ static __global__ void flash_attn_ext_vec( #pragma unroll for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { - if (ncols > 1 && ic0 + j_VKQ >= ne01) { + if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z)) { break; } @@ -479,7 +479,7 @@ static __global__ void flash_attn_ext_vec( if (gridDim.y == 1) { dst_val /= KQ_sum[j_VKQ]; } - dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val; + dst[(((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val; } } @@ -489,8 +489,8 @@ static __global__ void flash_attn_ext_vec( } - if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < ne01)) { - dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]); + if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z))) { + dst_meta[((sequence*int(ne01.z) + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]); } #else GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 6c90d6d5..8694fd06 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -38,14 +38,14 @@ static __global__ void flash_attn_ext_f16( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))) +#if defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { 
NO_DEVICE_CODE; @@ -149,7 +149,7 @@ static __global__ void flash_attn_ext_f16( if (i0 + warp_size > D && i >= D) { break; } - KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + KQ[j*D_padded + i] = ic0 + j < int(ne01.z) ? Q_f[j*stride_Q + i] * scale : 0.0f; } } @@ -218,8 +218,9 @@ static __global__ void flash_attn_ext_f16( for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) { const int k = k0 + threadIdx.x; - KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; - KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]); + KQ_f_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ? + __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size] + FATTN_KQ_MAX_OFFSET); } KQ_max_new = warp_reduce_max(KQ_max_new); @@ -270,7 +271,7 @@ static __global__ void flash_attn_ext_f16( for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) { const int k = k0 + threadIdx.x; - KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ2_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]); } KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); @@ -431,7 +432,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j_VKQ = j0 + threadIdx.y; - if (ic0 + j_VKQ >= ne01) { + if (ic0 + j_VKQ >= int(ne01.z)) { return; } @@ -442,7 +443,7 @@ static __global__ void flash_attn_ext_f16( KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); } - const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; + const int j_dst_unrolled = ((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; #pragma unroll for (int i0 = 0; i0 < D; i0 += warp_size) { @@ -481,7 +482,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))) +#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)) } constexpr int get_max_power_of_2(int x) { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index 7235f1b7..cd3bfd40 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -2,9 +2,9 @@ #include "common.cuh" -#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) +#if defined(GGML_USE_MUSA) #define GGML_USE_WMMA_FATTN -#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) +#endif // defined(GGML_USE_MUSA) #if defined(GGML_HIP_ROCWMMA_FATTN) #if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu index 82405991..01554066 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu @@ -12,13 +12,13 @@ static void 
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con const ggml_tensor * Q = dst->src[0]; if constexpr (ncols2 <= 8) { - if (Q->ne[1] <= 8/ncols2) { + if (turing_mma_available(cc) && Q->ne[1] <= 8/ncols2) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } } - if (Q->ne[1] <= 16/ncols2) { + if (turing_mma_available(cc) && Q->ne[1] <= 16/ncols2) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } @@ -36,12 +36,26 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; const ggml_tensor * mask = dst->src[3]; float max_bias = 0.0f; memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); - const bool use_gqa_opt = mask && max_bias == 0.0f; + // Edge cases like no mask, ALiBi, unpadded K/V, or misaligned addresses for large data transfers + // are put into the template specialization without GQA optimizations. + bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; + for (const ggml_tensor * t : {Q, K, V, mask}) { + if (t == nullptr) { + continue; + } + for (size_t i = 1; i < GGML_MAX_DIMS; ++i) { + if (t->nb[i] % 16 != 0) { + use_gqa_opt = false; + break; + } + } + } GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; @@ -275,8 +289,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; - // If Turing tensor cores available, use them: - if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72) { + // If Turing tensor cores are available, use them: + if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { if (can_use_vector_kernel) { if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) { @@ -297,7 +311,21 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_VEC; } } + return BEST_FATTN_KERNEL_MMA_F16; + } + if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { + int gqa_ratio_eff = 1; + const int ncols2_max = Q->ne[0] == 576 ? 16 : 8; + while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { + gqa_ratio_eff *= 2; + } + if (can_use_vector_kernel && Q->ne[1] * gqa_ratio_eff <= 2) { + return BEST_FATTN_KERNEL_VEC; + } + if (Q->ne[1] * gqa_ratio_eff <= 16) { + return BEST_FATTN_KERNEL_TILE; // On Volta tensor cores are only faster for sufficiently large matrices. 
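// Illustrative sketch (worked example, not part of the hunk above): gqa_ratio_eff is the
// largest power of two dividing gqa_ratio, capped at ncols2_max. For example, gqa_ratio = 12
// with ncols2_max = 8 gives gqa_ratio_eff = 4 (12 is divisible by 2 and 4 but not 8). The
// Volta heuristic above then picks the vector kernel when Q->ne[1]*gqa_ratio_eff <= 2, the
// tile kernel up to 16, and the MMA kernel otherwise.
static int gqa_ratio_eff_sketch(int gqa_ratio, int ncols2_max) {
    int gqa_ratio_eff = 1;
    while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
        gqa_ratio_eff *= 2;
    }
    return gqa_ratio_eff;
}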
+ } return BEST_FATTN_KERNEL_MMA_F16; } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fill.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fill.cu new file mode 100644 index 00000000..739062c4 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fill.cu @@ -0,0 +1,37 @@ +#include "fill.cuh" +#include "convert.cuh" + +#define CUDA_FILL_BLOCK_SIZE 256 + +template +static __global__ void fill_kernel(T * dst, const int64_t k, const T value) { + const int64_t i = (int64_t)blockDim.x * blockIdx.x + threadIdx.x; + if (i >= k) { + return; + } + dst[i] = value; +} + +void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + void * dst_d = dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(dst)); + + float value; + memcpy(&value, dst->op_params, sizeof(float)); + + const int64_t k = ggml_nelements(dst); + const int64_t num_blocks = (k + CUDA_FILL_BLOCK_SIZE - 1) / CUDA_FILL_BLOCK_SIZE; + + switch (dst->type) { + case GGML_TYPE_F32: + fill_kernel<<>>((float *)dst_d, k, value); + break; + case GGML_TYPE_F16: + fill_kernel<<>>((half *)dst_d, k, ggml_cuda_cast(value)); + break; + default: + GGML_ABORT("unsupported type"); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fill.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fill.cuh new file mode 100644 index 00000000..8443c836 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fill.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu index 16c166a0..d69d6219 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu @@ -20,6 +20,7 @@ #include "ggml-cuda/cpy.cuh" #include "ggml-cuda/cross-entropy-loss.cuh" #include "ggml-cuda/diagmask.cuh" +#include "ggml-cuda/diag.cuh" #include "ggml-cuda/fattn.cuh" #include "ggml-cuda/getrows.cuh" #include "ggml-cuda/im2col.cuh" @@ -54,6 +55,9 @@ #include "ggml-cuda/set-rows.cuh" #include "ggml-cuda/pad_reflect_1d.cuh" #include "ggml-cuda/solve_tri.cuh" +#include "ggml-cuda/tri.cuh" +#include "ggml-cuda/cumsum.cuh" +#include "ggml-cuda/fill.cuh" #include "ggml.h" #include @@ -2772,6 +2776,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: break; + case GGML_OP_DIAG: + ggml_cuda_op_diag(ctx, dst); + break; case GGML_OP_DIAG_MASK_INF: ggml_cuda_op_diag_mask_inf(ctx, dst); break; @@ -2835,6 +2842,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_CROSS_ENTROPY_LOSS: ggml_cuda_cross_entropy_loss(ctx, dst); break; + case GGML_OP_CUMSUM: + ggml_cuda_op_cumsum(ctx, dst); + break; + case GGML_OP_TRI: + ggml_cuda_op_tri(ctx, dst); + break; case GGML_OP_RWKV_WKV6: ggml_cuda_op_rwkv_wkv6(ctx, dst); break; @@ -2856,6 +2869,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SOLVE_TRI: ggml_cuda_op_solve_tri(ctx, dst); break; + case GGML_OP_FILL: + ggml_cuda_op_fill(ctx, dst); + break; default: return false; } @@ -3383,9 +3399,56 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx } } if (should_launch_concurrent_events) { - //Restore the original graph to enable fusion within the streams - cgraph->nodes = const_cast(stream_ctx.original_nodes.data()); - cgraph->n_nodes = (int) stream_ctx.original_nodes.size(); + // Restore original node order 
within each concurrent region to enable fusion within streams + + std::unordered_map node_to_idx; + node_to_idx.reserve(cgraph->n_nodes); + for (int i = 0; i < cgraph->n_nodes; ++i) { + node_to_idx[cgraph->nodes[i]] = i; + } + + for (auto & [fork_node, event] : stream_ctx.concurrent_events) { + // Find positions of all nodes from this event in the current graph + std::vector positions; + positions.reserve(event.original_order.size()); + + bool all_found = true; + for (const ggml_tensor * orig_node : event.original_order) { + auto it = node_to_idx.find(orig_node); + if (it != node_to_idx.end()) { + positions.push_back(it->second); + } else { + all_found = false; + break; + } + } + + if (!all_found || positions.size() != event.original_order.size()) { + continue; + } + + // Sort positions to get contiguous range + std::vector sorted_positions = positions; + std::sort(sorted_positions.begin(), sorted_positions.end()); + + bool is_contiguous = true; + for (size_t i = 1; i < sorted_positions.size(); ++i) { + if (sorted_positions[i] != sorted_positions[i-1] + 1) { + is_contiguous = false; + break; + } + } + + if (!is_contiguous) { + continue; + } + + // Restore original order at the sorted positions + int start_pos = sorted_positions[0]; + for (size_t i = 0; i < event.original_order.size(); ++i) { + cgraph->nodes[start_pos + i] = const_cast(event.original_order[i]); + } + } } for (int i = 0; i < cgraph->n_nodes; i++) { @@ -3419,7 +3482,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name); } } - prev_i = i; #ifdef GGML_CUDA_DEBUG const int nodes_fused = i - prev_i - 1; @@ -3427,6 +3489,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused); } #endif + prev_i = i; if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { continue; @@ -4026,14 +4089,6 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph // store {fork_idx, join_idx} std::vector> concurrent_node_ranges; - // save the original nodes - std::vector original_nodes; - original_nodes.reserve(cgraph->n_nodes); - for (int i = 0; i < cgraph->n_nodes; ++i) { - original_nodes.push_back(cgraph->nodes[i]); - } - cuda_ctx->stream_context().original_nodes = std::move(original_nodes); - for (const auto & [root_node, count] : fan_out) { if (count >= min_fan_out && count <= max_fan_out) { const int root_node_idx = node_indices[root_node]; @@ -4138,6 +4193,13 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph continue; } + // Save the original order of nodes in this region before interleaving + // This is used later to restore grouping for fusion within streams + concurrent_event.original_order.reserve(total_branch_nodes); + for (int i = fork_node_idx + 1; i < join_node_idx; ++i) { + concurrent_event.original_order.push_back(cgraph->nodes[i]); + } + std::unordered_map & concurrent_events = cuda_ctx->stream_context().concurrent_events; GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end()); concurrent_events.emplace(root_node, std::move(concurrent_event)); @@ -4841,6 +4903,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CROSS_ENTROPY_LOSS_BACK: case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: 
+ case GGML_OP_FILL: + case GGML_OP_CUMSUM: + case GGML_OP_TRI: + case GGML_OP_DIAG: return true; case GGML_OP_SOLVE_TRI: return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32; diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh index 0ed42e87..0b13293d 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh @@ -68,10 +68,31 @@ static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) { namespace ggml_cuda_mma { + // Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel, + // effectively the warp is being split into subgroups of threads that each perform a single mma instruction. + // In those cases the data can be split in different ways across the warp. + enum data_layout { + // By default the data uses the I direction as its major dimension and the J direction as its minor dimension. + // For the A/C matrices this means I major == row major, J major == column major. + // For the B matrix this means I major == column major, J major == row major. + // MIRRORED == Each data value is held exactly once per thread subgroup. + DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell. + DATA_LAYOUT_I_MAJOR_MIRRORED = 10, + DATA_LAYOUT_J_MAJOR_MIRRORED = 20, + }; + // Implemented mma combinations are: + // - (I_MAJOR, I_MAJOR) -> I_MAJOR + // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR + // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR + + template + struct tile {}; + template - struct tile { - static constexpr int I = I_; - static constexpr int J = J_; + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; #if defined(AMD_MFMA_AVAILABLE) static constexpr int ne = I * J / 64; @@ -131,9 +152,9 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_i(const int l) { if constexpr (I == 32 && J == 8) { #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM - return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2); + return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (l & 2) + (threadIdx.x % 2); #else - return (l & 2) | (threadIdx.x & ~2); + return (l & 2) + (threadIdx.x & ~2); #endif // GGML_CUDA_MMA_NO_VOLTA_PERM } else { NO_DEVICE_CODE; @@ -143,7 +164,7 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 32 && J == 8) { - return (threadIdx.x & 2) | (l & (4 + 1)); + return (threadIdx.x & 2) + (l & (4 + 1)); } else { NO_DEVICE_CODE; return -1; @@ -152,6 +173,9 @@ namespace ggml_cuda_mma { #elif defined(AMD_WMMA_AVAILABLE) #if defined(RDNA4) static constexpr int ne = I * J / 32; +#elif defined(RDNA3) + static constexpr int ne = (I == 16 && J == 16) ? 
I * J / 32 : I * J / 16; +#endif // defined(RDNA4) T x[ne] = {0}; static constexpr __device__ bool supported() { @@ -161,7 +185,11 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_i(const int l) { if constexpr (I == 16 && J == 16) { +#if defined(RDNA4) return 8 * (threadIdx.x / 16) + l; +#elif defined(RDNA3) + return 2 * l + (threadIdx.x / 16); +#endif // defined(RDNA4) } else { NO_DEVICE_CODE; return -1; @@ -176,7 +204,6 @@ namespace ggml_cuda_mma { return -1; } } -#endif #else static constexpr int ne = I * J / 32; T x[ne] = {0}; @@ -196,9 +223,9 @@ namespace ggml_cuda_mma { } else if constexpr (I == 8 && J == 8) { return threadIdx.x / 4; } else if constexpr (I == 16 && J == 8) { - return ((l / 2) * 8) | (threadIdx.x / 4); + return ((l / 2) * 8) + (threadIdx.x / 4); } else if constexpr (I == 16 && J == 16) { - return (((l / 2) % 2) * 8) | (threadIdx.x / 4); + return (((l / 2) % 2) * 8) + (threadIdx.x / 4); } else if constexpr (I == 32 && J == 8) { return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction. } else { @@ -211,11 +238,11 @@ namespace ggml_cuda_mma { if constexpr (I == 8 && J == 4) { return threadIdx.x % 4; } else if constexpr (I == 8 && J == 8) { - return (l * 4) | (threadIdx.x % 4); + return (l * 4) + (threadIdx.x % 4); } else if constexpr (I == 16 && J == 8) { - return ((threadIdx.x % 4) * 2) | (l % 2); + return ((threadIdx.x % 4) * 2) + (l % 2); } else if constexpr (I == 16 && J == 16) { - return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2); + return ((l / 4) * 8) + ((threadIdx.x % 4) * 2) + (l % 2); } else if constexpr (I == 32 && J == 8) { return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction. } else { @@ -227,26 +254,24 @@ namespace ggml_cuda_mma { }; template - struct tile { - static constexpr int I = I_; - static constexpr int J = J_; + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA - static constexpr int ne = I == 8 && J == 8 ? 
I * J / (WARP_SIZE/4) : I * J / WARP_SIZE; + static constexpr int ne = I * J / WARP_SIZE; half2 x[ne] = {{0.0f, 0.0f}}; static constexpr __device__ bool supported() { - if (I == 8 && J == 8) return true; - if (I == 32 && J == 8) return true; + if (I == 32 && J == 4) return true; return false; } static __device__ __forceinline__ int get_i(const int l) { - if constexpr (I == 8 && J == 8) { - return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4); - } else if constexpr (I == 32 && J == 8) { + if constexpr (I == 32 && J == 4) { #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM - return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4); + return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (threadIdx.x % 4); #else return threadIdx.x; #endif // GGML_CUDA_MMA_NO_VOLTA_PERM @@ -257,7 +282,7 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_j(const int l) { - if constexpr ((I == 8 || I == 32) && J == 8) { + if constexpr (I == 32 && J == 4) { return l; } else { NO_DEVICE_CODE; @@ -265,6 +290,7 @@ namespace ggml_cuda_mma { } } #elif defined(AMD_WMMA_AVAILABLE) + static constexpr int ne = I * J / 32; half2 x[ne] = {{0.0f, 0.0f}}; @@ -307,11 +333,11 @@ namespace ggml_cuda_mma { if constexpr (I == 8 && J == 8) { return threadIdx.x / 4; } else if constexpr (I == 16 && J == 4) { - return (l * 8) | (threadIdx.x / 4); + return (l * 8) + (threadIdx.x / 4); } else if constexpr (I == 16 && J == 8) { - return ((l % 2) * 8) | (threadIdx.x / 4); + return ((l % 2) * 8) + (threadIdx.x / 4); } else if constexpr (I == 32 && J == 8) { - return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4); + return ((l / 4) * 16) + ((l % 2) * 8) + (threadIdx.x / 4); } else { NO_DEVICE_CODE; return -1; @@ -320,13 +346,13 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 8 && J == 8) { - return (l * 4) | (threadIdx.x % 4); + return (l * 4) + (threadIdx.x % 4); } else if constexpr (I == 16 && J == 4) { return threadIdx.x % 4; } else if constexpr (I == 16 && J == 8) { - return ((l / 2) * 4) | (threadIdx.x % 4); + return ((l / 2) * 4) + (threadIdx.x % 4); } else if constexpr (I == 32 && J == 8) { - return ((l & 2) * 2) | (threadIdx.x % 4); + return ((l & 2) * 2) + (threadIdx.x % 4); } else { NO_DEVICE_CODE; return -1; @@ -336,14 +362,15 @@ namespace ggml_cuda_mma { }; template - struct tile { - static constexpr int I = I_; - static constexpr int J = J_; + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; + static constexpr int ne = I * J / WARP_SIZE; -#if defined(AMD_WMMA_AVAILABLE) - static constexpr int ne = I * J / 32; nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; +#if defined(AMD_WMMA_AVAILABLE) static constexpr __device__ bool supported() { if (I == 16 && J == 8) return true; return false; @@ -367,9 +394,6 @@ namespace ggml_cuda_mma { } } #else - static constexpr int ne = I * J / WARP_SIZE; - nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; - static constexpr __device__ bool supported() { if (I == 8 && J == 8) return true; if (I == 16 && J == 4) return true; @@ -381,9 +405,9 @@ namespace ggml_cuda_mma { if constexpr (I == 8 && J == 8) { return threadIdx.x / 4; } else if constexpr (I == 16 && J == 4) { - return (l * 8) | (threadIdx.x / 4); + return (l * 8) + (threadIdx.x / 4); } else if constexpr (I == 16 && J == 8) { - return ((l % 2) * 8) | (threadIdx.x / 4); + return ((l % 2) * 8) + (threadIdx.x / 4); } else { NO_DEVICE_CODE; return -1; @@ -392,11 +416,11 @@ 
namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 8 && J == 8) { - return (l * 4) | (threadIdx.x % 4); + return (l * 4) + (threadIdx.x % 4); } else if constexpr (I == 16 && J == 4) { return threadIdx.x % 4; } else if constexpr (I == 16 && J == 8) { - return ((l / 2) * 4) | (threadIdx.x % 4); + return ((l / 2) * 4) + (threadIdx.x % 4); } else { NO_DEVICE_CODE; return -1; @@ -405,6 +429,73 @@ namespace ggml_cuda_mma { #endif // defined(AMD_WMMA_AVAILABLE) }; + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED; + static constexpr int ne = I * J / (WARP_SIZE/4); + + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 8 && J == 4) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int /*l*/) { + if constexpr (I == 8 && J == 4) { + return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4); + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 4) { + return l; + } else { + NO_DEVICE_CODE; + return -1; + } + } + }; + + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED; + static constexpr int ne = I * J / (WARP_SIZE/4); + + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 8 && J == 4) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 8 && J == 4) { + return ((l / 2) * 4) + (threadIdx.x % 4); + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 4) { + return ((threadIdx.x / 16) * 2) + (l % 2); + } else { + NO_DEVICE_CODE; + return -1; + } + } + }; + +#if defined(TURING_MMA_AVAILABLE) template static __device__ __forceinline__ tile get_half2(const tile & tile_float) { tile ret; @@ -422,9 +513,26 @@ namespace ggml_cuda_mma { return ret; } +#else // Volta + template + static __device__ __forceinline__ tile get_half2(const tile & tile_float) { + tile ret; +#pragma unroll + for (int l0 = 0; l0 < tile_float.ne; l0 += 4) { + ret.x[l0/2 + 0] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]); + ret.x[l0/2 + 1] = make_half2(tile_float.x[l0 + 2], tile_float.x[l0 + 3]); - template - static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { + // On Volta FP16 and FP32 tiles have a different memory layout, + // for the conversion threads with an offset of 2 need to exchange half their values: + ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)] = __shfl_xor_sync( + 0xFFFFFFFF, ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)], 2, WARP_SIZE); + } + return ret; + } +#endif // defined(TURING_MMA_AVAILABLE) + + template + static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { #if defined(AMD_MFMA_AVAILABLE) if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> #pragma unroll @@ -443,18 +551,34 @@ namespace ggml_cuda_mma { } else if constexpr (std::is_same_v) { if constexpr (I == 16 && J == 4) { int64_t * xi = (int64_t *) t.x; +#if defined(RDNA4) const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I)); xi[0] = xs[0]; - - }else if 
constexpr (I == 16 && J == 8) { +#elif defined(RDNA3) + static_assert(tile::ne >= 4, "fragment too small"); + const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride); + xi[0] = xs[0]; + xi[1] = xs[1]; +#endif // defined(RDNA4) + } else if constexpr (I == 16 && J == 8) { int64_t * xi = (int64_t *) t.x; +#if defined(RDNA4) const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I)); xi[0] = xs[0]; const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2); xi[1] = xs1[0]; - - }else{ +#elif defined(RDNA3) + static_assert(tile::ne >= 8, "fragment too small"); + const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride); + // contiguous four 64-bit chunks per lane for the wider RDNA3 fragment + xi[0] = xs[0]; + xi[1] = xs[1]; + const int64_t * xs1 = xs + 2; + xi[2] = xs1[0]; + xi[3] = xs1[1]; +#endif // defined(RDNA4) + } else { NO_DEVICE_CODE; } } else { @@ -511,18 +635,6 @@ namespace ggml_cuda_mma { : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3]) : "l"(xs)); #else -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA - GGML_UNUSED_VARS(t, xs0, stride); - NO_DEVICE_CODE; -#else - load_generic(t, xs0, stride); -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA -#endif // TURING_MMA_AVAILABLE - } - - template - static __device__ __forceinline__ void load_ldmatrix( - tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) { #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA #if 1 // TODO: more generic handling @@ -533,9 +645,31 @@ namespace ggml_cuda_mma { load_generic(t, xs0, stride); #endif // 1 #else - tile<16, 8, T> * t16 = (tile<16, 8, T> *) &t; - load_ldmatrix(t16[0], xs0 + 0*stride, stride); - load_ldmatrix(t16[1], xs0 + 16*stride, stride); + load_generic(t, xs0, stride); +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // TURING_MMA_AVAILABLE + } + + static __device__ __forceinline__ void load_ldmatrix( + tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) { + ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride); + } + + static __device__ __forceinline__ void load_ldmatrix( + tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) { +#pragma unroll + for (int l0 = 0; l0 < t.ne; l0 += 2) { + ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0)); + } + } + + static __device__ __forceinline__ void load_ldmatrix( + tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride); +#else + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA } @@ -747,12 +881,14 @@ namespace ggml_cuda_mma { : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #elif defined(AMD_WMMA_AVAILABLE) +#if defined(RDNA4) using halfx8_t = __attribute__((ext_vector_type(8))) _Float16; using floatx8_t = __attribute__((ext_vector_type(8))) float; floatx8_t& acc_frag = reinterpret_cast(D.x[0]); const halfx8_t& a_frag = reinterpret_cast(A.x[0]); const halfx8_t& b_frag = reinterpret_cast(B.x[0]); acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag); +#endif // RDNA4 #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; @@ -762,12 +898,14 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ void mma( 
tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) { #if defined(AMD_WMMA_AVAILABLE) +#if defined(RDNA4) using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16; using floatx8_t = __attribute__((ext_vector_type(8))) float; floatx8_t& acc_frag = reinterpret_cast(D.x[0]); const bf16x8_t& a_frag = reinterpret_cast(A.x[0]); const bf16x8_t& b_frag = reinterpret_cast(B.x[0]); acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag); +#endif // RDNA4 #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; @@ -796,14 +934,14 @@ namespace ggml_cuda_mma { #endif // defined(CDNA3) #elif defined(AMD_WMMA_AVAILABLE) - using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; - int32x2_t * a_vec = (int32x2_t *) A.x; - int32x2_t * b_vec = (int32x2_t *) B.x; using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; int32x8_t * acc = (int32x8_t *) D.x; #if defined(RDNA4) + using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; + int32x2_t * a_vec = (int32x2_t *) A.x; + int32x2_t * b_vec = (int32x2_t *) B.x; acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( true, @@ -822,7 +960,30 @@ namespace ggml_cuda_mma { acc[0], true ); -#endif // defined(RDNA4) + +#elif defined(RDNA3) + using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; + int32x4_t * a_vec = (int32x4_t *) A.x; + int32x4_t * b_vec = (int32x4_t *) B.x; + + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( + true, + a_vec[0], + true, + b_vec[0], + acc[0], + true + ); + + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( + true, + a_vec[1], + true, + b_vec[1], + acc[0], + true + ); +#endif // RDNA4 #else GGML_UNUSED_VARS(D, A, B); @@ -860,14 +1021,14 @@ namespace ggml_cuda_mma { template static __device__ __forceinline__ void mma( tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile & B) { - tile<16, J, T1> * D16 = (tile<16, J, T1> *) &D; - tile<16, K, T2> * A16 = (tile<16, K, T2> *) &A; + tile <16, J, T1> * D16 = reinterpret_cast< tile<16, J, T1> *>(&D); + const tile<16, K, T2> * A16 = reinterpret_cast *>(&A); mma(D16[0], A16[0], B); mma(D16[1], A16[1], B); } static __device__ __forceinline__ void mma( - tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) { + tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) { #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; @@ -880,46 +1041,69 @@ namespace ggml_cuda_mma { "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};" : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3])); - asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};" - : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) - : "r"(Axi[4]), "r"(Axi[5]), "r"(Bxi[4]), "r"(Bxi[5])); - asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};" - : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) - : "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7])); #else - tile <16, 8, float> * D16 = 
reinterpret_cast *>(&D); - const tile<16, 8, half2> * A16 = reinterpret_cast *>(&A); - mma(D16[0], A16[0], B); - mma(D16[1], A16[1], B); -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + } + + static __device__ __forceinline__ void mma( + tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; + asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 " + "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1])); + asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 " + "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3])); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA } static __device__ __forceinline__ void mma( tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) { #if defined(AMD_WMMA_AVAILABLE) - using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; - int32x2_t * a_vec = (int32x2_t *) A.x; - int32x2_t * b_vec = (int32x2_t *) B.x; + using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; + int32x8_t * acc = (int32x8_t *) D.x; +#if defined(RDNA4) + using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; + int32x2_t * a_vec = (int32x2_t *) A.x; + int32x2_t * b_vec = (int32x2_t *) B.x; - using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; - int32x8_t * acc = (int32x8_t *) D.x; + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, + a_vec[0], + true, + b_vec[0], + acc[0], + false + ); +#elif defined(RDNA3) + using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; + int32x4_t * a_vec = (int32x4_t *) A.x; + int32x4_t * b_vec = (int32x4_t *) B.x; - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - false - ); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( + true, + a_vec[0], + true, + b_vec[0], + acc[0], + false + ); +#endif // RDNA4 #else GGML_UNUSED(D); GGML_UNUSED(A); GGML_UNUSED(B); NO_DEVICE_CODE; -#endif +#endif // AMD_WMMA_AVAILABLE } } - diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cu index be2ad1c6..7cf33f0d 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cu @@ -160,9 +160,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const case GGML_TYPE_F32: return ampere_mma_available(cc); case GGML_TYPE_F16: - return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc); + return volta_mma_available(cc) || turing_mma_available(cc) || (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc)); case GGML_TYPE_BF16: - return ampere_mma_available(cc) || amd_wmma_available(cc); + return ampere_mma_available(cc) || (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc)); default: return false; } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cuh index c2a0a2e4..e1c695c5 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cuh +++ 
b/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cuh @@ -37,23 +37,19 @@ static __global__ void mul_mat_f( typedef tile<16, 8, T> tile_A; typedef tile tile_B; typedef tile<16, tile_C_J, float> tile_C; - - constexpr bool a_supported = tile_A::supported(); - constexpr bool b_supported = tile_B::supported(); - constexpr bool c_supported = tile_C::supported(); - constexpr bool supported = a_supported && b_supported && c_supported; #else - constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); - constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); - constexpr bool supported = I_16_supported || I_32_supported; - - constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster. - - typedef tile tile_A; - typedef tile<8, 8, T> tile_B; - typedef tile tile_C; +#ifdef VOLTA_MMA_AVAILABLE + if constexpr (!std::is_same_v) {NO_DEVICE_CODE;} else { + typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A; + typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B; + typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C; +#else + typedef tile<16, 8, T> tile_A; + typedef tile<8, 8, T> tile_B; + typedef tile<16, 8, float> tile_C; +#endif // VOLTA_MMA_AVAILABLE #endif // defined(AMD_WMMA_AVAILABLE) - if constexpr (!supported) { + if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) { NO_DEVICE_CODE; return; } @@ -248,6 +244,9 @@ static __global__ void mul_mat_f( } } } +#ifdef VOLTA_MMA_AVAILABLE + } +#endif //VOLTA_MMA_AVAILABLE #else GGML_UNUSED_VARS(x, y, ids, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, @@ -278,27 +277,24 @@ static __global__ void mul_mat_f_ids( typedef tile<16, 8, T> tile_A; typedef tile tile_B; typedef tile<16, tile_C_J, float> tile_C; - - constexpr bool a_supported = tile_A::supported(); - constexpr bool b_supported = tile_B::supported(); - constexpr bool c_supported = tile_C::supported(); - constexpr bool supported = a_supported && b_supported && c_supported; #else - constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); - constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); - constexpr bool supported = I_16_supported || I_32_supported; - - constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster. 
- - typedef tile tile_A; - typedef tile<8, 8, T> tile_B; - typedef tile tile_C; +#ifdef VOLTA_MMA_AVAILABLE + if constexpr (!std::is_same_v) {NO_DEVICE_CODE;} else { + typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A; + typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B; + typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C; +#else + typedef tile<16, 8, T> tile_A; + typedef tile<8, 8, T> tile_B; + typedef tile<16, 8, float> tile_C; +#endif // VOLTA_MMA_AVAILABLE #endif // defined(AMD_WMMA_AVAILABLE) - if constexpr (!supported) { + if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) { NO_DEVICE_CODE; return; } + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int tile_k_padded = warp_size + 4; constexpr int ntA = rows_per_block / tile_A::I; @@ -517,6 +513,9 @@ static __global__ void mul_mat_f_ids( } } } +#ifdef VOLTA_MMA_AVAILABLE + } +#endif // VOLTA_MMA_AVAILABLE #else GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu index 03ceba87..f7a2cbca 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu @@ -307,10 +307,9 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { } if (amd_wmma_available(cc)) { - if (GGML_CUDA_CC_IS_RDNA4(cc)) { - return true; - } + return true; } - return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh index 82468b38..1298f99f 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh @@ -1542,8 +1542,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( tile_C Cm; if (k01 >= MMQ_TILE_NE_K * 3/4) { tile_A A1; - A1.x[0] = 0x01010101; - A1.x[1] = 0x01010101; +#pragma unroll + for (int l = 0; l < tile_A::ne; ++l) { + A1.x[l] = 0x01010101; + } mma(Cm, A1, B); } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu b/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu index 29aef33c..660c192e 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu @@ -1,9 +1,17 @@ #include "pad.cuh" +#include + +__device__ __forceinline__ int64_t wrap_around(int64_t coord, int64_t size) { + // + size ensures negatives are handled properly + return (coord + size) % size; +} + static __global__ void pad_f32(const float * src, float * dst, const int lp0, const int rp0, const int lp1, const int rp1, const int lp2, const int rp2, const int lp3, const int rp3, - const int ne0, const int ne1, const int ne2, const int ne3) { + const int ne0, const int ne1, const int ne2, const int ne3, + const bool circular) { // blockIdx.z: i3*ne2+i2 // blockIdx.y: i1 // blockIDx.x: i0 / CUDA_PAD_BLOCK_SIZE @@ -12,61 +20,84 @@ static __global__ void pad_f32(const float * src, float * dst, int i1 = blockIdx.y; int i2 = blockIdx.z % ne2; int i3 = blockIdx.z / ne2; + if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { return; } - // operation - const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; - if ((i0 >= lp0 && i0 < ne0 - rp0) && - (i1 >= lp1 && i1 < ne1 - rp1) && - (i2 >= lp2 && i2 < ne2 - rp2) && - (i3 >= lp3 && i3 < ne3 - rp3)) { 
- const int64_t i00 = i0 - lp0; - const int64_t i01 = i1 - lp1; - const int64_t i02 = i2 - lp2; - const int64_t i03 = i3 - lp3; - const int64_t ne02 = ne2 - lp2 - rp2; - const int64_t ne01 = ne1 - lp1 - rp1; - const int64_t ne00 = ne0 - lp0 - rp0; + const int64_t dst_idx = i3 * (ne0 * ne1 * ne2) + i2 * (ne0 * ne1) + i1 * ne0 + i0; - const int64_t src_idx = i03*(ne00*ne01*ne02) + i02*(ne00*ne01) + i01*ne00 + i00; + if (!circular) { + if ((i0 >= lp0 && i0 < ne0 - rp0) && (i1 >= lp1 && i1 < ne1 - rp1) && (i2 >= lp2 && i2 < ne2 - rp2) && + (i3 >= lp3 && i3 < ne3 - rp3)) { + const int64_t i00 = i0 - lp0; + const int64_t i01 = i1 - lp1; + const int64_t i02 = i2 - lp2; + const int64_t i03 = i3 - lp3; + const int64_t ne02 = ne2 - lp2 - rp2; + const int64_t ne01 = ne1 - lp1 - rp1; + const int64_t ne00 = ne0 - lp0 - rp0; + + const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00; + + dst[dst_idx] = src[src_idx]; + } else { + dst[dst_idx] = 0.0f; + } + } + // circular means on a torus, so x and y wrap around + else { + const int64_t ne00 = ne0 - lp0 - rp0; + const int64_t ne01 = ne1 - lp1 - rp1; + const int64_t ne02 = ne2 - lp2 - rp2; + const int64_t ne03 = ne3 - lp3 - rp3; + + const int64_t i00 = wrap_around(i0 - lp0, ne00); + const int64_t i01 = wrap_around(i1 - lp1, ne01); + const int64_t i02 = wrap_around(i2 - lp2, ne02); + const int64_t i03 = wrap_around(i3 - lp3, ne03); + + const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00; dst[dst_idx] = src[src_idx]; - } else { - dst[dst_idx] = 0.0f; } } + static void pad_f32_cuda(const float * src, float * dst, const int lp0, const int rp0, const int lp1, const int rp1, const int lp2, const int rp2, const int lp3, const int rp3, - const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) { - int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE; - dim3 gridDim(num_blocks, ne1, ne2*ne3); - pad_f32<<>>(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3); + const int ne0, const int ne1, const int ne2, const int ne3, + const bool circular, cudaStream_t stream) { + int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE; + dim3 gridDim(num_blocks, ne1, ne2 * ne3); + pad_f32<<>>(src, dst, + lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, + ne0, ne1, ne2, ne3, circular); } void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *) src0->data; + float * dst_d = (float *) dst->data; + cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(src0)); - const int32_t lp0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t rp0 = ((const int32_t*)(dst->op_params))[1]; - const int32_t lp1 = ((const int32_t*)(dst->op_params))[2]; - const int32_t rp1 = ((const int32_t*)(dst->op_params))[3]; - const int32_t lp2 = ((const int32_t*)(dst->op_params))[4]; - const int32_t rp2 = ((const int32_t*)(dst->op_params))[5]; - const int32_t lp3 = ((const int32_t*)(dst->op_params))[6]; - const int32_t rp3 = ((const int32_t*)(dst->op_params))[7]; + const int32_t lp0 = ((const int32_t *) (dst->op_params))[0]; + const int32_t rp0 = ((const int32_t *) (dst->op_params))[1]; + const int32_t lp1 
= ((const int32_t *) (dst->op_params))[2]; + const int32_t rp1 = ((const int32_t *) (dst->op_params))[3]; + const int32_t lp2 = ((const int32_t *) (dst->op_params))[4]; + const int32_t rp2 = ((const int32_t *) (dst->op_params))[5]; + const int32_t lp3 = ((const int32_t *) (dst->op_params))[6]; + const int32_t rp3 = ((const int32_t *) (dst->op_params))[7]; + const int32_t circular = ((const int32_t *) (dst->op_params))[8]; pad_f32_cuda(src0_d, dst_d, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, - dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream); + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (bool) circular, stream); } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/solve_tri.cu b/ml/backend/ggml/ggml/src/ggml-cuda/solve_tri.cu index 2e2b3972..e161d4dc 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/solve_tri.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/solve_tri.cu @@ -3,7 +3,6 @@ #include "solve_tri.cuh" #define MAX_N_FAST 64 -#define MAX_K_FAST 32 // ====================== // Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction @@ -48,65 +47,58 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A, float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3); __shared__ float sA[MAX_N_FAST * MAX_N_FAST]; - __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)]; const int offset = threadIdx.x + threadIdx.y * blockDim.x; #pragma unroll for (int i = 0; i < n * n; i += k * WARP_SIZE) { - int i0 = i + offset; + const int i0 = i + offset; if (i0 < n * n) { sA[i0] = A_batch[i0]; } } - const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE; - -#pragma unroll - for (int i = 0; i < rows_per_warp; i++) { - const int i0 = lane + i * WARP_SIZE; - if (i0 < n) { - sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx]; - } - } - __syncthreads(); + float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f; + float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f; + + const int half = WARP_SIZE; + const int nrows_low = (n < half) ? n : half; + #pragma unroll - for (int row = 0; row < n; ++row) { + for (int row = 0; row < nrows_low; ++row) { float sum = 0.0f; - - { - int j = lane; - if (j < row) { - sum += sA[row * n + j] * sXt[col_idx * n + j]; - } + if (lane < row) { + sum += sA[row * n + lane] * x_low; } - if (row >= WARP_SIZE) { - int j = WARP_SIZE + lane; - if (j < row) { - sum += sA[row * n + j] * sXt[col_idx * n + j]; - } - } - sum = warp_reduce_sum(sum); - if (lane == 0) { - const float b_val = sXt[col_idx * n + row]; - const float a_diag = sA[row * n + row]; - // no safeguards for division by zero because that indicates corrupt - // data anyway - sXt[col_idx * n + row] = (b_val - sum) / a_diag; + if (lane == row) { + x_low = (x_low - sum) / sA[row * n + row]; } } - __syncthreads(); +#pragma unroll + for (int row = half; row < n; ++row) { + float sum = sA[row * n + lane] * x_low; + const int j = half + lane; + if (j < row) { + sum += sA[row * n + j] * x_high; + } + sum = warp_reduce_sum(sum); + + if (lane == row - half) { + x_high = (x_high - sum) / sA[row * n + row]; + } + } #pragma unroll - for (int i = 0; i < rows_per_warp; i++) { - const int i0 = lane + i * WARP_SIZE; - if (i0 < n) { - X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0]; + for (int rr = 0; rr < 2; ++rr) { + const int row = rr * WARP_SIZE + lane; + if (row < n) { + const float val = (row < half) ? 
x_low : x_high; + X_batch[row * k + col_idx] = val; } } } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/tri.cu b/ml/backend/ggml/ggml/src/ggml-cuda/tri.cu new file mode 100644 index 00000000..44156b63 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/tri.cu @@ -0,0 +1,136 @@ +#include "common.cuh" +#include "convert.cuh" +#include "tri.cuh" +#include "ggml.h" + +template +static __global__ void tri_kernel( + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) { + const int64_t i3 = blockIdx.z; + const int64_t i2 = blockIdx.y; + const int64_t i1 = blockIdx.x; + const int64_t split_point = i1 + add_to_split; + + GGML_UNUSED_VARS(nb00, nb0); + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + const T * src_row = src + i1*nb01 + i2*nb02 + i3*nb03; + T * dst_row = dst + i1*nb1 + i2*nb2 + i3*nb3; + + if constexpr (prefix_keep) { + for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) { + dst_row[i0] = src_row[i0]; + } + for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) { + dst_row[i0] = ggml_cuda_cast(0.0f); + } + } else { + for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) { + dst_row[i0] = ggml_cuda_cast(0.0f); + } + for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) { + dst_row[i0] = src_row[i0]; + } + } +} + +template +static void tri_cuda( + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, + const ggml_tri_type ttype, + cudaStream_t stream) { + + dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1); + dim3 grid_dims(ne01, ne02, ne03); + const size_t type_size = sizeof(T); + + const int add_to_split = (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) ? 
1 : 0; + const bool prefix_keep = (ttype == GGML_TRI_TYPE_LOWER || ttype == GGML_TRI_TYPE_LOWER_DIAG); + + if (prefix_keep) { + if (add_to_split == 0) { + tri_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size + ); + } else { // only 0 and 1 supported + tri_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size + ); + } + } else { + if (add_to_split == 0) { + tri_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size + ); + } else { + tri_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size + ); + } + } +} + +void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + cudaStream_t stream = ctx.stream(); + + const ggml_tri_type ttype = static_cast(ggml_get_op_params_i32(dst, 0)); + + GGML_ASSERT(src0->type == dst->type); + + switch(src0->type) { + case GGML_TYPE_F32: + { + tri_cuda( + (const float *)src0->data, (float *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + ttype, stream + ); + } break; + case GGML_TYPE_F16: + { + tri_cuda( + (const half *)src0->data, (half *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + ttype, stream + ); + } break; + case GGML_TYPE_BF16: + { + tri_cuda( + (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + ttype, stream + ); + } break; + default: + GGML_ABORT("fatal error"); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/tri.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/tri.cuh new file mode 100644 index 00000000..a4cc6675 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/tri.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_TRI_BLOCK_SIZE 256 + +void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/upscale.cu b/ml/backend/ggml/ggml/src/ggml-cuda/upscale.cu index 687c6693..6bdf3cd9 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/upscale.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/upscale.cu @@ -81,6 +81,76 @@ static __global__ void upscale_f32_bilinear(const float * x, float * dst, dst[index] = result; } +// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True) +// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp +static __global__ void upscale_f32_bilinear_antialias(const float * src0, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00_src, const int ne01_src, + const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, + const float sf0, const float sf1, const float sf2, const float sf3, + 
const float pixel_offset) { + const int64_t index = threadIdx.x + blockIdx.x * blockDim.x; + const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + const int i10_dst = index % ne10_dst; + const int i11_dst = (index / ne10_dst) % ne11_dst; + const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + const int i02_src = (int)(i12_dst / sf2); + const int i03_src = (int)(i13_dst / sf3); + + const float y = ((float)i11_dst + pixel_offset) / sf1; + const float x = ((float)i10_dst + pixel_offset) / sf0; + + // support and invscale, minimum 1 pixel for bilinear + const float support1 = max(1.0f / sf1, 1.0f); + const float invscale1 = 1.0f / support1; + const float support0 = max(1.0f / sf0, 1.0f); + const float invscale0 = 1.0f / support0; + + // the range of source pixels that contribute + const int64_t x_min = max(int64_t(0), int64_t(x - support0 + pixel_offset)); + const int64_t x_max = min(int64_t(ne00_src), int64_t(x + support0 + pixel_offset)); + const int64_t y_min = max(int64_t(0), int64_t(y - support1 + pixel_offset)); + const int64_t y_max = min(int64_t(ne01_src), int64_t(y + support1 + pixel_offset)); + + // bilinear filter with antialiasing + float val = 0.0f; + float total_weight = 0.0f; + + auto triangle_filter = [](float x) -> float { + return max(1.0f - fabsf(x), 0.0f); + }; + + for (int64_t sy = y_min; sy < y_max; sy++) { + const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1); + + for (int64_t sx = x_min; sx < x_max; sx++) { + const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0); + const float weight = weight_x * weight_y; + + if (weight <= 0.0f) { + continue; + } + + const float pixel = *(const float *)((const char *)src0 + sx*nb00 + sy*nb01 + i02_src*nb02 + i03_src*nb03); + val += pixel * weight; + total_weight += weight; + } + } + + if (total_weight > 0.0f) { + val /= total_weight; + } + + dst[index] = val; +} + namespace bicubic_interpolation { // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm __device__ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch) @@ -161,11 +231,15 @@ static void upscale_f32_bilinear_cuda(const float * x, float * dst, const int ne00_src, const int ne01_src, const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, const float sf0, const float sf1, const float sf2, const float sf3, - const float pixel_offset, cudaStream_t stream) { + const float pixel_offset, bool antialias, cudaStream_t stream) { const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst; const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE; - upscale_f32_bilinear<<>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + if (antialias) { + upscale_f32_bilinear_antialias<<>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + } else { + upscale_f32_bilinear<<>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + } } static void upscale_f32_bicubic_cuda(const float * x, float * dst, @@ -207,9 +281,10 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { if (mode == GGML_SCALE_MODE_NEAREST) { upscale_f32_cuda(src0_d, dst_d, 
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream); } else if (mode == GGML_SCALE_MODE_BILINEAR) { + const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS); upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - sf0, sf1, sf2, sf3, pixel_offset, stream); + sf0, sf1, sf2, sf3, pixel_offset, antialias, stream); } else if (mode == GGML_SCALE_MODE_BICUBIC) { upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-context.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-context.m index e6664628..42a35736 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-context.m +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-context.m @@ -24,9 +24,6 @@ struct ggml_metal_command_buffer { }; struct ggml_metal { - id device; - id queue; // currently a pointer to the device queue, but might become separate queue [TAG_QUEUE_PER_BACKEND] - ggml_metal_device_t dev; ggml_metal_library_t lib; @@ -91,15 +88,15 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) { // init context ggml_metal_t res = calloc(1, sizeof(struct ggml_metal)); - res->device = ggml_metal_device_get_obj(dev); + id device = ggml_metal_device_get_obj(dev); - GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[res->device name] UTF8String]); + GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]); // TODO: would it be better to have one queue for the backend and one queue for the device? // the graph encoders and async ops would use the backend queue while the sync ops would use the device queue? 
//res->queue = [device newCommandQueue]; [TAG_QUEUE_PER_BACKEND] - res->queue = ggml_metal_device_get_queue(dev); - if (res->queue == nil) { + id queue = ggml_metal_device_get_queue(dev); + if (queue == nil) { GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__); return NULL; } @@ -274,7 +271,8 @@ static struct ggml_metal_buffer_id ggml_metal_get_buffer_id(const struct ggml_te void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { @autoreleasepool { // wrap the source data into a Metal buffer - id buf_src = [ctx->device newBufferWithBytes:data + id device = ggml_metal_device_get_obj(ctx->dev); + id buf_src = [device newBufferWithBytes:data length:size options:MTLResourceStorageModeShared]; @@ -289,7 +287,8 @@ void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, // queue the copy operation into the queue of the Metal context // this will be queued at the end, after any currently ongoing GPU operations - id cmd_buf = [ctx->queue commandBuffer]; + id queue = ggml_metal_device_get_queue(ctx->dev); + id cmd_buf = [queue commandBuffer]; id encoder = [cmd_buf blitCommandEncoder]; [encoder copyFromBuffer:buf_src @@ -315,7 +314,8 @@ void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { @autoreleasepool { - id buf_dst = [ctx->device newBufferWithBytesNoCopy:data + id device = ggml_metal_device_get_obj(ctx->dev); + id buf_dst = [device newBufferWithBytesNoCopy:data length:size options:MTLResourceStorageModeShared deallocator:nil]; @@ -331,7 +331,8 @@ void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * te // queue the copy operation into the queue of the Metal context // this will be queued at the end, after any currently ongoing GPU operations - id cmd_buf = [ctx->queue commandBuffer]; + id queue = ggml_metal_device_get_queue(ctx->dev); + id cmd_buf = [queue commandBuffer]; id encoder = [cmd_buf blitCommandEncoder]; [encoder copyFromBuffer:bid_src.metal @@ -362,6 +363,9 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * // number of threads in addition to the main thread const int n_cb = ctx->n_cb; + // keep the memory wired + ggml_metal_device_rsets_keep_alive(ctx->dev); + // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes @@ -389,7 +393,8 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * if (!ctx->capture_started) { // create capture scope - ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device]; + id device = ggml_metal_device_get_obj(ctx->dev); + ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:device]; MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; descriptor.captureObject = ctx->capture_scope; @@ -406,10 +411,13 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * } } + // short-hand + id queue = ggml_metal_device_get_queue(ctx->dev); + // the main thread commits the first few commands immediately // cmd_buf[n_cb] { - id cmd_buf = [ctx->queue 
commandBufferWithUnretainedReferences]; + id cmd_buf = [queue commandBufferWithUnretainedReferences]; [cmd_buf retain]; if (ctx->cmd_bufs[n_cb].obj) { @@ -428,7 +436,7 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * // prepare the rest of the command buffers asynchronously (optional) // cmd_buf[0.. n_cb) for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { - id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + id cmd_buf = [queue commandBufferWithUnretainedReferences]; [cmd_buf retain]; if (ctx->cmd_bufs[cb_idx].obj) { @@ -589,9 +597,11 @@ void ggml_metal_set_abort_callback(ggml_metal_t ctx, ggml_abort_callback abort_c } bool ggml_metal_supports_family(ggml_metal_t ctx, int family) { - GGML_ASSERT(ctx->device != nil); + GGML_ASSERT(ctx->dev != nil); - return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)]; + id device = ggml_metal_device_get_obj(ctx->dev); + + return [device supportsFamily:(MTLGPUFamilyApple1 + family - 1)]; } void ggml_metal_capture_next_compute(ggml_metal_t ctx) { diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp index 329500a0..680904d1 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -50,14 +50,14 @@ void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, gg } ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) { - if (ppls->data.find(name) == ppls->data.end()) { + if (ppls->data.find(name) == ppls->data.end()) { return nullptr; } return ppls->data[name]; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) { +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) { char base[256]; char name[256]; @@ -71,34 +71,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base(ggml_metal_library_t snprintf(base, 256, "kernel_%s", op_str); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) { char base[256]; char name[256]; snprintf(base, 256, "kernel_cpy_%s_%s", ggml_type_name(tsrc), ggml_type_name(tdst)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) { 
GGML_ASSERT(ggml_is_contiguous(op->src[0])); GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type); @@ -115,68 +111,60 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library snprintf(base, 256, "kernel_pool_2d_%s_%s", pool_str, ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) { char base[256]; char name[256]; snprintf(base, 256, "kernel_get_rows_%s", ggml_type_name(tsrc)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) { char base[256]; char name[256]; snprintf(base, 256, "kernel_set_rows_%s_%s", ggml_type_name(tdst), ggml_type_name(tidx)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) { char base[256]; char name[256]; snprintf(base, 256, "kernel_repeat_%s", ggml_type_name(tsrc)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(ggml_is_contiguous(op->src[0])); char base[256]; @@ -187,6 +175,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t const char * op_str = "undefined"; switch (op->op) { case GGML_OP_SCALE: op_str = "scale"; break; + case GGML_OP_FILL: op_str = "fill"; break; case GGML_OP_CLAMP: op_str = "clamp"; break; case GGML_OP_SQR: op_str = 
"sqr"; break; case GGML_OP_SQRT: op_str = "sqrt"; break; @@ -211,6 +200,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break; case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break; case GGML_UNARY_OP_EXP: op_str = "exp"; break; + case GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break; + case GGML_UNARY_OP_EXPM1: op_str = "expm1"; break; default: GGML_ABORT("fatal error"); } break; default: GGML_ABORT("fatal error"); @@ -224,17 +215,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(ggml_is_contiguous_1(op->src[0])); char base[256]; @@ -258,17 +247,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_SUM); char base[256]; @@ -277,17 +264,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t l snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); char base[256]; @@ -306,19 +291,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_librar snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = 
ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + res.smem = 32*sizeof(float); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->op == GGML_OP_CUMSUM); char base[256]; @@ -327,17 +310,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_libr snprintf(base, 256, "kernel_cumsum_blk_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->op == GGML_OP_CUMSUM); char base[256]; @@ -346,17 +327,37 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_libr snprintf(base, 256, "kernel_cumsum_add_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(op->op == GGML_OP_TRI); + GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); + + char base[256]; + char name[256]; + + const char * op_str = "tri"; + const int ttype = op->op_params[0]; + + snprintf(base, 256, "kernel_%s_%s_%d", op_str, ggml_type_name(op->src[0]->type), ttype); + + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + } + + return res; +} + +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32); char base[256]; @@ -373,19 +374,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_librar snprintf(base, 256, "kernel_soft_max_%s%s", ggml_type_name(tsrc1), suffix); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = 
ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + res.smem = 32*sizeof(float); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); @@ -404,17 +403,47 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, const ggml_tensor * op, int ssm_conv_bs) { + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(ggml_is_contiguous(op->src[1])); + + char base[256]; + char name[256]; + + const char * suffix = ""; + if (op->src[1]->ne[0] % 4 == 0) { + suffix = "_4"; + } + + snprintf(base, 256, "kernel_ssm_conv_%s_%s_batched%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix); + snprintf(name, 256, "%s_ssm_conv_bs=%d", base, ssm_conv_bs); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, ssm_conv_bs, FC_SSM_CONV + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + } + + return res; +} + +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); char base[256]; @@ -425,19 +454,22 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s_nsg=%d", base, nsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg); + // Shared memory layout: + // - sgptg * NW floats for partial sums (nsg * 32) + // - sgptg floats for shared_x_dt (nsg) + // - sgptg floats for shared_dA (nsg) + // Total: nsg * (32 + 2) floats + res.smem = (32 + 2)*sizeof(float)*nsg; return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const 
ggml_tensor * op) { char base[256]; char name[256]; @@ -467,41 +499,37 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) { char base[256]; char name[256]; snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg); snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); - ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) { char base[256]; char name[256]; @@ -514,27 +542,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_ snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); + ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); - ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - // when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes - ggml_metal_pipeline_set_smem(res, bc_out ? 8192 : 4096 + 2048); + res.smem = bc_out ? 
8192 : 4096 + 2048; return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); @@ -689,49 +715,43 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_ snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix); snprintf(name, 256, "%s_nsg=%d", base, nsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - - ggml_metal_pipeline_set_nr0 (res, nr0); - ggml_metal_pipeline_set_nr1 (res, nr1); - ggml_metal_pipeline_set_nsg (res, nsg); - ggml_metal_pipeline_set_smem(res, smem); + res.nr0 = nr0; + res.nr1 = nr1; + res.nsg = nsg; + res.smem = smem; return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) { char base[256]; char name[256]; snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20); snprintf(name, 256, "%s_ne02=%d", base, ne02); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - const size_t smem = (size_t) ne02*ne20*sizeof(uint16_t); - - ggml_metal_pipeline_set_smem(res, smem); + res.smem = (size_t) ne02*ne20*sizeof(uint16_t); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) { char base[256]; char name[256]; @@ -743,25 +763,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_libra snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); snprintf(name, 256, "%s_bci=%d", base, bc_inp); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - - ggml_metal_pipeline_set_smem(res, 8192); + 
res.smem = 8192; return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); @@ -909,28 +927,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix); snprintf(name, 256, "%s_nsg=%d", base, nsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - - ggml_metal_pipeline_set_nr0 (res, nr0); - ggml_metal_pipeline_set_nr1 (res, nr1); - ggml_metal_pipeline_set_nsg (res, nsg); - ggml_metal_pipeline_set_smem(res, smem); + res.nr0 = nr0; + res.nr1 = nr1; + res.nsg = nsg; + res.smem = smem; return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous_1(op->src[0])); GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); @@ -941,19 +957,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax(ggml_metal_library_ snprintf(base, 256, "kernel_argmax_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*(sizeof(float) + sizeof(int32_t))); + res.smem = 32*(sizeof(float) + sizeof(int32_t)); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_ARGSORT); char base[256]; @@ -971,17 +985,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t 
ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_ARGSORT); char base[256]; @@ -999,18 +1011,16 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_l snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } // note: reuse the argsort kernel for top_k -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_TOP_K); char base[256]; @@ -1029,17 +1039,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_TOP_K); char base[256]; @@ -1057,17 +1065,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_lib snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad( ggml_metal_library_t lib, const struct ggml_tensor * op, bool has_mask, @@ -1086,33 +1092,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( has_mask, ncpsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0); + //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1); + 
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2); + //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3); + + //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20); + //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21); + //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22); + //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23); + //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24); + ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0); - //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1); - //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2); - //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3); - - //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20); - //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21); - //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22); - //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23); - //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24); - ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk( +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk( ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t nqptg, @@ -1131,33 +1135,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk( nqptg, ncpsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + //ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0); + //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1); + //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2); + //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3); + + //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20); + //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21); + //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22); + //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23); + ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24); + ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - //ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0); - //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1); - //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2); - //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3); - - //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20); - //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21); - //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22); - //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23); - ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24); - ggml_metal_cv_set_int32(cv, ncpsg, 
FC_FLASH_ATTN_EXT_BLK + 25); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext( ggml_metal_library_t lib, const ggml_tensor * op, bool has_mask, @@ -1198,33 +1200,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( ns20, nsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT + 0); + ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1); + ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2); + ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3); + ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4); + + ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10); + + ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20); + ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21); + ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT + 22); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT + 0); - ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1); - ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2); - ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3); - ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4); - - ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10); - - ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20); - ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21); - ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT + 22); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec( ggml_metal_library_t lib, const ggml_tensor * op, bool has_mask, @@ -1262,32 +1262,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( ns20, nsg, nwg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_VEC + 0); + ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1); + ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2); + ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3); + ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4); + + ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20); + ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21); + ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_VEC + 22); + ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC + 23); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_VEC + 0); - ggml_metal_cv_set_bool(cv, has_sinks, 
FC_FLASH_ATTN_EXT_VEC + 1); - ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2); - ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3); - ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4); - - ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20); - ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21); - ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_VEC + 22); - ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC + 23); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( ggml_metal_library_t lib, const ggml_tensor * op, int32_t dv, @@ -1300,26 +1298,24 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce"); snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int32(cv, dv, FC_FLASH_ATTN_EXT_VEC_REDUCE + 0); + ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_int32(cv, dv, FC_FLASH_ATTN_EXT_VEC_REDUCE + 0); - ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - return res; GGML_UNUSED(op); } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin( +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin( ggml_metal_library_t lib, ggml_op op, int32_t n_fuse, @@ -1344,17 +1340,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin( snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_L2_NORM); GGML_ASSERT(op->src[0]->ne[0] % 4 == 0); @@ -1366,19 +1360,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library snprintf(base, 256, "kernel_l2_norm_f32"); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + res.smem = 32*sizeof(float); return res; } -ggml_metal_pipeline_t 
ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_GROUP_NORM); GGML_ASSERT(ggml_is_contiguous(op->src[0])); @@ -1389,19 +1381,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_libr snprintf(base, 256, "kernel_group_norm_f32"); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + res.smem = 32*sizeof(float); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) { assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM); GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); @@ -1434,19 +1424,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + res.smem = 32*sizeof(float); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_ROPE); char base[256]; @@ -1473,23 +1461,21 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 
1 : 0); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_IM2COL); GGML_ASSERT(ggml_is_contiguous(op->src[1])); @@ -1502,17 +1488,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_ snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_CONV_TRANSPOSE_1D); GGML_ASSERT(ggml_is_contiguous(op->src[0])); @@ -1527,17 +1511,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_met snprintf(base, 256, "kernel_conv_transpose_1d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_CONV_TRANSPOSE_2D); GGML_ASSERT(ggml_is_contiguous(op->src[0])); @@ -1552,17 +1534,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_met snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t 
ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_CONV_2D); GGML_ASSERT(ggml_is_contiguous(op->src[0])); @@ -1576,17 +1556,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library snprintf(base, 256, "kernel_conv_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_UPSCALE); char base[256]; @@ -1595,17 +1573,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_PAD); char base[256]; @@ -1614,8 +1590,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t l snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (res.pipeline) { return res; } @@ -1624,7 +1600,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t l return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_PAD_REFLECT_1D); char base[256]; @@ -1633,17 +1609,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_ snprintf(base, 256, "kernel_pad_reflect_1d_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t 
ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_ARANGE); char base[256]; @@ -1652,17 +1626,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange(ggml_metal_library_ snprintf(base, 256, "kernel_arange_%s", ggml_type_name(op->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_TIMESTEP_EMBEDDING); char base[256]; @@ -1671,17 +1643,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me snprintf(base, 256, "kernel_timestep_embedding_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_OPT_STEP_ADAMW); char base[256]; @@ -1690,17 +1660,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_ snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_OPT_STEP_SGD); char base[256]; @@ -1709,12 +1677,10 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_li snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } diff 
--git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h index 3976e622..0a8b9211 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h @@ -35,20 +35,6 @@ typedef struct ggml_metal_pipeline * ggml_metal_pipeline_t; ggml_metal_pipeline_t ggml_metal_pipeline_init(void); void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline); -void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg); -int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline); - -void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0); -int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline); - -void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1); -int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline); - -void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem); -size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline); - -int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline); - // a collection of pipelines typedef struct ggml_metal_pipelines * ggml_metal_pipelines_t; @@ -58,6 +44,19 @@ void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls); void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline); ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name); +struct ggml_metal_pipeline_with_params { + ggml_metal_pipeline_t pipeline; + + int nsg; + + int nr0; + int nr1; + + size_t smem; +}; + +int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline); + // // MTLCommandBuffer wrapper // @@ -76,7 +75,7 @@ void ggml_metal_encoder_free(ggml_metal_encoder_t encoder); void ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name); void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder); -void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline); +void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline); void ggml_metal_encoder_set_bytes (ggml_metal_encoder_t encoder, void * data, size_t size, int idx); void ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx); @@ -100,66 +99,68 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev void ggml_metal_library_free(ggml_metal_library_t lib); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name); -ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name); +struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool); 
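For orientation, a minimal sketch of how a caller consumes the by-value ggml_metal_pipeline_with_params introduced just above. The encoder handle enc, the op tensor, and the call to ggml_metal_encoder_set_threadgroup_memory_size are assumptions for illustration only; the norm getter and ggml_metal_encoder_set_pipeline are the declarations from this header.

    struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_get_pipeline_norm(lib, op, 1);
    if (!ppl.pipeline) {
        return; // compilation failed; the library already logged the error
    }

    ggml_metal_encoder_set_pipeline(enc, ppl);

    // dispatch parameters now travel with the result instead of being stored
    // on the cached pipeline object: ppl.nsg, ppl.nr0, ppl.nr1, ppl.smem
    ggml_metal_encoder_set_threadgroup_memory_size(enc, ppl.smem, 0);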
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope 
(ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct 
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const 
struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad( ggml_metal_library_t lib, const struct ggml_tensor * op, bool has_mask, int32_t ncpsg); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk( +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk( ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t nqptg, int32_t ncpsg); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext( ggml_metal_library_t lib, const struct ggml_tensor * op, bool has_mask, @@ -169,7 +170,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( bool has_kvpad, int32_t nsg); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec( ggml_metal_library_t lib, const struct ggml_tensor * op, bool has_mask, @@ -180,12 +181,22 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( int32_t nsg, int32_t nwg); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t dv, int32_t nwg); +// MTLResidencySet wrapper + +typedef void * ggml_metal_rset_t; + +// a collection of residency sets (non-owning) +typedef struct ggml_metal_rsets * ggml_metal_rsets_t; + +ggml_metal_rsets_t ggml_metal_rsets_init(void); +void ggml_metal_rsets_free(ggml_metal_rsets_t rsets); + // // device // @@ -219,6 +230,11 @@ void * ggml_metal_device_get_queue(ggml_metal_device_t dev); // id @@ -75,14 +74,6 @@ void ggml_metal_cv_set_bool(ggml_metal_cv_t cv, bool value, int32_t idx) { struct ggml_metal_pipeline { id obj; - - // suggested dispatch sizes - int nsg; - - int nr0; - int nr1; - - size_t smem; }; ggml_metal_pipeline_t ggml_metal_pipeline_init(void) { @@ -90,10 +81,6 @@ ggml_metal_pipeline_t ggml_metal_pipeline_init(void) { *res = (struct ggml_metal_pipeline) { /*.obj =*/ nil, - /*.nsg =*/ 0, - /*.nr0 =*/ 0, - /*.nr1 =*/ 0, - /*.smem =*/ 0, }; return res; @@ -105,40 +92,8 @@ void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline) { free(pipeline); } -void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg) { - pipeline->nsg = nsg; -} - -int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline) { - return pipeline->nsg; -} - -void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0) { - pipeline->nr0 = nr0; -} - -int 
ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline) { - return pipeline->nr0; -} - -void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1) { - pipeline->nr1 = nr1; -} - -int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline) { - return pipeline->nr1; -} - -void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem) { - pipeline->smem = smem; -} - -size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline) { - return pipeline->smem; -} - -int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline) { - return pipeline->obj.maxTotalThreadsPerThreadgroup; +int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline) { + return pipeline.pipeline->obj.maxTotalThreadsPerThreadgroup; } struct ggml_metal_library { @@ -146,6 +101,8 @@ struct ggml_metal_library { id device; ggml_metal_pipelines_t pipelines; // cache of compiled pipelines + + NSLock * lock; }; ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) { @@ -296,9 +253,10 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) { ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library)); - res->obj = library; - res->device = device; + res->obj = library; + res->device = device; res->pipelines = ggml_metal_pipelines_init(); + res->lock = [NSLock new]; return res; } @@ -365,6 +323,7 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev res->obj = library; res->device = device; res->pipelines = ggml_metal_pipelines_init(); + res->lock = [NSLock new]; return res; } @@ -380,26 +339,47 @@ void ggml_metal_library_free(ggml_metal_library_t lib) { ggml_metal_pipelines_free(lib->pipelines); + [lib->lock release]; + free(lib); } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) { - return ggml_metal_pipelines_get(lib->pipelines, name); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) { + [lib->lock lock]; + + struct ggml_metal_pipeline_with_params res = { + /*.pipeline =*/ nil, + /*.nr0 =*/ 0, + /*.nr1 =*/ 0, + /*.nsg =*/ 0, + /*.smem =*/ 0, + }; + + res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name); + + [lib->lock unlock]; + + return res; } -ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) { - // note: the pipelines are cached in the library per device, so they are shared across all metal contexts - ggml_critical_section_start(); +struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) { + struct ggml_metal_pipeline_with_params res = { + /*.pipeline =*/ nil, + /*.nr0 =*/ 0, + /*.nr1 =*/ 0, + /*.nsg =*/ 0, + /*.smem =*/ 0, + }; - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - ggml_critical_section_end(); + [lib->lock lock]; + + res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name); + if (res.pipeline) { + [lib->lock unlock]; return res; } - res = ggml_metal_pipeline_init(); - @autoreleasepool { NSError * error = nil; @@ -414,36 +394,53 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error]; } if (!mtl_function) { - ggml_critical_section_end(); + [lib->lock unlock]; 
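/*
 * Condensed shape of the caching discipline this function now implements
 * (an illustrative sketch, not additional patch content): lib->lock replaces
 * ggml_critical_section_start/end, so every exit path, including this failure
 * branch, has to unlock before returning a result whose pipeline may be nil:
 *
 *     [lib->lock lock];
 *     res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
 *     if (!res.pipeline) {
 *         // compile the function and pipeline state; on any failure,
 *         // unlock and return res with res.pipeline == nil;
 *         // on success, wrap the new MTLComputePipelineState and cache it:
 *         ggml_metal_pipelines_add(lib->pipelines, name, res.pipeline);
 *     }
 *     [lib->lock unlock];
 *     return res;
 */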
GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name); if (error) { GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]); } - return nil; + return res; } - res->obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error]; + id obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error]; [mtl_function release]; - GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) res->obj, - (int) res->obj.maxTotalThreadsPerThreadgroup, - (int) res->obj.threadExecutionWidth); + if (!obj) { + [lib->lock unlock]; - if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) { - ggml_critical_section_end(); + GGML_LOG_ERROR("%s: failed to create pipeline state: base = '%s', name = '%s'\n", __func__, base, name); + if (error) { + GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]); + } + + return res; + } + + GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, + (void *) obj, + (int) obj.maxTotalThreadsPerThreadgroup, + (int) obj.threadExecutionWidth); + + if (obj.maxTotalThreadsPerThreadgroup == 0 || obj.threadExecutionWidth == 0) { + [obj release]; + + [lib->lock unlock]; GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name); - return nil; + return res; } - ggml_metal_pipelines_add(lib->pipelines, name, res); + res.pipeline = ggml_metal_pipeline_init(); + res.pipeline->obj = obj; + + ggml_metal_pipelines_add(lib->pipelines, name, res.pipeline); } - ggml_critical_section_end(); + [lib->lock unlock]; return res; } @@ -485,8 +482,8 @@ void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder) { [encoder->obj popDebugGroup]; } -void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline) { - [encoder->obj setComputePipelineState:pipeline->obj]; +void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline) { + [encoder->obj setComputePipelineState:pipeline.pipeline->obj]; } void ggml_metal_encoder_set_bytes(ggml_metal_encoder_t encoder, void * data, size_t size, int idx) { @@ -521,11 +518,106 @@ struct ggml_metal_device { // ref: https://github.com/ggml-org/llama.cpp/pull/15906 id mtl_queue; + ggml_metal_rsets_t rsets; + ggml_metal_library_t library; struct ggml_metal_device_props props; }; +// +// MTLResidenceSet wrapper +// + +struct ggml_metal_rsets { + NSLock * lock; + + NSMutableArray * data; + + // number of seconds since the last graph computation + // keep the residency sets wired for that amount of time to avoid being collected by the OS + int keep_alive_s; + + // background heartbeat thread to keep the residency sets alive + atomic_bool d_stop; + atomic_int d_loop; + + dispatch_group_t d_group; +}; + +ggml_metal_rsets_t ggml_metal_rsets_init(void) { + ggml_metal_rsets_t res = calloc(1, sizeof(struct ggml_metal_rsets)); + + res->lock = [[NSLock alloc] init]; + res->data = [[NSMutableArray alloc] init]; + + // by default keep the memory wired for 3 minutes + res->keep_alive_s = 3*60; + + const char * GGML_METAL_RESIDENCY_KEEP_ALIVE_S = getenv("GGML_METAL_RESIDENCY_KEEP_ALIVE_S"); + if (GGML_METAL_RESIDENCY_KEEP_ALIVE_S) { + res->keep_alive_s = atoi(GGML_METAL_RESIDENCY_KEEP_ALIVE_S); + } + + if (res->keep_alive_s <= 0) { + res->keep_alive_s = 3*60; + } + + GGML_LOG_INFO("%s: creating a residency set collection (keep_alive = %d s)\n", __func__, 
res->keep_alive_s); + + atomic_store_explicit(&res->d_stop, false, memory_order_relaxed); + atomic_store_explicit(&res->d_loop, 2*res->keep_alive_s, memory_order_relaxed); + + res->d_group = dispatch_group_create(); + + // start a background thread that periodically requests residency for all the currently active sets in the collection + // the requests stop after a certain amount of time (keep_alive_s) of inactivity + dispatch_queue_t d_queue = dispatch_get_global_queue(QOS_CLASS_DEFAULT, 0); + dispatch_group_async(res->d_group, d_queue, ^{ +#if defined(GGML_METAL_HAS_RESIDENCY_SETS) + if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) { + while (!atomic_load_explicit(&res->d_stop, memory_order_relaxed)) { + if (atomic_load_explicit(&res->d_loop, memory_order_relaxed) > 0) { + [res->lock lock]; + + for (int i = 0; i < (int) res->data.count; ++i) { + [res->data[i] requestResidency]; + } + + atomic_fetch_sub_explicit(&res->d_loop, 1, memory_order_relaxed); + + [res->lock unlock]; + } + + // half a second + usleep(500 * 1000); + } + } +#endif + }); + + return res; +} + +void ggml_metal_rsets_free(ggml_metal_rsets_t rsets) { + if (rsets == NULL) { + return; + } + + // note: if you hit this assert, most likely you haven't deallocated all Metal resources before exiting + GGML_ASSERT([rsets->data count] == 0); + + atomic_store_explicit(&rsets->d_stop, true, memory_order_relaxed); + + dispatch_group_wait(rsets->d_group, DISPATCH_TIME_FOREVER); + dispatch_release(rsets->d_group); + + [rsets->data release]; + [rsets->lock release]; + + free(rsets); +} + ggml_metal_device_t ggml_metal_device_init(void) { ggml_metal_device_t dev = calloc(1, sizeof(struct ggml_metal_device)); @@ -611,8 +703,8 @@ ggml_metal_device_t ggml_metal_device_init(void) { GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__); dev->props.has_tensor = false; } else { - ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); - if (!ppl) { + struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); + if (!ppl.pipeline) { GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__); dev->props.has_tensor = false; } @@ -661,8 +753,8 @@ ggml_metal_device_t ggml_metal_device_init(void) { GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__); dev->props.has_bfloat = false; } else { - ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); - if (!ppl) { + struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); + if (!ppl.pipeline) { GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__); dev->props.has_bfloat = false; } @@ -694,7 +786,11 @@ ggml_metal_device_t ggml_metal_device_init(void) { GGML_LOG_ERROR("%s: error: failed to create library\n", __func__); } - // -------------------------------------------------- + if (dev->props.use_residency_sets) { + dev->rsets = ggml_metal_rsets_init(); + } else { + dev->rsets = nil; + } // print MTL GPU family: GGML_LOG_INFO("%s: GPU name: %s\n", __func__, dev->props.name); @@ -747,6 +843,8 @@ ggml_metal_device_t ggml_metal_device_init(void) { void ggml_metal_device_free(ggml_metal_device_t dev) { assert(dev != NULL); + ggml_metal_rsets_free(dev->rsets); + 
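For reference, the residency-set collection introduced above is driven by the device and buffer changes later in this diff. A minimal C-style lifecycle sketch follows, using only functions that appear in this patch; the exact call site of the keep-alive refresh during graph computation is not shown in this excerpt, so that step is an assumption.

// Illustrative lifecycle sketch (not part of the patch).
ggml_metal_device_t dev = ggml_metal_device_init();            // creates dev->rsets when props.use_residency_sets is set
ggml_metal_buffer_t buf = ggml_metal_buffer_init(dev, /*size=*/ nbytes, /*shared=*/ true);
                                                               // registers buf->rset via ggml_metal_device_rsets_add()

// expected to run around each graph computation (call site assumed, not shown in this excerpt):
ggml_metal_device_rsets_keep_alive(dev);                       // resets d_loop to 2*keep_alive_s; the heartbeat ticks every
                                                               // 0.5 s, so residency keeps being requested for keep_alive_s
                                                               // seconds after the last compute (default 180 s, overridable
                                                               // with GGML_METAL_RESIDENCY_KEEP_ALIVE_S)

ggml_metal_buffer_free(buf);                                   // unregisters the set via ggml_metal_device_rsets_rm()
ggml_metal_device_free(dev);                                   // ggml_metal_rsets_free() asserts the collection is empty,
                                                               // stops the heartbeat and joins it with dispatch_group_wait()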
ggml_metal_library_free(dev->library); dev->library = NULL; @@ -775,6 +873,42 @@ ggml_metal_library_t ggml_metal_device_get_library(ggml_metal_device_t dev) { return dev->library; } +void ggml_metal_device_rsets_add(ggml_metal_device_t dev, ggml_metal_rset_t rset) { + if (rset == nil) { + return; + } + + GGML_ASSERT(dev->rsets); + + [dev->rsets->lock lock]; + + [dev->rsets->data addObject:rset]; + + [dev->rsets->lock unlock]; +} + +void ggml_metal_device_rsets_rm(ggml_metal_device_t dev, ggml_metal_rset_t rset) { + if (rset == nil) { + return; + } + + GGML_ASSERT(dev->rsets); + + [dev->rsets->lock lock]; + + [dev->rsets->data removeObject:rset]; + + [dev->rsets->lock unlock]; +} + +void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev) { + if (dev->rsets == NULL) { + return; + } + + atomic_store_explicit(&dev->rsets->d_loop, 2*dev->rsets->keep_alive_s, memory_order_relaxed); +} + void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total) { if (@available(macOS 10.12, iOS 16.0, *)) { *total = dev->mtl_device.recommendedMaxWorkingSetSize; @@ -820,6 +954,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SOFTPLUS: + case GGML_UNARY_OP_EXPM1: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; default: return false; @@ -852,6 +988,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_ACC: case GGML_OP_REPEAT: case GGML_OP_SCALE: + case GGML_OP_FILL: case GGML_OP_CONV_TRANSPOSE_1D: return true; case GGML_OP_CONV_TRANSPOSE_2D: @@ -869,6 +1006,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SUM: return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); + case GGML_OP_TRI: + return ggml_is_contiguous_rows(op->src[0]); case GGML_OP_SUM_ROWS: case GGML_OP_CUMSUM: case GGML_OP_MEAN: @@ -894,10 +1033,15 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_POOL_1D: return false; case GGML_OP_UPSCALE: - return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; + return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); case GGML_OP_POOL_2D: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_PAD: + // TODO: add circular padding support for metal, see https://github.com/ggml-org/llama.cpp/pull/16985 + if (ggml_get_op_params_i32(op, 8) != 0) { + return false; + } + return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) && (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0); case GGML_OP_PAD_REFLECT_1D: @@ -912,6 +1056,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te // for new head sizes, add checks here if (op->src[0]->ne[0] != 32 && op->src[0]->ne[0] != 40 && + op->src[0]->ne[0] != 48 && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 72 && op->src[0]->ne[0] != 80 && @@ -1062,9 +1207,8 @@ struct ggml_metal_buffer { // note: cannot use explicity "id" here because it is not available on certain OSes id rset; - // pointers to global device objects - id device; - id queue; + // pointers to global device + ggml_metal_device_t dev; }; static void ggml_metal_log_allocated_size(id 
device, size_t size_aligned) { @@ -1107,7 +1251,7 @@ static bool ggml_metal_buffer_rset_init(ggml_metal_buffer_t buf) { desc.initialCapacity = buf->n_buffers; NSError * error; - buf->rset = [buf->device newResidencySetWithDescriptor:desc error:&error]; + buf->rset = [buf->dev->mtl_device newResidencySetWithDescriptor:desc error:&error]; if (error) { GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); [desc release]; @@ -1168,6 +1312,8 @@ static void * ggml_metal_host_malloc(size_t n) { ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, bool shared) { ggml_metal_buffer_t res = calloc(1, sizeof(struct ggml_metal_buffer)); + res->dev = dev; + const size_t size_page = sysconf(_SC_PAGESIZE); size_t size_aligned = size; @@ -1192,9 +1338,6 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, res->owned = true; - res->device = ggml_metal_device_get_obj(dev); - res->queue = ggml_metal_device_get_queue(dev); - res->n_buffers = 1; if (res->all_data != NULL) { @@ -1203,12 +1346,12 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, if (size_aligned > 0) { if (props_dev->use_shared_buffers && shared) { - res->buffers[0].metal = [res->device newBufferWithBytesNoCopy:res->all_data + res->buffers[0].metal = [res->dev->mtl_device newBufferWithBytesNoCopy:res->all_data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; } else { - res->buffers[0].metal = [res->device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate]; + res->buffers[0].metal = [res->dev->mtl_device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate]; } } @@ -1229,6 +1372,8 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, return NULL; } + ggml_metal_device_rsets_add(dev, res->rset); + //ggml_metal_log_allocated_size(device, size_aligned); return res; @@ -1237,6 +1382,8 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, size_t size, size_t max_tensor_size) { ggml_metal_buffer_t res = calloc(1, sizeof(struct ggml_metal_buffer)); + res->dev = dev; + res->all_data = ptr; res->all_size = size; @@ -1259,9 +1406,6 @@ ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, s size_aligned += (size_page - (size_aligned % size_page)); } - res->device = ggml_metal_device_get_obj(dev); - res->queue = ggml_metal_device_get_queue(dev); - const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev); // the buffer fits into the max buffer size allowed by the device @@ -1271,7 +1415,7 @@ ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, s res->buffers[res->n_buffers].metal = nil; if (size_aligned > 0) { - res->buffers[res->n_buffers].metal = [res->device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; + res->buffers[res->n_buffers].metal = [res->dev->mtl_device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; if (res->buffers[res->n_buffers].metal == nil) { GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); @@ -1280,7 +1424,7 @@ ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, s } } - ggml_metal_log_allocated_size(res->device, size_aligned); + 
ggml_metal_log_allocated_size(res->dev->mtl_device, size_aligned); ++res->n_buffers; } else { @@ -1298,7 +1442,7 @@ ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, s res->buffers[res->n_buffers].metal = nil; if (size_step_aligned > 0) { - res->buffers[res->n_buffers].metal = [res->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; + res->buffers[res->n_buffers].metal = [res->dev->mtl_device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; if (res->buffers[res->n_buffers].metal == nil) { GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0); @@ -1307,7 +1451,7 @@ ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, s } } - ggml_metal_log_allocated_size(res->device, size_step_aligned); + ggml_metal_log_allocated_size(res->dev->mtl_device, size_step_aligned); if (i + size_step < size) { GGML_LOG_INFO("\n"); @@ -1325,10 +1469,14 @@ ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, s return NULL; } + ggml_metal_device_rsets_add(dev, res->rset); + return res; } void ggml_metal_buffer_free(ggml_metal_buffer_t buf) { + ggml_metal_device_rsets_rm(buf->dev, buf->rset); + for (int i = 0; i < buf->n_buffers; i++) { [buf->buffers[i].metal release]; } @@ -1365,8 +1513,7 @@ void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor struct ggml_metal_buffer_id bid_dst = ggml_metal_buffer_get_id(buf, tensor); bid_dst.offs += offset; - id queue = buf->queue; - id cmd_buf = [queue commandBufferWithUnretainedReferences]; + id cmd_buf = [buf->dev->mtl_queue commandBufferWithUnretainedReferences]; { id encoder = [cmd_buf blitCommandEncoder]; @@ -1392,7 +1539,7 @@ void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * @autoreleasepool { // src void * data_ptr = (void *)(uintptr_t) data; // "const cast" the src data - id buf_src = [buf->device newBufferWithBytesNoCopy:data_ptr + id buf_src = [buf->dev->mtl_device newBufferWithBytesNoCopy:data_ptr length:size options:MTLResourceStorageModeShared deallocator:nil]; @@ -1407,8 +1554,7 @@ void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * // this is alternative to waitUntilCompleted, which should be faster, but don't seem to make much difference dispatch_semaphore_t completion_semaphore = dispatch_semaphore_create(0); - id queue = buf->queue; - id cmd_buf = [queue commandBufferWithUnretainedReferences]; + id cmd_buf = [buf->dev->mtl_queue commandBufferWithUnretainedReferences]; { id encoder = [cmd_buf blitCommandEncoder]; @@ -1450,15 +1596,14 @@ void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_ten bid_src.offs += offset; // dst - id buf_dst = [buf->device newBufferWithBytesNoCopy:data + id buf_dst = [buf->dev->mtl_device newBufferWithBytesNoCopy:data length:size options:MTLResourceStorageModeShared deallocator:nil]; GGML_ASSERT(buf_dst); - id queue = buf->queue; - id cmd_buf = [queue commandBufferWithUnretainedReferences]; + id cmd_buf = [buf->dev->mtl_queue commandBufferWithUnretainedReferences]; { id encoder = [cmd_buf blitCommandEncoder]; @@ -1484,8 +1629,7 @@ void ggml_metal_buffer_clear(ggml_metal_buffer_t buf, uint8_t value) { } @autoreleasepool { - id queue = buf->queue; - id cmd_buf = [queue commandBufferWithUnretainedReferences]; 
+ id cmd_buf = [buf->dev->mtl_queue commandBufferWithUnretainedReferences]; { id encoder = [cmd_buf blitCommandEncoder]; diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal index 9903af36..13c6715b 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal @@ -1962,6 +1962,7 @@ GGML_TABLE_END() #define FC_MUL_MV 600 #define FC_MUL_MM 700 #define FC_ROPE 800 +#define FC_SSM_CONV 900 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPTG 8 @@ -2067,6 +2068,10 @@ typedef struct { float bias; } ggml_metal_kargs_scale; +typedef struct { + float val; +} ggml_metal_kargs_fill; + typedef struct { float min; float max; @@ -2716,6 +2721,25 @@ typedef struct { float slope; } ggml_metal_kargs_leaky_relu; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_tri; + typedef struct { int32_t ne00; int32_t ne01; @@ -4026,6 +4050,22 @@ kernel void kernel_scale_f32_4( dst[tpig] = src0[tpig] * args.scale + args.bias; } +kernel void kernel_fill_f32( + constant ggml_metal_kargs_fill & args, + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = args.val; +} + +kernel void kernel_fill_f32_4( + constant ggml_metal_kargs_fill & args, + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = args.val; +} + kernel void kernel_clamp_f32( constant ggml_metal_kargs_clamp & args, device const float * src0, @@ -4372,6 +4412,36 @@ kernel void kernel_exp_f32_4( dst[tpig] = exp(src0[tpig]); } +kernel void kernel_softplus_f32( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f); +} + +kernel void kernel_softplus_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f); +} + +kernel void kernel_expm1_f32( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = exp(src0[tpig]) - 1.0f; +} + +kernel void kernel_expm1_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = exp(src0[tpig]) - 1.0f; +} + kernel void kernel_reglu_f32( constant ggml_metal_kargs_glu & args, device const char * src0, @@ -4720,6 +4790,75 @@ typedef decltype(kernel_cumsum_add) kernel_cumsum_add_t; template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add; + +template +bool _ggml_vec_tri_cmp(const int i, const int r); + +template<> +bool _ggml_vec_tri_cmp(const int i, const int r) { + return i < r; +} + +template<> +bool _ggml_vec_tri_cmp(const int i, const int r) { + return i <= r; +} + +template<> +bool _ggml_vec_tri_cmp(const int i, const int r) { + return i > r; +} + +template<> +bool _ggml_vec_tri_cmp(const int i, const int r) { + return i >= r; +} + +template +kernel void kernel_tri( + constant ggml_metal_kargs_tri & args, + device const char * src0, + device const char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 
tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; + + if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { + return; + } + + device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); + device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); + + // Each thread is a single element of the row if ne00 < max threads per + // threadgroup, so this will loop once for each index that this thread is + // responsible for + for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { + // Use the comparison as a mask for branchless + dst_row[i0] = static_cast(_ggml_vec_tri_cmp(i0, i1)) * src_row[i0]; + } +} + +typedef decltype(kernel_tri) kernel_tri_t; + +template [[host_name("kernel_tri_f32_0")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_f32_1")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_f32_2")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_f32_3")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_f16_0")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_f16_1")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_f16_2")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_f16_3")]] kernel kernel_tri_t kernel_tri; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_tri_bf16_0")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_bf16_1")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_bf16_2")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_bf16_3")]] kernel kernel_tri_t kernel_tri; +#endif + template kernel void kernel_soft_max( constant ggml_metal_kargs_soft_max & args, @@ -5005,7 +5144,102 @@ kernel void kernel_ssm_conv_f32_f32_4( x[0] = sumf; } +constant short FC_ssm_conv_bs [[function_constant(FC_SSM_CONV + 0)]]; + +// Batched version: each threadgroup processes multiple tokens for better efficiency +// Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens +kernel void kernel_ssm_conv_f32_f32_batched( + constant ggml_metal_kargs_ssm_conv & args, + device const void * src0, + device const void * src1, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + // tgpig.x = row index (ir) + // tgpig.y = batch of tokens (i2_base / BATCH_SIZE) + // tgpig.z = sequence index (i3) + // tpitg.x = thread within batch (0..BATCH_SIZE-1) + const short BATCH_SIZE = FC_ssm_conv_bs; + + const int64_t ir = tgpig.x; + const int64_t i2_base = tgpig.y * BATCH_SIZE; + const int64_t i3 = tgpig.z; + const int64_t i2_off = tpitg.x; + const int64_t i2 = i2_base + i2_off; + + const int64_t nc = args.ne10; // conv kernel size (typically 4) + const int64_t n_t = args.ne1; // number of tokens + + // Bounds check for partial batches at the end + if (i2 >= n_t) { + return; + } + + // Load conv weights (shared across all tokens for this row) + device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11); + + // Load source for this specific token + device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); + + // Output location for this token + device float * x = (device float *) 
((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); + + float sumf = 0.0f; + for (int64_t i0 = 0; i0 < nc; ++i0) { + sumf += s[i0] * c[i0]; + } + + x[0] = sumf; +} + +kernel void kernel_ssm_conv_f32_f32_batched_4( + constant ggml_metal_kargs_ssm_conv & args, + device const void * src0, + device const void * src1, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + // tgpig.x = row index (ir) + // tgpig.y = batch of tokens (i2_base / BATCH_SIZE) + // tgpig.z = sequence index (i3) + // tpitg.x = thread within batch (0..BATCH_SIZE-1) + const short BATCH_SIZE = FC_ssm_conv_bs; + + const int64_t ir = tgpig.x; + const int64_t i2_base = tgpig.y * BATCH_SIZE; + const int64_t i3 = tgpig.z; + const int64_t i2_off = tpitg.x; + const int64_t i2 = i2_base + i2_off; + + const int64_t nc = args.ne10; // conv kernel size (typically 4) + const int64_t n_t = args.ne1; // number of tokens + + // Bounds check for partial batches at the end + if (i2 >= n_t) { + return; + } + + // Load conv weights (shared across all tokens for this row) + device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11); + + // Load source for this specific token + device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); + + // Output location for this token + device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); + + float sumf = 0.0f; + for (int64_t i0 = 0; i0 < nc/4; ++i0) { + sumf += dot(s[i0], c[i0]); + } + + x[0] = sumf; +} + // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part +// Optimized version: reduces redundant memory loads by having one thread load shared values kernel void kernel_ssm_scan_f32( constant ggml_metal_kargs_ssm_scan & args, device const void * src0, @@ -5025,7 +5259,15 @@ kernel void kernel_ssm_scan_f32( uint3 tgpg[[threadgroups_per_grid]]) { constexpr short NW = N_SIMDWIDTH; - shared[tpitg.x] = 0.0f; + // Shared memory layout: + // [0..sgptg*NW-1]: partial sums for reduction (existing) + // [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch + // [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch + threadgroup float * shared_sums = shared; + threadgroup float * shared_x_dt = shared + sgptg * NW; + threadgroup float * shared_dA = shared + sgptg * NW + sgptg; + + shared_sums[tpitg.x] = 0.0f; const int32_t i0 = tpitg.x; const int32_t i1 = tgpig.x; @@ -5065,32 +5307,47 @@ kernel void kernel_ssm_scan_f32( for (int i2 = 0; i2 < n_t; i2 += sgptg) { threadgroup_barrier(mem_flags::mem_threadgroup); - for (int t = 0; t < sgptg && i2 + t < n_t; t++) { - const float dt0 = dt[0]; + // Pre-compute x_dt and dA for this batch of tokens + // Only first sgptg threads do the loads and expensive math + if (i0 < sgptg && i2 + i0 < n_t) { + // ns12 and ns21 are element strides (nb12/nb10, nb21/nb20) + device const float * x_t = x + i0 * args.ns12; + device const float * dt_t = dt + i0 * args.ns21; + + const float dt0 = dt_t[0]; const float dtsp = dt0 <= 20.0f ? 
log(1.0f + exp(dt0)) : dt0; - const float x_dt = x[0] * dtsp; - const float dA = exp(dtsp * A0); + shared_x_dt[i0] = x_t[0] * dtsp; + shared_dA[i0] = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int t = 0; t < sgptg && i2 + t < n_t; t++) { + const float x_dt = shared_x_dt[t]; + const float dA = exp(shared_dA[t] * A0); s = (s0 * dA) + (B[i0] * x_dt); const float sumf = simd_sum(s * C[i0]); if (tiisg == 0) { - shared[t*NW + sgitg] = sumf; + shared_sums[t*NW + sgitg] = sumf; } // recurse s0 = s; - x += args.ns12; - dt += args.ns21; B += args.ns42; C += args.ns52; } + // Advance pointers for next batch + x += sgptg * args.ns12; + dt += sgptg * args.ns21; + threadgroup_barrier(mem_flags::mem_threadgroup); - const float sumf = simd_sum(shared[sgitg*NW + tiisg]); + const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]); if (tiisg == 0 && i2 + sgitg < n_t) { y[sgitg*nh*nr] = sumf; @@ -8749,6 +9006,7 @@ typedef decltype(kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -8762,6 +9020,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -8776,6 +9035,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_at #if defined(GGML_METAL_HAS_BF16) template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -8790,6 +9050,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk48_dv48" )]] 
kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -8803,6 +9064,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -8816,6 +9078,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -8829,6 +9092,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -8842,6 +9106,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h 
b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h index 342dc4f8..8944b07e 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h @@ -77,6 +77,7 @@ #define FC_MUL_MV 600 #define FC_MUL_MM 700 #define FC_ROPE 800 +#define FC_SSM_CONV 900 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPTG 8 @@ -182,6 +183,10 @@ typedef struct { float bias; } ggml_metal_kargs_scale; +typedef struct { + float val; +} ggml_metal_kargs_fill; + typedef struct { float min; float max; @@ -831,6 +836,25 @@ typedef struct { float slope; } ggml_metal_kargs_leaky_relu; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_tri; + typedef struct { int32_t ne00; int32_t ne01; diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp index 9871e976..e99c1763 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -221,7 +221,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { } if (ctx->debug_graph > 0) { - GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(node->op), is_concurrent ? "(concurrent)" : ""); + GGML_LOG_DEBUG("%s: node[%5d] - %-12s %-12s %s\n", __func__, idx, ggml_op_name(node->op), ggml_get_name(node), is_concurrent ? "(concurrent)" : ""); } if (ctx->debug_graph > 1) { GGML_TENSOR_LOCALS( int64_t, ne0, node->src[0], ne); @@ -286,6 +286,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_scale(ctx, idx); } break; + case GGML_OP_FILL: + { + n_fuse = ggml_metal_op_fill(ctx, idx); + } break; case GGML_OP_CLAMP: { n_fuse = ggml_metal_op_clamp(ctx, idx); @@ -414,6 +418,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_leaky_relu(ctx, idx); } break; + case GGML_OP_TRI: + { + n_fuse = ggml_metal_op_tri(ctx, idx); + } break; case GGML_OP_FLASH_ATTN_EXT: { n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx); @@ -524,7 +532,7 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) { /*.dim =*/ dim, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT); + auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -550,7 +558,7 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type); + auto pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type); ggml_metal_kargs_repeat args = { /*.ne00 =*/ ne00, @@ -616,7 +624,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { // TODO: make a simpler cpy_bytes kernel //const id pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); + auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); ggml_metal_kargs_cpy args = { /*.nk0 =*/ ne00, @@ -679,7 +687,7 @@ int 
ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { /*.o1 =*/ { 0 }, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false); + auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -721,7 +729,42 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) { n /= 4; } - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + + return 1; +} + +int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + const float val = ggml_get_op_params_f32(op, 0); + + ggml_metal_kargs_fill args = { + /*.val =*/ val + }; + + int64_t n = ggml_nelements(op); + + if (n % 4 == 0) { + n /= 4; + } + + auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -760,7 +803,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) { n /= 4; } - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -789,7 +832,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { n /= 4; } - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0); @@ -817,7 +860,7 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1])); } - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_glu(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_glu(lib, op); const int32_t swp = ggml_get_op_params_i32(op, 1); const float alpha = ggml_get_op_params_f32(op, 2); @@ -870,7 +913,7 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) { /*.np =*/ n, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_sum(lib, op); int nth = 32; // SIMD width @@ -925,7 +968,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { /*.nb3 =*/ nb3, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op); int nth = 32; // SIMD width @@ -936,7 +979,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); nth = std::min(nth, ne00); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); 
+ const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -963,7 +1006,7 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op); + auto pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op); int nth = 1; while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) { @@ -1060,7 +1103,7 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) { ggml_metal_op_concurrency_reset(ctx); { - ggml_metal_pipeline_t pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op); + auto pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op); ggml_metal_kargs_cumsum_add args = { /*.ne00 =*/ ne00, @@ -1106,7 +1149,7 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type); + auto pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type); ggml_metal_kargs_get_rows args = { /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00, @@ -1151,7 +1194,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type); + auto pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type); const int32_t nk0 = ne0/ggml_blck_size(op->type); @@ -1252,7 +1295,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) { /*.n_head_log2 =*/ n_head_log2, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op); int nth = 32; // SIMD width @@ -1266,7 +1309,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) { } } - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); @@ -1322,15 +1365,43 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) { /*.nb2 =*/ nb2, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op); + // Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead + const bool use_batched = (ne1 > 1); - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + if (use_batched) { + // Determine the smallest power of 2 that's >= ne1, but <= 256 + int BATCH_SIZE; + if (ne1 > 128) BATCH_SIZE = 256; + else if (ne1 > 64 ) BATCH_SIZE = 128; + else if (ne1 > 32 ) BATCH_SIZE = 64; + else if (ne1 > 16 ) BATCH_SIZE = 32; + else if (ne1 > 8 ) BATCH_SIZE = 16; + else if (ne1 > 4 ) BATCH_SIZE = 8; + else BATCH_SIZE = 2; - ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1); + auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op, BATCH_SIZE); + + 
ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + + // Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences + // Each threadgroup has BATCH_SIZE threads, each handling one token + const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE; + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1); + } else { + auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1); + } return 1; } @@ -1409,11 +1480,11 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { /*.nb0 =*/ nb0, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op); GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - const size_t sms = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -1426,7 +1497,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 8); - ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0); + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1); @@ -1449,7 +1520,7 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { const int64_t C = op->ne[0]; const int64_t H = op->src[0]->ne[1]; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op); int ida = 0; @@ -1485,7 +1556,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); + auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0); @@ -1592,7 +1663,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) { /* .np = */ np }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool); + auto pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool); const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np); const int ntg = (np + nth - 1) / nth; @@ -1701,7 +1772,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { GGML_ABORT("unsupported ne11"); }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg); + auto pipeline = 
ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg); ggml_metal_kargs_mul_mv_ext args = { /*.ne00 =*/ ne00, @@ -1748,7 +1819,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { // default: break; //} - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op); ggml_metal_kargs_mul_mm args = { /*.ne00 =*/ ne00, @@ -1773,18 +1844,18 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1); } else { - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op); - const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); - const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); - const int nsg = ggml_metal_pipeline_get_nsg(pipeline); + const int nr0 = pipeline.nr0; + const int nr1 = pipeline.nr1; + const int nsg = pipeline.nsg; - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_kargs_mul_mv args = { /*.ne00 =*/ ne00, @@ -1915,9 +1986,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { nb21, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20); + auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); @@ -1938,7 +2009,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { ggml_metal_op_concurrency_reset(ctx); { - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op); ggml_metal_kargs_mul_mm_id args = { /*.ne00 =*/ ne00, @@ -1967,20 +2038,20 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, bid_ids, 4); ggml_metal_encoder_set_buffer (enc, bid_dst, 5); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1); } } else { - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op); - const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); - const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); - const int nsg = ggml_metal_pipeline_get_nsg(pipeline); + const int nr0 = pipeline.nr0; + const int nr1 = pipeline.nr1; + const int nsg = pipeline.nsg; - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_kargs_mul_mv_id args = { /*.nei0 =*/ ne20, @@ -2064,7 +2135,7 @@ int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) { /*.nb21 =*/ nb21, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID); + auto 
pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -2308,7 +2379,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.nb33 =*/nb33, }; - ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); + auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); ggml_metal_encoder_set_pipeline(enc, pipeline0); ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); @@ -2339,7 +2410,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.nb33 =*/ nb33, }; - ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg); + auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg); ggml_metal_encoder_set_pipeline(enc, pipeline0); ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); @@ -2424,7 +2495,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.logit_softcap =*/ logit_softcap, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg); + auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -2476,7 +2547,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.nb33 =*/nb33, }; - ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); + auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); ggml_metal_encoder_set_pipeline(enc, pipeline0); ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); @@ -2578,7 +2649,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.logit_softcap =*/ logit_softcap, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg); + auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg); GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); @@ -2630,7 +2701,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { nrows, }; - ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg); + auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg); ggml_metal_encoder_set_pipeline(enc, pipeline0); ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); @@ -2762,7 +2833,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { // the offsets of src1 and all fused buffers are relative to the start of the src1 buffer bid_src1.offs = 0; - ggml_metal_pipeline_t pipeline = nullptr; + struct ggml_metal_pipeline_with_params pipeline; if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) { GGML_ASSERT(ggml_is_contiguous(op->src[0])); @@ -2835,7 +2906,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { /*.eps =*/ eps, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op); + auto pipeline = 
ggml_metal_library_get_pipeline_l2_norm(lib, op); while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { nth *= 2; @@ -2844,7 +2915,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); nth = std::min(nth, ne00/4); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; const int64_t nrows = ggml_nrows(op->src[0]); @@ -2887,7 +2958,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) { /*.eps =*/ eps, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op); int nth = 32; // SIMD width //while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { @@ -2897,7 +2968,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) { //nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); //nth = std::min(nth, ne00/4); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3022,7 +3093,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) { } } - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse); + auto pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse); int nth = 32; // SIMD width @@ -3033,7 +3104,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) { nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); nth = std::min(nth, args.ne00_t); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3127,7 +3198,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) { /* src2 =*/ op->src[2] != nullptr, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_rope(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3199,7 +3270,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) { /*.KHW =*/ KH * KW, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_im2col(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op); GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); @@ -3270,7 +3341,7 @@ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) { /*.d1 =*/ d1, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op); int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline); nth = std::min(nth, 256); @@ -3325,7 +3396,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) { /*.nb1 =*/ nb1, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3377,7 +3448,7 @@ int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) { /*.nb2 =*/ nb2, }; - ggml_metal_pipeline_t pipeline = 
ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3433,7 +3504,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) { /*.sf3 =*/ sf3 }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_upscale(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op); const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); @@ -3477,7 +3548,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) { /*.nb3 =*/ nb3 }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op); const int nth = std::min(1024, ne0); @@ -3523,7 +3594,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) { /*.p1 =*/ ((const int32_t *)(op->op_params))[1] }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op); const int nth = std::min(1024, ne0); @@ -3560,7 +3631,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) { const int nth = std::min(1024, ne0); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_arange(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_arange(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3591,7 +3662,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) { /*.max_period =*/ max_period, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op); const int nth = std::max(1, std::min(1024, dim/2)); @@ -3621,7 +3692,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) { /*.nb01 = */ nb01, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argmax(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_argmax(lib, op); const int64_t nrows = ggml_nrows(op->src[0]); @@ -3630,7 +3701,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) { nth *= 2; } - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3657,7 +3728,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_argsort(lib, op); // bitonic sort requires the number of elements to be power of 2 int nth = 1; @@ -3706,7 +3777,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1); - ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op); + auto pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op); int len = nth; @@ -3764,7 +3835,7 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k(lib, op); + auto pipeline = 
ggml_metal_library_get_pipeline_top_k(lib, op); // bitonic sort requires the number of elements to be power of 2 int nth = 1; @@ -3818,7 +3889,7 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1); - ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op); + auto pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op); int len = args.top_k; @@ -3881,7 +3952,7 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { /*.slope =*/ slope }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); int64_t n = ggml_nelements(op); @@ -3899,6 +3970,57 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_kargs_tri args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + auto pipeline = ggml_metal_library_get_pipeline_tri(lib, op); + + int nth = 32; // SIMD width + + while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + nth = std::min(nth, ne00); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + return 1; +} + int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); @@ -3910,7 +4032,7 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op); const int64_t np = ggml_nelements(op->src[0]); ggml_metal_kargs_opt_step_adamw args = { @@ -3946,7 +4068,7 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op); const int64_t np = ggml_nelements(op->src[0]); ggml_metal_kargs_opt_step_sgd args = { diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h index b5546146..902b5445 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h @@ -47,6 +47,7 @@ int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx); int 
ggml_metal_op_repeat            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_acc               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_scale             (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_fill              (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_clamp             (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_unary             (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_glu               (ggml_metal_op_t ctx, int idx);
@@ -83,6 +84,7 @@ int ggml_metal_op_argmax            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_argsort           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_top_k             (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_leaky_relu        (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_tri               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_opt_step_adamw    (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_opt_step_sgd      (ggml_metal_op_t ctx, int idx);
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
index 761b57a2..c98d269d 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
@@ -1249,6 +1249,22 @@ kernel void kernel_scale_f32_4(
     dst[tpig] = src0[tpig] * args.scale + args.bias;
 }
 
+kernel void kernel_fill_f32(
+        constant ggml_metal_kargs_fill & args,
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = args.val;
+}
+
+kernel void kernel_fill_f32_4(
+        constant ggml_metal_kargs_fill & args,
+        device const float4 * src0,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = args.val;
+}
+
 kernel void kernel_clamp_f32(
         constant ggml_metal_kargs_clamp & args,
         device const float * src0,
@@ -1595,6 +1611,36 @@ kernel void kernel_exp_f32_4(
     dst[tpig] = exp(src0[tpig]);
 }
 
+kernel void kernel_softplus_f32(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
+}
+
+kernel void kernel_softplus_f32_4(
+        device const float4 * src0,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+    dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
+}
+
+kernel void kernel_expm1_f32(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = exp(src0[tpig]) - 1.0f;
+}
+
+kernel void kernel_expm1_f32_4(
+        device const float4 * src0,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = exp(src0[tpig]) - 1.0f;
+}
+
 kernel void kernel_reglu_f32(
         constant ggml_metal_kargs_glu & args,
         device const char * src0,
@@ -1943,6 +1989,75 @@ typedef decltype(kernel_cumsum_add) kernel_cumsum_add_t;
 
 template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add;
+
+template
+bool _ggml_vec_tri_cmp(const int i, const int r);
+
+template<>
+bool _ggml_vec_tri_cmp(const int i, const int r) {
+    return i < r;
+}
+
+template<>
+bool _ggml_vec_tri_cmp(const int i, const int r) {
+    return i <= r;
+}
+
+template<>
+bool _ggml_vec_tri_cmp(const int i, const int r) {
+    return i > r;
+}
+
+template<>
+bool _ggml_vec_tri_cmp(const int i, const int r) {
+    return i >= r;
+}
+
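
For reference, the masking that kernel_tri below performs can be read as the following host-side C++ sketch; tri_cmp, tri_keep and tri_row are illustrative names for this note only and are not part of the patch:

    enum class tri_cmp { lt, le, gt, ge };

    // Mirror of the comparator specializations above: keep the element at row i1,
    // column i0 when the chosen comparison between i0 and i1 holds.
    static bool tri_keep(tri_cmp c, int i, int r) {
        switch (c) {
            case tri_cmp::lt: return i <  r; // strictly lower triangle
            case tri_cmp::le: return i <= r; // lower triangle incl. diagonal
            case tri_cmp::gt: return i >  r; // strictly upper triangle
            case tri_cmp::ge: return i >= r; // upper triangle incl. diagonal
        }
        return false;
    }

    // One row of the TRI op, written branch-free the same way the kernel is:
    // the boolean is converted to 0.0f/1.0f and used as a multiplicative mask.
    static void tri_row(tri_cmp c, const float * src, float * dst, int ne00, int i1) {
        for (int i0 = 0; i0 < ne00; ++i0) {
            dst[i0] = static_cast<float>(tri_keep(c, i0, i1)) * src[i0];
        }
    }
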
+template
+kernel void kernel_tri(
+        constant ggml_metal_kargs_tri & args,
+        device const char * src0,
+        device const char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
+
+    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
+        return;
+    }
+
+    device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
+    device       T * dst_row = (device       T *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);
+
+    // Each thread handles a single element of the row if ne00 < max threads per
+    // threadgroup; otherwise this loops once for each index the thread is
+    // responsible for
+    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+        // Use the comparison result as a mask to stay branchless
+        dst_row[i0] = static_cast(_ggml_vec_tri_cmp(i0, i1)) * src_row[i0];
+    }
+}
+
+typedef decltype(kernel_tri) kernel_tri_t;
+
+template [[host_name("kernel_tri_f32_0")]]  kernel kernel_tri_t kernel_tri;
+template [[host_name("kernel_tri_f32_1")]]  kernel kernel_tri_t kernel_tri;
+template [[host_name("kernel_tri_f32_2")]]  kernel kernel_tri_t kernel_tri;
+template [[host_name("kernel_tri_f32_3")]]  kernel kernel_tri_t kernel_tri;
+template [[host_name("kernel_tri_f16_0")]]  kernel kernel_tri_t kernel_tri;
+template [[host_name("kernel_tri_f16_1")]]  kernel kernel_tri_t kernel_tri;
+template [[host_name("kernel_tri_f16_2")]]  kernel kernel_tri_t kernel_tri;
+template [[host_name("kernel_tri_f16_3")]]  kernel kernel_tri_t kernel_tri;
+#if defined(GGML_METAL_HAS_BF16)
+template [[host_name("kernel_tri_bf16_0")]] kernel kernel_tri_t kernel_tri;
+template [[host_name("kernel_tri_bf16_1")]] kernel kernel_tri_t kernel_tri;
+template [[host_name("kernel_tri_bf16_2")]] kernel kernel_tri_t kernel_tri;
+template [[host_name("kernel_tri_bf16_3")]] kernel kernel_tri_t kernel_tri;
+#endif
+
 template
 kernel void kernel_soft_max(
         constant ggml_metal_kargs_soft_max & args,
@@ -2228,7 +2343,102 @@ kernel void kernel_ssm_conv_f32_f32_4(
     x[0] = sumf;
 }
 
+constant short FC_ssm_conv_bs [[function_constant(FC_SSM_CONV + 0)]];
+
+// Batched version: each threadgroup processes multiple tokens for better efficiency
+// Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens
+kernel void kernel_ssm_conv_f32_f32_batched(
+        constant ggml_metal_kargs_ssm_conv & args,
+        device const void * src0,
+        device const void * src1,
+        device      float * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    // tgpig.x = row index (ir)
+    // tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
+    // tgpig.z = sequence index (i3)
+    // tpitg.x = thread within batch (0..BATCH_SIZE-1)
+    const short BATCH_SIZE = FC_ssm_conv_bs;
+
+    const int64_t ir      = tgpig.x;
+    const int64_t i2_base = tgpig.y * BATCH_SIZE;
+    const int64_t i3      = tgpig.z;
+    const int64_t i2_off  = tpitg.x;
+    const int64_t i2      = i2_base + i2_off;
+
+    const int64_t nc  = args.ne10; // conv kernel size (typically 4)
+    const int64_t n_t = args.ne1;  // number of tokens
+
+    // Bounds check for partial batches at the end
+    if (i2 >= n_t) {
+        return;
+    }
+
+    // Load conv weights (shared across all tokens for this row)
+    device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
+
+    // Load source for this specific token
+    device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
+
+    // Output location for this token
+    device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
+
+    float sumf = 0.0f;
+    for (int64_t i0 = 0; i0 < nc; ++i0) {
+        sumf += s[i0] * c[i0];
+    }
+
+    x[0] = sumf;
+}
+
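
The batched variant above only changes the dispatch geometry: one threadgroup spans FC_ssm_conv_bs consecutive tokens of a single row, with one thread per token, so the grid shrinks along the token dimension. A rough C++ sketch of that mapping (the struct and helper are illustrative, not part of the patch):

    struct grid_dims { int x, y, z; };

    // tgpig.x = row, tgpig.y = token batch, tgpig.z = sequence; the threadgroup is
    // batch_size wide along x, and thread i handles token i2_base + i.
    static grid_dims ssm_conv_batched_grid(int n_rows, int n_tokens, int n_seqs, int batch_size) {
        return { n_rows, (n_tokens + batch_size - 1) / batch_size, n_seqs };
    }

The float4 variant that follows applies the same layout while reading the weights and the input four elements at a time.
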
+kernel void kernel_ssm_conv_f32_f32_batched_4(
+        constant ggml_metal_kargs_ssm_conv & args,
+        device const void * src0,
+        device const void * src1,
+        device      float * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    // tgpig.x = row index (ir)
+    // tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
+    // tgpig.z = sequence index (i3)
+    // tpitg.x = thread within batch (0..BATCH_SIZE-1)
+    const short BATCH_SIZE = FC_ssm_conv_bs;
+
+    const int64_t ir      = tgpig.x;
+    const int64_t i2_base = tgpig.y * BATCH_SIZE;
+    const int64_t i3      = tgpig.z;
+    const int64_t i2_off  = tpitg.x;
+    const int64_t i2      = i2_base + i2_off;
+
+    const int64_t nc  = args.ne10; // conv kernel size (typically 4)
+    const int64_t n_t = args.ne1;  // number of tokens
+
+    // Bounds check for partial batches at the end
+    if (i2 >= n_t) {
+        return;
+    }
+
+    // Load conv weights (shared across all tokens for this row)
+    device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
+
+    // Load source for this specific token
+    device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
+
+    // Output location for this token
+    device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
+
+    float sumf = 0.0f;
+    for (int64_t i0 = 0; i0 < nc/4; ++i0) {
+        sumf += dot(s[i0], c[i0]);
+    }
+
+    x[0] = sumf;
+}
+
 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
+// Optimized version: reduces redundant memory loads by having one thread load shared values
 kernel void kernel_ssm_scan_f32(
         constant ggml_metal_kargs_ssm_scan & args,
         device const void * src0,
@@ -2248,7 +2458,15 @@ kernel void kernel_ssm_scan_f32(
         uint3 tgpg[[threadgroups_per_grid]]) {
     constexpr short NW = N_SIMDWIDTH;
 
-    shared[tpitg.x] = 0.0f;
+    // Shared memory layout:
+    //   [0..sgptg*NW-1]:                      partial sums for reduction (existing)
+    //   [sgptg*NW..sgptg*NW+sgptg-1]:         pre-computed x_dt values for each token in batch
+    //   [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch
+    threadgroup float * shared_sums = shared;
+    threadgroup float * shared_x_dt = shared + sgptg * NW;
+    threadgroup float * shared_dA   = shared + sgptg * NW + sgptg;
+
+    shared_sums[tpitg.x] = 0.0f;
 
     const int32_t i0 = tpitg.x;
     const int32_t i1 = tgpig.x;
@@ -2288,32 +2506,47 @@ kernel void kernel_ssm_scan_f32(
     for (int i2 = 0; i2 < n_t; i2 += sgptg) {
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
-            const float dt0  = dt[0];
+        // Pre-compute x_dt and dA for this batch of tokens
+        // Only first sgptg threads do the loads and expensive math
+        if (i0 < sgptg && i2 + i0 < n_t) {
+            // ns12 and ns21 are element strides (nb12/nb10, nb21/nb20)
+            device const float * x_t  = x  + i0 * args.ns12;
+            device const float * dt_t = dt + i0 * args.ns21;
+
+            const float dt0  = dt_t[0];
+            const float dtsp = dt0 <= 20.0f ?
log(1.0f + exp(dt0)) : dt0; - const float x_dt = x[0] * dtsp; - const float dA = exp(dtsp * A0); + shared_x_dt[i0] = x_t[0] * dtsp; + shared_dA[i0] = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int t = 0; t < sgptg && i2 + t < n_t; t++) { + const float x_dt = shared_x_dt[t]; + const float dA = exp(shared_dA[t] * A0); s = (s0 * dA) + (B[i0] * x_dt); const float sumf = simd_sum(s * C[i0]); if (tiisg == 0) { - shared[t*NW + sgitg] = sumf; + shared_sums[t*NW + sgitg] = sumf; } // recurse s0 = s; - x += args.ns12; - dt += args.ns21; B += args.ns42; C += args.ns52; } + // Advance pointers for next batch + x += sgptg * args.ns12; + dt += sgptg * args.ns21; + threadgroup_barrier(mem_flags::mem_threadgroup); - const float sumf = simd_sum(shared[sgitg*NW + tiisg]); + const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]); if (tiisg == 0 && i2 + sgitg < n_t) { y[sgitg*nh*nr] = sumf; @@ -5972,6 +6205,7 @@ typedef decltype(kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5985,6 +6219,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5999,6 +6234,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_at #if defined(GGML_METAL_HAS_BF16) template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6013,6 +6249,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk48_dv48" )]] 
kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6026,6 +6263,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6039,6 +6277,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6052,6 +6291,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6065,6 +6305,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp 
b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c98f98c7..8a83427f 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -355,10 +355,17 @@ enum vk_conv_shapes { CONV_SHAPE_COUNT, }; -uint32_t conv_shapes_wg_denoms[][3] = { - { 128, 128, 1 }, - { 64, 32, 1 }, - { 32, 256, 1 }, +struct vk_conv_block_size { + uint32_t K; + uint32_t NPQ; + uint32_t CRS; +}; + +vk_conv_block_size vk_conv_block_sizes[CONV_SHAPE_COUNT] = { + // K NPQ CRS + { 128, 128, 16 }, // CONV_SHAPE_128x128 + { 64, 32, 32 }, // CONV_SHAPE_64x32 + { 32, 256, 16 }, // CONV_SHAPE_32x256 }; enum dmmv_wg_sizes { @@ -521,6 +528,7 @@ struct vk_device_struct { bool fp16; bool bf16; bool pipeline_robustness; + bool memory_priority; vk::Device device; uint32_t vendor_id; vk::DriverId driver_id; @@ -771,11 +779,6 @@ struct vk_device_struct { std::unique_ptr memory_logger; #endif - // for GGML_VK_PERF_LOGGER - std::unique_ptr perf_logger; - vk::QueryPool query_pool; - int32_t num_queries; - ~vk_device_struct() { VK_LOG_DEBUG("destroy device " << name); @@ -1044,6 +1047,7 @@ struct vk_op_pad_push_constants { uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; uint32_t misalign_offsets; + uint32_t circular; uint32_t lp0; uint32_t rp0; uint32_t lp1; uint32_t rp1; @@ -1086,6 +1090,7 @@ static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor p.rp2 = dst->op_params[5]; p.lp3 = dst->op_params[6]; p.rp3 = dst->op_params[7]; + p.circular = dst->op_params[8]; return p; // fastdiv values and offsets are initialized later in ggml_vk_op } @@ -1229,6 +1234,7 @@ struct vk_op_topk_push_constants { uint32_t orig_ncols; uint32_t ncols_input; uint32_t ncols_output; + uint32_t k; uint32_t nrows; uint32_t first_pass; uint32_t last_pass; @@ -1344,20 +1350,11 @@ struct vk_op_conv2d_push_constants { uint32_t Cin; uint32_t N; - uint32_t KW; - uint32_t KH; uint32_t W; uint32_t H; uint32_t OW; uint32_t OH; - uint32_t s0; - uint32_t s1; - uint32_t p0; - uint32_t p1; - uint32_t d0; - uint32_t d1; - uint32_t nb01; uint32_t nb02; uint32_t nb03; @@ -1381,48 +1378,6 @@ template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) { init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL); } -struct vk_op_conv_transpose_2d_push_constants { - uint32_t Cout; - uint32_t Cin; - uint32_t N; - - uint32_t KW; - uint32_t KH; - uint32_t W; - uint32_t H; - uint32_t OW; - uint32_t OH; - - uint32_t s0; - uint32_t s1; - uint32_t p0; - uint32_t p1; - uint32_t d0; - uint32_t d1; - - uint32_t nb01; - uint32_t nb02; - uint32_t nb03; - - uint32_t nb11; - uint32_t nb12; - uint32_t nb13; - - uint32_t nb1; - uint32_t nb2; - uint32_t nb3; - - // init_fastdiv_values constants for dividing by OW, OW*OH - uint32_t OWmp; uint32_t OWL; - uint32_t OWOHmp; uint32_t OWOHL; -}; - -template <> void init_pushconst_fastdiv(vk_op_conv_transpose_2d_push_constants &p) { - // Compute magic values to divide by OW, OW*OH - init_fastdiv_values(p.OW, p.OWmp, p.OWL); - init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL); -} - struct vk_op_conv2d_dw_push_constants { uint32_t ne; uint32_t batches; @@ -1565,12 +1520,21 @@ private: #define VK_LOG_MEMORY(msg) ((void) 0) #endif // GGML_VULKAN_MEMORY_DEBUG +static bool vk_perf_logger_enabled = false; +// number of calls between perf logger prints +static uint32_t 
vk_perf_logger_frequency = 1; + class vk_perf_logger { public: - void print_timings() { + void print_timings(bool force = false) { if (timings.empty()) { return; } + print_count++; + if ((print_count % vk_perf_logger_frequency) != 0 && !force) { + return; + } + print_count = 0; uint64_t total_all_op_times = 0; std::cerr << "----------------\nVulkan Timings:" << std::endl; for (const auto & t : timings) { @@ -1607,16 +1571,20 @@ class vk_perf_logger { flops.clear(); } - void log_timing(const ggml_tensor * node, uint64_t time) { + void log_timing(const ggml_tensor * node, const char *fusion_name, uint64_t time) { + std::string fusion_str; + if (fusion_name) { + fusion_str = fusion_name + std::string(" "); + } if (node->op == GGML_OP_UNARY) { - timings[ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time); + timings[fusion_str + ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time); return; } if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) { - const uint64_t m = node->src[0]->ne[1]; - const uint64_t n = (node->op == GGML_OP_MUL_MAT) ? node->ne[1] : node->ne[2]; + const uint64_t m = node->ne[0]; + const uint64_t n = node->ne[1]; const uint64_t k = node->src[1]->ne[0]; - const uint64_t batch = node->src[1]->ne[2] * node->src[1]->ne[3]; + const uint64_t batch = node->ne[2] * node->ne[3]; std::string name = ggml_op_name(node->op); if ((node->op == GGML_OP_MUL_MAT && n <= mul_mat_vec_max_cols) || (node->op == GGML_OP_MUL_MAT_ID && node->src[2]->ne[1] == 1)) { @@ -1625,9 +1593,13 @@ class vk_perf_logger { name += " "; name += ggml_type_name(node->src[0]->type); name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k); + if (node->op == GGML_OP_MUL_MAT_ID) { + name += " n_expert=" + std::to_string(node->src[0]->ne[2]); + } if (batch > 1) { name += " batch=" + std::to_string(batch); } + name = fusion_str + name; timings[name].push_back(time); flops[name].push_back(m * n * (k + (k - 1)) * batch); return; @@ -1649,6 +1621,7 @@ class vk_perf_logger { uint64_t n_flops = size_M * size_N * (size_K + (size_K - 1)); name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) + ", N=N*OW*OH=" + std::to_string(size_N); + name = fusion_str + name; flops[name].push_back(n_flops); timings[name].push_back(time); return; @@ -1656,6 +1629,7 @@ class vk_perf_logger { if (node->op == GGML_OP_RMS_NORM) { std::string name = ggml_op_name(node->op); name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")"; + name = fusion_str + name; timings[name].push_back(time); return; } @@ -1666,6 +1640,7 @@ class vk_perf_logger { const ggml_tensor * v = node->src[2]; const ggml_tensor * m = node->src[3]; std::stringstream name; + name << fusion_str; name << ggml_op_name(node->op) << " dst(" << dst->ne[0] << "," << dst->ne[1] << "," << dst->ne[2] << "," << dst->ne[3] << "), " << " q(" << q->ne[0] << "," << q->ne[1] << "," << q->ne[2] << "," << q->ne[3] << "), " << @@ -1675,11 +1650,21 @@ class vk_perf_logger { timings[name.str()].push_back(time); return; } - timings[ggml_op_name(node->op)].push_back(time); + if (node->op == GGML_OP_TOP_K) { + std::stringstream name; + name << fusion_str; + name << ggml_op_name(node->op) << + " K=" << node->ne[0] << + " (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")"; + timings[name.str()].push_back(time); + return; + } + 
timings[fusion_str + ggml_op_name(node->op)].push_back(time); } private: std::map> timings; std::map> flops; + uint32_t print_count {}; }; struct ggml_backend_vk_context { @@ -1733,6 +1718,14 @@ struct ggml_backend_vk_context { // Bit 'i' means nodes[start_of_fusion + i] writes to memory. // If there's no fusion, bit 0 is still set. int fused_ops_write_mask {}; + + // for GGML_VK_PERF_LOGGER + std::unique_ptr perf_logger; + vk::QueryPool query_pool; + std::vector query_fusion_names; + std::vector query_nodes; + int32_t num_queries {}; + int32_t query_idx {}; }; static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT @@ -1858,8 +1851,6 @@ struct vk_instance_t { static bool vk_instance_initialized = false; static vk_instance_t vk_instance; -static bool vk_perf_logger_enabled = false; - #ifdef GGML_VULKAN_CHECK_RESULTS static size_t vk_skip_checks; static size_t vk_output_tensor; @@ -2362,7 +2353,13 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties(); - const vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags }; + const vk::MemoryPriorityAllocateInfoEXT mem_priority_info { 1.0f }; + + vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags }; + + if (device->memory_priority) { + mem_flags_info.setPNext(&mem_priority_info); + } for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) { const auto & req_flags = *it; @@ -3567,7 +3564,7 @@ static void ggml_vk_load_shaders(vk_device& device) { SHADER_REDUCTION_MODE_SHMEM; for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) { - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32", arr_dmmv_f16_f32_f32_len[reduc], arr_dmmv_f16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); @@ -3591,7 +3588,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); @@ -3637,7 +3634,7 @@ static void ggml_vk_load_shaders(vk_device& device) { #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT } - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {wg_size_subgroup, 1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", arr_dmmv_id_f16_f32_f32_len[reduc], arr_dmmv_id_f16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, 
sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", arr_dmmv_id_bf16_f32_f32_len[reduc], arr_dmmv_id_bf16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", arr_dmmv_id_q4_0_f32_f32_len[reduc], arr_dmmv_id_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); @@ -4043,7 +4040,7 @@ static void ggml_vk_load_shaders(vk_device& device) { uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE + sizeof(int) * device->subgroup_size + 2 * sizeof(int) + - (BLOCK_SIZE / device->subgroup_size) * sizeof(int); + 2 * (BLOCK_SIZE / device->subgroup_size) * sizeof(int); if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot && nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) { ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size); @@ -4063,10 +4060,16 @@ static void ggml_vk_load_shaders(vk_device& device) { for (auto &s : device->pipeline_solve_tri_f32) { const vk_solve_tri_pipeline_state &state = s.first; + + // Max number of rows to load at a time, limited by shared memory + const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((state.N + state.K) * sizeof(float)); + // Need at least K invocations, and prefer a minimum of 128 to spread out loading shared memory + const uint32_t block_size = std::max(128u, 1u << (uint32_t)ceilf(log2f(float(state.K)))); + ggml_vk_create_pipeline( device, s.second, "solve_tri_f32", solve_tri_f32_len, solve_tri_f32_data, "main", 3, - sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K }, 1, true); + sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K, batch_N, block_size }, 1, true); } #define IM2COL(bda) \ @@ -4112,12 +4115,10 @@ static void ggml_vk_load_shaders(vk_device& device) { // conv2d, conv_transpose_2d for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) { uint32_t conv2d_WG_SIZE = 256; - uint32_t conv2d_BS_K = 128; - uint32_t conv2d_BS_CRS = 16; uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices. - uint32_t conv2d_BS_NPQ = 128; - uint32_t conv2d_TS_K = 8; + uint32_t conv2d_TS_K = (s == CONV_SHAPE_64x32) ? 4 : 8; uint32_t conv2d_SHMEM_PAD = 4; + vk_conv_block_size conv2d_BS = vk_conv_block_sizes[s]; bool conv2d_UNROLL = true; #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) @@ -4131,29 +4132,9 @@ static void ggml_vk_load_shaders(vk_device& device) { conv2d_UNROLL = false; } else if (device->vendor_id == VK_VENDOR_ID_AMD) { conv2d_SHMEM_PAD = device->architecture == vk_device_architecture::AMD_GCN ? 
1 : 4; - } - - switch (s) { - default: - case CONV_SHAPE_128x128: - conv2d_BS_K = conv_shapes_wg_denoms[CONV_SHAPE_128x128][0]; - conv2d_BS_NPQ = conv_shapes_wg_denoms[CONV_SHAPE_128x128][1]; - conv2d_BS_CRS = 16; - if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != vk_device_architecture::AMD_GCN) { + if (s == CONV_SHAPE_128x128 && device->architecture != vk_device_architecture::AMD_GCN) { conv2d_UNROLL = false; } - break; - case CONV_SHAPE_64x32: - conv2d_BS_K = conv_shapes_wg_denoms[CONV_SHAPE_64x32][0]; - conv2d_BS_NPQ = conv_shapes_wg_denoms[CONV_SHAPE_64x32][1]; - conv2d_BS_CRS = 32; - conv2d_TS_K = 4; - break; - case CONV_SHAPE_32x256: - conv2d_BS_K = conv_shapes_wg_denoms[CONV_SHAPE_32x256][0]; - conv2d_BS_NPQ = conv_shapes_wg_denoms[CONV_SHAPE_32x256][1]; - conv2d_BS_CRS = 16; - break; } // Use collectives on pre-Turing NVIDIA GPUs and GCN AMD cards, which had slower integer math. @@ -4167,22 +4148,22 @@ static void ggml_vk_load_shaders(vk_device& device) { allow_collectives_nv && allow_collectives_amd) { use_collectives = 1; - conv2d_BS_CRS = std::min( + conv2d_BS.CRS = std::min( device->subgroup_size, - conv2d_BS_CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used. + conv2d_BS.CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used. } uint32_t conv2d_shmem_req = - (conv2d_BS_K * (conv2d_BS_CRS + conv2d_SHMEM_PAD) + conv2d_BS_CRS * (conv2d_BS_NPQ + conv2d_SHMEM_PAD)) * sizeof(float); + (conv2d_BS.K * (conv2d_BS.CRS + conv2d_SHMEM_PAD) + conv2d_BS.CRS * (conv2d_BS.NPQ + conv2d_SHMEM_PAD)) * sizeof(float); if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) { - conv2d_BS_CRS = 8; + conv2d_BS.CRS = 8; if (use_collectives) { - conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS); + conv2d_BS.CRS = std::min(device->subgroup_size, conv2d_BS.CRS); } } - std::array wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 }; - std::vector spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD }; + std::array wg_denoms = { conv2d_BS.K, 1, 1 }; + std::vector spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD }; #define CREATE_CONV(name, type_suffix, spv_suffix) \ for (auto &c : device->pipeline_##name##type_suffix[s]) { \ @@ -4199,15 +4180,13 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline( \ device, c.second, #name #type_suffix, \ name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \ - sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives); \ + sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives); \ } #define CREATE_CONVS(spv_suffix) \ CREATE_CONV(conv2d, _f32, spv_suffix) \ CREATE_CONV(conv2d, _f16_f32, spv_suffix) \ - if (device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_conv_transpose_2d_push_constants)) { \ - CREATE_CONV(conv_transpose_2d, _f32, spv_suffix) \ - CREATE_CONV(conv_transpose_2d, _f16_f32, spv_suffix) \ - } + CREATE_CONV(conv_transpose_2d, _f32, spv_suffix) \ + CREATE_CONV(conv_transpose_2d, _f16_f32, spv_suffix) #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (device->coopmat2) { CREATE_CONVS(_cm2) @@ -4228,9 +4207,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, 
"conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) { - ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); + ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); + ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); } for (auto &c : compiles) { @@ -4251,9 +4230,6 @@ static vk_device ggml_vk_get_device(size_t idx) { #ifdef GGML_VULKAN_MEMORY_DEBUG device->memory_logger = std::unique_ptr(new vk_memory_logger()); #endif - if (vk_perf_logger_enabled) { - device->perf_logger = std::unique_ptr(new vk_perf_logger()); - } size_t dev_num = vk_instance.device_indices[idx]; @@ -4333,6 +4309,9 @@ static vk_device ggml_vk_get_device(size_t idx) { #endif } else if (strcmp("VK_KHR_pipeline_executable_properties", properties.extensionName) == 0) { pipeline_executable_properties_support = true; + } else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 && + getenv("GGML_VK_ENABLE_MEMORY_PRIORITY")) { + device->memory_priority = true; } } @@ -4524,6 +4503,16 @@ static vk_device ggml_vk_get_device(size_t idx) { device_extensions.push_back("VK_EXT_pipeline_robustness"); } + VkPhysicalDeviceMemoryPriorityFeaturesEXT memory_priority_features; + memory_priority_features.pNext = nullptr; + memory_priority_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_PRIORITY_FEATURES_EXT; + memory_priority_features.memoryPriority = VK_FALSE; + if (device->memory_priority) { + last_struct->pNext = (VkBaseOutStructure *)&memory_priority_features; + last_struct = (VkBaseOutStructure *)&memory_priority_features; + device_extensions.push_back("VK_EXT_memory_priority"); + } + VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features; subgroup_size_control_features.pNext = nullptr; subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT; @@ -5103,7 +5092,7 @@ static void ggml_vk_print_gpu_info(size_t idx) { } } -static bool ggml_vk_instance_validation_ext_available(); +static bool ggml_vk_instance_layer_settings_available(); static bool 
ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions); static bool ggml_vk_instance_debug_utils_ext_available(const std::vector & instance_extensions); static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev); @@ -5132,19 +5121,19 @@ static void ggml_vk_instance_init() { vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version }; const std::vector instance_extensions = vk::enumerateInstanceExtensionProperties(); - const bool validation_ext = ggml_vk_instance_validation_ext_available(); + const bool layer_settings = ggml_vk_instance_layer_settings_available(); #ifdef __APPLE__ const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions); #endif const bool debug_utils_ext = ggml_vk_instance_debug_utils_ext_available(instance_extensions) && getenv("GGML_VK_DEBUG_MARKERS") != nullptr; std::vector layers; - if (validation_ext) { + if (layer_settings) { layers.push_back("VK_LAYER_KHRONOS_validation"); } std::vector extensions; - if (validation_ext) { - extensions.push_back("VK_EXT_validation_features"); + if (layer_settings) { + extensions.push_back("VK_EXT_layer_settings"); } #ifdef __APPLE__ if (portability_enumeration_ext) { @@ -5154,26 +5143,24 @@ static void ggml_vk_instance_init() { if (debug_utils_ext) { extensions.push_back("VK_EXT_debug_utils"); } - vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions); + VkBool32 enable_best_practice = layer_settings; + std::vector settings = { + { + "VK_LAYER_KHRONOS_validation", + "validate_best_practices", + vk::LayerSettingTypeEXT::eBool32, + 1, + &enable_best_practice + }, + }; + vk::LayerSettingsCreateInfoEXT layer_setting_info(settings); + vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions, &layer_setting_info); #ifdef __APPLE__ if (portability_enumeration_ext) { instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR; } #endif - std::vector features_enable; - vk::ValidationFeaturesEXT validation_features; - - if (validation_ext) { - features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices }; - validation_features = { - features_enable, - {}, - }; - validation_features.setPNext(nullptr); - instance_create_info.setPNext(&validation_features); - GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n"); - } vk_instance.instance = vk::createInstance(instance_create_info); vk_instance_initialized = true; @@ -5188,6 +5175,11 @@ static void ggml_vk_instance_init() { } vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr; + const char* GGML_VK_PERF_LOGGER_FREQUENCY = getenv("GGML_VK_PERF_LOGGER_FREQUENCY"); + + if (GGML_VK_PERF_LOGGER_FREQUENCY != nullptr) { + vk_perf_logger_frequency = std::stoul(GGML_VK_PERF_LOGGER_FREQUENCY); + } // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers- VULKAN_HPP_DEFAULT_DISPATCHER.init(vk_instance.instance); @@ -5365,6 +5357,10 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue); ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue); + if (vk_perf_logger_enabled) { + ctx->perf_logger = std::unique_ptr(new vk_perf_logger()); + } + #ifdef GGML_VULKAN_CHECK_RESULTS const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS"); vk_skip_checks = (skip_checks == NULL ? 
0 : atoi(skip_checks)); @@ -6921,6 +6917,10 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ // Quantization overhead is not worth it for small k switch (device->vendor_id) { case VK_VENDOR_ID_NVIDIA: + if (src0_type == GGML_TYPE_Q2_K) { + return true; + } + if (k <= 4096) { return false; } @@ -8253,59 +8253,23 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } } -static std::array ggml_vk_get_conv_elements(const ggml_tensor *dst) { - const ggml_tensor *src0 = dst->src[0]; - const ggml_tensor *src1 = dst->src[1]; - - // src0 - kernel: [KW, KH, Cin, Cout] - // src1 - input: [W, H, Cin, N] - // dst - result: [OW, OH, Cout, N] - - // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) - auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { - return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; +static vk_conv_shapes ggml_vk_conv_select_shape(ggml_backend_vk_context * ctx, uint32_t K, uint32_t NPQ) { + auto n_tiles = [&](vk_conv_shapes s) { + return CEIL_DIV(K, vk_conv_block_sizes[s].K) + * CEIL_DIV(NPQ, vk_conv_block_sizes[s].NPQ); }; - // parallelize in {OW/BS_K, OH/BS_NPQ, 1} - int64_t W = src1->ne[0]; - int64_t H = src1->ne[1]; - int64_t KW = src0->ne[0]; - int64_t KH = src0->ne[1]; - int64_t Cout = src0->ne[3]; - int64_t N = src1->ne[3]; - int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]); - int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]); - int64_t NPQ = N * OW * OH; - // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups - std::array elements = { static_cast(Cout), static_cast(NPQ), 1 }; - return elements; -} + // We can't query number of shader cores on Intel, use 32 as a placeholder + // so small convolutions will still choose a smaller tile. + const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? 
ctx->device->shader_core_count : 32; -static std::array ggml_vk_get_conv_transpose_2d_elements(const ggml_tensor *dst) { - const ggml_tensor *src0 = dst->src[0]; - const ggml_tensor *src1 = dst->src[1]; - - // src0 - kernel: [KW, KH, Cout, Cin] - // src1 - input: [W, H, Cin, N] - // dst - result: [OW, OH, Cout, N] - - auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { - return (ins - 1) * s - 2 * p + (ks - 1) * d + 1; - }; - // parallelize in {OW/BS_K, OH/BS_NPQ, 1} - int64_t W = src1->ne[0]; - int64_t H = src1->ne[1]; - int64_t KW = src0->ne[0]; - int64_t KH = src0->ne[1]; - int64_t Cout = src0->ne[2]; - int64_t N = src1->ne[3]; - int64_t OH = calc_conv_output_size(H, KH, dst->op_params[0], 0, 1); - int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], 0, 1); - int64_t NPQ = N * OW * OH; - - // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups - std::array elements = { static_cast(Cout), static_cast(NPQ), 1 }; - return elements; + if (K > 64 && n_tiles(CONV_SHAPE_128x128) >= shader_core_count * 2) { + return CONV_SHAPE_128x128; + } else if (K <= 32 && n_tiles(CONV_SHAPE_32x256) >= shader_core_count * 2) { + return CONV_SHAPE_32x256; + } else { + return CONV_SHAPE_64x32; + } } static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) { @@ -8768,39 +8732,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_CONV_2D: case GGML_OP_CONV_TRANSPOSE_2D: - if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && - ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { - std::array elements{}; - if (op == GGML_OP_CONV_2D) elements = ggml_vk_get_conv_elements(dst); - else if (op == GGML_OP_CONV_TRANSPOSE_2D) elements = ggml_vk_get_conv_transpose_2d_elements(dst); - vk_conv_shapes shape; - - uint32_t tiles[CONV_SHAPE_COUNT]; - for (uint32_t i = 0; i < CONV_SHAPE_COUNT; ++i) { - tiles[i] = CEIL_DIV(elements[0], conv_shapes_wg_denoms[i][0]) * CEIL_DIV(elements[1], conv_shapes_wg_denoms[i][1]); - } - - // We can't query number of shader cores on Intel, use 32 as a placeholder - // so small convolutions will still choose a smaller tile. - const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? ctx->device->shader_core_count : 32; - - if (elements[0] > 64 && tiles[CONV_SHAPE_128x128] >= shader_core_count * 2) { - shape = CONV_SHAPE_128x128; - } else if (elements[0] <= 32 && tiles[CONV_SHAPE_32x256] >= shader_core_count * 2) { - shape = CONV_SHAPE_32x256; - } else { - shape = CONV_SHAPE_64x32; - } - - uint32_t KW = static_cast(src0->ne[0]); - uint32_t KH = static_cast(src0->ne[1]); - uint32_t s0 = static_cast(dst->op_params[0]); - uint32_t s1 = op == GGML_OP_CONV_2D ? static_cast(dst->op_params[1]) : static_cast(dst->op_params[0]); - uint32_t p0 = op == GGML_OP_CONV_2D ? static_cast(dst->op_params[2]) : 0; - uint32_t p1 = op == GGML_OP_CONV_2D ? static_cast(dst->op_params[3]) : 0; - uint32_t d0 = op == GGML_OP_CONV_2D ? static_cast(dst->op_params[4]) : 1; - uint32_t d1 = op == GGML_OP_CONV_2D ? 
static_cast(dst->op_params[5]) : 1; + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + uint32_t K = dst->ne[2]; // Cout + uint32_t NPQ = dst->ne[3] * dst->ne[1] * dst->ne[0]; // N * OH * OW + vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, K, NPQ); + bool transpose = dst->op == GGML_OP_CONV_TRANSPOSE_2D; + uint32_t KW = (uint32_t)src0->ne[0]; + uint32_t KH = (uint32_t)src0->ne[1]; + uint32_t s0 = (uint32_t)(ggml_get_op_params_i32(dst, 0)); + uint32_t s1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 1) : s0; + uint32_t p0 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 2) : 0; + uint32_t p1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 3) : 0; + uint32_t d0 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 4) : 1; + uint32_t d1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 5) : 1; vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH); std::map *pipelines = nullptr; @@ -9119,13 +9064,21 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { N * OC * OH * OW, 1, 1}; } break; case GGML_OP_CONV_2D: - { - elements = ggml_vk_get_conv_elements(dst); - } break; case GGML_OP_CONV_TRANSPOSE_2D: - { - elements = ggml_vk_get_conv_transpose_2d_elements(dst); - } break; + if constexpr (std::is_same_v) { + const uint32_t NPQ = pc.N * pc.OH * pc.OW; + const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, pc.Cout, NPQ); + const uint32_t NPQ_blocks = CEIL_DIV(NPQ, vk_conv_block_sizes[shape].NPQ); + + elements = { pc.Cout, NPQ_blocks, 1 }; + if (elements[1] > 512) { + elements[2] = CEIL_DIV(elements[1], 512); + elements[1] = 512; + } + } else { + GGML_ABORT("invalid push constant type for CONV_2D"); + } + break; case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_DIV: @@ -10347,17 +10300,8 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons uint32_t nrows = ggml_nrows(src0); uint32_t k = dst->ne[0]; - vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 }; + vk_op_topk_push_constants pc { ncols, ncols, ncols, k, nrows, 0, 0 }; - // Reserve space for ivec2 per element, double buffered - const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int); - const size_t x_sz = dbl_buf_size * 2; - uint32_t dbl_buf_index = 0; - - if (ctx->prealloc_size_x < x_sz) { - ctx->prealloc_size_x = x_sz; - ggml_vk_preallocate_buffers(ctx, subctx); - } if (ctx->prealloc_x_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } @@ -10372,8 +10316,9 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons // largest elements. Repeat until we have the top K elements. // Need to do at least one iteration to write out the results. bool done_one_iter = false; + uint32_t dbl_buf_index = 0; + size_t dbl_buf_size; while (num_elements > k || !done_one_iter) { - done_one_iter = true; // Prefer going as small as num_topk_pipelines - 3 for perf reasons. 
// But if K is larger, then we need a larger workgroup @@ -10413,6 +10358,21 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons // Number of elements remaining after this pass uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]); + pc2.ncols_output = num_dst_elements; + + if (!done_one_iter) { + // Reserve space for ivec2 per element, double buffered + // K per workgroup per row + dbl_buf_size = num_dst_elements * nrows * 2 * sizeof(int); + dbl_buf_size = ROUNDUP_POW2(dbl_buf_size, ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const size_t x_sz = dbl_buf_size * 2; + + if (ctx->prealloc_size_x < x_sz) { + ctx->prealloc_size_x = x_sz; + ggml_vk_preallocate_buffers(ctx, subctx); + } + } + vk_subbuffer src_buf; vk_subbuffer dst_buf; @@ -10438,6 +10398,7 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons if (num_elements > k) { ggml_vk_sync_buffers(ctx, subctx); } + done_one_iter = true; } ctx->prealloc_x_need_sync = true; } @@ -10668,30 +10629,24 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_TENSOR_BINARY_OP_LOCALS - GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb10 == sizeof(float)); GGML_ASSERT(nb0 == sizeof(float)); + bool transpose = dst->op == GGML_OP_CONV_TRANSPOSE_2D; + vk_op_conv2d_push_constants p{}; - p.Cout = static_cast(ne03); - p.Cin = static_cast(ne02); + p.Cout = static_cast(!transpose ? ne03 : ne02); + p.Cin = static_cast(!transpose ? ne02 : ne03); p.N = static_cast(ne13); + GGML_ASSERT(p.Cout == ne2); + GGML_ASSERT(p.Cin == ne12); - p.KW = static_cast(ne00); - p.KH = static_cast(ne01); p.W = static_cast(ne10); p.H = static_cast(ne11); p.OW = static_cast(ne0); p.OH = static_cast(ne1); - p.s0 = static_cast(dst->op_params[0]); - p.s1 = static_cast(dst->op_params[1]); - p.p0 = static_cast(dst->op_params[2]); - p.p1 = static_cast(dst->op_params[3]); - p.d0 = static_cast(dst->op_params[4]); - p.d1 = static_cast(dst->op_params[5]); - p.nb01 = static_cast(nb01 / nb00); p.nb02 = static_cast(nb02 / nb00); p.nb03 = static_cast(nb03 / nb00); @@ -10704,59 +10659,7 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, p.nb2 = static_cast(nb2 / nb0); p.nb3 = static_cast(nb3 / nb0); - GGML_ASSERT(ne03 == ne2); - GGML_ASSERT(ne02 == ne12); - - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_2D, std::move(p)); -} - -static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0, - const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb10 == sizeof(float)); - GGML_ASSERT(nb0 == sizeof(float)); - - vk_op_conv_transpose_2d_push_constants p{}; - p.Cout = static_cast(ne02); - p.Cin = static_cast(ne03); - p.N = static_cast(ne13); - - p.KW = static_cast(ne00); - p.KH = static_cast(ne01); - p.W = static_cast(ne10); - p.H = static_cast(ne11); - p.OW = static_cast(ne0); - p.OH = static_cast(ne1); - - p.s0 = static_cast(dst->op_params[0]); - p.s1 = static_cast(dst->op_params[0]); - p.p0 = 0; - p.p1 = 0; - p.d0 = 1; - p.d1 = 1; - - p.nb01 = static_cast(nb01 / 
nb00); - p.nb02 = static_cast(nb02 / nb00); - p.nb03 = static_cast(nb03 / nb00); - - p.nb11 = static_cast(nb11 / nb10); - p.nb12 = static_cast(nb12 / nb10); - p.nb13 = static_cast(nb13 / nb10); - - p.nb1 = static_cast(nb1 / nb0); - p.nb2 = static_cast(nb2 / nb0); - p.nb3 = static_cast(nb3 / nb0); - - GGML_ASSERT(ne02 == ne2); - GGML_ASSERT(ne03 == ne12); - - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_2D, std::move(p)); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, std::move(p)); } static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -12127,11 +12030,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_CONV_2D: - ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node); - - break; case GGML_OP_CONV_TRANSPOSE_2D: - ggml_vk_conv_transpose_2d(ctx, compute_ctx, src0, src1, node); + ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node); break; case GGML_OP_CONV_2D_DW: @@ -12336,6 +12236,9 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ctx->compute_cmd_pool.destroy(ctx->device->device); ctx->transfer_cmd_pool.destroy(ctx->device->device); + if (vk_perf_logger_enabled) { + ctx->perf_logger->print_timings(true); + } } static int ggml_vk_get_device_count() { @@ -13159,24 +13062,29 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg vk_context compute_ctx; if (vk_perf_logger_enabled) { // allocate/resize the query pool - if (ctx->device->num_queries < cgraph->n_nodes + 1) { - if (ctx->device->query_pool) { - ctx->device->device.destroyQueryPool(ctx->device->query_pool); + if (ctx->num_queries < cgraph->n_nodes + 1) { + if (ctx->query_pool) { + ctx->device->device.destroyQueryPool(ctx->query_pool); } vk::QueryPoolCreateInfo query_create_info; query_create_info.queryType = vk::QueryType::eTimestamp; query_create_info.queryCount = cgraph->n_nodes + 100; - ctx->device->query_pool = ctx->device->device.createQueryPool(query_create_info); - ctx->device->num_queries = query_create_info.queryCount; + ctx->query_pool = ctx->device->device.createQueryPool(query_create_info); + ctx->num_queries = query_create_info.queryCount; + ctx->query_fusion_names.resize(ctx->num_queries); + ctx->query_nodes.resize(ctx->num_queries); } - ctx->device->device.resetQueryPool(ctx->device->query_pool, 0, cgraph->n_nodes+1); + ctx->device->device.resetQueryPool(ctx->query_pool, 0, cgraph->n_nodes+1); + std::fill(ctx->query_fusion_names.begin(), ctx->query_fusion_names.end(), nullptr); + std::fill(ctx->query_nodes.begin(), ctx->query_nodes.end(), nullptr); GGML_ASSERT(ctx->compute_ctx.expired()); compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); ctx->compute_ctx = compute_ctx; ggml_vk_ctx_begin(ctx->device, compute_ctx); - compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0); + ctx->query_idx = 0; + compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); } ctx->prealloc_y_last_pipeline_used = nullptr; @@ -13217,52 +13125,66 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg total_mul_mat_bytes += bytes; } + const char *fusion_string {}; if (!ctx->device->disable_fusion) { uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i); if (num_adds) { ctx->num_additional_fused_ops = num_adds - 1; + 
fusion_string = "MULTI_ADD"; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_ADD })) { ctx->num_additional_fused_ops = 2; + fusion_string = "MUL_MAT_ADD_ADD"; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) { ctx->num_additional_fused_ops = 1; + fusion_string = "MUL_MAT_ADD"; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 2; + fusion_string = "MUL_MAT_ID_ADD_ID_MUL"; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) { ctx->num_additional_fused_ops = 1; + fusion_string = "MUL_MAT_ID_ADD_ID"; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 1; + fusion_string = "MUL_MAT_ID_MUL"; } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) && ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) && ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) && ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) { ctx->num_additional_fused_ops = 4; + fusion_string = "RMS_NORM_MUL_ROPE_VIEW_SET_ROWS"; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE })&& ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) { ctx->num_additional_fused_ops = 2; + fusion_string = "RMS_NORM_MUL_ROPE"; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 1; + fusion_string = "RMS_NORM_MUL"; } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) && ggml_check_edges(cgraph, i, rope_view_set_rows_edges) && ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) { ctx->num_additional_fused_ops = 2; + fusion_string = "ROPE_VIEW_SET_ROWS"; } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) && ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) { ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1; // view of argsort writes to memory ctx->fused_ops_write_mask |= 1 << 3; + fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM"; } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) && ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) { ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1; // view of argsort writes to memory ctx->fused_ops_write_mask |= 1 << 3; + fusion_string = "TOPK_MOE_EARLY_SOFTMAX"; } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) && ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) { ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1; // view of argsort writes to memory ctx->fused_ops_write_mask |= 1 << 1; + fusion_string = "TOPK_MOE_LATE_SOFTMAX"; } } ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops; @@ -13276,7 +13198,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit); - if (vk_perf_logger_enabled) { + if (vk_perf_logger_enabled && 
enqueued) { if (ctx->compute_ctx.expired()) { compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); ctx->compute_ctx = compute_ctx; @@ -13284,10 +13206,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg } else { compute_ctx = ctx->compute_ctx.lock(); } - // If there are fused ops, just write out timestamps for all nodes to keep the accounting simple - for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) { - compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1); - } + ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i]; + ctx->query_fusion_names[ctx->query_idx] = fusion_string; + compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); } if (enqueued) { @@ -13328,14 +13249,14 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg // Get the results and pass them to the logger std::vector timestamps(cgraph->n_nodes + 1); - VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->device->query_pool, 0, cgraph->n_nodes + 1, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results"); - for (int i = 0; i < cgraph->n_nodes; i++) { - if (!ggml_vk_is_empty(cgraph->nodes[i])) { - ctx->device->perf_logger->log_timing(cgraph->nodes[i], uint64_t((timestamps[i+1] - timestamps[i]) * ctx->device->properties.limits.timestampPeriod)); - } + VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->query_pool, 0, ctx->query_idx, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results"); + for (int i = 1; i < ctx->query_idx; i++) { + auto node = ctx->query_nodes[i]; + auto name = ctx->query_fusion_names[i]; + ctx->perf_logger->log_timing(node, name, uint64_t((timestamps[i] - timestamps[i-1]) * ctx->device->properties.limits.timestampPeriod)); } - ctx->device->perf_logger->print_timings(); + ctx->perf_logger->print_timings(); } if (!ctx->device->support_async) { @@ -14235,6 +14156,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } return true; case GGML_OP_UPSCALE: + return op->src[0]->type == GGML_TYPE_F32 && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); case GGML_OP_ACC: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_CONCAT: @@ -14283,10 +14205,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm const uint32_t N = op->src[0]->ne[0]; const uint32_t K = op->src[1]->ne[0]; // K dimension limited to workgroup size - if (K > 128) { + if (K > 1u << device->max_workgroup_size_log2) { return false; } - if (N * N * sizeof(float) + N * K * sizeof(float) > device->properties.limits.maxComputeSharedMemorySize) { + const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((N + K) * sizeof(float)); + + if (batch_N == 0) { return false; } return true; @@ -14359,13 +14283,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_CONV_2D: case GGML_OP_CONV_TRANSPOSE_2D: { - // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - const vk_device& device = ggml_vk_get_device(ctx->device); - if (op->op == GGML_OP_CONV_TRANSPOSE_2D && - 
device->properties.limits.maxPushConstantsSize < sizeof(vk_op_conv_transpose_2d_push_constants)) { - return false; - } // Channel-contiguous format is not supported yet. return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && op->src[1]->type == GGML_TYPE_F32 && @@ -14528,21 +14445,21 @@ ggml_backend_reg_t ggml_backend_vk_reg() { } // Extension availability -static bool ggml_vk_instance_validation_ext_available() { +static bool ggml_vk_instance_layer_settings_available() { #ifdef GGML_VULKAN_VALIDATE // Check if validation layer provides the extension const std::string layer_name = "VK_LAYER_KHRONOS_validation"; for (const auto& layer : vk::enumerateInstanceLayerProperties()) { if (layer_name == layer.layerName.data()) { for (const auto& ext : vk::enumerateInstanceExtensionProperties(layer_name)) { - if (strcmp("VK_EXT_validation_features", ext.extensionName.data()) == 0) { + if (strcmp("VK_EXT_layer_settings", ext.extensionName.data()) == 0) { return true; } } } } - std::cerr << "ggml_vulkan: WARNING: Validation layer or layer extension VK_EXT_validation_features not found." << std::endl; + std::cerr << "ggml_vulkan: WARNING: Validation layer or layer extension VK_EXT_layer_settings not found." << std::endl; #endif return false; } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp index e9bdbf7d..875c012c 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -32,22 +32,12 @@ layout(push_constant) uniform parameter { uint32_t Cin; uint32_t N; - // Tensor spatial sizes: kernel, input, output - uint32_t KW; - uint32_t KH; + // Tensor spatial sizes: input, output uint32_t W; uint32_t H; uint32_t OW; uint32_t OH; - // Parameters: stride, padding, dilation - 0=y, 1=x - uint32_t s0; - uint32_t s1; - uint32_t p0; - uint32_t p1; - uint32_t d0; - uint32_t d1; - // Strides in elements uint32_t nb01; uint32_t nb02; @@ -77,13 +67,14 @@ layout(constant_id = 3) const uint BS_NPQ = 128; layout(constant_id = 4) const uint TS_K = 8; layout(constant_id = 5) const uint use_collectives = 1; layout(constant_id = 6) const uint SHMEM_PAD = 4; - +// Stride, padding, dilation layout(constant_id = 7) const uint s0 = 1; layout(constant_id = 8) const uint s1 = 1; layout(constant_id = 9) const uint p0 = 0; layout(constant_id = 10) const uint p1 = 0; layout(constant_id = 11) const uint d0 = 1; layout(constant_id = 12) const uint d1 = 1; +// Kernel spatial sizes layout(constant_id = 13) const uint KW = 1; layout(constant_id = 14) const uint KH = 1; @@ -138,7 +129,7 @@ P,Q=OH,OW */ uint32_t B_idx_K = gl_WorkGroupID.x; -uint32_t B_idx_NPQ = gl_WorkGroupID.y; +uint32_t B_idx_NPQ = gl_WorkGroupID.y + gl_WorkGroupID.z * 512; uint32_t T_y = tid / NT_NPQ; uint32_t T_x = tid % NT_NPQ; @@ -178,6 +169,10 @@ ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_T #endif void main() { + if (B_idx_NPQ * BS_NPQ >= NPQ) { + return; + } + #ifdef COOPMAT2 coopmat matC; matC = coopmat(0.0); diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp index 4cb29238..e5cc7ff8 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp @@ -7,35 +7,85 @@ layout(local_size_x_id = 0, 
local_size_y = 1, local_size_z = 1) in; FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; -void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { +void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, + const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + // Compute starting index in matrix B for this superblock const uint y_idx = i * QUANT_K + 32 * ib32; - uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + + // Precompute indices for quantization lookup tables + const uint qh_base = 2 * ib32; + const uint qs_base = 4 * ib32; + const uint sc_index = ib32 / 2; + const uint sc_shift = 6 * (ib32 & 1); + + // Loop over rows in the superblock [[unroll]] for (uint n = 0; n < num_rows; ++n) { + // Load per-block scales and shift for quantization const uint16_t[4] scales = data_a[ibi].scales; const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + const uint sc = data_a[ibi].scales[sc_index] >> sc_shift; - const uint sc = data_a[ibi].scales[ib32 / 2] >> (6 * (ib32 & 1)); + // Temporary caches for decoding + FLOAT_TYPE dl_cache[4]; + uint16_t gvf_cache[4]; + float delta_cache[4]; + + // Precompute the multiplier and lookup values for 4 sub-blocks [[unroll]] for (uint l = 0; l < 4; ++l) { - const uint qh = data_a[ibi].qh[2 * ib32 + l / 2] >> (4 * (l&1)); - const uint qs = data_a[ibi].qs[4 * ib32 + l]; - const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; - const float dl = d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1); + dl_cache[l] = FLOAT_TYPE(d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1)); + const uint qh = data_a[ibi].qh[qh_base + l / 2] >> (4 * (l & 1)); + const uint qs = data_a[ibi].qs[qs_base + l]; + gvf_cache[l] = iq1s_grid[qs | ((qh & 7) << 8)]; + delta_cache[l] = ((qh & 8) != 0) ? 
-IQ1M_DELTA : IQ1M_DELTA; + } - const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + // Loop over columns of the output + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + // Compute base index for matrix B + const uint base_b_idx = (j * p.batch_stride_b + b_offset + y_idx) / 4; + vec4 b_vals[8]; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); - vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); - - FLOAT_TYPE sum = FLOAT_TYPE(0.0); - [[unroll]] for (int k = 0; k < 4; ++k) { - sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta, - fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum)); - } - temp[j][n] = fma(dl, sum, temp[j][n]); + // Load 8 vec4 values from matrix B + [[unroll]] for (int idx = 0; idx < 8; ++idx) { + b_vals[idx] = vec4(data_b_v4[base_b_idx + idx]); } + + FLOAT_TYPE col_sum = FLOAT_TYPE(0.0); + + // Loop over sub-blocks + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint16_t grid = gvf_cache[l]; + const float dl = dl_cache[l]; + + // Decode 8 2-bit fbits from gvf_cache + float f0 = float(bitfieldExtract(grid, 0, 2)); + float f1 = float(bitfieldExtract(grid, 2, 2)); + float f2 = float(bitfieldExtract(grid, 4, 2)); + float f3 = float(bitfieldExtract(grid, 6, 2)); + float f4 = float(bitfieldExtract(grid, 8, 2)); + float f5 = float(bitfieldExtract(grid, 10, 2)); + float f6 = float(bitfieldExtract(grid, 12, 2)); + float f7 = float(bitfieldExtract(grid, 14, 2)); + + // Pack into vec4 for vectorized FMA + const vec4 fbits_v0 = vec4(f0, f1, f2, f3); + const vec4 fbits_v1 = vec4(f4, f5, f6, f7); + const vec4 delta_v = vec4(delta_cache[l]); + + // Vectorized fused multiply-add + vec4 sum_v = fma(b_vals[2*l + 0], fbits_v0 + delta_v, vec4(0.0)); + sum_v = fma(b_vals[2*l + 1], fbits_v1 + delta_v, sum_v); + + // Horizontal add to get scalar sum + FLOAT_TYPE sum = sum_v.x + sum_v.y + sum_v.z + sum_v.w; + + // Accumulate to column sum + col_sum = fma(dl, sum, col_sum); + } + // Write result to temporary buffer + temp[j][n] += col_sum; } ibi += num_blocks_per_row; } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp index f3c81768..5abd2f6f 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp @@ -8,6 +8,7 @@ layout (push_constant) uniform parameter uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; uint misalign_offsets; + uint circular; uint lp0; uint rp0; uint lp1; uint rp1; @@ -18,6 +19,10 @@ layout (push_constant) uniform parameter uint get_aoffset() { return p.misalign_offsets >> 16; } uint get_doffset() { return p.misalign_offsets & 0xFFFF; } +uint wrap_around(int coord, uint size) { + return (uint(coord + int(size))) % size; // add size to avoid issues with negative +} + layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; @@ -40,10 +45,20 @@ void main() { const uint src0_idx = (i3 - p.lp3)*p.nb03 + (i2 - p.lp2)*p.nb02 + (i1 - p.lp1)*p.nb01 + (i0 - p.lp0)*p.nb00; const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10; - const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 && - i1 >= p.lp1 && i1 < p.ne11 - p.rp1 && - i2 >= p.lp2 && i2 < p.ne12 
- p.rp2 && - i3 >= p.lp3 && i3 < p.ne13 - p.rp3; + if (p.circular != 0u) { + const uint ci0 = wrap_around(int(i0) - int(p.lp0), p.ne00); + const uint ci1 = wrap_around(int(i1) - int(p.lp1), p.ne01); + const uint ci2 = wrap_around(int(i2) - int(p.lp2), p.ne02); + const uint ci3 = wrap_around(int(i3) - int(p.lp3), p.ne03); + const uint circular_src_idx = ci3*p.nb03 + ci2*p.nb02 + ci1*p.nb01 + ci0*p.nb00; + data_d[get_doffset() + dst_idx] = D_TYPE(data_a[get_aoffset() + circular_src_idx]); + } else { + const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 && + i1 >= p.lp1 && i1 < p.ne11 - p.rp1 && + i2 >= p.lp2 && i2 < p.ne12 - p.rp2 && + i3 >= p.lp3 && i3 < p.ne13 - p.rp3; + data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f); + } + - data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp index 3a47949d..9d6d3665 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -131,8 +131,12 @@ void main() { rms_norm(num_blocks); } else if (num_blocks > 16) { rms_norm(32); - } else if (num_blocks > 8) { + } else if (num_blocks > 12) { rms_norm(16); + } else if (num_blocks > 10) { + rms_norm(12); + } else if (num_blocks > 8) { + rms_norm(10); } else if (num_blocks > 4) { rms_norm(8); } else if (num_blocks == 4) { diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp index 253a9e7e..3b651450 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp @@ -5,8 +5,9 @@ layout (constant_id = 1) const uint N = 64; layout (constant_id = 2) const uint K = 32; +layout (constant_id = 3) const uint BATCH_N = 32; -layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x_id = 4, local_size_y = 1, local_size_z = 1) in; uint a_base, b_base, x_base; @@ -22,8 +23,8 @@ void store_x(uint r, uint c, FLOAT_TYPE v) { data_d[x_base + r * p.nb21 + c * p.nb20] = D_TYPE(v); } -shared FLOAT_TYPE shA[N * N]; -shared FLOAT_TYPE shB[N * K]; +shared FLOAT_TYPE shA[BATCH_N * N]; +shared FLOAT_TYPE shB[BATCH_N * K]; void main() { const uint batch = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; @@ -39,34 +40,42 @@ void main() { b_base = get_boffset() + i2 * p.nb12 + i3 * p.nb13; x_base = get_doffset() + i2 * p.nb22 + i3 * p.nb23; - // Load the A matrix into shA - [[unroll]] for (uint i = 0; i < N * N; i += gl_WorkGroupSize.x) { - uint idx = i + tid; - if (((N * N) % gl_WorkGroupSize.x == 0) || idx < N * N) { - shA[idx] = get_a(idx / N, idx % N); - } - } - // Load the B matrix into shB - [[unroll]] for (uint i = 0; i < N * K; i += gl_WorkGroupSize.x) { - uint idx = i + tid; - if (((N * K) % gl_WorkGroupSize.x == 0) || idx < N * K) { - shB[idx] = get_b(idx / K, idx % K); - } - } - barrier(); - FLOAT_TYPE X[N]; - // Each thread solves one column - if (tid < K) { - [[unroll]] for (int r = 0; r < N; ++r) { - FLOAT_TYPE b = shB[r * K + tid]; - // Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r] - [[unroll]] for (int c = 0; c < r; ++c) { - b -= shA[r * N + c] * X[c]; + + // Loop over batches of rows + [[unroll]] for (uint row_base = 0; row_base < N; row_base += BATCH_N) { + 
const uint cur_N = min(BATCH_N, N - row_base); + + // Load the A matrix batch into shA + [[unroll]] for (uint i = 0; i < cur_N * N; i += gl_WorkGroupSize.x) { + uint idx = i + tid; + if (((cur_N * N) % gl_WorkGroupSize.x == 0) || idx < cur_N * N) { + shA[idx] = get_a(row_base + idx / N, idx % N); } - FLOAT_TYPE x = b / shA[r * N + r]; - X[r] = x; - store_x(r, tid, x); } + // Load the B matrix batch into shB + [[unroll]] for (uint i = 0; i < cur_N * K; i += gl_WorkGroupSize.x) { + uint idx = i + tid; + if (((cur_N * K) % gl_WorkGroupSize.x == 0) || idx < cur_N * K) { + shB[idx] = get_b(row_base + idx / K, idx % K); + } + } + barrier(); + + // Each thread solves one column + if (tid < K) { + [[unroll]] for (uint row_offset = 0; row_offset < cur_N; ++row_offset) { + uint r = row_base + row_offset; + FLOAT_TYPE b = shB[row_offset * K + tid]; + // Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r] + [[unroll]] for (int c = 0; c < r; ++c) { + b -= shA[row_offset * N + c] * X[c]; + } + FLOAT_TYPE x = b / shA[row_offset * N + r]; + X[r] = x; + store_x(r, tid, x); + } + } + barrier(); } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp index cd858b7d..49d4ab8e 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp @@ -19,6 +19,7 @@ layout (push_constant) uniform parameter { uint orig_ncols; uint ncols_input; uint ncols_output; + uint k; uint nrows; uint first_pass; uint last_pass; @@ -36,7 +37,7 @@ void topk(bool needs_bounds_check, const uint row) { const uint row_offset = row * p.ncols_input; dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x])); } else { - const uint row_offset = row * p.orig_ncols; + const uint row_offset = row * p.ncols_input; dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x]; } } else { @@ -44,7 +45,7 @@ void topk(bool needs_bounds_check, const uint row) { } barrier(); - if (p.ncols_output == 1) { + if (p.k == 1) { // Fast path for single output - just do a max reduction [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) { if (col < s) { @@ -84,13 +85,17 @@ void topk(bool needs_bounds_check, const uint row) { } } - if (col < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) { + if (col < p.k) { if (p.last_pass != 0) { - const uint row_offset = row * p.ncols_output; - data_d[row_offset + col] = dst_row[col].x; + if (gl_GlobalInvocationID.x < p.ncols_input) { + const uint row_offset = row * p.k; + data_d[row_offset + col] = dst_row[col].x; + } } else { - const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output; - data_t[row_offset + col] = dst_row[col]; + if (gl_WorkGroupID.x * p.k + col < p.ncols_output) { + const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k; + data_t[row_offset + col] = dst_row[col]; + } } } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp index bc1c278b..5cd0785d 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp @@ -75,7 +75,7 @@ void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit } void main() { - const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y; + const uint row = 
gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID; if (row >= n_rows) { return; } @@ -83,17 +83,18 @@ void main() { const uint logits_offset = n_experts * row; const uint weights_offset = n_expert_used * row; const uint ids_offset = n_experts * row; + const uint lane = gl_SubgroupInvocationID; float wt[experts_per_thread]; [[unroll]] for (uint i = 0; i < n_experts; i += WARP_SIZE) { - const uint expert = i + gl_LocalInvocationID.x; + const uint expert = i + lane; wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY; } if (!late_softmax) { - softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false); + softmax_warp_inplace(wt, n_experts, lane, false); } // at this point, each thread holds a portion of softmax, @@ -111,11 +112,11 @@ void main() { for (int k = 0; k < n_expert_used; k++) { float max_val = wt[0]; - uint max_expert = gl_LocalInvocationID.x; + uint max_expert = lane; [[unroll]] for (int i = 1; i < experts_per_thread; i++) { - const uint expert = gl_LocalInvocationID.x + i * WARP_SIZE; + const uint expert = lane + i * WARP_SIZE; if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) { max_val = wt[i]; max_expert = expert; @@ -132,11 +133,11 @@ void main() { } } - if ((k & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) { + if ((k & (WARP_SIZE - 1)) == lane) { output_weights[k / WARP_SIZE] = max_val; } - if ((max_expert & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) { + if ((max_expert & (WARP_SIZE - 1)) == lane) { wt[max_expert / WARP_SIZE] = -INFINITY; ids[ids_offset + k] = max_expert; @@ -158,12 +159,12 @@ void main() { } if (late_softmax) { - softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true); + softmax_warp_inplace(output_weights, n_expert_used, lane, true); } [[unroll]] for (uint i = 0; i < experts_per_thread; ++i) { - uint idx = i * WARP_SIZE + gl_LocalInvocationID.x; + uint idx = i * WARP_SIZE + lane; if (idx < n_expert_used) { weights[weights_offset + idx] = output_weights[i]; } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp index c902e602..0b757f38 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp @@ -25,6 +25,7 @@ layout (push_constant) uniform parameter { uint orig_ncols; uint ncols_input; uint ncols_output; + uint k; uint nrows; uint first_pass; uint last_pass; @@ -37,6 +38,7 @@ shared int counts[SUBGROUP_SIZE]; shared int sh_min_idx; shared uint sh_total; shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE]; +shared uint eq_min_partials[BLOCK_SIZE / SUBGROUP_SIZE]; // Map float values to uint such that comparisons still work. // Positive values set the high bit, negative values are inverted. 
@@ -60,7 +62,7 @@ void topk(const uint row) { const uint row_offset = row * p.ncols_input; dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x])); } else { - const uint row_offset = row * p.orig_ncols; + const uint row_offset = row * p.ncols_input; dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x]; } } else { @@ -68,7 +70,7 @@ void topk(const uint row) { } barrier(); - if (p.ncols_output == 1) { + if (p.k == 1) { // Fast path for single output - just do a max reduction [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) { if (tid < s) { @@ -98,7 +100,7 @@ void topk(const uint row) { uint range_max = 0xFF800000; // How many are above the current range, and how many we need to find. uint total = 0; - uint limit = min(p.ncols_output, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE); + uint limit = min(p.k, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE); while (mask != 0) { barrier(); @@ -139,7 +141,7 @@ void topk(const uint row) { range_max = range_min + ((min_idx + 1) << shift); range_min = range_min + (min_idx << shift); - if (total == p.ncols_output) { + if (total == p.k) { break; } total -= counts[min_idx]; @@ -155,37 +157,82 @@ void topk(const uint row) { // We need to compact these values to the start of the dst_row array. // Have each subgroup count how many items it'll store, so other // subgroups can compute their base offset. - bool top = f2ui(intBitsToFloat(v.y)) >= range_min; - uvec4 b = subgroupBallot(top); - uint bit_count = subgroupBallotBitCount(b); - if ((tid % SUBGROUP_SIZE) == 0) { - offset_partials[tid / SUBGROUP_SIZE] = bit_count; - } - barrier(); - - uint out_idx = 0; - [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) { - if (i < tid / SUBGROUP_SIZE) { - out_idx += offset_partials[i]; + // Values strictly greater than range_min must be stored. For values equal + // to range_min, there can be ties and it's possible we'll need to store + // an arbitrary subset of them. + // If total == p.k, have a fast path where we don't need to handle ties. + if (total == p.k) { + bool top = f2ui(intBitsToFloat(v.y)) >= range_min; + uvec4 b = subgroupBallot(top); + uint bit_count = subgroupBallotBitCount(b); + if ((tid % SUBGROUP_SIZE) == 0) { + offset_partials[tid / SUBGROUP_SIZE] = bit_count; } - } + barrier(); - uint bit_count_ex = subgroupBallotExclusiveBitCount(b); - if (top) { - // TODO: Copy directly to the output? - dst_row[out_idx + bit_count_ex] = v; + uint out_idx = 0; + [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) { + if (i < tid / SUBGROUP_SIZE) { + out_idx += offset_partials[i]; + } + } + + uint bit_count_ex = subgroupBallotExclusiveBitCount(b); + if (top) { + // TODO: Copy directly to the output? 
+ dst_row[out_idx + bit_count_ex] = v; + } + } else { + bool top = f2ui(intBitsToFloat(v.y)) > range_min; + bool eq_min = f2ui(intBitsToFloat(v.y)) == range_min; + uvec4 b_top = subgroupBallot(top); + uvec4 b_eq_min = subgroupBallot(eq_min); + uint bit_count_top = subgroupBallotBitCount(b_top); + uint bit_count_eq_min = subgroupBallotBitCount(b_eq_min); + if ((tid % SUBGROUP_SIZE) == 0) { + offset_partials[tid / SUBGROUP_SIZE] = bit_count_top; + eq_min_partials[tid / SUBGROUP_SIZE] = bit_count_eq_min; + } + barrier(); + + uint out_idx = 0; + uint eq_min_base = 0; + uint eq_min_idx = 0; + [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) { + if (i < tid / SUBGROUP_SIZE) { + out_idx += offset_partials[i]; + eq_min_idx += eq_min_partials[i]; + } + eq_min_base += offset_partials[i]; + } + // range_min values are stored at the end + eq_min_idx += eq_min_base; + + uint bit_count_ex_top = subgroupBallotExclusiveBitCount(b_top); + uint bit_count_ex_eq_min = subgroupBallotExclusiveBitCount(b_eq_min); + if (top) { + // TODO: Copy directly to the output? + dst_row[out_idx + bit_count_ex_top] = v; + } + if (eq_min && eq_min_idx + bit_count_ex_eq_min < p.k) { + dst_row[eq_min_idx + bit_count_ex_eq_min] = v; + } } barrier(); } - if (tid < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) { + if (tid < p.k) { if (p.last_pass != 0) { - const uint row_offset = row * p.ncols_output; - data_d[row_offset + tid] = dst_row[tid].x; + if (gl_GlobalInvocationID.x < p.ncols_input) { + const uint row_offset = row * p.k; + data_d[row_offset + tid] = dst_row[tid].x; + } } else { - const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output; - data_t[row_offset + tid] = dst_row[tid]; + if (gl_WorkGroupID.x * p.k + tid < p.ncols_output) { + const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k; + data_t[row_offset + tid] = dst_row[tid]; + } } } } diff --git a/ml/backend/ggml/ggml/src/ggml.c b/ml/backend/ggml/ggml/src/ggml.c index 1c9e0bc0..fc0196eb 100644 --- a/ml/backend/ggml/ggml/src/ggml.c +++ b/ml/backend/ggml/ggml/src/ggml.c @@ -124,6 +124,13 @@ static void ggml_print_backtrace_symbols(void) { int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0])); backtrace_symbols_fd(trace, nptrs, STDERR_FILENO); } +#elif defined(__APPLE__) +#include +static void ggml_print_backtrace_symbols(void) { + void * trace[100]; + int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0])); + backtrace_symbols_fd(trace, nptrs, STDERR_FILENO); +} #else static void ggml_print_backtrace_symbols(void) { // platform not supported @@ -135,6 +142,20 @@ void ggml_print_backtrace(void) { if (GGML_NO_BACKTRACE) { return; } +#if defined(__APPLE__) + // On macOS, fork+debugger attachment is problematic due to: + // 1. libdispatch "poisons" forked child processes + // 2. lldb has issues attaching to parent from forked child + // Use simple backtrace() instead to avoid Terminal.app crashes + const char * GGML_BACKTRACE_LLDB = getenv("GGML_BACKTRACE_LLDB"); + if (!GGML_BACKTRACE_LLDB) { + fprintf(stderr, "WARNING: Using native backtrace. 
Set GGML_BACKTRACE_LLDB for more info.\n"); + fprintf(stderr, "WARNING: GGML_BACKTRACE_LLDB may cause native MacOS Terminal.app to crash.\n"); + fprintf(stderr, "See: https://github.com/ggml-org/llama.cpp/pull/17869\n"); + ggml_print_backtrace_symbols(); + return; + } +#endif #if defined(__linux__) FILE * f = fopen("/proc/self/status", "r"); size_t size = 0; @@ -4896,6 +4917,8 @@ static struct ggml_tensor * ggml_interpolate_impl( int64_t ne3, uint32_t mode) { GGML_ASSERT((mode & 0xFF) < GGML_SCALE_MODE_COUNT); + // TODO: implement antialias for modes other than bilinear + GGML_ASSERT(!(mode & GGML_SCALE_FLAG_ANTIALIAS) || (mode & 0xFF) == GGML_SCALE_MODE_BILINEAR); struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); @@ -4950,6 +4973,18 @@ struct ggml_tensor * ggml_pad( return ggml_pad_ext(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3); } +// ggml_pad_circular + +struct ggml_tensor * ggml_pad_circular( + struct ggml_context * ctx, + struct ggml_tensor * a, + int p0, + int p1, + int p2, + int p3) { + return ggml_pad_ext_circular(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3); +} + struct ggml_tensor * ggml_pad_ext( struct ggml_context * ctx, struct ggml_tensor * a, @@ -4976,6 +5011,7 @@ struct ggml_tensor * ggml_pad_ext( ggml_set_op_params_i32(result, 5, rp2); ggml_set_op_params_i32(result, 6, lp3); ggml_set_op_params_i32(result, 7, rp3); + ggml_set_op_params_i32(result, 8, 0); // not circular by default result->op = GGML_OP_PAD; @@ -4984,6 +5020,25 @@ struct ggml_tensor * ggml_pad_ext( return result; } +// ggml_pad_ext_circular + +struct ggml_tensor * ggml_pad_ext_circular( + struct ggml_context * ctx, + struct ggml_tensor * a, + int lp0, + int rp0, + int lp1, + int rp1, + int lp2, + int rp2, + int lp3, + int rp3 + ) { + struct ggml_tensor * result = ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3); + ggml_set_op_params_i32(result, 8, 1); // circular + return result; +} + // ggml_pad_reflect_1d struct ggml_tensor * ggml_pad_reflect_1d( diff --git a/ml/backend/ggml/ggml/src/gguf.cpp b/ml/backend/ggml/ggml/src/gguf.cpp index d950dbdf..f91d4fab 100644 --- a/ml/backend/ggml/ggml/src/gguf.cpp +++ b/ml/backend/ggml/ggml/src/gguf.cpp @@ -1172,7 +1172,7 @@ void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const vo struct gguf_writer_base { size_t written_bytes {0u}; - ~gguf_writer_base(void) {} + ~gguf_writer_base(void) = default; // we bet on devirtualization virtual void write(int8_t val) = 0; From 56b8fb024cb530c738dc6ddb5cd76714255f1a6e Mon Sep 17 00:00:00 2001 From: Julia Scheaffer Date: Wed, 10 Dec 2025 16:07:48 -0600 Subject: [PATCH 17/35] cmd/bench: fix options table in cmd/bench/README.md (#13216) --- cmd/bench/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/cmd/bench/README.md b/cmd/bench/README.md index 210cc4a2..e45dc94e 100644 --- a/cmd/bench/README.md +++ b/cmd/bench/README.md @@ -54,6 +54,7 @@ benchstat -col /name gemma.bench ## Command Line Options | Option | Description | Default | +|----------|-------------|---------| | -model | Comma-separated list of models to benchmark | (required) | | -epochs | Number of iterations per model | 1 | | -max-tokens | Maximum tokens for model response | 0 (unlimited) | From dac4f17fea99dc18628d743e80f91dcd15ab4bce Mon Sep 17 00:00:00 2001 From: Eloi Torrents Date: Wed, 10 Dec 2025 23:16:58 +0100 Subject: [PATCH 18/35] cmd/bench: fix binary name in README (#13276) --- cmd/bench/README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cmd/bench/README.md 
b/cmd/bench/README.md index e45dc94e..cf261dd0 100644 --- a/cmd/bench/README.md +++ b/cmd/bench/README.md @@ -15,7 +15,7 @@ A Go-based command-line tool for benchmarking Ollama models with configurable pa ``` go build -o ollama-bench bench.go -./bench -model gpt-oss:20b -epochs 6 -format csv +./ollama-bench -model gpt-oss:20b -epochs 6 -format csv ``` Using Go Run (without building) @@ -29,26 +29,26 @@ go run bench.go -model gpt-oss:20b -epochs 3 ### Basic Example ``` -./bench -model gemma3 -epochs 6 +./ollama-bench -model gemma3 -epochs 6 ``` ### Benchmark Multiple Models ``` -./bench -model gemma3,gemma3n -epochs 6 -max-tokens 100 -p "Write me a short story" | tee gemma.bench +./ollama-bench -model gemma3,gemma3n -epochs 6 -max-tokens 100 -p "Write me a short story" | tee gemma.bench benchstat -col /name gemma.bench ``` ### With Image Prompt ``` -./bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image" +./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image" ``` ### Advanced Example ``` -./bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv +./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv ``` ## Command Line Options From 1c4e85b4df1a8ebcb0f578ea423cc1a0d0adf873 Mon Sep 17 00:00:00 2001 From: EasonLin <75676459+Eason023@users.noreply.github.com> Date: Thu, 11 Dec 2025 09:28:41 +0800 Subject: [PATCH 19/35] routes: add logprobs in tool calls (#13238) --- server/routes.go | 12 ++++- server/routes_generate_test.go | 89 ++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 2 deletions(-) diff --git a/server/routes.go b/server/routes.go index 4dd870ed..bbf6b9b9 100644 --- a/server/routes.go +++ b/server/routes.go @@ -2195,7 +2195,7 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done { + if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done || len(res.Logprobs) > 0 { slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done) ch <- res } else { @@ -2235,8 +2235,16 @@ func (s *Server) ChatHandler(c *gin.Context) { res.Message.ToolCalls = toolCalls res.Message.Content = "" } else if res.Message.Thinking != "" { - // don't return + // don't return, fall through to send } else { + // Send logprobs while content is being buffered by the parser for tool calls + if len(res.Logprobs) > 0 && !r.Done { + logprobRes := res + logprobRes.Message.Content = "" + logprobRes.Message.ToolCalls = nil + ch <- logprobRes + } + if r.Done { res.Message.Content = toolParser.Content() ch <- res diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index a9931ea2..13befff2 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -708,6 +708,95 @@ func TestGenerateChat(t *testing.T) { } }) + t.Run("messages with tools and logprobs (streaming)", func(t *testing.T) { + tools := []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "location": {Type: api.PropertyType{"string"}}, + }, + }, + }, + }, + } + + var wg sync.WaitGroup + wg.Add(1) + + 
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error { + defer wg.Done() + + // Simulate a response where logprobs are sent while the tool call is being buffered + responses := []llm.CompletionResponse{ + { + Content: `{ "name": "get_weather"`, + Done: false, + Logprobs: []llm.Logprob{{}}, + }, + { + Content: `,"arguments":{"location":"Seattle, WA","unit":"celsius"}}`, + Done: false, + Logprobs: []llm.Logprob{{}}, + }, + { + Content: ``, + Done: true, + DoneReason: llm.DoneReasonStop, + Logprobs: nil, + }, + } + + for _, resp := range responses { + select { + case <-ctx.Done(): + return ctx.Err() + default: + fn(resp) + time.Sleep(10 * time.Millisecond) + } + } + return nil + } + + w := createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "test-system", + Messages: []api.Message{ + {Role: "user", Content: "Weather?"}, + }, + Tools: tools, + Stream: &stream, + }) + + wg.Wait() + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + decoder := json.NewDecoder(w.Body) + var totalLogprobs int + + for { + var resp api.ChatResponse + if err := decoder.Decode(&resp); err == io.EOF { + break + } else if err != nil { + t.Fatal(err) + } + + totalLogprobs += len(resp.Logprobs) + } + + expectedLogprobs := 2 + if totalLogprobs != expectedLogprobs { + t.Errorf("expected %d logprobs, got %d", expectedLogprobs, totalLogprobs) + } + }) + t.Run("status error non-streaming", func(t *testing.T) { mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error { return api.StatusError{ From a838421ea35366ea35a39102981c96a4a16658bb Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Thu, 11 Dec 2025 13:04:00 -0800 Subject: [PATCH 20/35] model: conversion and hyperparameter fixes for ministral and devstral (#13424) --- convert/convert.go | 2 + convert/convert_mistral.go | 35 ++++-- convert/convert_mistral_causal.go | 181 ++++++++++++++++++++++++++++ model/models/mistral3/model_text.go | 44 ++++++- 4 files changed, 250 insertions(+), 12 deletions(-) create mode 100644 convert/convert_mistral_causal.go diff --git a/convert/convert.go b/convert/convert.go index bc110c6f..f0846795 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -182,6 +182,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error { conv = &llama4Model{} case "Mistral3ForConditionalGeneration": conv = &mistral3Model{} + case "Ministral3ForCausalLM": + conv = &mistral3CausalModel{} case "MixtralForCausalLM": conv = &mixtralModel{} case "GemmaForCausalLM": diff --git a/convert/convert_mistral.go b/convert/convert_mistral.go index 81774853..f11bd964 100644 --- a/convert/convert_mistral.go +++ b/convert/convert_mistral.go @@ -30,13 +30,15 @@ type mistral3Model struct { HiddenAct string `json:"hidden_act"` VocabSize uint32 `json:"vocab_size"` RopeParameters struct { - BetaFast float32 `json:"beta_fast"` - BetaSlow float32 `json:"beta_slow"` - Factor float32 `json:"factor"` - ScalingBeta float32 `json:"llama_4_scaling_beta"` - OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"` - RopeType string `json:"rope_type"` - RopeTheta float32 `json:"rope_theta"` + BetaFast float32 `json:"beta_fast"` + BetaSlow float32 `json:"beta_slow"` + Factor float32 `json:"factor"` + Llama4ScalingBeta *float32 `json:"llama_4_scaling_beta"` + OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"` + RopeType string `json:"rope_type"` + RopeTheta float32 `json:"rope_theta"` + Mscale 
*float32 `json:"mscale"` + MscaleAllDim *float32 `json:"mscale_all_dim"` } `json:"rope_parameters"` } `json:"text_config"` VisionModel struct { @@ -50,6 +52,9 @@ type mistral3Model struct { HeadDim uint32 `json:"head_dim"` HiddenAct string `json:"hidden_act"` RopeTheta float32 `json:"rope_theta"` + RopeParameters struct { + RopeTheta float32 `json:"rope_theta"` + } `json:"rope_parameters"` } `json:"vision_config"` MultiModalProjectorBias bool `json:"multimodal_projector_bias"` ProjectorHiddenAct string `json:"projector_hidden_act"` @@ -72,10 +77,22 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV { kv["mistral3.attention.value_length"] = p.TextModel.HeadDim kv["mistral3.rope.dimension_count"] = cmp.Or(p.TextModel.HeadDim, p.TextModel.HiddenSize/p.TextModel.NumAttentionHeads) kv["mistral3.rope.freq_base"] = cmp.Or(p.TextModel.RopeTheta, p.TextModel.RopeParameters.RopeTheta) + kv["mistral3.rope.scaling.factor"] = p.TextModel.RopeParameters.Factor + kv["mistral3.rope.scaling.type"] = p.TextModel.RopeParameters.RopeType + kv["mistral3.rope.scaling.beta_fast"] = p.TextModel.RopeParameters.BetaFast + kv["mistral3.rope.scaling.beta_slow"] = p.TextModel.RopeParameters.BetaSlow + if p.TextModel.RopeParameters.Mscale != nil { + kv["mistral3.rope.scaling.mscale"] = *p.TextModel.RopeParameters.Mscale + } + if p.TextModel.RopeParameters.MscaleAllDim != nil { + kv["mistral3.rope.scaling.mscale_all_dim"] = *p.TextModel.RopeParameters.MscaleAllDim + } if p.TextModel.RopeParameters.OrigMaxPositionEmbeddings > 0 { kv["mistral3.rope.scaling.original_context_length"] = p.TextModel.RopeParameters.OrigMaxPositionEmbeddings - kv["mistral3.rope.scaling_beta"] = p.TextModel.RopeParameters.ScalingBeta + } + if p.TextModel.RopeParameters.Llama4ScalingBeta != nil { + kv["mistral3.rope.scaling_beta"] = *p.TextModel.RopeParameters.Llama4ScalingBeta } // Vision configuration @@ -88,7 +105,7 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV { kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels // kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value - kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta + kv["mistral3.vision.rope.freq_base"] = cmp.Or(p.VisionModel.RopeTheta, p.VisionModel.RopeParameters.RopeTheta) // Multimodal configuration kv["mistral3.image_token_index"] = p.ImageTokenIndex diff --git a/convert/convert_mistral_causal.go b/convert/convert_mistral_causal.go new file mode 100644 index 00000000..99a48373 --- /dev/null +++ b/convert/convert_mistral_causal.go @@ -0,0 +1,181 @@ +package convert + +import ( + "cmp" + "fmt" + "strings" + + "github.com/pdevine/tensor" + "github.com/pdevine/tensor/native" + + "github.com/ollama/ollama/fs/ggml" +) + +type mistral3CausalModel struct { + ModelParameters + + NumHiddenLayers uint32 `json:"num_hidden_layers"` + MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` + HiddenSize uint32 `json:"hidden_size"` + IntermediateSize uint32 `json:"intermediate_size"` + NumAttentionHeads uint32 `json:"num_attention_heads"` + NumKeyValueHeads uint32 `json:"num_key_value_heads"` + RopeTheta float32 `json:"rope_theta"` + RMSNormEPS float32 `json:"rms_norm_eps"` + HeadDim uint32 `json:"head_dim"` + SlidingWindow *uint32 `json:"sliding_window"` + HiddenAct string `json:"hidden_act"` + VocabSize uint32 `json:"vocab_size"` + RopeParameters struct { + BetaFast float32 `json:"beta_fast"` + BetaSlow float32 `json:"beta_slow"` + Factor float32 `json:"factor"` + 
Llama4ScalingBeta *float32 `json:"llama_4_scaling_beta"` + OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"` + RopeType string `json:"rope_type"` + RopeTheta float32 `json:"rope_theta"` + Mscale *float32 `json:"mscale"` + MscaleAllDim *float32 `json:"mscale_all_dim"` + } `json:"rope_parameters"` +} + +func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV { + kv := p.ModelParameters.KV(t) + kv["general.architecture"] = "mistral3" + kv["mistral3.vocab_size"] = p.VocabSize + + // Text configuration + kv["mistral3.block_count"] = p.NumHiddenLayers + kv["mistral3.context_length"] = p.MaxPositionEmbeddings + kv["mistral3.embedding_length"] = p.HiddenSize + kv["mistral3.feed_forward_length"] = p.IntermediateSize + kv["mistral3.attention.head_count"] = p.NumAttentionHeads + kv["mistral3.attention.head_count_kv"] = p.NumKeyValueHeads + kv["mistral3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS + kv["mistral3.attention.key_length"] = p.HeadDim + kv["mistral3.attention.value_length"] = p.HeadDim + kv["mistral3.rope.dimension_count"] = cmp.Or(p.HeadDim, p.HiddenSize/p.NumAttentionHeads) + kv["mistral3.rope.freq_base"] = cmp.Or(p.RopeTheta, p.RopeParameters.RopeTheta) + kv["mistral3.rope.scaling.factor"] = p.RopeParameters.Factor + kv["mistral3.rope.scaling.type"] = p.RopeParameters.RopeType + kv["mistral3.rope.scaling.beta_fast"] = p.RopeParameters.BetaFast + kv["mistral3.rope.scaling.beta_slow"] = p.RopeParameters.BetaSlow + + if p.RopeParameters.Mscale != nil { + kv["mistral3.rope.scaling.mscale"] = *p.RopeParameters.Mscale + } + + if p.RopeParameters.MscaleAllDim != nil { + kv["mistral3.rope.scaling.mscale_all_dim"] = *p.RopeParameters.MscaleAllDim + } + + if p.RopeParameters.OrigMaxPositionEmbeddings > 0 { + kv["mistral3.rope.scaling.original_context_length"] = p.RopeParameters.OrigMaxPositionEmbeddings + kv["mistral3.rope.scaling_beta"] = *p.RopeParameters.Llama4ScalingBeta + } + + if p.RopeParameters.Llama4ScalingBeta != nil { + kv["mistral3.rope.scaling_beta"] = *p.RopeParameters.Llama4ScalingBeta + } + + return kv +} + +func (p *mistral3CausalModel) Tensors(ts []Tensor) []*ggml.Tensor { + var out []*ggml.Tensor + + for _, t := range ts { + if !strings.HasPrefix(t.Name(), "v.") { + if strings.HasSuffix(t.Name(), ".attn_q.weight") || + strings.HasSuffix(t.Name(), ".attn_k.weight") { + t.SetRepacker(p.repack) + } + } + + out = append(out, &ggml.Tensor{ + Name: t.Name(), + Kind: t.Kind(), + Shape: t.Shape(), + WriterTo: t, + }) + } + + return out +} + +func (p *mistral3CausalModel) Replacements() []string { + return []string{ + "model.norm", "output_norm", + "model.", "", + "layers", "blk", + "transformer.layers", "blk", + "vision_tower", "v", + "ln_pre", "encoder_norm", + "input_layernorm", "attn_norm", + "post_attention_layernorm", "ffn_norm", + "embed_tokens", "token_embd", + "self_attn.q_proj", "attn_q", + "self_attn.k_proj", "attn_k", + "self_attn.v_proj", "attn_v", + "self_attn.o_proj", "attn_output", + "mlp.down_proj", "ffn_down", + "mlp.gate_proj", "ffn_gate", + "mlp.up_proj", "ffn_up", + "attention.q_proj", "attn_q", + "attention.k_proj", "attn_k", + "attention.v_proj", "attn_v", + "attention.o_proj", "attn_output", + "attention_norm", "attn_norm", + "feed_forward.gate_proj", "ffn_gate", + "feed_forward.down_proj", "ffn_down", + "feed_forward.up_proj", "ffn_up", + "multi_modal_projector", "mm", + "ffn_norm", "ffn_norm", + "lm_head", "output", + } +} + +func (p *mistral3CausalModel) repack(name string, data []float32, shape []uint64) ([]float32, error) { + var 
dims []int + for _, dim := range shape { + dims = append(dims, int(dim)) + } + + var heads uint32 + if strings.HasSuffix(name, ".attn_q.weight") { + heads = p.NumAttentionHeads + } else if strings.HasSuffix(name, ".attn_k.weight") { + heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads) + } else { + return nil, fmt.Errorf("unknown tensor for repack: %s", name) + } + + n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data)) + if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil { + return nil, err + } + + if err := n.T(0, 2, 1, 3); err != nil { + return nil, err + } + + if err := n.Reshape(dims...); err != nil { + return nil, err + } + + if err := n.Transpose(); err != nil { + return nil, err + } + + ts, err := native.SelectF32(n, 1) + if err != nil { + return nil, err + } + + var f32s []float32 + for _, t := range ts { + f32s = append(f32s, t...) + } + + return f32s, nil +} diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go index ebb7b3aa..01eca1c5 100644 --- a/model/models/mistral3/model_text.go +++ b/model/models/mistral3/model_text.go @@ -8,6 +8,7 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model/input" ) @@ -17,10 +18,41 @@ type TextOptions struct { eps, ropeBase, ropeScale float32 ropeOrigPosEmbeddings int ropeScalingBeta float32 + ropeType string + ropeExtrapolation float32 + ropeBetaFast float32 + ropeBetaSlow float32 + ropeMscale float32 + ropeMscaleAllDim float32 } func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { - return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale) + var ropeOpts []func(*rope.Options) + if o.ropeType == "yarn" { + getMscale := func(scale, mscale float64) float64 { + if scale <= 1.0 { + return 1.0 + } + return 0.1*mscale*math.Log(scale) + 1.0 + } + + var attnFactor float32 + if o.ropeMscale != 0 && o.ropeMscaleAllDim != 0 { + attnFactor = float32(getMscale(float64(o.ropeScale), float64(o.ropeMscale)) / getMscale(float64(o.ropeScale), float64(o.ropeMscaleAllDim))) + } else { + attnFactor = float32(getMscale(float64(o.ropeScale), 1)) + } + + ropeOpts = append(ropeOpts, + rope.WithOriginalContextLength(o.ropeOrigPosEmbeddings), + rope.WithExtrapolationFactor(o.ropeExtrapolation), + rope.WithAttentionFactor(attnFactor), + rope.WithBetaFast(o.ropeBetaFast), + rope.WithBetaSlow(o.ropeBetaSlow), + ) + } + + return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, ropeOpts...) 
} type TextModel struct { @@ -150,9 +182,15 @@ func newTextModel(c fs.Config) *TextModel { ropeDim: int(c.Uint("rope.dimension_count")), eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.scaling.factor", 1), + ropeScale: c.Float("rope.scaling.factor", 1.0), ropeOrigPosEmbeddings: int(c.Uint("rope.scaling.original_context_length")), - ropeScalingBeta: c.Float("rope.scaling_beta"), + ropeScalingBeta: c.Float("rope.scaling_beta", 0.1), + ropeBetaFast: c.Float("rope.scaling.beta_fast", 32.0), + ropeBetaSlow: c.Float("rope.scaling.beta_slow", 1.0), + ropeType: c.String("rope.scaling.type"), + ropeMscale: c.Float("rope.scaling.mscale"), + ropeMscaleAllDim: c.Float("rope.scaling.mscale_all_dim"), + ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1), }, } } From 48e78e9be1cb39473a8220dd0d293c9e65ffb07d Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Thu, 11 Dec 2025 14:47:55 -0800 Subject: [PATCH 21/35] template: add yesterdayDate helper function (#13431) --- template/template.go | 3 ++ template/template_test.go | 67 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) diff --git a/template/template.go b/template/template.go index c90190d7..39b6ad7b 100644 --- a/template/template.go +++ b/template/template.go @@ -127,6 +127,9 @@ var funcs = template.FuncMap{ // Default format is YYYY-MM-DD return time.Now().Format("2006-01-02") }, + "yesterdayDate": func(args ...string) string { + return time.Now().AddDate(0, 0, -1).Format("2006-01-02") + }, "toTypeScriptType": func(v any) string { if param, ok := v.(api.ToolProperty); ok { return param.ToTypeScriptType() diff --git a/template/template_test.go b/template/template_test.go index 74388d6e..fbea0ed0 100644 --- a/template/template_test.go +++ b/template/template_test.go @@ -10,6 +10,7 @@ import ( "slices" "strings" "testing" + "time" "github.com/google/go-cmp/cmp" @@ -451,6 +452,72 @@ func TestExecuteWithSuffix(t *testing.T) { } } +func TestDateFunctions(t *testing.T) { + t.Run("currentDate", func(t *testing.T) { + tmpl, err := Parse("{{- range .Messages }}{{ .Content }}{{ end }} Today is {{ currentDate }}") + if err != nil { + t.Fatal(err) + } + + var b bytes.Buffer + if err := tmpl.Execute(&b, Values{Messages: []api.Message{{Role: "user", Content: "Hello"}}}); err != nil { + t.Fatal(err) + } + + expected := "Hello Today is " + time.Now().Format("2006-01-02") + if b.String() != expected { + t.Errorf("got %q, want %q", b.String(), expected) + } + }) + + t.Run("yesterdayDate", func(t *testing.T) { + tmpl, err := Parse("{{- range .Messages }}{{ .Content }}{{ end }} Yesterday was {{ yesterdayDate }}") + if err != nil { + t.Fatal(err) + } + + var b bytes.Buffer + if err := tmpl.Execute(&b, Values{Messages: []api.Message{{Role: "user", Content: "Hello"}}}); err != nil { + t.Fatal(err) + } + + expected := "Hello Yesterday was " + time.Now().AddDate(0, 0, -1).Format("2006-01-02") + if b.String() != expected { + t.Errorf("got %q, want %q", b.String(), expected) + } + }) + + t.Run("yesterdayDate format", func(t *testing.T) { + tmpl, err := Parse("{{- range .Messages }}{{ end }}{{ yesterdayDate }}") + if err != nil { + t.Fatal(err) + } + + var b bytes.Buffer + if err := tmpl.Execute(&b, Values{Messages: []api.Message{{Role: "user", Content: "Hello"}}}); err != nil { + t.Fatal(err) + } + + // Verify the format matches YYYY-MM-DD + result := b.String() + if len(result) != 10 { + t.Errorf("expected date length 10, got %d: %q", len(result), result) + } + + // 
Parse and verify it's a valid date + parsed, err := time.Parse("2006-01-02", result) + if err != nil { + t.Errorf("failed to parse date %q: %v", result, err) + } + + // Verify it's yesterday + yesterday := time.Now().AddDate(0, 0, -1) + if parsed.Year() != yesterday.Year() || parsed.Month() != yesterday.Month() || parsed.Day() != yesterday.Day() { + t.Errorf("expected yesterday's date, got %v", parsed) + } + }) +} + func TestCollate(t *testing.T) { cases := []struct { name string From 3475d915cb0882042041d7746e6baf888469c3e0 Mon Sep 17 00:00:00 2001 From: nicole pardal <109545900+npardal@users.noreply.github.com> Date: Thu, 11 Dec 2025 15:36:31 -0800 Subject: [PATCH 22/35] embeddings: modified batch size (#13429) This PR detects embedding models and sets batch_size = context_size so the full input fits in a single batch. Previously, if batch size was smaller than the input, tokens could be split across batches and cause a SIGTRAP crash. This change ensures all tokens stay in one batch and prevents crashes. Fixes: #12938 #13054 Co-authored-by: Jesse Gross --- integration/embed_test.go | 57 +++++++++++++++++++++++++++++++++++ llama/llama.go | 3 +- llm/server.go | 7 +++++ runner/llamarunner/runner.go | 2 +- runner/ollamarunner/runner.go | 16 +++++++--- 5 files changed, 78 insertions(+), 7 deletions(-) diff --git a/integration/embed_test.go b/integration/embed_test.go index f01903ee..e4506673 100644 --- a/integration/embed_test.go +++ b/integration/embed_test.go @@ -487,6 +487,63 @@ func TestEmbedTruncation(t *testing.T) { } } +// TestEmbedLargeInput tests that embedding models can handle large inputs that would exceed typical batch sizes. +func TestEmbedLargeInput(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + for _, model := range libraryEmbedModels { + model := model + t.Run(model, func(t *testing.T) { + mctx, mcancel := context.WithTimeout(ctx, 2*time.Minute) + defer mcancel() + + // Test with progressively larger inputs + testCases := []struct { + name string + inputWords int + }{ + {"medium_input_256_words", 256}, + {"large_input_512_words", 512}, + {"very_large_input_800_words", 800}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + words := make([]string, tc.inputWords) + for i := range words { + words[i] = "word" + } + input := strings.Join(words, " ") + + req := api.EmbedRequest{ + Model: model, + Input: input, + KeepAlive: &api.Duration{Duration: 30 * time.Second}, + } + + res, err := embedTestHelper(mctx, client, t, req) + if err != nil { + t.Fatalf("embedding failed for %d words: %v", tc.inputWords, err) + } + + if len(res.Embeddings) != 1 { + t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings)) + } + + if len(res.Embeddings[0]) == 0 { + t.Fatal("expected non-empty embedding") + } + + t.Logf("Successfully embedded %d words (%d tokens)", tc.inputWords, res.PromptEvalCount) + }) + } + }) + } +} + // TestEmbedStatusCode tests that errors from the embedding endpoint // properly preserve their HTTP status codes when returned to the client. 
// This test specifically checks the error handling path in EmbedHandler diff --git a/llama/llama.go b/llama/llama.go index 582d4128..70bf3b9c 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -121,7 +121,8 @@ type ContextParams struct { func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams { params := C.llama_context_default_params() params.n_ctx = C.uint(numCtx) - params.n_batch = C.uint(batchSize) + params.n_batch = C.uint(batchSize * numSeqMax) + params.n_ubatch = C.uint(batchSize) params.n_seq_max = C.uint(numSeqMax) params.n_threads = C.int(threads) params.n_threads_batch = params.n_threads diff --git a/llm/server.go b/llm/server.go index 1c47601f..5c232f0f 100644 --- a/llm/server.go +++ b/llm/server.go @@ -474,6 +474,13 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system s.mem.GPUs[i].Cache = make([]uint64, s.totalLayers) } + // Check if embedding model and adjust batch size accordingly + _, isEmbedding := s.ggml.KV()[fmt.Sprintf("%s.pooling_type", s.ggml.KV().Architecture())] + if isEmbedding && s.loadRequest.BatchSize < s.options.NumCtx { + s.loadRequest.BatchSize = s.options.NumCtx + slog.Info("embedding model detected, setting batch size to context length", "batch_size", s.loadRequest.BatchSize) + } + kv, graphPartialOffload, graphFullOffload := s.ggml.GraphSize(uint64(s.options.NumCtx), uint64(s.loadRequest.BatchSize), s.loadRequest.Parallel, s.loadRequest.KvCacheType, s.loadRequest.FlashAttention) diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index 0f32fd2a..cb4bbe50 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -842,7 +842,7 @@ func (s *Server) loadModel( panic(err) } - ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, kvCacheType) + ctxParams := llama.NewContextParams(kvSize, s.batchSize, s.parallel, threads, flashAttention, kvCacheType) s.lc, err = llama.NewContextWithModel(s.model, ctxParams) if err != nil { panic(err) diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index d0427662..a756cba2 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -1203,16 +1203,22 @@ func (s *Server) allocModel( return errors.New("loras are not yet implemented") } + if s.model.Config().Cache == nil { + if parallel > 1 { + parallel = 1 + slog.Warn("model does not support caching, disabling parallel processing") + } + if s.batchSize < kvSize { + s.batchSize = kvSize + slog.Warn("model does not support caching, setting batch size to context length", "batch_size", kvSize) + } + } + s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache) if err != nil { return err } - if !s.cache.enabled && parallel > 1 { - parallel = 1 - slog.Warn("model does not support caching, disabling parallel processing") - } - s.parallel = parallel s.seqs = make([]*Sequence, s.parallel) s.seqsSem = semaphore.NewWeighted(int64(s.parallel)) From 1eb5e759724a10fea90a2f8e9ab7c292e7287191 Mon Sep 17 00:00:00 2001 From: Devon Rifkin Date: Thu, 11 Dec 2025 15:37:10 -0800 Subject: [PATCH 23/35] openai: add v1/responses support (#13351) Only supporting the stateless part of the API. Doc updates to come once this is shipped. 
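For illustration, a minimal stateless request against the new endpoint can be
exercised roughly as follows. This is only a sketch: it assumes a local server
on the default 127.0.0.1:11434 address, and the model name is just an example
borrowed from the tests.

	package main

	import (
		"bytes"
		"encoding/json"
		"fmt"
		"io"
		"net/http"
	)

	func main() {
		// Request body matching the stateless ResponsesRequest shape;
		// "input" may be a plain string or an array of input items.
		body, _ := json.Marshal(map[string]any{
			"model":  "gpt-oss:20b",
			"input":  "hello",
			"stream": false,
		})

		resp, err := http.Post("http://127.0.0.1:11434/v1/responses",
			"application/json", bytes.NewReader(body))
		if err != nil {
			panic(err)
		}
		defer resp.Body.Close()

		// The non-streaming reply is a single "response" object whose
		// "output" array holds message, reasoning, or function_call items.
		out, _ := io.ReadAll(resp.Body)
		fmt.Println(string(out))
	}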
Closes: #9659 --- middleware/openai.go | 108 +++ openai/openai.go | 53 +- openai/responses.go | 1004 +++++++++++++++++++++++++ openai/responses_test.go | 1543 ++++++++++++++++++++++++++++++++++++++ server/routes.go | 2 + 5 files changed, 2688 insertions(+), 22 deletions(-) create mode 100644 openai/responses.go create mode 100644 openai/responses_test.go diff --git a/middleware/openai.go b/middleware/openai.go index b2e43f16..5e526416 100644 --- a/middleware/openai.go +++ b/middleware/openai.go @@ -433,3 +433,111 @@ func ChatMiddleware() gin.HandlerFunc { c.Next() } } + +type ResponsesWriter struct { + BaseWriter + converter *openai.ResponsesStreamConverter + model string + stream bool + responseID string + itemID string +} + +func (w *ResponsesWriter) writeEvent(eventType string, data any) error { + d, err := json.Marshal(data) + if err != nil { + return err + } + _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d))) + if err != nil { + return err + } + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } + return nil +} + +func (w *ResponsesWriter) writeResponse(data []byte) (int, error) { + var chatResponse api.ChatResponse + if err := json.Unmarshal(data, &chatResponse); err != nil { + return 0, err + } + + if w.stream { + w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") + + events := w.converter.Process(chatResponse) + for _, event := range events { + if err := w.writeEvent(event.Event, event.Data); err != nil { + return 0, err + } + } + return len(data), nil + } + + // Non-streaming response + w.ResponseWriter.Header().Set("Content-Type", "application/json") + response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse) + return len(data), json.NewEncoder(w.ResponseWriter).Encode(response) +} + +func (w *ResponsesWriter) Write(data []byte) (int, error) { + code := w.ResponseWriter.Status() + if code != http.StatusOK { + return w.writeError(data) + } + return w.writeResponse(data) +} + +func ResponsesMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + var req openai.ResponsesRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error())) + return + } + + chatReq, err := openai.FromResponsesRequest(req) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error())) + return + } + + // Check if client requested streaming (defaults to false) + streamRequested := req.Stream != nil && *req.Stream + + // Pass streaming preference to the underlying chat request + chatReq.Stream = &streamRequested + + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(chatReq); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error())) + return + } + + c.Request.Body = io.NopCloser(&b) + + responseID := fmt.Sprintf("resp_%d", rand.Intn(999999)) + itemID := fmt.Sprintf("msg_%d", rand.Intn(999999)) + + w := &ResponsesWriter{ + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model), + model: req.Model, + stream: streamRequested, + responseID: responseID, + itemID: itemID, + } + + // Set headers based on streaming mode + if streamRequested { + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + } + + 
c.Writer = w + c.Next() + } +} diff --git a/openai/openai.go b/openai/openai.go index 4713d481..9dcba300 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -487,29 +487,9 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { } } - types := []string{"jpeg", "jpg", "png", "webp"} - valid := false - // support blank mime type to match api/chat taking just unadorned base64 - if strings.HasPrefix(url, "data:;base64,") { - url = strings.TrimPrefix(url, "data:;base64,") - valid = true - } - for _, t := range types { - prefix := "data:image/" + t + ";base64," - if strings.HasPrefix(url, prefix) { - url = strings.TrimPrefix(url, prefix) - valid = true - break - } - } - - if !valid { - return nil, errors.New("invalid image input") - } - - img, err := base64.StdEncoding.DecodeString(url) + img, err := decodeImageURL(url) if err != nil { - return nil, errors.New("invalid message format") + return nil, err } messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}}) @@ -648,6 +628,35 @@ func nameFromToolCallID(messages []Message, toolCallID string) string { return "" } +// decodeImageURL decodes a base64 data URI into raw image bytes. +func decodeImageURL(url string) (api.ImageData, error) { + types := []string{"jpeg", "jpg", "png", "webp"} + + // Support blank mime type to match /api/chat's behavior of taking just unadorned base64 + if strings.HasPrefix(url, "data:;base64,") { + url = strings.TrimPrefix(url, "data:;base64,") + } else { + valid := false + for _, t := range types { + prefix := "data:image/" + t + ";base64," + if strings.HasPrefix(url, prefix) { + url = strings.TrimPrefix(url, prefix) + valid = true + break + } + } + if !valid { + return nil, errors.New("invalid image input") + } + } + + img, err := base64.StdEncoding.DecodeString(url) + if err != nil { + return nil, errors.New("invalid image input") + } + return img, nil +} + // FromCompletionToolCall converts OpenAI ToolCall format to api.ToolCall func FromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) { apiToolCalls := make([]api.ToolCall, len(toolCalls)) diff --git a/openai/responses.go b/openai/responses.go new file mode 100644 index 00000000..8f6b1d94 --- /dev/null +++ b/openai/responses.go @@ -0,0 +1,1004 @@ +package openai + +import ( + "encoding/json" + "fmt" + "math/rand" + + "github.com/ollama/ollama/api" +) + +// ResponsesContent is a discriminated union for input content types. +// Concrete types: ResponsesTextContent, ResponsesImageContent +type ResponsesContent interface { + responsesContent() // unexported marker method +} + +type ResponsesTextContent struct { + Type string `json:"type"` // always "input_text" + Text string `json:"text"` +} + +func (ResponsesTextContent) responsesContent() {} + +type ResponsesImageContent struct { + Type string `json:"type"` // always "input_image" + // TODO(drifkin): is this really required? that seems verbose and a default is specified in the docs + Detail string `json:"detail"` // required + FileID string `json:"file_id,omitempty"` // optional + ImageURL string `json:"image_url,omitempty"` // optional +} + +func (ResponsesImageContent) responsesContent() {} + +// ResponsesOutputTextContent represents output text from a previous assistant response +// that is being passed back as part of the conversation history. 
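+// For example, as one item in a message's "content" array (illustrative text):
+//
+//	{"type": "output_text", "text": "Hi, how can I help?"}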
+type ResponsesOutputTextContent struct { + Type string `json:"type"` // always "output_text" + Text string `json:"text"` +} + +func (ResponsesOutputTextContent) responsesContent() {} + +type ResponsesInputMessage struct { + Type string `json:"type"` // always "message" + Role string `json:"role"` // one of `user`, `system`, `developer` + Content []ResponsesContent `json:"content,omitempty"` +} + +func (m *ResponsesInputMessage) UnmarshalJSON(data []byte) error { + var aux struct { + Type string `json:"type"` + Role string `json:"role"` + Content json.RawMessage `json:"content"` + } + + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + m.Type = aux.Type + m.Role = aux.Role + + if len(aux.Content) == 0 { + return nil + } + + // Try to parse content as a string first (shorthand format) + var contentStr string + if err := json.Unmarshal(aux.Content, &contentStr); err == nil { + m.Content = []ResponsesContent{ + ResponsesTextContent{Type: "input_text", Text: contentStr}, + } + return nil + } + + // Otherwise, parse as an array of content items + var rawItems []json.RawMessage + if err := json.Unmarshal(aux.Content, &rawItems); err != nil { + return fmt.Errorf("content must be a string or array: %w", err) + } + + m.Content = make([]ResponsesContent, 0, len(rawItems)) + for i, raw := range rawItems { + // Peek at the type field to determine which concrete type to use + var typeField struct { + Type string `json:"type"` + } + if err := json.Unmarshal(raw, &typeField); err != nil { + return fmt.Errorf("content[%d]: %w", i, err) + } + + switch typeField.Type { + case "input_text": + var content ResponsesTextContent + if err := json.Unmarshal(raw, &content); err != nil { + return fmt.Errorf("content[%d]: %w", i, err) + } + m.Content = append(m.Content, content) + case "input_image": + var content ResponsesImageContent + if err := json.Unmarshal(raw, &content); err != nil { + return fmt.Errorf("content[%d]: %w", i, err) + } + m.Content = append(m.Content, content) + case "output_text": + var content ResponsesOutputTextContent + if err := json.Unmarshal(raw, &content); err != nil { + return fmt.Errorf("content[%d]: %w", i, err) + } + m.Content = append(m.Content, content) + default: + return fmt.Errorf("content[%d]: unknown content type: %s", i, typeField.Type) + } + } + + return nil +} + +type ResponsesOutputMessage struct{} + +// ResponsesInputItem is a discriminated union for input items. +// Concrete types: ResponsesInputMessage (more to come) +type ResponsesInputItem interface { + responsesInputItem() // unexported marker method +} + +func (ResponsesInputMessage) responsesInputItem() {} + +// ResponsesFunctionCall represents an assistant's function call in conversation history. +type ResponsesFunctionCall struct { + ID string `json:"id,omitempty"` // item ID + Type string `json:"type"` // always "function_call" + CallID string `json:"call_id"` // the tool call ID + Name string `json:"name"` // function name + Arguments string `json:"arguments"` // JSON arguments string +} + +func (ResponsesFunctionCall) responsesInputItem() {} + +// ResponsesFunctionCallOutput represents a function call result from the client. 
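+// For example (the same shape exercised in responses_test.go):
+//
+//	{"type": "function_call_output", "call_id": "call_abc123", "output": "sunny, 72F"}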
+type ResponsesFunctionCallOutput struct { + Type string `json:"type"` // always "function_call_output" + CallID string `json:"call_id"` // links to the original function call + Output string `json:"output"` // the function result +} + +func (ResponsesFunctionCallOutput) responsesInputItem() {} + +// ResponsesReasoningInput represents a reasoning item passed back as input. +// This is used when the client sends previous reasoning back for context. +type ResponsesReasoningInput struct { + ID string `json:"id,omitempty"` + Type string `json:"type"` // always "reasoning" + Summary []ResponsesReasoningSummary `json:"summary,omitempty"` + EncryptedContent string `json:"encrypted_content,omitempty"` +} + +func (ResponsesReasoningInput) responsesInputItem() {} + +// unmarshalResponsesInputItem unmarshals a single input item from JSON. +func unmarshalResponsesInputItem(data []byte) (ResponsesInputItem, error) { + var typeField struct { + Type string `json:"type"` + Role string `json:"role"` + } + if err := json.Unmarshal(data, &typeField); err != nil { + return nil, err + } + + // Handle shorthand message format: {"role": "...", "content": "..."} + // When type is empty but role is present, treat as a message + itemType := typeField.Type + if itemType == "" && typeField.Role != "" { + itemType = "message" + } + + switch itemType { + case "message": + var msg ResponsesInputMessage + if err := json.Unmarshal(data, &msg); err != nil { + return nil, err + } + return msg, nil + case "function_call": + var fc ResponsesFunctionCall + if err := json.Unmarshal(data, &fc); err != nil { + return nil, err + } + return fc, nil + case "function_call_output": + var output ResponsesFunctionCallOutput + if err := json.Unmarshal(data, &output); err != nil { + return nil, err + } + return output, nil + case "reasoning": + var reasoning ResponsesReasoningInput + if err := json.Unmarshal(data, &reasoning); err != nil { + return nil, err + } + return reasoning, nil + default: + return nil, fmt.Errorf("unknown input item type: %s", typeField.Type) + } +} + +// ResponsesInput can be either: +// - a string (equivalent to a text input with the user role) +// - an array of input items (see ResponsesInputItem) +type ResponsesInput struct { + Text string // set if input was a plain string + Items []ResponsesInputItem // set if input was an array +} + +func (r *ResponsesInput) UnmarshalJSON(data []byte) error { + // Try string first + var s string + if err := json.Unmarshal(data, &s); err == nil { + r.Text = s + return nil + } + + // Otherwise, try array of input items + var rawItems []json.RawMessage + if err := json.Unmarshal(data, &rawItems); err != nil { + return fmt.Errorf("input must be a string or array: %w", err) + } + + r.Items = make([]ResponsesInputItem, 0, len(rawItems)) + for i, raw := range rawItems { + item, err := unmarshalResponsesInputItem(raw) + if err != nil { + return fmt.Errorf("input[%d]: %w", i, err) + } + r.Items = append(r.Items, item) + } + + return nil +} + +type ResponsesReasoning struct { + // originally: optional, default is per-model + Effort string `json:"effort,omitempty"` + + // originally: deprecated, use `summary` instead. 
One of `auto`, `concise`, `detailed` + GenerateSummary string `json:"generate_summary,omitempty"` + + // originally: optional, one of `auto`, `concise`, `detailed` + Summary string `json:"summary,omitempty"` +} + +type ResponsesTextFormat struct { + Type string `json:"type"` // "text", "json_schema" + Name string `json:"name,omitempty"` // for json_schema + Schema json.RawMessage `json:"schema,omitempty"` // for json_schema + Strict *bool `json:"strict,omitempty"` // for json_schema +} + +type ResponsesText struct { + Format *ResponsesTextFormat `json:"format,omitempty"` +} + +// ResponsesTool represents a tool in the Responses API format. +// Note: This differs from api.Tool which nests fields under "function". +type ResponsesTool struct { + Type string `json:"type"` // "function" + Name string `json:"name"` + Description string `json:"description,omitempty"` + Strict bool `json:"strict,omitempty"` + Parameters map[string]any `json:"parameters,omitempty"` +} + +type ResponsesRequest struct { + Model string `json:"model"` + + // originally: optional, default is false + // for us: not supported + Background bool `json:"background"` + + // originally: optional `string | {id: string}` + // for us: not supported + Conversation json.RawMessage `json:"conversation"` + + // originally: string[] + // for us: ignored + Include []string `json:"include"` + + Input ResponsesInput `json:"input"` + + // optional, inserts a system message at the start of the conversation + Instructions string `json:"instructions,omitempty"` + + // optional, maps to num_predict + MaxOutputTokens *int `json:"max_output_tokens,omitempty"` + + Reasoning ResponsesReasoning `json:"reasoning"` + + // optional, default is 1.0 + Temperature *float64 `json:"temperature"` + + // optional, controls output format (e.g. json_schema) + Text *ResponsesText `json:"text,omitempty"` + + // optional, default is 1.0 + TopP *float64 `json:"top_p"` + + // optional, default is `"disabled"` + Truncation *string `json:"truncation"` + + Tools []ResponsesTool `json:"tools,omitempty"` + + // TODO(drifkin): tool_choice is not supported. We could support "none" by not + // passing tools, but the other controls like `"required"` cannot be generally + // supported. 
+ + // optional, default is false + Stream *bool `json:"stream,omitempty"` +} + +// FromResponsesRequest converts a ResponsesRequest to api.ChatRequest +func FromResponsesRequest(r ResponsesRequest) (*api.ChatRequest, error) { + var messages []api.Message + + // Add instructions as system message if present + if r.Instructions != "" { + messages = append(messages, api.Message{ + Role: "system", + Content: r.Instructions, + }) + } + + // Handle simple string input + if r.Input.Text != "" { + messages = append(messages, api.Message{ + Role: "user", + Content: r.Input.Text, + }) + } + + // Handle array of input items + // Track pending reasoning to merge with the next assistant message + var pendingThinking string + + for _, item := range r.Input.Items { + switch v := item.(type) { + case ResponsesReasoningInput: + // Store thinking to merge with the next assistant message + pendingThinking = v.EncryptedContent + case ResponsesInputMessage: + msg, err := convertInputMessage(v) + if err != nil { + return nil, err + } + // If this is an assistant message, attach pending thinking + if msg.Role == "assistant" && pendingThinking != "" { + msg.Thinking = pendingThinking + pendingThinking = "" + } + messages = append(messages, msg) + case ResponsesFunctionCall: + // Convert function call to assistant message with tool calls + var args api.ToolCallFunctionArguments + if v.Arguments != "" { + if err := json.Unmarshal([]byte(v.Arguments), &args); err != nil { + return nil, fmt.Errorf("failed to parse function call arguments: %w", err) + } + } + msg := api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{{ + ID: v.CallID, + Function: api.ToolCallFunction{ + Name: v.Name, + Arguments: args, + }, + }}, + } + // Attach pending thinking + if pendingThinking != "" { + msg.Thinking = pendingThinking + pendingThinking = "" + } + messages = append(messages, msg) + case ResponsesFunctionCallOutput: + messages = append(messages, api.Message{ + Role: "tool", + Content: v.Output, + ToolCallID: v.CallID, + }) + } + } + + // If there's trailing reasoning without a following message, emit it + if pendingThinking != "" { + messages = append(messages, api.Message{ + Role: "assistant", + Thinking: pendingThinking, + }) + } + + options := make(map[string]any) + + if r.Temperature != nil { + options["temperature"] = *r.Temperature + } else { + options["temperature"] = 1.0 + } + + if r.TopP != nil { + options["top_p"] = *r.TopP + } else { //nolint:staticcheck // SA9003: empty branch + // TODO(drifkin): OpenAI defaults to 1.0 here, but we don't follow that here + // in case the model has a different default. It would be best if we + // understood whether there was a model-specific default and if not, we + // should also default to 1.0, but that will require some additional + // plumbing + } + + if r.MaxOutputTokens != nil { + options["num_predict"] = *r.MaxOutputTokens + } + + // Convert tools from Responses API format to api.Tool format + var tools []api.Tool + for _, t := range r.Tools { + tool, err := convertTool(t) + if err != nil { + return nil, err + } + tools = append(tools, tool) + } + + // Handle text format (e.g. 
json_schema) + var format json.RawMessage + if r.Text != nil && r.Text.Format != nil { + switch r.Text.Format.Type { + case "json_schema": + if r.Text.Format.Schema != nil { + format = r.Text.Format.Schema + } + } + } + + return &api.ChatRequest{ + Model: r.Model, + Messages: messages, + Options: options, + Tools: tools, + Format: format, + }, nil +} + +func convertTool(t ResponsesTool) (api.Tool, error) { + // Convert parameters from map[string]any to api.ToolFunctionParameters + var params api.ToolFunctionParameters + if t.Parameters != nil { + // Marshal and unmarshal to convert + b, err := json.Marshal(t.Parameters) + if err != nil { + return api.Tool{}, fmt.Errorf("failed to marshal tool parameters: %w", err) + } + if err := json.Unmarshal(b, ¶ms); err != nil { + return api.Tool{}, fmt.Errorf("failed to unmarshal tool parameters: %w", err) + } + } + + return api.Tool{ + Type: t.Type, + Function: api.ToolFunction{ + Name: t.Name, + Description: t.Description, + Parameters: params, + }, + }, nil +} + +func convertInputMessage(m ResponsesInputMessage) (api.Message, error) { + var content string + var images []api.ImageData + + for _, c := range m.Content { + switch v := c.(type) { + case ResponsesTextContent: + content += v.Text + case ResponsesOutputTextContent: + content += v.Text + case ResponsesImageContent: + if v.ImageURL == "" { + continue // Skip if no URL (FileID not supported) + } + img, err := decodeImageURL(v.ImageURL) + if err != nil { + return api.Message{}, err + } + images = append(images, img) + } + } + + return api.Message{ + Role: m.Role, + Content: content, + Images: images, + }, nil +} + +// Response types for the Responses API + +type ResponsesResponse struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Status string `json:"status"` + Model string `json:"model"` + Output []ResponsesOutputItem `json:"output"` + Usage *ResponsesUsage `json:"usage,omitempty"` + // TODO(drifkin): add `temperature` and `top_p` to the response, but this + // requires additional plumbing to find the effective values since the + // defaults can come from the model or the request +} + +type ResponsesOutputItem struct { + ID string `json:"id"` + Type string `json:"type"` // "message", "function_call", or "reasoning" + Status string `json:"status,omitempty"` + Role string `json:"role,omitempty"` // for message + Content []ResponsesOutputContent `json:"content,omitempty"` // for message + CallID string `json:"call_id,omitempty"` // for function_call + Name string `json:"name,omitempty"` // for function_call + Arguments string `json:"arguments,omitempty"` // for function_call + + // Reasoning fields + Summary []ResponsesReasoningSummary `json:"summary,omitempty"` // for reasoning + EncryptedContent string `json:"encrypted_content,omitempty"` // for reasoning +} + +type ResponsesReasoningSummary struct { + Type string `json:"type"` // "summary_text" + Text string `json:"text"` +} + +type ResponsesOutputContent struct { + Type string `json:"type"` // "output_text" + Text string `json:"text"` +} + +type ResponsesUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// ToResponse converts an api.ChatResponse to a Responses API response +func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse) ResponsesResponse { + var output []ResponsesOutputItem + + // Add reasoning item if thinking is present + if chatResponse.Message.Thinking != "" { 
+ output = append(output, ResponsesOutputItem{ + ID: fmt.Sprintf("rs_%s", responseID), + Type: "reasoning", + Summary: []ResponsesReasoningSummary{ + { + Type: "summary_text", + Text: chatResponse.Message.Thinking, + }, + }, + EncryptedContent: chatResponse.Message.Thinking, // Plain text for now + }) + } + + if len(chatResponse.Message.ToolCalls) > 0 { + toolCalls := ToToolCalls(chatResponse.Message.ToolCalls) + for i, tc := range toolCalls { + output = append(output, ResponsesOutputItem{ + ID: fmt.Sprintf("fc_%s_%d", responseID, i), + Type: "function_call", + CallID: tc.ID, + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + }) + } + } else { + output = append(output, ResponsesOutputItem{ + ID: itemID, + Type: "message", + Status: "completed", + Role: "assistant", + Content: []ResponsesOutputContent{ + { + Type: "output_text", + Text: chatResponse.Message.Content, + }, + }, + }) + } + + return ResponsesResponse{ + ID: responseID, + Object: "response", + CreatedAt: chatResponse.CreatedAt.Unix(), + Status: "completed", + Model: model, + Output: output, + Usage: &ResponsesUsage{ + InputTokens: chatResponse.PromptEvalCount, + OutputTokens: chatResponse.EvalCount, + TotalTokens: chatResponse.PromptEvalCount + chatResponse.EvalCount, + }, + } +} + +// Streaming events: + +// ResponsesStreamEvent represents a single Server-Sent Event for the Responses API. +type ResponsesStreamEvent struct { + Event string // The event type (e.g., "response.created") + Data any // The event payload (will be JSON-marshaled) +} + +// ResponsesStreamConverter converts api.ChatResponse objects to Responses API +// streaming events. It maintains state across multiple calls to handle the +// streaming event sequence correctly. +type ResponsesStreamConverter struct { + // Configuration (immutable after creation) + responseID string + itemID string + model string + + // State tracking (mutated across Process calls) + firstWrite bool + outputIndex int + contentIndex int + contentStarted bool + toolCallsSent bool + accumulatedText string + sequenceNumber int + + // Reasoning/thinking state + accumulatedThinking string + reasoningItemID string + reasoningStarted bool + reasoningDone bool + + // Tool calls state (for final output) + toolCallItems []map[string]any +} + +// newEvent creates a ResponsesStreamEvent with the sequence number included in the data. +func (c *ResponsesStreamConverter) newEvent(eventType string, data map[string]any) ResponsesStreamEvent { + data["type"] = eventType + data["sequence_number"] = c.sequenceNumber + c.sequenceNumber++ + return ResponsesStreamEvent{ + Event: eventType, + Data: data, + } +} + +// NewResponsesStreamConverter creates a new converter with the given configuration. +func NewResponsesStreamConverter(responseID, itemID, model string) *ResponsesStreamConverter { + return &ResponsesStreamConverter{ + responseID: responseID, + itemID: itemID, + model: model, + firstWrite: true, + } +} + +// Process takes a ChatResponse and returns the events that should be emitted. +// Events are returned in order. The caller is responsible for serializing +// and sending these events. 
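+//
+// A typical caller (see ResponsesWriter in middleware/openai.go) drives the
+// converter roughly like this, writing each returned event as one SSE frame
+// ("event: <Event>\n" followed by "data: <JSON-encoded Data>\n\n"):
+//
+//	conv := NewResponsesStreamConverter(responseID, itemID, model)
+//	// for every streamed api.ChatResponse chunk received from the chat handler:
+//	for _, ev := range conv.Process(chunk) {
+//		// serialize ev.Data and write it under the ev.Event event type
+//	}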
+func (c *ResponsesStreamConverter) Process(r api.ChatResponse) []ResponsesStreamEvent { + var events []ResponsesStreamEvent + + hasToolCalls := len(r.Message.ToolCalls) > 0 + hasThinking := r.Message.Thinking != "" + + // First chunk - emit initial events + if c.firstWrite { + c.firstWrite = false + events = append(events, c.createResponseCreatedEvent()) + events = append(events, c.createResponseInProgressEvent()) + } + + // Handle reasoning/thinking (before other content) + if hasThinking { + events = append(events, c.processThinking(r.Message.Thinking)...) + } + + // Handle tool calls + if hasToolCalls { + events = append(events, c.processToolCalls(r.Message.ToolCalls)...) + c.toolCallsSent = true + } + + // Handle text content (only if no tool calls) + if !hasToolCalls && !c.toolCallsSent && r.Message.Content != "" { + events = append(events, c.processTextContent(r.Message.Content)...) + } + + // Done - emit closing events + if r.Done { + events = append(events, c.processCompletion(r)...) + } + + return events +} + +func (c *ResponsesStreamConverter) createResponseCreatedEvent() ResponsesStreamEvent { + return c.newEvent("response.created", map[string]any{ + "response": map[string]any{ + "id": c.responseID, + "object": "response", + "status": "in_progress", + "output": []any{}, + }, + }) +} + +func (c *ResponsesStreamConverter) createResponseInProgressEvent() ResponsesStreamEvent { + return c.newEvent("response.in_progress", map[string]any{ + "response": map[string]any{ + "id": c.responseID, + "object": "response", + "status": "in_progress", + "output": []any{}, + }, + }) +} + +func (c *ResponsesStreamConverter) processThinking(thinking string) []ResponsesStreamEvent { + var events []ResponsesStreamEvent + + // Start reasoning item if not started + if !c.reasoningStarted { + c.reasoningStarted = true + c.reasoningItemID = fmt.Sprintf("rs_%d", rand.Intn(999999)) + + events = append(events, c.newEvent("response.output_item.added", map[string]any{ + "output_index": c.outputIndex, + "item": map[string]any{ + "id": c.reasoningItemID, + "type": "reasoning", + "summary": []any{}, + }, + })) + } + + // Accumulate thinking + c.accumulatedThinking += thinking + + // Emit delta + events = append(events, c.newEvent("response.reasoning_summary_text.delta", map[string]any{ + "item_id": c.reasoningItemID, + "output_index": c.outputIndex, + "delta": thinking, + })) + + // TODO(drifkin): consider adding + // [`response.reasoning_text.delta`](https://platform.openai.com/docs/api-reference/responses-streaming/response/reasoning_text/delta), + // but need to do additional research to understand how it's used and how + // widely supported it is + + return events +} + +func (c *ResponsesStreamConverter) finishReasoning() []ResponsesStreamEvent { + if !c.reasoningStarted || c.reasoningDone { + return nil + } + c.reasoningDone = true + + events := []ResponsesStreamEvent{ + c.newEvent("response.reasoning_summary_text.done", map[string]any{ + "item_id": c.reasoningItemID, + "output_index": c.outputIndex, + "text": c.accumulatedThinking, + }), + c.newEvent("response.output_item.done", map[string]any{ + "output_index": c.outputIndex, + "item": map[string]any{ + "id": c.reasoningItemID, + "type": "reasoning", + "summary": []map[string]any{{"type": "summary_text", "text": c.accumulatedThinking}}, + "encrypted_content": c.accumulatedThinking, // Plain text for now + }, + }), + } + + c.outputIndex++ + return events +} + +func (c *ResponsesStreamConverter) processToolCalls(toolCalls []api.ToolCall) 
[]ResponsesStreamEvent { + var events []ResponsesStreamEvent + + // Finish reasoning first if it was started + events = append(events, c.finishReasoning()...) + + converted := ToToolCalls(toolCalls) + + for i, tc := range converted { + fcItemID := fmt.Sprintf("fc_%d_%d", rand.Intn(999999), i) + + // Store for final output (with status: completed) + toolCallItem := map[string]any{ + "id": fcItemID, + "type": "function_call", + "status": "completed", + "call_id": tc.ID, + "name": tc.Function.Name, + "arguments": tc.Function.Arguments, + } + c.toolCallItems = append(c.toolCallItems, toolCallItem) + + // response.output_item.added for function call + events = append(events, c.newEvent("response.output_item.added", map[string]any{ + "output_index": c.outputIndex + i, + "item": map[string]any{ + "id": fcItemID, + "type": "function_call", + "status": "in_progress", + "call_id": tc.ID, + "name": tc.Function.Name, + "arguments": "", + }, + })) + + // response.function_call_arguments.delta + if tc.Function.Arguments != "" { + events = append(events, c.newEvent("response.function_call_arguments.delta", map[string]any{ + "item_id": fcItemID, + "output_index": c.outputIndex + i, + "delta": tc.Function.Arguments, + })) + } + + // response.function_call_arguments.done + events = append(events, c.newEvent("response.function_call_arguments.done", map[string]any{ + "item_id": fcItemID, + "output_index": c.outputIndex + i, + "arguments": tc.Function.Arguments, + })) + + // response.output_item.done for function call + events = append(events, c.newEvent("response.output_item.done", map[string]any{ + "output_index": c.outputIndex + i, + "item": map[string]any{ + "id": fcItemID, + "type": "function_call", + "status": "completed", + "call_id": tc.ID, + "name": tc.Function.Name, + "arguments": tc.Function.Arguments, + }, + })) + } + + return events +} + +func (c *ResponsesStreamConverter) processTextContent(content string) []ResponsesStreamEvent { + var events []ResponsesStreamEvent + + // Finish reasoning first if it was started + events = append(events, c.finishReasoning()...) 
+ + // Emit output item and content part for first text content + if !c.contentStarted { + c.contentStarted = true + + // response.output_item.added + events = append(events, c.newEvent("response.output_item.added", map[string]any{ + "output_index": c.outputIndex, + "item": map[string]any{ + "id": c.itemID, + "type": "message", + "status": "in_progress", + "role": "assistant", + "content": []any{}, + }, + })) + + // response.content_part.added + events = append(events, c.newEvent("response.content_part.added", map[string]any{ + "item_id": c.itemID, + "output_index": c.outputIndex, + "content_index": c.contentIndex, + "part": map[string]any{ + "type": "output_text", + "text": "", + }, + })) + } + + // Accumulate text + c.accumulatedText += content + + // Emit content delta + events = append(events, c.newEvent("response.output_text.delta", map[string]any{ + "item_id": c.itemID, + "output_index": c.outputIndex, + "content_index": 0, + "delta": content, + })) + + return events +} + +func (c *ResponsesStreamConverter) buildFinalOutput() []any { + var output []any + + // Add reasoning item if present + if c.reasoningStarted { + output = append(output, map[string]any{ + "id": c.reasoningItemID, + "type": "reasoning", + "summary": []map[string]any{{"type": "summary_text", "text": c.accumulatedThinking}}, + "encrypted_content": c.accumulatedThinking, + }) + } + + // Add tool calls if present + if len(c.toolCallItems) > 0 { + for _, item := range c.toolCallItems { + output = append(output, item) + } + } else if c.contentStarted { + // Add message item if we had text content + output = append(output, map[string]any{ + "id": c.itemID, + "type": "message", + "status": "completed", + "role": "assistant", + "content": []map[string]any{{ + "type": "output_text", + "text": c.accumulatedText, + }}, + }) + } + + return output +} + +func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []ResponsesStreamEvent { + var events []ResponsesStreamEvent + + // Finish reasoning if not done + events = append(events, c.finishReasoning()...) 
+ + // Emit text completion events if we had text content + if !c.toolCallsSent && c.contentStarted { + // response.output_text.done + events = append(events, c.newEvent("response.output_text.done", map[string]any{ + "item_id": c.itemID, + "output_index": c.outputIndex, + "content_index": 0, + "text": c.accumulatedText, + })) + + // response.content_part.done + events = append(events, c.newEvent("response.content_part.done", map[string]any{ + "item_id": c.itemID, + "output_index": c.outputIndex, + "content_index": 0, + "part": map[string]any{ + "type": "output_text", + "text": c.accumulatedText, + }, + })) + + // response.output_item.done + events = append(events, c.newEvent("response.output_item.done", map[string]any{ + "output_index": c.outputIndex, + "item": map[string]any{ + "id": c.itemID, + "type": "message", + "status": "completed", + "role": "assistant", + "content": []map[string]any{{ + "type": "output_text", + "text": c.accumulatedText, + }}, + }, + })) + } + + // response.completed + events = append(events, c.newEvent("response.completed", map[string]any{ + "response": map[string]any{ + "id": c.responseID, + "object": "response", + "status": "completed", + "output": c.buildFinalOutput(), + "usage": map[string]any{ + "input_tokens": r.PromptEvalCount, + "output_tokens": r.EvalCount, + "total_tokens": r.PromptEvalCount + r.EvalCount, + }, + }, + })) + + return events +} diff --git a/openai/responses_test.go b/openai/responses_test.go new file mode 100644 index 00000000..50fbfdc5 --- /dev/null +++ b/openai/responses_test.go @@ -0,0 +1,1543 @@ +package openai + +import ( + "encoding/json" + "testing" + "time" + + "github.com/ollama/ollama/api" +) + +func TestResponsesInputMessage_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + json string + want ResponsesInputMessage + wantErr bool + }{ + { + name: "text content", + json: `{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hello"}]}`, + want: ResponsesInputMessage{ + Type: "message", + Role: "user", + Content: []ResponsesContent{ResponsesTextContent{Type: "input_text", Text: "hello"}}, + }, + }, + { + name: "image content", + json: `{"type": "message", "role": "user", "content": [{"type": "input_image", "detail": "auto", "image_url": "https://example.com/img.png"}]}`, + want: ResponsesInputMessage{ + Type: "message", + Role: "user", + Content: []ResponsesContent{ResponsesImageContent{ + Type: "input_image", + Detail: "auto", + ImageURL: "https://example.com/img.png", + }}, + }, + }, + { + name: "multiple content items", + json: `{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hello"}, {"type": "input_text", "text": "world"}]}`, + want: ResponsesInputMessage{ + Type: "message", + Role: "user", + Content: []ResponsesContent{ + ResponsesTextContent{Type: "input_text", Text: "hello"}, + ResponsesTextContent{Type: "input_text", Text: "world"}, + }, + }, + }, + { + name: "unknown content type", + json: `{"type": "message", "role": "user", "content": [{"type": "unknown"}]}`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got ResponsesInputMessage + err := json.Unmarshal([]byte(tt.json), &got) + + if tt.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if got.Type != tt.want.Type { + t.Errorf("Type = %q, want %q", got.Type, tt.want.Type) + } + + if got.Role != tt.want.Role { + t.Errorf("Role = %q, want %q", 
got.Role, tt.want.Role) + } + + if len(got.Content) != len(tt.want.Content) { + t.Fatalf("len(Content) = %d, want %d", len(got.Content), len(tt.want.Content)) + } + + for i := range tt.want.Content { + switch wantContent := tt.want.Content[i].(type) { + case ResponsesTextContent: + gotContent, ok := got.Content[i].(ResponsesTextContent) + if !ok { + t.Fatalf("Content[%d] type = %T, want ResponsesTextContent", i, got.Content[i]) + } + if gotContent != wantContent { + t.Errorf("Content[%d] = %+v, want %+v", i, gotContent, wantContent) + } + case ResponsesImageContent: + gotContent, ok := got.Content[i].(ResponsesImageContent) + if !ok { + t.Fatalf("Content[%d] type = %T, want ResponsesImageContent", i, got.Content[i]) + } + if gotContent != wantContent { + t.Errorf("Content[%d] = %+v, want %+v", i, gotContent, wantContent) + } + } + } + }) + } +} + +func TestResponsesInput_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + json string + wantText string + wantItems int + wantErr bool + }{ + { + name: "plain string", + json: `"hello world"`, + wantText: "hello world", + }, + { + name: "array with one message", + json: `[{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hello"}]}]`, + wantItems: 1, + }, + { + name: "array with multiple messages", + json: `[{"type": "message", "role": "system", "content": [{"type": "input_text", "text": "you are helpful"}]}, {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hello"}]}]`, + wantItems: 2, + }, + { + name: "invalid input", + json: `123`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got ResponsesInput + err := json.Unmarshal([]byte(tt.json), &got) + + if tt.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if got.Text != tt.wantText { + t.Errorf("Text = %q, want %q", got.Text, tt.wantText) + } + + if len(got.Items) != tt.wantItems { + t.Errorf("len(Items) = %d, want %d", len(got.Items), tt.wantItems) + } + }) + } +} + +func TestUnmarshalResponsesInputItem(t *testing.T) { + t.Run("message item", func(t *testing.T) { + got, err := unmarshalResponsesInputItem([]byte(`{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hello"}]}`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + msg, ok := got.(ResponsesInputMessage) + if !ok { + t.Fatalf("got type %T, want ResponsesInputMessage", got) + } + + if msg.Role != "user" { + t.Errorf("Role = %q, want %q", msg.Role, "user") + } + }) + + t.Run("function_call item", func(t *testing.T) { + got, err := unmarshalResponsesInputItem([]byte(`{"type": "function_call", "call_id": "call_abc123", "name": "get_weather", "arguments": "{\"city\":\"Paris\"}"}`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + fc, ok := got.(ResponsesFunctionCall) + if !ok { + t.Fatalf("got type %T, want ResponsesFunctionCall", got) + } + + if fc.Type != "function_call" { + t.Errorf("Type = %q, want %q", fc.Type, "function_call") + } + if fc.CallID != "call_abc123" { + t.Errorf("CallID = %q, want %q", fc.CallID, "call_abc123") + } + if fc.Name != "get_weather" { + t.Errorf("Name = %q, want %q", fc.Name, "get_weather") + } + }) + + t.Run("function_call_output item", func(t *testing.T) { + got, err := unmarshalResponsesInputItem([]byte(`{"type": "function_call_output", "call_id": "call_abc123", "output": "the result"}`)) + if err != nil { + 
t.Fatalf("unexpected error: %v", err) + } + + output, ok := got.(ResponsesFunctionCallOutput) + if !ok { + t.Fatalf("got type %T, want ResponsesFunctionCallOutput", got) + } + + if output.Type != "function_call_output" { + t.Errorf("Type = %q, want %q", output.Type, "function_call_output") + } + if output.CallID != "call_abc123" { + t.Errorf("CallID = %q, want %q", output.CallID, "call_abc123") + } + if output.Output != "the result" { + t.Errorf("Output = %q, want %q", output.Output, "the result") + } + }) + + t.Run("unknown item type", func(t *testing.T) { + _, err := unmarshalResponsesInputItem([]byte(`{"type": "unknown_type"}`)) + if err == nil { + t.Error("expected error, got nil") + } + }) +} + +func TestResponsesRequest_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + json string + check func(t *testing.T, req ResponsesRequest) + wantErr bool + }{ + { + name: "simple string input", + json: `{"model": "gpt-oss:20b", "input": "hello"}`, + check: func(t *testing.T, req ResponsesRequest) { + if req.Model != "gpt-oss:20b" { + t.Errorf("Model = %q, want %q", req.Model, "gpt-oss:20b") + } + if req.Input.Text != "hello" { + t.Errorf("Input.Text = %q, want %q", req.Input.Text, "hello") + } + }, + }, + { + name: "array input with messages", + json: `{"model": "gpt-oss:20b", "input": [{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hello"}]}]}`, + check: func(t *testing.T, req ResponsesRequest) { + if len(req.Input.Items) != 1 { + t.Fatalf("len(Input.Items) = %d, want 1", len(req.Input.Items)) + } + msg, ok := req.Input.Items[0].(ResponsesInputMessage) + if !ok { + t.Fatalf("Input.Items[0] type = %T, want ResponsesInputMessage", req.Input.Items[0]) + } + if msg.Role != "user" { + t.Errorf("Role = %q, want %q", msg.Role, "user") + } + }, + }, + { + name: "with temperature", + json: `{"model": "gpt-oss:20b", "input": "hello", "temperature": 0.5}`, + check: func(t *testing.T, req ResponsesRequest) { + if req.Temperature == nil || *req.Temperature != 0.5 { + t.Errorf("Temperature = %v, want 0.5", req.Temperature) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got ResponsesRequest + err := json.Unmarshal([]byte(tt.json), &got) + + if tt.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if tt.check != nil { + tt.check(t, got) + } + }) + } +} + +func TestFromResponsesRequest_Tools(t *testing.T) { + reqJSON := `{ + "model": "gpt-oss:20b", + "input": "hello", + "tools": [ + { + "type": "function", + "name": "shell", + "description": "Runs a shell command", + "strict": false, + "parameters": { + "type": "object", + "properties": { + "command": { + "type": "array", + "items": {"type": "string"}, + "description": "The command to execute" + } + }, + "required": ["command"] + } + } + ] + }` + + var req ResponsesRequest + if err := json.Unmarshal([]byte(reqJSON), &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + + // Check that tools were parsed + if len(req.Tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(req.Tools)) + } + + if req.Tools[0].Name != "shell" { + t.Errorf("expected tool name 'shell', got %q", req.Tools[0].Name) + } + + // Convert and check + chatReq, err := FromResponsesRequest(req) + if err != nil { + t.Fatalf("failed to convert request: %v", err) + } + + if len(chatReq.Tools) != 1 { + t.Fatalf("expected 1 converted tool, got %d", len(chatReq.Tools)) + } + + tool := 
chatReq.Tools[0] + if tool.Type != "function" { + t.Errorf("expected tool type 'function', got %q", tool.Type) + } + if tool.Function.Name != "shell" { + t.Errorf("expected function name 'shell', got %q", tool.Function.Name) + } + if tool.Function.Description != "Runs a shell command" { + t.Errorf("expected function description 'Runs a shell command', got %q", tool.Function.Description) + } + if tool.Function.Parameters.Type != "object" { + t.Errorf("expected parameters type 'object', got %q", tool.Function.Parameters.Type) + } + if len(tool.Function.Parameters.Required) != 1 || tool.Function.Parameters.Required[0] != "command" { + t.Errorf("expected required ['command'], got %v", tool.Function.Parameters.Required) + } +} + +func TestFromResponsesRequest_FunctionCallOutput(t *testing.T) { + // Test a complete tool call round-trip: + // 1. User message asking about weather + // 2. Assistant's function call (from previous response) + // 3. Function call output (the tool result) + reqJSON := `{ + "model": "gpt-oss:20b", + "input": [ + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "what is the weather?"}]}, + {"type": "function_call", "call_id": "call_abc123", "name": "get_weather", "arguments": "{\"city\":\"Paris\"}"}, + {"type": "function_call_output", "call_id": "call_abc123", "output": "sunny, 72F"} + ] + }` + + var req ResponsesRequest + if err := json.Unmarshal([]byte(reqJSON), &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + + // Check that input items were parsed + if len(req.Input.Items) != 3 { + t.Fatalf("expected 3 input items, got %d", len(req.Input.Items)) + } + + // Verify the function_call item + fc, ok := req.Input.Items[1].(ResponsesFunctionCall) + if !ok { + t.Fatalf("Input.Items[1] type = %T, want ResponsesFunctionCall", req.Input.Items[1]) + } + if fc.Name != "get_weather" { + t.Errorf("Name = %q, want %q", fc.Name, "get_weather") + } + + // Verify the function_call_output item + fcOutput, ok := req.Input.Items[2].(ResponsesFunctionCallOutput) + if !ok { + t.Fatalf("Input.Items[2] type = %T, want ResponsesFunctionCallOutput", req.Input.Items[2]) + } + if fcOutput.CallID != "call_abc123" { + t.Errorf("CallID = %q, want %q", fcOutput.CallID, "call_abc123") + } + + // Convert and check + chatReq, err := FromResponsesRequest(req) + if err != nil { + t.Fatalf("failed to convert request: %v", err) + } + + if len(chatReq.Messages) != 3 { + t.Fatalf("expected 3 messages, got %d", len(chatReq.Messages)) + } + + // Check the user message + userMsg := chatReq.Messages[0] + if userMsg.Role != "user" { + t.Errorf("expected role 'user', got %q", userMsg.Role) + } + + // Check the assistant message with tool call + assistantMsg := chatReq.Messages[1] + if assistantMsg.Role != "assistant" { + t.Errorf("expected role 'assistant', got %q", assistantMsg.Role) + } + if len(assistantMsg.ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(assistantMsg.ToolCalls)) + } + if assistantMsg.ToolCalls[0].ID != "call_abc123" { + t.Errorf("expected tool call ID 'call_abc123', got %q", assistantMsg.ToolCalls[0].ID) + } + if assistantMsg.ToolCalls[0].Function.Name != "get_weather" { + t.Errorf("expected function name 'get_weather', got %q", assistantMsg.ToolCalls[0].Function.Name) + } + + // Check the tool response message + toolMsg := chatReq.Messages[2] + if toolMsg.Role != "tool" { + t.Errorf("expected role 'tool', got %q", toolMsg.Role) + } + if toolMsg.Content != "sunny, 72F" { + t.Errorf("expected content 'sunny, 72F', got %q", 
toolMsg.Content) + } + if toolMsg.ToolCallID != "call_abc123" { + t.Errorf("expected ToolCallID 'call_abc123', got %q", toolMsg.ToolCallID) + } +} + +func TestDecodeImageURL(t *testing.T) { + // Valid PNG base64 (1x1 red pixel) + validPNG := "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==" + + t.Run("valid png", func(t *testing.T) { + img, err := decodeImageURL(validPNG) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(img) == 0 { + t.Error("expected non-empty image data") + } + }) + + t.Run("valid jpeg", func(t *testing.T) { + // Just test the prefix validation with minimal base64 + _, err := decodeImageURL("data:image/jpeg;base64,/9j/4AAQSkZJRg==") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("blank mime type", func(t *testing.T) { + _, err := decodeImageURL("data:;base64,dGVzdA==") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("invalid mime type", func(t *testing.T) { + _, err := decodeImageURL("data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7") + if err == nil { + t.Error("expected error for unsupported mime type") + } + }) + + t.Run("invalid base64", func(t *testing.T) { + _, err := decodeImageURL("data:image/png;base64,not-valid-base64!") + if err == nil { + t.Error("expected error for invalid base64") + } + }) + + t.Run("not a data url", func(t *testing.T) { + _, err := decodeImageURL("https://example.com/image.png") + if err == nil { + t.Error("expected error for non-data URL") + } + }) +} + +func TestFromResponsesRequest_Images(t *testing.T) { + // 1x1 red PNG pixel + pngBase64 := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==" + + reqJSON := `{ + "model": "llava", + "input": [ + {"type": "message", "role": "user", "content": [ + {"type": "input_text", "text": "What is in this image?"}, + {"type": "input_image", "detail": "auto", "image_url": "data:image/png;base64,` + pngBase64 + `"} + ]} + ] + }` + + var req ResponsesRequest + if err := json.Unmarshal([]byte(reqJSON), &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + + chatReq, err := FromResponsesRequest(req) + if err != nil { + t.Fatalf("failed to convert request: %v", err) + } + + if len(chatReq.Messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(chatReq.Messages)) + } + + msg := chatReq.Messages[0] + if msg.Role != "user" { + t.Errorf("expected role 'user', got %q", msg.Role) + } + if msg.Content != "What is in this image?" 
{ + t.Errorf("expected content 'What is in this image?', got %q", msg.Content) + } + if len(msg.Images) != 1 { + t.Fatalf("expected 1 image, got %d", len(msg.Images)) + } + if len(msg.Images[0]) == 0 { + t.Error("expected non-empty image data") + } +} + +func TestResponsesStreamConverter_TextOnly(t *testing.T) { + converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b") + + // First chunk with content + events := converter.Process(api.ChatResponse{ + Message: api.Message{ + Content: "Hello", + }, + }) + + // Should have: response.created, response.in_progress, output_item.added, content_part.added, output_text.delta + if len(events) != 5 { + t.Fatalf("expected 5 events, got %d", len(events)) + } + + if events[0].Event != "response.created" { + t.Errorf("events[0].Event = %q, want %q", events[0].Event, "response.created") + } + if events[1].Event != "response.in_progress" { + t.Errorf("events[1].Event = %q, want %q", events[1].Event, "response.in_progress") + } + if events[2].Event != "response.output_item.added" { + t.Errorf("events[2].Event = %q, want %q", events[2].Event, "response.output_item.added") + } + if events[3].Event != "response.content_part.added" { + t.Errorf("events[3].Event = %q, want %q", events[3].Event, "response.content_part.added") + } + if events[4].Event != "response.output_text.delta" { + t.Errorf("events[4].Event = %q, want %q", events[4].Event, "response.output_text.delta") + } + + // Second chunk with more content + events = converter.Process(api.ChatResponse{ + Message: api.Message{ + Content: " World", + }, + }) + + // Should only have output_text.delta (no more created/in_progress/added) + if len(events) != 1 { + t.Fatalf("expected 1 event, got %d", len(events)) + } + if events[0].Event != "response.output_text.delta" { + t.Errorf("events[0].Event = %q, want %q", events[0].Event, "response.output_text.delta") + } + + // Final chunk + events = converter.Process(api.ChatResponse{ + Message: api.Message{}, + Done: true, + }) + + // Should have: output_text.done, content_part.done, output_item.done, response.completed + if len(events) != 4 { + t.Fatalf("expected 4 events, got %d", len(events)) + } + if events[0].Event != "response.output_text.done" { + t.Errorf("events[0].Event = %q, want %q", events[0].Event, "response.output_text.done") + } + // Check that accumulated text is present + data := events[0].Data.(map[string]any) + if data["text"] != "Hello World" { + t.Errorf("accumulated text = %q, want %q", data["text"], "Hello World") + } +} + +func TestResponsesStreamConverter_ToolCalls(t *testing.T) { + converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b") + + events := converter.Process(api.ChatResponse{ + Message: api.Message{ + ToolCalls: []api.ToolCall{ + { + ID: "call_abc", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + }, + }, + }, + }, + }) + + // Should have: created, in_progress, output_item.added, arguments.delta, arguments.done, output_item.done + if len(events) != 6 { + t.Fatalf("expected 6 events, got %d", len(events)) + } + + if events[2].Event != "response.output_item.added" { + t.Errorf("events[2].Event = %q, want %q", events[2].Event, "response.output_item.added") + } + if events[3].Event != "response.function_call_arguments.delta" { + t.Errorf("events[3].Event = %q, want %q", events[3].Event, "response.function_call_arguments.delta") + } + if events[4].Event != "response.function_call_arguments.done" { + 
t.Errorf("events[4].Event = %q, want %q", events[4].Event, "response.function_call_arguments.done") + } + if events[5].Event != "response.output_item.done" { + t.Errorf("events[5].Event = %q, want %q", events[5].Event, "response.output_item.done") + } +} + +func TestResponsesStreamConverter_Reasoning(t *testing.T) { + converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b") + + // First chunk with thinking + events := converter.Process(api.ChatResponse{ + Message: api.Message{ + Thinking: "Let me think...", + }, + }) + + // Should have: created, in_progress, output_item.added (reasoning), reasoning_summary_text.delta + if len(events) != 4 { + t.Fatalf("expected 4 events, got %d", len(events)) + } + + if events[2].Event != "response.output_item.added" { + t.Errorf("events[2].Event = %q, want %q", events[2].Event, "response.output_item.added") + } + // Check it's a reasoning item + data := events[2].Data.(map[string]any) + item := data["item"].(map[string]any) + if item["type"] != "reasoning" { + t.Errorf("item type = %q, want %q", item["type"], "reasoning") + } + + if events[3].Event != "response.reasoning_summary_text.delta" { + t.Errorf("events[3].Event = %q, want %q", events[3].Event, "response.reasoning_summary_text.delta") + } + + // Second chunk with text content (reasoning should close first) + events = converter.Process(api.ChatResponse{ + Message: api.Message{ + Content: "The answer is 42", + }, + }) + + // Should have: reasoning_summary_text.done, output_item.done (reasoning), output_item.added (message), content_part.added, output_text.delta + if len(events) != 5 { + t.Fatalf("expected 5 events, got %d", len(events)) + } + + if events[0].Event != "response.reasoning_summary_text.done" { + t.Errorf("events[0].Event = %q, want %q", events[0].Event, "response.reasoning_summary_text.done") + } + if events[1].Event != "response.output_item.done" { + t.Errorf("events[1].Event = %q, want %q", events[1].Event, "response.output_item.done") + } + // Check the reasoning done item has encrypted_content + doneData := events[1].Data.(map[string]any) + doneItem := doneData["item"].(map[string]any) + if doneItem["encrypted_content"] != "Let me think..." 
{ + t.Errorf("encrypted_content = %q, want %q", doneItem["encrypted_content"], "Let me think...") + } +} + +func TestFromResponsesRequest_ReasoningMerge(t *testing.T) { + t.Run("reasoning merged with following message", func(t *testing.T) { + reqJSON := `{ + "model": "qwen3", + "input": [ + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "solve 2+2"}]}, + {"type": "reasoning", "id": "rs_123", "encrypted_content": "Let me think about this math problem...", "summary": [{"type": "summary_text", "text": "Thinking about math"}]}, + {"type": "message", "role": "assistant", "content": [{"type": "input_text", "text": "The answer is 4"}]} + ] + }` + + var req ResponsesRequest + if err := json.Unmarshal([]byte(reqJSON), &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + + chatReq, err := FromResponsesRequest(req) + if err != nil { + t.Fatalf("failed to convert request: %v", err) + } + + // Should have 2 messages: user and assistant (with thinking merged) + if len(chatReq.Messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(chatReq.Messages)) + } + + // Check user message + if chatReq.Messages[0].Role != "user" { + t.Errorf("Messages[0].Role = %q, want %q", chatReq.Messages[0].Role, "user") + } + + // Check assistant message has both content and thinking + assistantMsg := chatReq.Messages[1] + if assistantMsg.Role != "assistant" { + t.Errorf("Messages[1].Role = %q, want %q", assistantMsg.Role, "assistant") + } + if assistantMsg.Content != "The answer is 4" { + t.Errorf("Messages[1].Content = %q, want %q", assistantMsg.Content, "The answer is 4") + } + if assistantMsg.Thinking != "Let me think about this math problem..." { + t.Errorf("Messages[1].Thinking = %q, want %q", assistantMsg.Thinking, "Let me think about this math problem...") + } + }) + + t.Run("reasoning merged with following function call", func(t *testing.T) { + reqJSON := `{ + "model": "qwen3", + "input": [ + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "what is the weather?"}]}, + {"type": "reasoning", "id": "rs_123", "encrypted_content": "I need to call a tool for this...", "summary": []}, + {"type": "function_call", "call_id": "call_abc", "name": "get_weather", "arguments": "{\"city\":\"Paris\"}"} + ] + }` + + var req ResponsesRequest + if err := json.Unmarshal([]byte(reqJSON), &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + + chatReq, err := FromResponsesRequest(req) + if err != nil { + t.Fatalf("failed to convert request: %v", err) + } + + // Should have 2 messages: user and assistant (with thinking + tool call) + if len(chatReq.Messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(chatReq.Messages)) + } + + // Check assistant message has both tool call and thinking + assistantMsg := chatReq.Messages[1] + if assistantMsg.Role != "assistant" { + t.Errorf("Messages[1].Role = %q, want %q", assistantMsg.Role, "assistant") + } + if assistantMsg.Thinking != "I need to call a tool for this..." 
{ + t.Errorf("Messages[1].Thinking = %q, want %q", assistantMsg.Thinking, "I need to call a tool for this...") + } + if len(assistantMsg.ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(assistantMsg.ToolCalls)) + } + if assistantMsg.ToolCalls[0].Function.Name != "get_weather" { + t.Errorf("ToolCalls[0].Function.Name = %q, want %q", assistantMsg.ToolCalls[0].Function.Name, "get_weather") + } + }) + + t.Run("multi-turn conversation with reasoning", func(t *testing.T) { + // Simulates: user asks -> model thinks + responds -> user follows up + reqJSON := `{ + "model": "qwen3", + "input": [ + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "What is 2+2?"}]}, + {"type": "reasoning", "id": "rs_001", "encrypted_content": "This is a simple arithmetic problem. 2+2=4.", "summary": [{"type": "summary_text", "text": "Calculating 2+2"}]}, + {"type": "message", "role": "assistant", "content": [{"type": "input_text", "text": "The answer is 4."}]}, + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "Now multiply that by 3"}]} + ] + }` + + var req ResponsesRequest + if err := json.Unmarshal([]byte(reqJSON), &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + + chatReq, err := FromResponsesRequest(req) + if err != nil { + t.Fatalf("failed to convert request: %v", err) + } + + // Should have 3 messages: + // 1. user: "What is 2+2?" + // 2. assistant: thinking + "The answer is 4." + // 3. user: "Now multiply that by 3" + if len(chatReq.Messages) != 3 { + t.Fatalf("expected 3 messages, got %d", len(chatReq.Messages)) + } + + // Check first user message + if chatReq.Messages[0].Role != "user" || chatReq.Messages[0].Content != "What is 2+2?" { + t.Errorf("Messages[0] = {Role: %q, Content: %q}, want {Role: \"user\", Content: \"What is 2+2?\"}", + chatReq.Messages[0].Role, chatReq.Messages[0].Content) + } + + // Check assistant message has merged thinking + content + if chatReq.Messages[1].Role != "assistant" { + t.Errorf("Messages[1].Role = %q, want \"assistant\"", chatReq.Messages[1].Role) + } + if chatReq.Messages[1].Content != "The answer is 4." { + t.Errorf("Messages[1].Content = %q, want \"The answer is 4.\"", chatReq.Messages[1].Content) + } + if chatReq.Messages[1].Thinking != "This is a simple arithmetic problem. 2+2=4." { + t.Errorf("Messages[1].Thinking = %q, want \"This is a simple arithmetic problem. 
2+2=4.\"", + chatReq.Messages[1].Thinking) + } + + // Check second user message + if chatReq.Messages[2].Role != "user" || chatReq.Messages[2].Content != "Now multiply that by 3" { + t.Errorf("Messages[2] = {Role: %q, Content: %q}, want {Role: \"user\", Content: \"Now multiply that by 3\"}", + chatReq.Messages[2].Role, chatReq.Messages[2].Content) + } + }) + + t.Run("multi-turn with tool calls and reasoning", func(t *testing.T) { + // Simulates: user asks -> model thinks + calls tool -> tool responds -> model thinks + responds -> user follows up + reqJSON := `{ + "model": "qwen3", + "input": [ + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "What is the weather in Paris?"}]}, + {"type": "reasoning", "id": "rs_001", "encrypted_content": "I need to call the weather API for Paris.", "summary": []}, + {"type": "function_call", "call_id": "call_abc", "name": "get_weather", "arguments": "{\"city\":\"Paris\"}"}, + {"type": "function_call_output", "call_id": "call_abc", "output": "Sunny, 72°F"}, + {"type": "reasoning", "id": "rs_002", "encrypted_content": "The weather API returned sunny and 72°F. I should format this nicely.", "summary": []}, + {"type": "message", "role": "assistant", "content": [{"type": "input_text", "text": "It's sunny and 72°F in Paris!"}]}, + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "What about London?"}]} + ] + }` + + var req ResponsesRequest + if err := json.Unmarshal([]byte(reqJSON), &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + + chatReq, err := FromResponsesRequest(req) + if err != nil { + t.Fatalf("failed to convert request: %v", err) + } + + // Should have 5 messages: + // 1. user: "What is the weather in Paris?" + // 2. assistant: thinking + tool call + // 3. tool: "Sunny, 72°F" + // 4. assistant: thinking + "It's sunny and 72°F in Paris!" + // 5. user: "What about London?" + if len(chatReq.Messages) != 5 { + t.Fatalf("expected 5 messages, got %d", len(chatReq.Messages)) + } + + // Message 1: user + if chatReq.Messages[0].Role != "user" { + t.Errorf("Messages[0].Role = %q, want \"user\"", chatReq.Messages[0].Role) + } + + // Message 2: assistant with thinking + tool call + if chatReq.Messages[1].Role != "assistant" { + t.Errorf("Messages[1].Role = %q, want \"assistant\"", chatReq.Messages[1].Role) + } + if chatReq.Messages[1].Thinking != "I need to call the weather API for Paris." { + t.Errorf("Messages[1].Thinking = %q, want \"I need to call the weather API for Paris.\"", chatReq.Messages[1].Thinking) + } + if len(chatReq.Messages[1].ToolCalls) != 1 || chatReq.Messages[1].ToolCalls[0].Function.Name != "get_weather" { + t.Errorf("Messages[1].ToolCalls not as expected") + } + + // Message 3: tool response + if chatReq.Messages[2].Role != "tool" || chatReq.Messages[2].Content != "Sunny, 72°F" { + t.Errorf("Messages[2] = {Role: %q, Content: %q}, want {Role: \"tool\", Content: \"Sunny, 72°F\"}", + chatReq.Messages[2].Role, chatReq.Messages[2].Content) + } + + // Message 4: assistant with thinking + content + if chatReq.Messages[3].Role != "assistant" { + t.Errorf("Messages[3].Role = %q, want \"assistant\"", chatReq.Messages[3].Role) + } + if chatReq.Messages[3].Thinking != "The weather API returned sunny and 72°F. I should format this nicely." { + t.Errorf("Messages[3].Thinking = %q, want correct thinking", chatReq.Messages[3].Thinking) + } + if chatReq.Messages[3].Content != "It's sunny and 72°F in Paris!" 
{ + t.Errorf("Messages[3].Content = %q, want \"It's sunny and 72°F in Paris!\"", chatReq.Messages[3].Content) + } + + // Message 5: user follow-up + if chatReq.Messages[4].Role != "user" || chatReq.Messages[4].Content != "What about London?" { + t.Errorf("Messages[4] = {Role: %q, Content: %q}, want {Role: \"user\", Content: \"What about London?\"}", + chatReq.Messages[4].Role, chatReq.Messages[4].Content) + } + }) + + t.Run("trailing reasoning creates separate message", func(t *testing.T) { + reqJSON := `{ + "model": "qwen3", + "input": [ + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "think about this"}]}, + {"type": "reasoning", "id": "rs_123", "encrypted_content": "Still thinking...", "summary": []} + ] + }` + + var req ResponsesRequest + if err := json.Unmarshal([]byte(reqJSON), &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + + chatReq, err := FromResponsesRequest(req) + if err != nil { + t.Fatalf("failed to convert request: %v", err) + } + + // Should have 2 messages: user and assistant (thinking only) + if len(chatReq.Messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(chatReq.Messages)) + } + + // Check assistant message has only thinking + assistantMsg := chatReq.Messages[1] + if assistantMsg.Role != "assistant" { + t.Errorf("Messages[1].Role = %q, want %q", assistantMsg.Role, "assistant") + } + if assistantMsg.Thinking != "Still thinking..." { + t.Errorf("Messages[1].Thinking = %q, want %q", assistantMsg.Thinking, "Still thinking...") + } + if assistantMsg.Content != "" { + t.Errorf("Messages[1].Content = %q, want empty", assistantMsg.Content) + } + }) +} + +func TestToResponse_WithReasoning(t *testing.T) { + response := ToResponse("gpt-oss:20b", "resp_123", "msg_456", api.ChatResponse{ + CreatedAt: time.Now(), + Message: api.Message{ + Thinking: "Analyzing the question...", + Content: "The answer is 42", + }, + Done: true, + }) + + // Should have 2 output items: reasoning + message + if len(response.Output) != 2 { + t.Fatalf("expected 2 output items, got %d", len(response.Output)) + } + + // First item should be reasoning + if response.Output[0].Type != "reasoning" { + t.Errorf("Output[0].Type = %q, want %q", response.Output[0].Type, "reasoning") + } + if len(response.Output[0].Summary) != 1 { + t.Fatalf("expected 1 summary item, got %d", len(response.Output[0].Summary)) + } + if response.Output[0].Summary[0].Text != "Analyzing the question..." { + t.Errorf("Summary[0].Text = %q, want %q", response.Output[0].Summary[0].Text, "Analyzing the question...") + } + if response.Output[0].EncryptedContent != "Analyzing the question..." { + t.Errorf("EncryptedContent = %q, want %q", response.Output[0].EncryptedContent, "Analyzing the question...") + } + + // Second item should be message + if response.Output[1].Type != "message" { + t.Errorf("Output[1].Type = %q, want %q", response.Output[1].Type, "message") + } + if response.Output[1].Content[0].Text != "The answer is 42" { + t.Errorf("Content[0].Text = %q, want %q", response.Output[1].Content[0].Text, "The answer is 42") + } +} + +func TestFromResponsesRequest_Instructions(t *testing.T) { + reqJSON := `{ + "model": "gpt-oss:20b", + "instructions": "You are a helpful pirate. 
Always respond in pirate speak.", + "input": "Hello" + }` + + var req ResponsesRequest + if err := json.Unmarshal([]byte(reqJSON), &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + + chatReq, err := FromResponsesRequest(req) + if err != nil { + t.Fatalf("failed to convert request: %v", err) + } + + // Should have 2 messages: system (instructions) + user + if len(chatReq.Messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(chatReq.Messages)) + } + + // First message should be system with instructions + if chatReq.Messages[0].Role != "system" { + t.Errorf("Messages[0].Role = %q, want %q", chatReq.Messages[0].Role, "system") + } + if chatReq.Messages[0].Content != "You are a helpful pirate. Always respond in pirate speak." { + t.Errorf("Messages[0].Content = %q, want instructions", chatReq.Messages[0].Content) + } + + // Second message should be user + if chatReq.Messages[1].Role != "user" { + t.Errorf("Messages[1].Role = %q, want %q", chatReq.Messages[1].Role, "user") + } + if chatReq.Messages[1].Content != "Hello" { + t.Errorf("Messages[1].Content = %q, want %q", chatReq.Messages[1].Content, "Hello") + } +} + +func TestFromResponsesRequest_MaxOutputTokens(t *testing.T) { + reqJSON := `{ + "model": "gpt-oss:20b", + "input": "Write a story", + "max_output_tokens": 100 + }` + + var req ResponsesRequest + if err := json.Unmarshal([]byte(reqJSON), &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + + chatReq, err := FromResponsesRequest(req) + if err != nil { + t.Fatalf("failed to convert request: %v", err) + } + + // Check that num_predict is set in options + numPredict, ok := chatReq.Options["num_predict"] + if !ok { + t.Fatal("expected num_predict in options") + } + if numPredict != 100 { + t.Errorf("num_predict = %v, want 100", numPredict) + } +} + +func TestFromResponsesRequest_TextFormatJsonSchema(t *testing.T) { + reqJSON := `{ + "model": "gpt-oss:20b", + "input": "Give me info about John who is 30", + "text": { + "format": { + "type": "json_schema", + "name": "person", + "strict": true, + "schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"} + }, + "required": ["name", "age"] + } + } + } + }` + + var req ResponsesRequest + if err := json.Unmarshal([]byte(reqJSON), &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + + // Verify the text format was parsed + if req.Text == nil || req.Text.Format == nil { + t.Fatal("expected Text.Format to be set") + } + if req.Text.Format.Type != "json_schema" { + t.Errorf("Text.Format.Type = %q, want %q", req.Text.Format.Type, "json_schema") + } + + chatReq, err := FromResponsesRequest(req) + if err != nil { + t.Fatalf("failed to convert request: %v", err) + } + + // Check that Format is set + if chatReq.Format == nil { + t.Fatal("expected Format to be set") + } + + // Verify the schema is passed through + var schema map[string]any + if err := json.Unmarshal(chatReq.Format, &schema); err != nil { + t.Fatalf("failed to unmarshal format: %v", err) + } + if schema["type"] != "object" { + t.Errorf("schema type = %v, want %q", schema["type"], "object") + } + props, ok := schema["properties"].(map[string]any) + if !ok { + t.Fatal("expected properties in schema") + } + if _, ok := props["name"]; !ok { + t.Error("expected 'name' in schema properties") + } + if _, ok := props["age"]; !ok { + t.Error("expected 'age' in schema properties") + } +} + +func TestFromResponsesRequest_TextFormatText(t *testing.T) { + // When format 
type is "text", Format should be nil (no constraint) + reqJSON := `{ + "model": "gpt-oss:20b", + "input": "Hello", + "text": { + "format": { + "type": "text" + } + } + }` + + var req ResponsesRequest + if err := json.Unmarshal([]byte(reqJSON), &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + + chatReq, err := FromResponsesRequest(req) + if err != nil { + t.Fatalf("failed to convert request: %v", err) + } + + // Format should be nil for "text" type + if chatReq.Format != nil { + t.Errorf("expected Format to be nil for text type, got %s", string(chatReq.Format)) + } +} + +func TestResponsesInputMessage_ShorthandFormats(t *testing.T) { + t.Run("string content shorthand", func(t *testing.T) { + // Content can be a plain string instead of an array of content items + jsonStr := `{"type": "message", "role": "user", "content": "Hello world"}` + + var msg ResponsesInputMessage + if err := json.Unmarshal([]byte(jsonStr), &msg); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if msg.Role != "user" { + t.Errorf("Role = %q, want %q", msg.Role, "user") + } + if len(msg.Content) != 1 { + t.Fatalf("len(Content) = %d, want 1", len(msg.Content)) + } + + textContent, ok := msg.Content[0].(ResponsesTextContent) + if !ok { + t.Fatalf("Content[0] type = %T, want ResponsesTextContent", msg.Content[0]) + } + if textContent.Text != "Hello world" { + t.Errorf("Content[0].Text = %q, want %q", textContent.Text, "Hello world") + } + if textContent.Type != "input_text" { + t.Errorf("Content[0].Type = %q, want %q", textContent.Type, "input_text") + } + }) + + t.Run("output_text content type", func(t *testing.T) { + // Previous assistant responses come back with output_text content type + jsonStr := `{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "I am an assistant"}]}` + + var msg ResponsesInputMessage + if err := json.Unmarshal([]byte(jsonStr), &msg); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if msg.Role != "assistant" { + t.Errorf("Role = %q, want %q", msg.Role, "assistant") + } + if len(msg.Content) != 1 { + t.Fatalf("len(Content) = %d, want 1", len(msg.Content)) + } + + outputContent, ok := msg.Content[0].(ResponsesOutputTextContent) + if !ok { + t.Fatalf("Content[0] type = %T, want ResponsesOutputTextContent", msg.Content[0]) + } + if outputContent.Text != "I am an assistant" { + t.Errorf("Content[0].Text = %q, want %q", outputContent.Text, "I am an assistant") + } + }) +} + +func TestUnmarshalResponsesInputItem_ShorthandMessage(t *testing.T) { + t.Run("message without type field", func(t *testing.T) { + // When type is omitted but role is present, treat as message + jsonStr := `{"role": "user", "content": "Hello"}` + + item, err := unmarshalResponsesInputItem([]byte(jsonStr)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + msg, ok := item.(ResponsesInputMessage) + if !ok { + t.Fatalf("got type %T, want ResponsesInputMessage", item) + } + if msg.Role != "user" { + t.Errorf("Role = %q, want %q", msg.Role, "user") + } + if len(msg.Content) != 1 { + t.Fatalf("len(Content) = %d, want 1", len(msg.Content)) + } + }) + + t.Run("message with both type and role", func(t *testing.T) { + // Explicit type should still work + jsonStr := `{"type": "message", "role": "system", "content": "You are helpful"}` + + item, err := unmarshalResponsesInputItem([]byte(jsonStr)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + msg, ok := item.(ResponsesInputMessage) + if !ok { + t.Fatalf("got type %T, want 
ResponsesInputMessage", item) + } + if msg.Role != "system" { + t.Errorf("Role = %q, want %q", msg.Role, "system") + } + }) +} + +func TestFromResponsesRequest_ShorthandFormats(t *testing.T) { + t.Run("shorthand message without type", func(t *testing.T) { + // Real-world format from OpenAI SDK + reqJSON := `{ + "model": "gpt-4.1", + "input": [ + {"role": "user", "content": "What is the weather in Tokyo?"} + ] + }` + + var req ResponsesRequest + if err := json.Unmarshal([]byte(reqJSON), &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + + if len(req.Input.Items) != 1 { + t.Fatalf("expected 1 input item, got %d", len(req.Input.Items)) + } + + msg, ok := req.Input.Items[0].(ResponsesInputMessage) + if !ok { + t.Fatalf("Input.Items[0] type = %T, want ResponsesInputMessage", req.Input.Items[0]) + } + if msg.Role != "user" { + t.Errorf("Role = %q, want %q", msg.Role, "user") + } + + chatReq, err := FromResponsesRequest(req) + if err != nil { + t.Fatalf("failed to convert request: %v", err) + } + + if len(chatReq.Messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(chatReq.Messages)) + } + if chatReq.Messages[0].Content != "What is the weather in Tokyo?" { + t.Errorf("Content = %q, want %q", chatReq.Messages[0].Content, "What is the weather in Tokyo?") + } + }) + + t.Run("conversation with output_text from previous response", func(t *testing.T) { + // Simulates a multi-turn conversation where previous assistant response is sent back + reqJSON := `{ + "model": "gpt-4.1", + "input": [ + {"role": "user", "content": "Hello"}, + {"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Hi there!"}]}, + {"role": "user", "content": "How are you?"} + ] + }` + + var req ResponsesRequest + if err := json.Unmarshal([]byte(reqJSON), &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + + chatReq, err := FromResponsesRequest(req) + if err != nil { + t.Fatalf("failed to convert request: %v", err) + } + + if len(chatReq.Messages) != 3 { + t.Fatalf("expected 3 messages, got %d", len(chatReq.Messages)) + } + + // Check first user message + if chatReq.Messages[0].Role != "user" || chatReq.Messages[0].Content != "Hello" { + t.Errorf("Messages[0] = {Role: %q, Content: %q}, want {Role: \"user\", Content: \"Hello\"}", + chatReq.Messages[0].Role, chatReq.Messages[0].Content) + } + + // Check assistant message (output_text should be converted to content) + if chatReq.Messages[1].Role != "assistant" || chatReq.Messages[1].Content != "Hi there!" { + t.Errorf("Messages[1] = {Role: %q, Content: %q}, want {Role: \"assistant\", Content: \"Hi there!\"}", + chatReq.Messages[1].Role, chatReq.Messages[1].Content) + } + + // Check second user message + if chatReq.Messages[2].Role != "user" || chatReq.Messages[2].Content != "How are you?" 
{ + t.Errorf("Messages[2] = {Role: %q, Content: %q}, want {Role: \"user\", Content: \"How are you?\"}", + chatReq.Messages[2].Role, chatReq.Messages[2].Content) + } + }) +} + +func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) { + // Verify that response.output_item.done includes content field for messages + converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b") + + // First chunk + converter.Process(api.ChatResponse{ + Message: api.Message{Content: "Hello World"}, + }) + + // Final chunk + events := converter.Process(api.ChatResponse{ + Message: api.Message{}, + Done: true, + }) + + // Find the output_item.done event + var outputItemDone map[string]any + for _, event := range events { + if event.Event == "response.output_item.done" { + outputItemDone = event.Data.(map[string]any) + break + } + } + + if outputItemDone == nil { + t.Fatal("expected response.output_item.done event") + } + + item := outputItemDone["item"].(map[string]any) + if item["type"] != "message" { + t.Errorf("item.type = %q, want %q", item["type"], "message") + } + + content, ok := item["content"].([]map[string]any) + if !ok { + t.Fatalf("item.content type = %T, want []map[string]any", item["content"]) + } + if len(content) != 1 { + t.Fatalf("len(content) = %d, want 1", len(content)) + } + if content[0]["type"] != "output_text" { + t.Errorf("content[0].type = %q, want %q", content[0]["type"], "output_text") + } + if content[0]["text"] != "Hello World" { + t.Errorf("content[0].text = %q, want %q", content[0]["text"], "Hello World") + } +} + +func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T) { + // Verify that response.completed includes the output array + converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b") + + // Process some content + converter.Process(api.ChatResponse{ + Message: api.Message{Content: "Test response"}, + }) + + // Final chunk + events := converter.Process(api.ChatResponse{ + Message: api.Message{}, + Done: true, + }) + + // Find the response.completed event + var responseCompleted map[string]any + for _, event := range events { + if event.Event == "response.completed" { + responseCompleted = event.Data.(map[string]any) + break + } + } + + if responseCompleted == nil { + t.Fatal("expected response.completed event") + } + + response := responseCompleted["response"].(map[string]any) + output, ok := response["output"].([]any) + if !ok { + t.Fatalf("response.output type = %T, want []any", response["output"]) + } + + if len(output) != 1 { + t.Fatalf("len(output) = %d, want 1", len(output)) + } + + item := output[0].(map[string]any) + if item["type"] != "message" { + t.Errorf("output[0].type = %q, want %q", item["type"], "message") + } +} + +func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) { + // Verify that response.created includes an empty output array + converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b") + + events := converter.Process(api.ChatResponse{ + Message: api.Message{Content: "Hi"}, + }) + + // First event should be response.created + if events[0].Event != "response.created" { + t.Fatalf("events[0].Event = %q, want %q", events[0].Event, "response.created") + } + + data := events[0].Data.(map[string]any) + response := data["response"].(map[string]any) + + output, ok := response["output"].([]any) + if !ok { + t.Fatalf("response.output type = %T, want []any", response["output"]) + } + + // Should be empty array initially + if len(output) != 0 { + 
t.Errorf("len(output) = %d, want 0", len(output)) + } +} + +func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) { + // Verify that events include incrementing sequence numbers + converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b") + + events := converter.Process(api.ChatResponse{ + Message: api.Message{Content: "Hello"}, + }) + + for i, event := range events { + data := event.Data.(map[string]any) + seqNum, ok := data["sequence_number"].(int) + if !ok { + t.Fatalf("events[%d] missing sequence_number", i) + } + if seqNum != i { + t.Errorf("events[%d].sequence_number = %d, want %d", i, seqNum, i) + } + } + + // Process more content, sequence should continue + moreEvents := converter.Process(api.ChatResponse{ + Message: api.Message{Content: " World"}, + }) + + expectedSeq := len(events) + for i, event := range moreEvents { + data := event.Data.(map[string]any) + seqNum := data["sequence_number"].(int) + if seqNum != expectedSeq+i { + t.Errorf("moreEvents[%d].sequence_number = %d, want %d", i, seqNum, expectedSeq+i) + } + } +} + +func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) { + // Verify that function call items include status field + converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b") + + events := converter.Process(api.ChatResponse{ + Message: api.Message{ + ToolCalls: []api.ToolCall{ + { + ID: "call_abc", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + }, + }, + }, + }, + }) + + // Find output_item.added event + var addedItem map[string]any + var doneItem map[string]any + for _, event := range events { + data := event.Data.(map[string]any) + if data["type"] == "response.output_item.added" { + item := data["item"].(map[string]any) + if item["type"] == "function_call" { + addedItem = item + } + } + if data["type"] == "response.output_item.done" { + item := data["item"].(map[string]any) + if item["type"] == "function_call" { + doneItem = item + } + } + } + + if addedItem == nil { + t.Fatal("expected function_call output_item.added event") + } + if addedItem["status"] != "in_progress" { + t.Errorf("output_item.added status = %q, want %q", addedItem["status"], "in_progress") + } + + if doneItem == nil { + t.Fatal("expected function_call output_item.done event") + } + if doneItem["status"] != "completed" { + t.Errorf("output_item.done status = %q, want %q", doneItem["status"], "completed") + } +} diff --git a/server/routes.go b/server/routes.go index bbf6b9b9..54f23d5d 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1532,6 +1532,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler) r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler) r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler) + r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler) if rc != nil { // wrap old with new @@ -2393,3 +2394,4 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message { } return msgs } + From 2dfb74410d2cca08fa6dd62a0863e2e8d5ad1a8a Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Thu, 11 Dec 2025 16:02:05 -0800 Subject: [PATCH 24/35] model: fix rotary embeddings for ministral 3 (#13432) --- model/models/mistral3/model_text.go | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go 
index 01eca1c5..36106107 100644 --- a/model/models/mistral3/model_text.go +++ b/model/models/mistral3/model_text.go @@ -29,24 +29,13 @@ type TextOptions struct { func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { var ropeOpts []func(*rope.Options) if o.ropeType == "yarn" { - getMscale := func(scale, mscale float64) float64 { - if scale <= 1.0 { - return 1.0 - } - return 0.1*mscale*math.Log(scale) + 1.0 - } - - var attnFactor float32 if o.ropeMscale != 0 && o.ropeMscaleAllDim != 0 { - attnFactor = float32(getMscale(float64(o.ropeScale), float64(o.ropeMscale)) / getMscale(float64(o.ropeScale), float64(o.ropeMscaleAllDim))) - } else { - attnFactor = float32(getMscale(float64(o.ropeScale), 1)) + ropeOpts = append(ropeOpts, rope.WithAttentionFactor(1.0/float32(0.1*math.Log(float64(o.ropeScale))+1.0))) } ropeOpts = append(ropeOpts, rope.WithOriginalContextLength(o.ropeOrigPosEmbeddings), rope.WithExtrapolationFactor(o.ropeExtrapolation), - rope.WithAttentionFactor(attnFactor), rope.WithBetaFast(o.ropeBetaFast), rope.WithBetaSlow(o.ropeBetaSlow), ) From 709f842457f40550c88da80f84bc8d7ba29371b9 Mon Sep 17 00:00:00 2001 From: JJ Date: Thu, 11 Dec 2025 16:08:57 -0800 Subject: [PATCH 25/35] Update README.md (#13373) Correct Markdown syntax for Swollama GitHub and DocC documentation links --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1f1560ca..bb08819d 100644 --- a/README.md +++ b/README.md @@ -555,7 +555,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Parakeet](https://github.com/parakeet-nest/parakeet) is a GoLang library, made to simplify the development of small generative AI applications with Ollama. - [Haverscript](https://github.com/andygill/haverscript) with [examples](https://github.com/andygill/haverscript/tree/main/examples) - [Ollama for Swift](https://github.com/mattt/ollama-swift) -- [Swollama for Swift]([https://github.com/marcusziade/Swollama](https://github.com/guitaripod/Swollama) with [DocC]( https://guitaripod.github.io/Swollama/documentation/swollama) +- [Swollama for Swift](https://github.com/guitaripod/Swollama) with [DocC](https://guitaripod.github.io/Swollama/documentation/swollama) - [GoLamify](https://github.com/prasad89/golamify) - [Ollama for Haskell](https://github.com/tusharad/ollama-haskell) - [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API) From 93d45d7a0464e534c0211f2fb9d720b1855ce43a Mon Sep 17 00:00:00 2001 From: Alexander Gusak Date: Fri, 12 Dec 2025 00:14:45 +0000 Subject: [PATCH 26/35] docs: fix link to modelfile.mdx (#13220) --- api/client.go | 2 +- docs/api.md | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/api/client.go b/api/client.go index 9a8f89e4..c7051689 100644 --- a/api/client.go +++ b/api/client.go @@ -347,7 +347,7 @@ type CreateProgressFunc func(ProgressResponse) error // Create creates a model from a [Modelfile]. fn is a progress function that // behaves similarly to other methods (see [Client.Pull]). 
// -// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md +// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.mdx func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error { return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error { var resp ProgressResponse diff --git a/docs/api.md b/docs/api.md index 99ceaa11..03f6dbea 100644 --- a/docs/api.md +++ b/docs/api.md @@ -50,7 +50,7 @@ Generate a response for a given prompt with a provided model. This is a streamin Advanced parameters (optional): - `format`: the format to return a response in. Format can be `json` or a JSON schema -- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` +- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature` - `system`: system message to (overrides what is defined in the `Modelfile`) - `template`: the prompt template to use (overrides what is defined in the `Modelfile`) - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects @@ -507,7 +507,7 @@ The `message` object has the following fields: Advanced parameters (optional): - `format`: the format to return a response in. Format can be `json` or a JSON schema. -- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` +- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature` - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`) @@ -1189,7 +1189,7 @@ If you are creating a model from a safetensors directory or from a GGUF file, yo - `template`: (optional) the prompt template for the model - `license`: (optional) a string or list of strings containing the license or licenses for the model - `system`: (optional) a string containing the system prompt for the model -- `parameters`: (optional) a dictionary of parameters for the model (see [Modelfile](./modelfile.md#valid-parameters-and-values) for a list of parameters) +- `parameters`: (optional) a dictionary of parameters for the model (see [Modelfile](./modelfile.mdx#valid-parameters-and-values) for a list of parameters) - `messages`: (optional) a list of message objects used to create a conversation - `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects - `quantize` (optional): quantize a non-quantized (e.g. float16) model @@ -1698,7 +1698,7 @@ Generate embeddings from a model Advanced parameters: - `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. 
Defaults to `true` -- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` +- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature` - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`) - `dimensions`: number of dimensions for the embedding @@ -1817,7 +1817,7 @@ Generate embeddings from a model Advanced parameters: -- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` +- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature` - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`) ### Examples From 9b2035d194d56227dd880d686410d8546e10563e Mon Sep 17 00:00:00 2001 From: Parth Sareen Date: Thu, 11 Dec 2025 17:30:12 -0800 Subject: [PATCH 27/35] openai: add tool call appending to previous assistant message (#13434) * openai: add tool call appending to previous asst message * add tests for thinking appending --- openai/responses.go | 39 +++-- openai/responses_test.go | 299 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 324 insertions(+), 14 deletions(-) diff --git a/openai/responses.go b/openai/responses.go index 8f6b1d94..1fbd2de0 100644 --- a/openai/responses.go +++ b/openai/responses.go @@ -365,22 +365,33 @@ func FromResponsesRequest(r ResponsesRequest) (*api.ChatRequest, error) { return nil, fmt.Errorf("failed to parse function call arguments: %w", err) } } - msg := api.Message{ - Role: "assistant", - ToolCalls: []api.ToolCall{{ - ID: v.CallID, - Function: api.ToolCallFunction{ - Name: v.Name, - Arguments: args, - }, - }}, + toolCall := api.ToolCall{ + ID: v.CallID, + Function: api.ToolCallFunction{ + Name: v.Name, + Arguments: args, + }, } - // Attach pending thinking - if pendingThinking != "" { - msg.Thinking = pendingThinking - pendingThinking = "" + + // Merge tool call into existing assistant message if it has content or tool calls + if len(messages) > 0 && messages[len(messages)-1].Role == "assistant" { + lastMsg := &messages[len(messages)-1] + lastMsg.ToolCalls = append(lastMsg.ToolCalls, toolCall) + if pendingThinking != "" { + lastMsg.Thinking = pendingThinking + pendingThinking = "" + } + } else { + msg := api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{toolCall}, + } + if pendingThinking != "" { + msg.Thinking = pendingThinking + pendingThinking = "" + } + messages = append(messages, msg) } - messages = append(messages, msg) case ResponsesFunctionCallOutput: messages = append(messages, api.Message{ Role: "tool", diff --git a/openai/responses_test.go b/openai/responses_test.go index 50fbfdc5..86731e72 100644 --- a/openai/responses_test.go +++ b/openai/responses_test.go @@ -456,6 +456,305 @@ func TestFromResponsesRequest_FunctionCallOutput(t *testing.T) { } } +func TestFromResponsesRequest_FunctionCallMerge(t *testing.T) { + t.Run("function call merges with preceding assistant message", func(t *testing.T) { + // When assistant message has content followed by function_call, + // they should be merged into a single message + reqJSON := `{ + "model": "gpt-oss:20b", + "input": [ + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": 
"what is the weather?"}]}, + {"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "I'll check the weather for you."}]}, + {"type": "function_call", "call_id": "call_abc123", "name": "get_weather", "arguments": "{\"city\":\"Paris\"}"} + ] + }` + + var req ResponsesRequest + if err := json.Unmarshal([]byte(reqJSON), &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + + chatReq, err := FromResponsesRequest(req) + if err != nil { + t.Fatalf("failed to convert request: %v", err) + } + + // Should have 2 messages: user and assistant (with content + tool call merged) + if len(chatReq.Messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(chatReq.Messages)) + } + + // Check user message + if chatReq.Messages[0].Role != "user" { + t.Errorf("Messages[0].Role = %q, want %q", chatReq.Messages[0].Role, "user") + } + + // Check assistant message has both content and tool call + assistantMsg := chatReq.Messages[1] + if assistantMsg.Role != "assistant" { + t.Errorf("Messages[1].Role = %q, want %q", assistantMsg.Role, "assistant") + } + if assistantMsg.Content != "I'll check the weather for you." { + t.Errorf("Messages[1].Content = %q, want %q", assistantMsg.Content, "I'll check the weather for you.") + } + if len(assistantMsg.ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(assistantMsg.ToolCalls)) + } + if assistantMsg.ToolCalls[0].Function.Name != "get_weather" { + t.Errorf("ToolCalls[0].Function.Name = %q, want %q", assistantMsg.ToolCalls[0].Function.Name, "get_weather") + } + }) + + t.Run("function call without preceding assistant creates new message", func(t *testing.T) { + // When there's no preceding assistant message, function_call creates its own message + reqJSON := `{ + "model": "gpt-oss:20b", + "input": [ + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "what is the weather?"}]}, + {"type": "function_call", "call_id": "call_abc123", "name": "get_weather", "arguments": "{\"city\":\"Paris\"}"} + ] + }` + + var req ResponsesRequest + if err := json.Unmarshal([]byte(reqJSON), &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + + chatReq, err := FromResponsesRequest(req) + if err != nil { + t.Fatalf("failed to convert request: %v", err) + } + + // Should have 2 messages: user and assistant (tool call only) + if len(chatReq.Messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(chatReq.Messages)) + } + + // Check assistant message has tool call but no content + assistantMsg := chatReq.Messages[1] + if assistantMsg.Role != "assistant" { + t.Errorf("Messages[1].Role = %q, want %q", assistantMsg.Role, "assistant") + } + if assistantMsg.Content != "" { + t.Errorf("Messages[1].Content = %q, want empty", assistantMsg.Content) + } + if len(assistantMsg.ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(assistantMsg.ToolCalls)) + } + }) + + t.Run("multiple function calls merge into same assistant message", func(t *testing.T) { + // Multiple consecutive function_calls should all merge into the same assistant message + reqJSON := `{ + "model": "gpt-oss:20b", + "input": [ + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "check weather and time"}]}, + {"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "I'll check both."}]}, + {"type": "function_call", "call_id": "call_1", "name": "get_weather", "arguments": "{\"city\":\"Paris\"}"}, + {"type": "function_call", "call_id": "call_2", "name": 
"get_time", "arguments": "{\"city\":\"Paris\"}"} + ] + }` + + var req ResponsesRequest + if err := json.Unmarshal([]byte(reqJSON), &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + + chatReq, err := FromResponsesRequest(req) + if err != nil { + t.Fatalf("failed to convert request: %v", err) + } + + // Should have 2 messages: user and assistant (content + both tool calls) + if len(chatReq.Messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(chatReq.Messages)) + } + + // Assistant has content + both tool calls + assistantMsg := chatReq.Messages[1] + if assistantMsg.Content != "I'll check both." { + t.Errorf("Messages[1].Content = %q, want %q", assistantMsg.Content, "I'll check both.") + } + if len(assistantMsg.ToolCalls) != 2 { + t.Fatalf("expected 2 tool calls, got %d", len(assistantMsg.ToolCalls)) + } + if assistantMsg.ToolCalls[0].Function.Name != "get_weather" { + t.Errorf("ToolCalls[0].Function.Name = %q, want %q", assistantMsg.ToolCalls[0].Function.Name, "get_weather") + } + if assistantMsg.ToolCalls[1].Function.Name != "get_time" { + t.Errorf("ToolCalls[1].Function.Name = %q, want %q", assistantMsg.ToolCalls[1].Function.Name, "get_time") + } + }) + + t.Run("new assistant message starts fresh tool call group", func(t *testing.T) { + // assistant → tool_call → tool_call → assistant → tool_call + // Should result in 2 assistant messages with their respective tool calls + reqJSON := `{ + "model": "gpt-oss:20b", + "input": [ + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "do multiple things"}]}, + {"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "First batch."}]}, + {"type": "function_call", "call_id": "call_1", "name": "func_a", "arguments": "{}"}, + {"type": "function_call", "call_id": "call_2", "name": "func_b", "arguments": "{}"}, + {"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Second batch."}]}, + {"type": "function_call", "call_id": "call_3", "name": "func_c", "arguments": "{}"} + ] + }` + + var req ResponsesRequest + if err := json.Unmarshal([]byte(reqJSON), &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + + chatReq, err := FromResponsesRequest(req) + if err != nil { + t.Fatalf("failed to convert request: %v", err) + } + + // Should have 3 messages: + // 1. user + // 2. assistant "First batch." + tool calls [func_a, func_b] + // 3. assistant "Second batch." + tool calls [func_c] + if len(chatReq.Messages) != 3 { + t.Fatalf("expected 3 messages, got %d", len(chatReq.Messages)) + } + + asst1 := chatReq.Messages[1] + if asst1.Content != "First batch." { + t.Errorf("Messages[1].Content = %q, want %q", asst1.Content, "First batch.") + } + if len(asst1.ToolCalls) != 2 { + t.Fatalf("expected 2 tool calls in Messages[1], got %d", len(asst1.ToolCalls)) + } + if asst1.ToolCalls[0].Function.Name != "func_a" { + t.Errorf("Messages[1].ToolCalls[0] = %q, want %q", asst1.ToolCalls[0].Function.Name, "func_a") + } + if asst1.ToolCalls[1].Function.Name != "func_b" { + t.Errorf("Messages[1].ToolCalls[1] = %q, want %q", asst1.ToolCalls[1].Function.Name, "func_b") + } + + asst2 := chatReq.Messages[2] + if asst2.Content != "Second batch." 
{ + t.Errorf("Messages[2].Content = %q, want %q", asst2.Content, "Second batch.") + } + if len(asst2.ToolCalls) != 1 { + t.Fatalf("expected 1 tool call in Messages[2], got %d", len(asst2.ToolCalls)) + } + if asst2.ToolCalls[0].Function.Name != "func_c" { + t.Errorf("Messages[2].ToolCalls[0] = %q, want %q", asst2.ToolCalls[0].Function.Name, "func_c") + } + }) + + t.Run("function call merges with assistant that has thinking", func(t *testing.T) { + // reasoning → assistant (gets thinking) → function_call → should merge + reqJSON := `{ + "model": "gpt-oss:20b", + "input": [ + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "think and act"}]}, + {"type": "reasoning", "id": "rs_1", "encrypted_content": "Let me think...", "summary": []}, + {"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "I thought about it."}]}, + {"type": "function_call", "call_id": "call_1", "name": "do_thing", "arguments": "{}"} + ] + }` + + var req ResponsesRequest + if err := json.Unmarshal([]byte(reqJSON), &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + + chatReq, err := FromResponsesRequest(req) + if err != nil { + t.Fatalf("failed to convert request: %v", err) + } + + // Should have 2 messages: user and assistant (thinking + content + tool call) + if len(chatReq.Messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(chatReq.Messages)) + } + + asst := chatReq.Messages[1] + if asst.Thinking != "Let me think..." { + t.Errorf("Messages[1].Thinking = %q, want %q", asst.Thinking, "Let me think...") + } + if asst.Content != "I thought about it." { + t.Errorf("Messages[1].Content = %q, want %q", asst.Content, "I thought about it.") + } + if len(asst.ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(asst.ToolCalls)) + } + if asst.ToolCalls[0].Function.Name != "do_thing" { + t.Errorf("ToolCalls[0].Function.Name = %q, want %q", asst.ToolCalls[0].Function.Name, "do_thing") + } + }) + + t.Run("mixed thinking and content with multiple tool calls", func(t *testing.T) { + // Test: + // 1. reasoning → assistant (empty content, gets thinking) → tc (merges) + // 2. assistant with content → tc → tc (both merge) + // Result: 2 assistant messages + reqJSON := `{ + "model": "gpt-oss:20b", + "input": [ + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "complex task"}]}, + {"type": "reasoning", "id": "rs_1", "encrypted_content": "Thinking first...", "summary": []}, + {"type": "message", "role": "assistant", "content": ""}, + {"type": "function_call", "call_id": "call_1", "name": "think_action", "arguments": "{}"}, + {"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Now doing more."}]}, + {"type": "function_call", "call_id": "call_2", "name": "action_a", "arguments": "{}"}, + {"type": "function_call", "call_id": "call_3", "name": "action_b", "arguments": "{}"} + ] + }` + + var req ResponsesRequest + if err := json.Unmarshal([]byte(reqJSON), &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + + chatReq, err := FromResponsesRequest(req) + if err != nil { + t.Fatalf("failed to convert request: %v", err) + } + + // Should have 3 messages: + // 1. user + // 2. assistant with thinking + tool call [think_action] + // 3. assistant with content "Now doing more." 
+ tool calls [action_a, action_b] + if len(chatReq.Messages) != 3 { + t.Fatalf("expected 3 messages, got %d", len(chatReq.Messages)) + } + + // First assistant: thinking + tool call + asst1 := chatReq.Messages[1] + if asst1.Thinking != "Thinking first..." { + t.Errorf("Messages[1].Thinking = %q, want %q", asst1.Thinking, "Thinking first...") + } + if asst1.Content != "" { + t.Errorf("Messages[1].Content = %q, want empty", asst1.Content) + } + if len(asst1.ToolCalls) != 1 { + t.Fatalf("expected 1 tool call in Messages[1], got %d", len(asst1.ToolCalls)) + } + if asst1.ToolCalls[0].Function.Name != "think_action" { + t.Errorf("Messages[1].ToolCalls[0] = %q, want %q", asst1.ToolCalls[0].Function.Name, "think_action") + } + + // Second assistant: content + 2 tool calls + asst2 := chatReq.Messages[2] + if asst2.Content != "Now doing more." { + t.Errorf("Messages[2].Content = %q, want %q", asst2.Content, "Now doing more.") + } + if len(asst2.ToolCalls) != 2 { + t.Fatalf("expected 2 tool calls in Messages[2], got %d", len(asst2.ToolCalls)) + } + if asst2.ToolCalls[0].Function.Name != "action_a" { + t.Errorf("Messages[2].ToolCalls[0] = %q, want %q", asst2.ToolCalls[0].Function.Name, "action_a") + } + if asst2.ToolCalls[1].Function.Name != "action_b" { + t.Errorf("Messages[2].ToolCalls[1] = %q, want %q", asst2.ToolCalls[1].Function.Name, "action_b") + } + }) +} + func TestDecodeImageURL(t *testing.T) { // Valid PNG base64 (1x1 red pixel) validPNG := "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==" From 9f7822851c1f080d7d2a1dbe0e4d51233e5a28bc Mon Sep 17 00:00:00 2001 From: Devon Rifkin Date: Thu, 11 Dec 2025 17:39:40 -0800 Subject: [PATCH 28/35] docs: add docs for v1/responses and rework openai compat section (#13416) * docs: add docs for v1/responses and rework openai compat section I reworked the examples to be separated by topic and to be fully runnable (i.e., they now log output instead of just suggesting how a call might be made). We now use `<CodeGroup>`s so that each example has a dropdown on the docs site for users to choose, which makes the examples a lot more digestible (since you only see approx 1/3 of the code you used to). I also added a new tool to extract code examples into files so that it's easier to actually run them and check that they work. ## Example ```shell go run docs/tools/extract-examples/main.go docs/api/openai-compatibility.mdx ``` Output: ``` Extracting code examples to: /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368 - 01_basic.py - 01_basic.js - 01_basic.sh - 02_responses.py - 02_responses.js - 02_responses.sh - 03_vision.py - 03_vision.js - 03_vision.sh Extracted 9 file(s) to /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368 To run examples: cd /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368 npm install # for JS examples then run individual files with `node file.js`, `python file.py`, `bash file.sh` ``` In the future we should consider actually running the examples in CI and having some sort of acceptance test so we can automatically detect when our examples break. So this is just a start in that direction.
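As a rough sketch of that CI direction (not something this change adds), an acceptance test could extract the examples and then run every generated file against a local server. This assumes Ollama is already running on localhost:11434 with the referenced models pulled, and that the extracted files keep the `01_basic.py`-style names shown in the output above:

```shell
# Hypothetical acceptance check: extract the examples, then run each
# extracted file and fail the job on the first non-zero exit code.
set -euo pipefail

dir=$(go run docs/tools/extract-examples/main.go docs/api/openai-compatibility.mdx \
  | awk '/^Extracting code examples to:/ {print $NF}')
cd "$dir"
npm install          # JS examples depend on the openai package
pip install openai   # Python examples do too
for f in *.sh; do bash "$f"; done
for f in *.py; do python "$f"; done
for f in *.js; do node "$f"; done
```

Whether something like this runs per-PR or nightly would mostly come down to how long pulling the referenced models takes in CI.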
* Update docs/api/openai-compatibility.mdx Co-authored-by: Parth Sareen * Update docs/api/openai-compatibility.mdx Co-authored-by: Parth Sareen --------- Co-authored-by: Parth Sareen --- docs/api/openai-compatibility.mdx | 288 +++++++++++++------------- docs/tools/extract-examples/README.md | 46 ++++ docs/tools/extract-examples/main.go | 137 ++++++++++++ 3 files changed, 322 insertions(+), 149 deletions(-) create mode 100644 docs/tools/extract-examples/README.md create mode 100644 docs/tools/extract-examples/main.go diff --git a/docs/api/openai-compatibility.mdx b/docs/api/openai-compatibility.mdx index 8329934a..94febc30 100644 --- a/docs/api/openai-compatibility.mdx +++ b/docs/api/openai-compatibility.mdx @@ -6,16 +6,16 @@ Ollama provides compatibility with parts of the [OpenAI API](https://platform.op ## Usage -### OpenAI Python library +### Simple `v1/chat/completions` example -```python + + +```python basic.py from openai import OpenAI client = OpenAI( base_url='http://localhost:11434/v1/', - - # required but ignored - api_key='ollama', + api_key='ollama', # required but ignored ) chat_completion = client.chat.completions.create( @@ -25,96 +25,125 @@ chat_completion = client.chat.completions.create( 'content': 'Say this is a test', } ], - model='llama3.2', + model='gpt-oss:20b', +) +print(chat_completion.choices[0].message.content) +``` + +```javascript basic.js +import OpenAI from "openai"; + +const openai = new OpenAI({ + baseURL: "http://localhost:11434/v1/", + apiKey: "ollama", // required but ignored +}); + +const chatCompletion = await openai.chat.completions.create({ + messages: [{ role: "user", content: "Say this is a test" }], + model: "gpt-oss:20b", +}); + +console.log(chatCompletion.choices[0].message.content); +``` + +```shell basic.sh +curl -X POST http://localhost:11434/v1/chat/completions \ +-H "Content-Type: application/json" \ +-d '{ + "model": "gpt-oss:20b", + "messages": [{ "role": "user", "content": "Say this is a test" }] +}' +``` + + + +### Simple `v1/responses` example + + + +```python responses.py +from openai import OpenAI + +client = OpenAI( + base_url='http://localhost:11434/v1/', + api_key='ollama', # required but ignored +) + +responses_result = client.responses.create( + model='qwen3:8b', + input='Write a short poem about the color blue', +) +print(responses_result.output_text) +``` + +```javascript responses.js +import OpenAI from "openai"; + +const openai = new OpenAI({ + baseURL: "http://localhost:11434/v1/", + apiKey: "ollama", // required but ignored +}); + +const responsesResult = await openai.responses.create({ + model: "qwen3:8b", + input: "Write a short poem about the color blue", +}); + +console.log(responsesResult.output_text); +``` + +```shell responses.sh +curl -X POST http://localhost:11434/v1/responses \ +-H "Content-Type: application/json" \ +-d '{ + "model": "qwen3:8b", + "input": "Write a short poem about the color blue" +}' +``` + + + +### v1/chat/completions with vision example + + + +```python vision.py +from openai import OpenAI + +client = OpenAI( + base_url='http://localhost:11434/v1/', + api_key='ollama', # required but ignored ) response = client.chat.completions.create( - model="llava", + model='qwen3-vl:8b', messages=[ { - "role": "user", - "content": [ - {"type": "text", "text": "What's in this image?"}, + 'role': 'user', + 'content': [ + {'type': 'text', 'text': "What's in this image?"}, { - "type": "image_url", - "image_url": 
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAG0AAABmCAYAAADBPx+VAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA3VSURBVHgB7Z27r0zdG8fX743i1bi1ikMoFMQloXRpKFFIqI7LH4BEQ+NWIkjQuSWCRIEoULk0gsK1kCBI0IhrQVT7tz/7zZo888yz1r7MnDl7z5xvsjkzs2fP3uu71nNfa7lkAsm7d++Sffv2JbNmzUqcc8m0adOSzZs3Z+/XES4ZckAWJEGWPiCxjsQNLWmQsWjRIpMseaxcuTKpG/7HP27I8P79e7dq1ars/yL4/v27S0ejqwv+cUOGEGGpKHR37tzJCEpHV9tnT58+dXXCJDdECBE2Ojrqjh071hpNECjx4cMHVycM1Uhbv359B2F79+51586daxN/+pyRkRFXKyRDAqxEp4yMlDDzXG1NPnnyJKkThoK0VFd1ELZu3TrzXKxKfW7dMBQ6bcuWLW2v0VlHjx41z717927ba22U9APcw7Nnz1oGEPeL3m3p2mTAYYnFmMOMXybPPXv2bNIPpFZr1NHn4HMw0KRBjg9NuRw95s8PEcz/6DZELQd/09C9QGq5RsmSRybqkwHGjh07OsJSsYYm3ijPpyHzoiacg35MLdDSIS/O1yM778jOTwYUkKNHWUzUWaOsylE00MyI0fcnOwIdjvtNdW/HZwNLGg+sR1kMepSNJXmIwxBZiG8tDTpEZzKg0GItNsosY8USkxDhD0Rinuiko2gfL/RbiD2LZAjU9zKQJj8RDR0vJBR1/Phx9+PHj9Z7REF4nTZkxzX4LCXHrV271qXkBAPGfP/atWvu/PnzHe4C97F48eIsRLZ9+3a3f/9+87dwP1JxaF7/3r17ba+5l4EcaVo0lj3SBq5kGTJSQmLWMjgYNei2GPT1MuMqGTDEFHzeQSP2wi/jGnkmPJ/nhccs44jvDAxpVcxnq0F6eT8h4ni/iIWpR5lPyA6ETkNXoSukvpJAD3AsXLiwpZs49+fPn5ke4j10TqYvegSfn0OnafC+Tv9ooA/JPkgQysqQNBzagXY55nO/oa1F7qvIPWkRL12WRpMWUvpVDYmxAPehxWSe8ZEXL20sadYIozfmNch4QJPAfeJgW3rNsnzphBKNJM2KKODo1rVOMRYik5ETy3ix4qWNI81qAAirizgMIc+yhTytx0JWZuNI03qsrgWlGtwjoS9XwgUhWGyhUaRZZQNNIEwCiXD16tXcAHUs79co0vSD8rrJCIW98pzvxpAWyyo3HYwqS0+H0BjStClcZJT5coMm6D2LOF8TolGJtK9fvyZpyiC5ePFi9nc/oJU4eiEP0jVoAnHa9wyJycITMP78+eMeP37sXrx44d6+fdt6f82aNdkx1pg9e3Zb5W+RSRE+n+VjksQWifvVaTKFhn5O8my63K8Qabdv33b379/PiAP//vuvW7BggZszZ072/+TJk91YgkafPn166zXB1rQHFvouAWHq9z3SEevSUerqCn2/dDCeta2jxYbr69evk4MHDyY7d+7MjhMnTiTPnz9Pfv/+nfQT2ggpO2dMF8cghuoM7Ygj5iWCqRlGFml0QC/ftGmTmzt3rmsaKDsgBSPh0/8yPeLLBihLkOKJc0jp8H8vUzcxIA1k6QJ/c78tWEyj5P3o4u9+jywNPdJi5rAH9x0KHcl4Hg570eQp3+vHXGyrmEeigzQsQsjavXt38ujRo44LQuDDhw+TW7duRS1HGgMxhNXHgflaNTOsHyKvHK5Ijo2jbFjJBQK9YwFd6RVMzfgRBmEfP37suBBm/p49e1qjEP2mwTViNRo0VJWH1deMXcNK08uUjVUu7s/zRaL+oLNxz1bpANco4npUgX4G2eFbpDFyQoQxojBCpEGSytmOH8qrH5Q9vuzD6ofQylkCUmh8DBAr+q8JCyVNtWQIidKQE9wNtLSQnS4jDSsxNHogzFuQBw4cyM61UKVsjfr3ooBkPSqqQHesUPWVtzi9/vQi1T+rJj7WiTz4Pt/l3LxUkr5P2VYZaZ4URpsE+st/dujQoaBBYokbrz/8TJNQYLSonrPS9kUaSkPeZyj1AWSj+d+VBoy1pIWVNed8P0Ll/ee5HdGRhrHhR5GGN0r4LGZBaj8oFDJitBTJzIZgFcmU0Y8ytWMZMzJOaXUSrUs5RxKnrxmbb5YXO9VGUhtpXldhEUogFr3IzIsvlpmdosVcGVGXFWp2oU9kLFL3dEkSz6NHEY1sjSRdIuDFWEhd8KxFqsRi1uM/nz9/zpxnwlESONdg6dKlbsaMGS4EHFHtjFIDHwKOo46l4TxSuxgDzi+rE2jg+BaFruOX4HXa0Nnf1lwAPufZeF8/r6zD97WK2qFnGjBxTw5qNGPxT+5T/r7/7RawFC3j4vTp09koCxkeHjqbHJqArmH5UrFKKksnxrK7FuRIs8STfBZv+luugXZ2pR/pP9Ois4z+TiMzUUkUjD0iEi1fzX8GmXyuxUBRcaUfykV0YZnlJGKQpOiGB76x5GeWkWWJc3mOrK6S7xdND+W5N6XyaRgtWJFe13GkaZnKOsYqGdOVVVbGupsyA/l7emTLHi7vwTdirNEt0qxnzAvBFcnQF16xh/TMpUuXHDowhlA9vQVraQhkudRdzOnK+04ZSP3DUhVSP61YsaLtd/ks7ZgtPcXqPqEafHkdqa84X6aCeL7YWlv6edGFHb+ZFICPlljHhg0bKuk0CSvVznWsotRu433alNdFrqG45ejoaPCaUkWERpLXjzFL2Rpllp7PJU2a/v7Ab8N05/9t27Z16KUqoFGsxnI9EosS2niSYg9SpU6B4JgTrvVW1flt1sT+0ADIJU2maXzcUTraGCRaL1Wp9rUMk16PMom8QhruxzvZIegJjFU7LLCePfS8uaQdPny4jTTL0dbee5mYokQsXTIWNY46kuMbnt8Kmec+LGWtOVIl9cT1rCB0V8WqkjAsRwta93TbwNYoGKsUSChN44lgBNCoHLHzquYKrU6qZ8lolCIN0Rh6cP0Q3U6I6IXILYOQI513hJaSKAorFpuHXJNfVlpRtmYBk1Su1obZr5dnKAO+L10Hrj3WZW+E3qh6IszE37F6EB+68mGpvKm4eb9bFrlzrok7fvr0Kfv727dvWRmdVTJHw0qiiCUSZ6wCK+7XL/AcsgNyL74DQQ730sv78Su7+t/A36MdY0sW5o40ahslXr58aZ5HtZB8GH64m9EmMZ7FpYw4T6QnrZfgenrhFxaSiSGXtPnz57e9TkNZLvTjeqhr734CNtrK41L40sUQckmj1lGKQ0rC37x544r8eNXRpnVE3ZZY7zXo8NomiO0ZUCj2uHz58rbXoZ6gc0uA+F6ZeKS/jhRDUq8MKrTho9fEkihMmhxtBI1DxKFY9XLpVcSkfoi8JGnToZO5sU5aiDQIW716ddt7ZLYtMQlhECdBGXZZMWldY5BHm5xgAroWj4C0hbYkSc/jBmggIrXJWlZM6pSETsEPGqZO
ndr2uuuR5rF169a2HoHPdurUKZM4CO1WTPqaDaAd+GFGKdIQkxAn9RuEWcTRyN2KSUgiSgF5aWzPTeA/lN5rZubMmR2bE4SIC4nJoltgAV/dVefZm72AtctUCJU2CMJ327hxY9t7EHbkyJFseq+EJSY16RPo3Dkq1kkr7+q0bNmyDuLQcZBEPYmHVdOBiJyIlrRDq41YPWfXOxUysi5fvtyaj+2BpcnsUV/oSoEMOk2CQGlr4ckhBwaetBhjCwH0ZHtJROPJkyc7UjcYLDjmrH7ADTEBXFfOYmB0k9oYBOjJ8b4aOYSe7QkKcYhFlq3QYLQhSidNmtS2RATwy8YOM3EQJsUjKiaWZ+vZToUQgzhkHXudb/PW5YMHD9yZM2faPsMwoc7RciYJXbGuBqJ1UIGKKLv915jsvgtJxCZDubdXr165mzdvtr1Hz5LONA8jrUwKPqsmVesKa49S3Q4WxmRPUEYdTjgiUcfUwLx589ySJUva3oMkP6IYddq6HMS4o55xBJBUeRjzfa4Zdeg56QZ43LhxoyPo7Lf1kNt7oO8wWAbNwaYjIv5lhyS7kRf96dvm5Jah8vfvX3flyhX35cuX6HfzFHOToS1H4BenCaHvO8pr8iDuwoUL7tevX+b5ZdbBair0xkFIlFDlW4ZknEClsp/TzXyAKVOmmHWFVSbDNw1l1+4f90U6IY/q4V27dpnE9bJ+v87QEydjqx/UamVVPRG+mwkNTYN+9tjkwzEx+atCm/X9WvWtDtAb68Wy9LXa1UmvCDDIpPkyOQ5ZwSzJ4jMrvFcr0rSjOUh+GcT4LSg5ugkW1Io0/SCDQBojh0hPlaJdah+tkVYrnTZowP8iq1F1TgMBBauufyB33x1v+NWFYmT5KmppgHC+NkAgbmRkpD3yn9QIseXymoTQFGQmIOKTxiZIWpvAatenVqRVXf2nTrAWMsPnKrMZHz6bJq5jvce6QK8J1cQNgKxlJapMPdZSR64/UivS9NztpkVEdKcrs5alhhWP9NeqlfWopzhZScI6QxseegZRGeg5a8C3Re1Mfl1ScP36ddcUaMuv24iOJtz7sbUjTS4qBvKmstYJoUauiuD3k5qhyr7QdUHMeCgLa1Ear9NquemdXgmum4fvJ6w1lqsuDhNrg1qSpleJK7K3TF0Q2jSd94uSZ60kK1e3qyVpQK6PVWXp2/FC3mp6jBhKKOiY2h3gtUV64TWM6wDETRPLDfSakXmH3w8g9Jlug8ZtTt4kVF0kLUYYmCCtD/DrQ5YhMGbA9L3ucdjh0y8kOHW5gU/VEEmJTcL4Pz/f7mgoAbYkAAAAAElFTkSuQmCC", + 'type': 'image_url', + 'image_url': 'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAG0AAABmCAYAAADBPx+VAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA3VSURBVHgB7Z27r0zdG8fX743i1bi1ikMoFMQloXRpKFFIqI7LH4BEQ+NWIkjQuSWCRIEoULk0gsK1kCBI0IhrQVT7tz/7zZo888yz1r7MnDl7z5xvsjkzs2fP3uu71nNfa7lkAsm7d++Sffv2JbNmzUqcc8m0adOSzZs3Z+/XES4ZckAWJEGWPiCxjsQNLWmQsWjRIpMseaxcuTKpG/7HP27I8P79e7dq1ars/yL4/v27S0ejqwv+cUOGEGGpKHR37tzJCEpHV9tnT58+dXXCJDdECBE2Ojrqjh071hpNECjx4cMHVycM1Uhbv359B2F79+51586daxN/+pyRkRFXKyRDAqxEp4yMlDDzXG1NPnnyJKkThoK0VFd1ELZu3TrzXKxKfW7dMBQ6bcuWLW2v0VlHjx41z717927ba22U9APcw7Nnz1oGEPeL3m3p2mTAYYnFmMOMXybPPXv2bNIPpFZr1NHn4HMw0KRBjg9NuRw95s8PEcz/6DZELQd/09C9QGq5RsmSRybqkwHGjh07OsJSsYYm3ijPpyHzoiacg35MLdDSIS/O1yM778jOTwYUkKNHWUzUWaOsylE00MyI0fcnOwIdjvtNdW/HZwNLGg+sR1kMepSNJXmIwxBZiG8tDTpEZzKg0GItNsosY8USkxDhD0Rinuiko2gfL/RbiD2LZAjU9zKQJj8RDR0vJBR1/Phx9+PHj9Z7REF4nTZkxzX4LCXHrV271qXkBAPGfP/atWvu/PnzHe4C97F48eIsRLZ9+3a3f/9+87dwP1JxaF7/3r17ba+5l4EcaVo0lj3SBq5kGTJSQmLWMjgYNei2GPT1MuMqGTDEFHzeQSP2wi/jGnkmPJ/nhccs44jvDAxpVcxnq0F6eT8h4ni/iIWpR5lPyA6ETkNXoSukvpJAD3AsXLiwpZs49+fPn5ke4j10TqYvegSfn0OnafC+Tv9ooA/JPkgQysqQNBzagXY55nO/oa1F7qvIPWkRL12WRpMWUvpVDYmxAPehxWSe8ZEXL20sadYIozfmNch4QJPAfeJgW3rNsnzphBKNJM2KKODo1rVOMRYik5ETy3ix4qWNI81qAAirizgMIc+yhTytx0JWZuNI03qsrgWlGtwjoS9XwgUhWGyhUaRZZQNNIEwCiXD16tXcAHUs79co0vSD8rrJCIW98pzvxpAWyyo3HYwqS0+H0BjStClcZJT5coMm6D2LOF8TolGJtK9fvyZpyiC5ePFi9nc/oJU4eiEP0jVoAnHa9wyJycITMP78+eMeP37sXrx44d6+fdt6f82aNdkx1pg9e3Zb5W+RSRE+n+VjksQWifvVaTKFhn5O8my63K8Qabdv33b379/PiAP//vuvW7BggZszZ072/+TJk91YgkafPn166zXB1rQHFvouAWHq9z3SEevSUerqCn2/dDCeta2jxYbr69evk4MHDyY7d+7MjhMnTiTPnz9Pfv/+nfQT2ggpO2dMF8cghuoM7Ygj5iWCqRlGFml0QC/ftGmTmzt3rmsaKDsgBSPh0/8yPeLLBihLkOKJc0jp8H8vUzcxIA1k6QJ/c78tWEyj5P3o4u9+jywNPdJi5rAH9x0KHcl4Hg570eQp3+vHXGyrmEeigzQsQsjavXt38ujRo44LQuDDhw+TW7duRS1HGgMxhNXHgflaNTOsHyKvHK5Ijo2jbFjJBQK9YwFd6RVMzfgRBmEfP37suBBm/p49e1qjEP2mwTViNRo0VJWH1deMXcNK08uUjVUu7s/zRaL+oLNxz1bpANco4npUgX4G2eFbpDFyQoQxojBCpEGSytmOH8qrH5Q9vuzD6ofQylkCUmh8DBAr+q8JCyVNtWQIidKQE9wNtLSQnS4jDSsxNHogzFuQBw4cyM61UKVsjfr3ooBkPSqqQHesUPWVtzi9/vQi1T+rJj7WiTz4Pt/l3LxUkr5P2VYZaZ4URpsE+st/dujQoaBBYokbrz/8TJNQYLSonrPS9kUaSkPeZyj1AWSj+d+VBoy1pIWVNed8P0Ll/ee5HdGRhrHhR5GGN0r
4LGZBaj8oFDJitBTJzIZgFcmU0Y8ytWMZMzJOaXUSrUs5RxKnrxmbb5YXO9VGUhtpXldhEUogFr3IzIsvlpmdosVcGVGXFWp2oU9kLFL3dEkSz6NHEY1sjSRdIuDFWEhd8KxFqsRi1uM/nz9/zpxnwlESONdg6dKlbsaMGS4EHFHtjFIDHwKOo46l4TxSuxgDzi+rE2jg+BaFruOX4HXa0Nnf1lwAPufZeF8/r6zD97WK2qFnGjBxTw5qNGPxT+5T/r7/7RawFC3j4vTp09koCxkeHjqbHJqArmH5UrFKKksnxrK7FuRIs8STfBZv+luugXZ2pR/pP9Ois4z+TiMzUUkUjD0iEi1fzX8GmXyuxUBRcaUfykV0YZnlJGKQpOiGB76x5GeWkWWJc3mOrK6S7xdND+W5N6XyaRgtWJFe13GkaZnKOsYqGdOVVVbGupsyA/l7emTLHi7vwTdirNEt0qxnzAvBFcnQF16xh/TMpUuXHDowhlA9vQVraQhkudRdzOnK+04ZSP3DUhVSP61YsaLtd/ks7ZgtPcXqPqEafHkdqa84X6aCeL7YWlv6edGFHb+ZFICPlljHhg0bKuk0CSvVznWsotRu433alNdFrqG45ejoaPCaUkWERpLXjzFL2Rpllp7PJU2a/v7Ab8N05/9t27Z16KUqoFGsxnI9EosS2niSYg9SpU6B4JgTrvVW1flt1sT+0ADIJU2maXzcUTraGCRaL1Wp9rUMk16PMom8QhruxzvZIegJjFU7LLCePfS8uaQdPny4jTTL0dbee5mYokQsXTIWNY46kuMbnt8Kmec+LGWtOVIl9cT1rCB0V8WqkjAsRwta93TbwNYoGKsUSChN44lgBNCoHLHzquYKrU6qZ8lolCIN0Rh6cP0Q3U6I6IXILYOQI513hJaSKAorFpuHXJNfVlpRtmYBk1Su1obZr5dnKAO+L10Hrj3WZW+E3qh6IszE37F6EB+68mGpvKm4eb9bFrlzrok7fvr0Kfv727dvWRmdVTJHw0qiiCUSZ6wCK+7XL/AcsgNyL74DQQ730sv78Su7+t/A36MdY0sW5o40ahslXr58aZ5HtZB8GH64m9EmMZ7FpYw4T6QnrZfgenrhFxaSiSGXtPnz57e9TkNZLvTjeqhr734CNtrK41L40sUQckmj1lGKQ0rC37x544r8eNXRpnVE3ZZY7zXo8NomiO0ZUCj2uHz58rbXoZ6gc0uA+F6ZeKS/jhRDUq8MKrTho9fEkihMmhxtBI1DxKFY9XLpVcSkfoi8JGnToZO5sU5aiDQIW716ddt7ZLYtMQlhECdBGXZZMWldY5BHm5xgAroWj4C0hbYkSc/jBmggIrXJWlZM6pSETsEPGqZOndr2uuuR5rF169a2HoHPdurUKZM4CO1WTPqaDaAd+GFGKdIQkxAn9RuEWcTRyN2KSUgiSgF5aWzPTeA/lN5rZubMmR2bE4SIC4nJoltgAV/dVefZm72AtctUCJU2CMJ327hxY9t7EHbkyJFseq+EJSY16RPo3Dkq1kkr7+q0bNmyDuLQcZBEPYmHVdOBiJyIlrRDq41YPWfXOxUysi5fvtyaj+2BpcnsUV/oSoEMOk2CQGlr4ckhBwaetBhjCwH0ZHtJROPJkyc7UjcYLDjmrH7ADTEBXFfOYmB0k9oYBOjJ8b4aOYSe7QkKcYhFlq3QYLQhSidNmtS2RATwy8YOM3EQJsUjKiaWZ+vZToUQgzhkHXudb/PW5YMHD9yZM2faPsMwoc7RciYJXbGuBqJ1UIGKKLv915jsvgtJxCZDubdXr165mzdvtr1Hz5LONA8jrUwKPqsmVesKa49S3Q4WxmRPUEYdTjgiUcfUwLx589ySJUva3oMkP6IYddq6HMS4o55xBJBUeRjzfa4Zdeg56QZ43LhxoyPo7Lf1kNt7oO8wWAbNwaYjIv5lhyS7kRf96dvm5Jah8vfvX3flyhX35cuX6HfzFHOToS1H4BenCaHvO8pr8iDuwoUL7tevX+b5ZdbBair0xkFIlFDlW4ZknEClsp/TzXyAKVOmmHWFVSbDNw1l1+4f90U6IY/q4V27dpnE9bJ+v87QEydjqx/UamVVPRG+mwkNTYN+9tjkwzEx+atCm/X9WvWtDtAb68Wy9LXa1UmvCDDIpPkyOQ5ZwSzJ4jMrvFcr0rSjOUh+GcT4LSg5ugkW1Io0/SCDQBojh0hPlaJdah+tkVYrnTZowP8iq1F1TgMBBauufyB33x1v+NWFYmT5KmppgHC+NkAgbmRkpD3yn9QIseXymoTQFGQmIOKTxiZIWpvAatenVqRVXf2nTrAWMsPnKrMZHz6bJq5jvce6QK8J1cQNgKxlJapMPdZSR64/UivS9NztpkVEdKcrs5alhhWP9NeqlfWopzhZScI6QxseegZRGeg5a8C3Re1Mfl1ScP36ddcUaMuv24iOJtz7sbUjTS4qBvKmstYJoUauiuD3k5qhyr7QdUHMeCgLa1Ear9NquemdXgmum4fvJ6w1lqsuDhNrg1qSpleJK7K3TF0Q2jSd94uSZ60kK1e3qyVpQK6PVWXp2/FC3mp6jBhKKOiY2h3gtUV64TWM6wDETRPLDfSakXmH3w8g9Jlug8ZtTt4kVF0kLUYYmCCtD/DrQ5YhMGbA9L3ucdjh0y8kOHW5gU/VEEmJTcL4Pz/f7mgoAbYkAAAAAElFTkSuQmCC', }, ], } ], max_tokens=300, ) - -completion = client.completions.create( - model="llama3.2", - prompt="Say this is a test", -) - -list_completion = client.models.list() - -model = client.models.retrieve("llama3.2") - -embeddings = client.embeddings.create( - model="all-minilm", - input=["why is the sky blue?", "why is the grass green?"], -) +print(response.choices[0].message.content) ``` -#### Structured outputs - -```python -from pydantic import BaseModel -from openai import OpenAI - -client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama") - -# Define the schema for the response -class FriendInfo(BaseModel): - name: str - age: int - is_available: bool - -class FriendList(BaseModel): - friends: list[FriendInfo] - -try: - completion = client.beta.chat.completions.parse( - temperature=0, - model="llama3.1:8b", - 
messages=[ - {"role": "user", "content": "I have two friends. The first is Ollama 22 years old busy saving the world, and the second is Alonso 23 years old and wants to hang out. Return a list of friends in JSON format"} - ], - response_format=FriendList, - ) - - friends_response = completion.choices[0].message - if friends_response.parsed: - print(friends_response.parsed) - elif friends_response.refusal: - print(friends_response.refusal) -except Exception as e: - print(f"Error: {e}") -``` - -### OpenAI JavaScript library - -```javascript +```javascript vision.js import OpenAI from "openai"; const openai = new OpenAI({ baseURL: "http://localhost:11434/v1/", - - // required but ignored - apiKey: "ollama", -}); - -const chatCompletion = await openai.chat.completions.create({ - messages: [{ role: "user", content: "Say this is a test" }], - model: "llama3.2", + apiKey: "ollama", // required but ignored }); const response = await openai.chat.completions.create({ - model: "llava", + model: "qwen3-vl:8b", messages: [ { role: "user", @@ -129,84 +158,20 @@ const response = await openai.chat.completions.create({ }, ], }); - -const completion = await openai.completions.create({ - model: "llama3.2", - prompt: "Say this is a test.", -}); - -const listCompletion = await openai.models.list(); - -const model = await openai.models.retrieve("llama3.2"); - -const embedding = await openai.embeddings.create({ - model: "all-minilm", - input: ["why is the sky blue?", "why is the grass green?"], -}); +console.log(response.choices[0].message.content); ``` -### `curl` - -```shell -curl http://localhost:11434/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "llama3.2", - "messages": [ - { - "role": "system", - "content": "You are a helpful assistant." - }, - { - "role": "user", - "content": "Hello!" - } - ] - }' - -curl http://localhost:11434/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "llava", - "messages": [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": "What'\''s in this image?" 
- }, - { - "type": "image_url", - "image_url": { - "url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAG0AAABmCAYAAADBPx+VAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA3VSURBVHgB7Z27r0zdG8fX743i1bi1ikMoFMQloXRpKFFIqI7LH4BEQ+NWIkjQuSWCRIEoULk0gsK1kCBI0IhrQVT7tz/7zZo888yz1r7MnDl7z5xvsjkzs2fP3uu71nNfa7lkAsm7d++Sffv2JbNmzUqcc8m0adOSzZs3Z+/XES4ZckAWJEGWPiCxjsQNLWmQsWjRIpMseaxcuTKpG/7HP27I8P79e7dq1ars/yL4/v27S0ejqwv+cUOGEGGpKHR37tzJCEpHV9tnT58+dXXCJDdECBE2Ojrqjh071hpNECjx4cMHVycM1Uhbv359B2F79+51586daxN/+pyRkRFXKyRDAqxEp4yMlDDzXG1NPnnyJKkThoK0VFd1ELZu3TrzXKxKfW7dMBQ6bcuWLW2v0VlHjx41z717927ba22U9APcw7Nnz1oGEPeL3m3p2mTAYYnFmMOMXybPPXv2bNIPpFZr1NHn4HMw0KRBjg9NuRw95s8PEcz/6DZELQd/09C9QGq5RsmSRybqkwHGjh07OsJSsYYm3ijPpyHzoiacg35MLdDSIS/O1yM778jOTwYUkKNHWUzUWaOsylE00MyI0fcnOwIdjvtNdW/HZwNLGg+sR1kMepSNJXmIwxBZiG8tDTpEZzKg0GItNsosY8USkxDhD0Rinuiko2gfL/RbiD2LZAjU9zKQJj8RDR0vJBR1/Phx9+PHj9Z7REF4nTZkxzX4LCXHrV271qXkBAPGfP/atWvu/PnzHe4C97F48eIsRLZ9+3a3f/9+87dwP1JxaF7/3r17ba+5l4EcaVo0lj3SBq5kGTJSQmLWMjgYNei2GPT1MuMqGTDEFHzeQSP2wi/jGnkmPJ/nhccs44jvDAxpVcxnq0F6eT8h4ni/iIWpR5lPyA6ETkNXoSukvpJAD3AsXLiwpZs49+fPn5ke4j10TqYvegSfn0OnafC+Tv9ooA/JPkgQysqQNBzagXY55nO/oa1F7qvIPWkRL12WRpMWUvpVDYmxAPehxWSe8ZEXL20sadYIozfmNch4QJPAfeJgW3rNsnzphBKNJM2KKODo1rVOMRYik5ETy3ix4qWNI81qAAirizgMIc+yhTytx0JWZuNI03qsrgWlGtwjoS9XwgUhWGyhUaRZZQNNIEwCiXD16tXcAHUs79co0vSD8rrJCIW98pzvxpAWyyo3HYwqS0+H0BjStClcZJT5coMm6D2LOF8TolGJtK9fvyZpyiC5ePFi9nc/oJU4eiEP0jVoAnHa9wyJycITMP78+eMeP37sXrx44d6+fdt6f82aNdkx1pg9e3Zb5W+RSRE+n+VjksQWifvVaTKFhn5O8my63K8Qabdv33b379/PiAP//vuvW7BggZszZ072/+TJk91YgkafPn166zXB1rQHFvouAWHq9z3SEevSUerqCn2/dDCeta2jxYbr69evk4MHDyY7d+7MjhMnTiTPnz9Pfv/+nfQT2ggpO2dMF8cghuoM7Ygj5iWCqRlGFml0QC/ftGmTmzt3rmsaKDsgBSPh0/8yPeLLBihLkOKJc0jp8H8vUzcxIA1k6QJ/c78tWEyj5P3o4u9+jywNPdJi5rAH9x0KHcl4Hg570eQp3+vHXGyrmEeigzQsQsjavXt38ujRo44LQuDDhw+TW7duRS1HGgMxhNXHgflaNTOsHyKvHK5Ijo2jbFjJBQK9YwFd6RVMzfgRBmEfP37suBBm/p49e1qjEP2mwTViNRo0VJWH1deMXcNK08uUjVUu7s/zRaL+oLNxz1bpANco4npUgX4G2eFbpDFyQoQxojBCpEGSytmOH8qrH5Q9vuzD6ofQylkCUmh8DBAr+q8JCyVNtWQIidKQE9wNtLSQnS4jDSsxNHogzFuQBw4cyM61UKVsjfr3ooBkPSqqQHesUPWVtzi9/vQi1T+rJj7WiTz4Pt/l3LxUkr5P2VYZaZ4URpsE+st/dujQoaBBYokbrz/8TJNQYLSonrPS9kUaSkPeZyj1AWSj+d+VBoy1pIWVNed8P0Ll/ee5HdGRhrHhR5GGN0r4LGZBaj8oFDJitBTJzIZgFcmU0Y8ytWMZMzJOaXUSrUs5RxKnrxmbb5YXO9VGUhtpXldhEUogFr3IzIsvlpmdosVcGVGXFWp2oU9kLFL3dEkSz6NHEY1sjSRdIuDFWEhd8KxFqsRi1uM/nz9/zpxnwlESONdg6dKlbsaMGS4EHFHtjFIDHwKOo46l4TxSuxgDzi+rE2jg+BaFruOX4HXa0Nnf1lwAPufZeF8/r6zD97WK2qFnGjBxTw5qNGPxT+5T/r7/7RawFC3j4vTp09koCxkeHjqbHJqArmH5UrFKKksnxrK7FuRIs8STfBZv+luugXZ2pR/pP9Ois4z+TiMzUUkUjD0iEi1fzX8GmXyuxUBRcaUfykV0YZnlJGKQpOiGB76x5GeWkWWJc3mOrK6S7xdND+W5N6XyaRgtWJFe13GkaZnKOsYqGdOVVVbGupsyA/l7emTLHi7vwTdirNEt0qxnzAvBFcnQF16xh/TMpUuXHDowhlA9vQVraQhkudRdzOnK+04ZSP3DUhVSP61YsaLtd/ks7ZgtPcXqPqEafHkdqa84X6aCeL7YWlv6edGFHb+ZFICPlljHhg0bKuk0CSvVznWsotRu433alNdFrqG45ejoaPCaUkWERpLXjzFL2Rpllp7PJU2a/v7Ab8N05/9t27Z16KUqoFGsxnI9EosS2niSYg9SpU6B4JgTrvVW1flt1sT+0ADIJU2maXzcUTraGCRaL1Wp9rUMk16PMom8QhruxzvZIegJjFU7LLCePfS8uaQdPny4jTTL0dbee5mYokQsXTIWNY46kuMbnt8Kmec+LGWtOVIl9cT1rCB0V8WqkjAsRwta93TbwNYoGKsUSChN44lgBNCoHLHzquYKrU6qZ8lolCIN0Rh6cP0Q3U6I6IXILYOQI513hJaSKAorFpuHXJNfVlpRtmYBk1Su1obZr5dnKAO+L10Hrj3WZW+E3qh6IszE37F6EB+68mGpvKm4eb9bFrlzrok7fvr0Kfv727dvWRmdVTJHw0qiiCUSZ6wCK+7XL/AcsgNyL74DQQ730sv78Su7+t/A36MdY0sW5o40ahslXr58aZ5HtZB8GH64m9EmMZ7FpYw4T6QnrZfgenrhFxaSiSGXtPnz57e9TkNZLvTjeqhr734CNtrK41L40sUQckmj1lGKQ0rC37x544r8eNXRpnVE3ZZY7zXo8NomiO0ZUCj2uHz58rbXoZ6gc0uA+F6ZeKS/jhRDUq8MKrTho9fEkihMmhxtBI1DxKFY9XLpVcSkfoi8JGnToZO5sU5aiDQIW716ddt7ZLYtMQlhEC
dBGXZZMWldY5BHm5xgAroWj4C0hbYkSc/jBmggIrXJWlZM6pSETsEPGqZOndr2uuuR5rF169a2HoHPdurUKZM4CO1WTPqaDaAd+GFGKdIQkxAn9RuEWcTRyN2KSUgiSgF5aWzPTeA/lN5rZubMmR2bE4SIC4nJoltgAV/dVefZm72AtctUCJU2CMJ327hxY9t7EHbkyJFseq+EJSY16RPo3Dkq1kkr7+q0bNmyDuLQcZBEPYmHVdOBiJyIlrRDq41YPWfXOxUysi5fvtyaj+2BpcnsUV/oSoEMOk2CQGlr4ckhBwaetBhjCwH0ZHtJROPJkyc7UjcYLDjmrH7ADTEBXFfOYmB0k9oYBOjJ8b4aOYSe7QkKcYhFlq3QYLQhSidNmtS2RATwy8YOM3EQJsUjKiaWZ+vZToUQgzhkHXudb/PW5YMHD9yZM2faPsMwoc7RciYJXbGuBqJ1UIGKKLv915jsvgtJxCZDubdXr165mzdvtr1Hz5LONA8jrUwKPqsmVesKa49S3Q4WxmRPUEYdTjgiUcfUwLx589ySJUva3oMkP6IYddq6HMS4o55xBJBUeRjzfa4Zdeg56QZ43LhxoyPo7Lf1kNt7oO8wWAbNwaYjIv5lhyS7kRf96dvm5Jah8vfvX3flyhX35cuX6HfzFHOToS1H4BenCaHvO8pr8iDuwoUL7tevX+b5ZdbBair0xkFIlFDlW4ZknEClsp/TzXyAKVOmmHWFVSbDNw1l1+4f90U6IY/q4V27dpnE9bJ+v87QEydjqx/UamVVPRG+mwkNTYN+9tjkwzEx+atCm/X9WvWtDtAb68Wy9LXa1UmvCDDIpPkyOQ5ZwSzJ4jMrvFcr0rSjOUh+GcT4LSg5ugkW1Io0/SCDQBojh0hPlaJdah+tkVYrnTZowP8iq1F1TgMBBauufyB33x1v+NWFYmT5KmppgHC+NkAgbmRkpD3yn9QIseXymoTQFGQmIOKTxiZIWpvAatenVqRVXf2nTrAWMsPnKrMZHz6bJq5jvce6QK8J1cQNgKxlJapMPdZSR64/UivS9NztpkVEdKcrs5alhhWP9NeqlfWopzhZScI6QxseegZRGeg5a8C3Re1Mfl1ScP36ddcUaMuv24iOJtz7sbUjTS4qBvKmstYJoUauiuD3k5qhyr7QdUHMeCgLa1Ear9NquemdXgmum4fvJ6w1lqsuDhNrg1qSpleJK7K3TF0Q2jSd94uSZ60kK1e3qyVpQK6PVWXp2/FC3mp6jBhKKOiY2h3gtUV64TWM6wDETRPLDfSakXmH3w8g9Jlug8ZtTt4kVF0kLUYYmCCtD/DrQ5YhMGbA9L3ucdjh0y8kOHW5gU/VEEmJTcL4Pz/f7mgoAbYkAAAAAElFTkSuQmCC" - } - } - ] - } - ], - "max_tokens": 300 - }' - -curl http://localhost:11434/v1/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "llama3.2", - "prompt": "Say this is a test" - }' - -curl http://localhost:11434/v1/models - -curl http://localhost:11434/v1/models/llama3.2 - -curl http://localhost:11434/v1/embeddings \ - -H "Content-Type: application/json" \ - -d '{ - "model": "all-minilm", - "input": ["why is the sky blue?", "why is the grass green?"] - }' +```shell vision.sh +curl -X POST http://localhost:11434/v1/chat/completions \ +-H "Content-Type: application/json" \ +-d '{ + "model": "qwen3-vl:8b", + "messages": [{ "role": "user", "content": [{"type": "text", "text": "What is this an image of?"}, {"type": "image_url", "image_url": 
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAG0AAABmCAYAAADBPx+VAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA3VSURBVHgB7Z27r0zdG8fX743i1bi1ikMoFMQloXRpKFFIqI7LH4BEQ+NWIkjQuSWCRIEoULk0gsK1kCBI0IhrQVT7tz/7zZo888yz1r7MnDl7z5xvsjkzs2fP3uu71nNfa7lkAsm7d++Sffv2JbNmzUqcc8m0adOSzZs3Z+/XES4ZckAWJEGWPiCxjsQNLWmQsWjRIpMseaxcuTKpG/7HP27I8P79e7dq1ars/yL4/v27S0ejqwv+cUOGEGGpKHR37tzJCEpHV9tnT58+dXXCJDdECBE2Ojrqjh071hpNECjx4cMHVycM1Uhbv359B2F79+51586daxN/+pyRkRFXKyRDAqxEp4yMlDDzXG1NPnnyJKkThoK0VFd1ELZu3TrzXKxKfW7dMBQ6bcuWLW2v0VlHjx41z717927ba22U9APcw7Nnz1oGEPeL3m3p2mTAYYnFmMOMXybPPXv2bNIPpFZr1NHn4HMw0KRBjg9NuRw95s8PEcz/6DZELQd/09C9QGq5RsmSRybqkwHGjh07OsJSsYYm3ijPpyHzoiacg35MLdDSIS/O1yM778jOTwYUkKNHWUzUWaOsylE00MyI0fcnOwIdjvtNdW/HZwNLGg+sR1kMepSNJXmIwxBZiG8tDTpEZzKg0GItNsosY8USkxDhD0Rinuiko2gfL/RbiD2LZAjU9zKQJj8RDR0vJBR1/Phx9+PHj9Z7REF4nTZkxzX4LCXHrV271qXkBAPGfP/atWvu/PnzHe4C97F48eIsRLZ9+3a3f/9+87dwP1JxaF7/3r17ba+5l4EcaVo0lj3SBq5kGTJSQmLWMjgYNei2GPT1MuMqGTDEFHzeQSP2wi/jGnkmPJ/nhccs44jvDAxpVcxnq0F6eT8h4ni/iIWpR5lPyA6ETkNXoSukvpJAD3AsXLiwpZs49+fPn5ke4j10TqYvegSfn0OnafC+Tv9ooA/JPkgQysqQNBzagXY55nO/oa1F7qvIPWkRL12WRpMWUvpVDYmxAPehxWSe8ZEXL20sadYIozfmNch4QJPAfeJgW3rNsnzphBKNJM2KKODo1rVOMRYik5ETy3ix4qWNI81qAAirizgMIc+yhTytx0JWZuNI03qsrgWlGtwjoS9XwgUhWGyhUaRZZQNNIEwCiXD16tXcAHUs79co0vSD8rrJCIW98pzvxpAWyyo3HYwqS0+H0BjStClcZJT5coMm6D2LOF8TolGJtK9fvyZpyiC5ePFi9nc/oJU4eiEP0jVoAnHa9wyJycITMP78+eMeP37sXrx44d6+fdt6f82aNdkx1pg9e3Zb5W+RSRE+n+VjksQWifvVaTKFhn5O8my63K8Qabdv33b379/PiAP//vuvW7BggZszZ072/+TJk91YgkafPn166zXB1rQHFvouAWHq9z3SEevSUerqCn2/dDCeta2jxYbr69evk4MHDyY7d+7MjhMnTiTPnz9Pfv/+nfQT2ggpO2dMF8cghuoM7Ygj5iWCqRlGFml0QC/ftGmTmzt3rmsaKDsgBSPh0/8yPeLLBihLkOKJc0jp8H8vUzcxIA1k6QJ/c78tWEyj5P3o4u9+jywNPdJi5rAH9x0KHcl4Hg570eQp3+vHXGyrmEeigzQsQsjavXt38ujRo44LQuDDhw+TW7duRS1HGgMxhNXHgflaNTOsHyKvHK5Ijo2jbFjJBQK9YwFd6RVMzfgRBmEfP37suBBm/p49e1qjEP2mwTViNRo0VJWH1deMXcNK08uUjVUu7s/zRaL+oLNxz1bpANco4npUgX4G2eFbpDFyQoQxojBCpEGSytmOH8qrH5Q9vuzD6ofQylkCUmh8DBAr+q8JCyVNtWQIidKQE9wNtLSQnS4jDSsxNHogzFuQBw4cyM61UKVsjfr3ooBkPSqqQHesUPWVtzi9/vQi1T+rJj7WiTz4Pt/l3LxUkr5P2VYZaZ4URpsE+st/dujQoaBBYokbrz/8TJNQYLSonrPS9kUaSkPeZyj1AWSj+d+VBoy1pIWVNed8P0Ll/ee5HdGRhrHhR5GGN0r4LGZBaj8oFDJitBTJzIZgFcmU0Y8ytWMZMzJOaXUSrUs5RxKnrxmbb5YXO9VGUhtpXldhEUogFr3IzIsvlpmdosVcGVGXFWp2oU9kLFL3dEkSz6NHEY1sjSRdIuDFWEhd8KxFqsRi1uM/nz9/zpxnwlESONdg6dKlbsaMGS4EHFHtjFIDHwKOo46l4TxSuxgDzi+rE2jg+BaFruOX4HXa0Nnf1lwAPufZeF8/r6zD97WK2qFnGjBxTw5qNGPxT+5T/r7/7RawFC3j4vTp09koCxkeHjqbHJqArmH5UrFKKksnxrK7FuRIs8STfBZv+luugXZ2pR/pP9Ois4z+TiMzUUkUjD0iEi1fzX8GmXyuxUBRcaUfykV0YZnlJGKQpOiGB76x5GeWkWWJc3mOrK6S7xdND+W5N6XyaRgtWJFe13GkaZnKOsYqGdOVVVbGupsyA/l7emTLHi7vwTdirNEt0qxnzAvBFcnQF16xh/TMpUuXHDowhlA9vQVraQhkudRdzOnK+04ZSP3DUhVSP61YsaLtd/ks7ZgtPcXqPqEafHkdqa84X6aCeL7YWlv6edGFHb+ZFICPlljHhg0bKuk0CSvVznWsotRu433alNdFrqG45ejoaPCaUkWERpLXjzFL2Rpllp7PJU2a/v7Ab8N05/9t27Z16KUqoFGsxnI9EosS2niSYg9SpU6B4JgTrvVW1flt1sT+0ADIJU2maXzcUTraGCRaL1Wp9rUMk16PMom8QhruxzvZIegJjFU7LLCePfS8uaQdPny4jTTL0dbee5mYokQsXTIWNY46kuMbnt8Kmec+LGWtOVIl9cT1rCB0V8WqkjAsRwta93TbwNYoGKsUSChN44lgBNCoHLHzquYKrU6qZ8lolCIN0Rh6cP0Q3U6I6IXILYOQI513hJaSKAorFpuHXJNfVlpRtmYBk1Su1obZr5dnKAO+L10Hrj3WZW+E3qh6IszE37F6EB+68mGpvKm4eb9bFrlzrok7fvr0Kfv727dvWRmdVTJHw0qiiCUSZ6wCK+7XL/AcsgNyL74DQQ730sv78Su7+t/A36MdY0sW5o40ahslXr58aZ5HtZB8GH64m9EmMZ7FpYw4T6QnrZfgenrhFxaSiSGXtPnz57e9TkNZLvTjeqhr734CNtrK41L40sUQckmj1lGKQ0rC37x544r8eNXRpnVE3ZZY7zXo8NomiO0ZUCj2uHz58rbXoZ6gc0uA+F6ZeKS/jhRDUq8MKrTho9fEkihMmhxtBI1DxKFY9XLpVcSkfoi8JGnToZO5sU5aiDQIW716ddt7ZLYtMQlhECdBGXZZMWldY5BHm5xgAroWj4C0hbYkSc/jBmggIrXJWlZM6pSETsEPGqZO
ndr2uuuR5rF169a2HoHPdurUKZM4CO1WTPqaDaAd+GFGKdIQkxAn9RuEWcTRyN2KSUgiSgF5aWzPTeA/lN5rZubMmR2bE4SIC4nJoltgAV/dVefZm72AtctUCJU2CMJ327hxY9t7EHbkyJFseq+EJSY16RPo3Dkq1kkr7+q0bNmyDuLQcZBEPYmHVdOBiJyIlrRDq41YPWfXOxUysi5fvtyaj+2BpcnsUV/oSoEMOk2CQGlr4ckhBwaetBhjCwH0ZHtJROPJkyc7UjcYLDjmrH7ADTEBXFfOYmB0k9oYBOjJ8b4aOYSe7QkKcYhFlq3QYLQhSidNmtS2RATwy8YOM3EQJsUjKiaWZ+vZToUQgzhkHXudb/PW5YMHD9yZM2faPsMwoc7RciYJXbGuBqJ1UIGKKLv915jsvgtJxCZDubdXr165mzdvtr1Hz5LONA8jrUwKPqsmVesKa49S3Q4WxmRPUEYdTjgiUcfUwLx589ySJUva3oMkP6IYddq6HMS4o55xBJBUeRjzfa4Zdeg56QZ43LhxoyPo7Lf1kNt7oO8wWAbNwaYjIv5lhyS7kRf96dvm5Jah8vfvX3flyhX35cuX6HfzFHOToS1H4BenCaHvO8pr8iDuwoUL7tevX+b5ZdbBair0xkFIlFDlW4ZknEClsp/TzXyAKVOmmHWFVSbDNw1l1+4f90U6IY/q4V27dpnE9bJ+v87QEydjqx/UamVVPRG+mwkNTYN+9tjkwzEx+atCm/X9WvWtDtAb68Wy9LXa1UmvCDDIpPkyOQ5ZwSzJ4jMrvFcr0rSjOUh+GcT4LSg5ugkW1Io0/SCDQBojh0hPlaJdah+tkVYrnTZowP8iq1F1TgMBBauufyB33x1v+NWFYmT5KmppgHC+NkAgbmRkpD3yn9QIseXymoTQFGQmIOKTxiZIWpvAatenVqRVXf2nTrAWMsPnKrMZHz6bJq5jvce6QK8J1cQNgKxlJapMPdZSR64/UivS9NztpkVEdKcrs5alhhWP9NeqlfWopzhZScI6QxseegZRGeg5a8C3Re1Mfl1ScP36ddcUaMuv24iOJtz7sbUjTS4qBvKmstYJoUauiuD3k5qhyr7QdUHMeCgLa1Ear9NquemdXgmum4fvJ6w1lqsuDhNrg1qSpleJK7K3TF0Q2jSd94uSZ60kK1e3qyVpQK6PVWXp2/FC3mp6jBhKKOiY2h3gtUV64TWM6wDETRPLDfSakXmH3w8g9Jlug8ZtTt4kVF0kLUYYmCCtD/DrQ5YhMGbA9L3ucdjh0y8kOHW5gU/VEEmJTcL4Pz/f7mgoAbYkAAAAAElFTkSuQmCC"}]}] +}' ``` + + ## Endpoints ### `/v1/chat/completions` @@ -310,6 +275,31 @@ curl http://localhost:11434/v1/embeddings \ - [x] `dimensions` - [ ] `user` +### `/v1/responses` + +Ollama supports the [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses). Only the non-stateful flavor is supported (i.e., there is no `previous_response_id` or `conversation` support). + +#### Supported features + +- [x] Streaming +- [x] Tools (function calling) +- [x] Reasoning summaries (for thinking models) +- [ ] Stateful requests + +#### Supported request fields + +- [x] `model` +- [x] `input` +- [x] `instructions` +- [x] `tools` +- [x] `stream` +- [x] `temperature` +- [x] `top_p` +- [x] `max_output_tokens` +- [ ] `previous_response_id` (stateful v1/responses not supported) +- [ ] `conversation` (stateful v1/responses not supported) +- [ ] `truncation` + ## Models Before using a model, pull it locally `ollama pull`: @@ -365,4 +355,4 @@ curl http://localhost:11434/v1/chat/completions \ } ] }' -``` \ No newline at end of file +``` diff --git a/docs/tools/extract-examples/README.md b/docs/tools/extract-examples/README.md new file mode 100644 index 00000000..38560492 --- /dev/null +++ b/docs/tools/extract-examples/README.md @@ -0,0 +1,46 @@ +# extract-examples + +Extracts code examples from MDX files to a temp directory so you can run them. 
+ +## Usage + +```shell +go run docs/tools/extract-examples/main.go <mdx-file> +``` + +## Example + +```shell +go run docs/tools/extract-examples/main.go docs/api/openai-compatibility.mdx +``` + +Output: + +``` +Extracting code examples to: /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368 + + - 01_basic.py + - 01_basic.js + - 01_basic.sh + - 02_responses.py + - 02_responses.js + - 02_responses.sh + - 03_vision.py + - 03_vision.js + - 03_vision.sh + +Extracted 9 file(s) to /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368 + +To run examples: + + cd /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368 + npm install # for JS examples + +then run individual files with `node file.js`, `python file.py`, `bash file.sh` +``` + +## How it works + +- Parses MDX files looking for fenced code blocks with filenames (e.g., ` ```python basic.py `) +- Groups examples by their `<CodeGroup>` and prefixes filenames with `01_`, `02_`, etc. +- Writes all extracted files to a temp directory diff --git a/docs/tools/extract-examples/main.go b/docs/tools/extract-examples/main.go new file mode 100644 index 00000000..3f09af5c --- /dev/null +++ b/docs/tools/extract-examples/main.go @@ -0,0 +1,137 @@ +package main + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "regexp" + "strings" +) + +func main() { + if len(os.Args) < 2 { + fmt.Fprintln(os.Stderr, "Usage: go run extract-examples.go <mdx-file>") + os.Exit(1) + } + + mdxFile := os.Args[1] + + f, err := os.Open(mdxFile) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + defer f.Close() + + // Create temp directory + tempDir, err := os.MkdirTemp("", "mdx-examples-*") + if err != nil { + fmt.Fprintf(os.Stderr, "Error creating temp dir: %v\n", err) + os.Exit(1) + } + + fmt.Printf("Extracting code examples to: %s\n\n", tempDir) + + // Patterns + codeBlockStart := regexp.MustCompile("^```([a-zA-Z0-9_-]+)\\s+([^\\s]+)$") + codeGroupStart := regexp.MustCompile("^<CodeGroup") + codeGroupEnd := regexp.MustCompile("^</CodeGroup>") + + scanner := bufio.NewScanner(f) + inCodeBlock := false + inCodeGroup := false + var currentFile string + var content strings.Builder + count := 0 + codeGroupNum := 0 + + for scanner.Scan() { + line := scanner.Text() + + // Track CodeGroup boundaries + if codeGroupStart.MatchString(line) { + inCodeGroup = true + codeGroupNum++ + continue + } + if codeGroupEnd.MatchString(line) { + inCodeGroup = false + continue + } + + if inCodeBlock { + if line == "```" { + // End of code block - write file + if currentFile != "" { + outPath := filepath.Join(tempDir, currentFile) + if err := os.WriteFile(outPath, []byte(content.String()), 0o644); err != nil { + fmt.Fprintf(os.Stderr, "Error writing %s: %v\n", currentFile, err) + } else { + fmt.Printf(" - %s\n", currentFile) + count++ + } + } + inCodeBlock = false + currentFile = "" + content.Reset() + } else { + content.WriteString(line) + content.WriteString("\n") + } + } else { + if matches := codeBlockStart.FindStringSubmatch(line); matches != nil { + inCodeBlock = true + filename := matches[2] + // Prefix with CodeGroup number if inside a CodeGroup + if inCodeGroup { + currentFile = fmt.Sprintf("%02d_%s", codeGroupNum, filename) + } else { + currentFile = filename + } + content.Reset() + } + } + } + + if err := scanner.Err(); err != nil { + fmt.Fprintf(os.Stderr, "Error reading file: %v\n", err) + os.Exit(1) + } + + // Write package.json for JavaScript dependencies + packageJSON := `{ + "name": "mdx-examples", + "type": "module", + "dependencies": { + "openai": "^4", + "ollama": "^0.5" + } +} +` + if
err := os.WriteFile(filepath.Join(tempDir, "package.json"), []byte(packageJSON), 0o644); err != nil { + fmt.Fprintf(os.Stderr, "Error writing package.json: %v\n", err) + } + + // Write pyproject.toml for Python dependencies + pyprojectTOML := `[project] +name = "mdx-examples" +version = "0.0.0" +dependencies = [ + "openai", + "ollama", +] +` + if err := os.WriteFile(filepath.Join(tempDir, "pyproject.toml"), []byte(pyprojectTOML), 0o644); err != nil { + fmt.Fprintf(os.Stderr, "Error writing pyproject.toml: %v\n", err) + } + + fmt.Printf("\n") + fmt.Printf("Extracted %d file(s) to %s\n", count, tempDir) + fmt.Printf("\n") + fmt.Printf("To run examples:\n") + fmt.Printf("\n") + fmt.Printf(" cd %s\n npm install # for JS examples\n", tempDir) + fmt.Printf("\n") + fmt.Printf("then run individual files with `node file.js`, `python file.py`, `bash file.sh`\n") +} From 95fdd8d619ad4dc9215cdce8a8665284a96cd96f Mon Sep 17 00:00:00 2001 From: Eva H <63033505+hoyyeva@users.noreply.github.com> Date: Fri, 12 Dec 2025 11:09:37 -0500 Subject: [PATCH 29/35] fix: select and update models folder in settings (#13412) --- app/dialog/cocoa/dlg.m | 60 ++++++++++++++++++++++++------------------ app/server/server.go | 4 +-- 2 files changed, 36 insertions(+), 28 deletions(-) diff --git a/app/dialog/cocoa/dlg.m b/app/dialog/cocoa/dlg.m index e90098e7..35851978 100644 --- a/app/dialog/cocoa/dlg.m +++ b/app/dialog/cocoa/dlg.m @@ -169,37 +169,47 @@ DlgResult fileDlg(FileDlgParams* params) { } NSArray* urls = [panel URLs]; - if(self->params->allowMultiple && [urls count] >= 1) { + if([urls count] == 0) { + return DLG_CANCEL; + } + + if(self->params->allowMultiple) { // For multiple files, we need to return all paths separated by null bytes char* bufPtr = self->params->buf; int remainingBuf = self->params->nbuf; - // Calculate total required buffer size first - int totalSize = 0; - for(NSURL* url in urls) { - char tempBuf[PATH_MAX]; - if(![url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX]) { - return DLG_URLFAIL; - } - totalSize += strlen(tempBuf) + 1; // +1 for null terminator - } - totalSize += 1; // Final null terminator + // Calculate total required buffer size first + int totalSize = 0; + for(NSURL* url in urls) { + char tempBuf[PATH_MAX]; + if(![url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX]) { + return DLG_URLFAIL; + } + totalSize += strlen(tempBuf) + 1; // +1 for null terminator + } + totalSize += 1; // Final null terminator - if(totalSize > self->params->nbuf) { - // Not enough buffer space - return DLG_URLFAIL; - } + if(totalSize > self->params->nbuf) { + // Not enough buffer space + return DLG_URLFAIL; + } - // Now actually copy the paths (we know we have space) - bufPtr = self->params->buf; - for(NSURL* url in urls) { - char tempBuf[PATH_MAX]; - [url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX]; - int pathLen = strlen(tempBuf); - strcpy(bufPtr, tempBuf); - bufPtr += pathLen + 1; - } - *bufPtr = '\0'; // Final null terminator + // Now actually copy the paths (we know we have space) + bufPtr = self->params->buf; + for(NSURL* url in urls) { + char tempBuf[PATH_MAX]; + [url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX]; + int pathLen = strlen(tempBuf); + strcpy(bufPtr, tempBuf); + bufPtr += pathLen + 1; + } + *bufPtr = '\0'; // Final null terminator + } else { + // Single file/directory selection - write path to buffer + NSURL* url = [urls firstObject]; + if(![url getFileSystemRepresentation:self->params->buf maxLength:self->params->nbuf]) { + return DLG_URLFAIL; + } } 
return DLG_OK; diff --git a/app/server/server.go b/app/server/server.go index 64b96b1f..2e0c2d1e 100644 --- a/app/server/server.go +++ b/app/server/server.go @@ -224,9 +224,7 @@ func (s *Server) cmd(ctx context.Context) (*exec.Cmd, error) { if _, err := os.Stat(settings.Models); err == nil { env["OLLAMA_MODELS"] = settings.Models } else { - slog.Warn("models path not accessible, clearing models setting", "path", settings.Models, "err", err) - settings.Models = "" - s.store.SetSettings(settings) + slog.Warn("models path not accessible, using default", "path", settings.Models, "err", err) } } if settings.ContextLength > 0 { From de9ecfd01c5fb22829d6daf090629da0849d0337 Mon Sep 17 00:00:00 2001 From: Eva H <63033505+hoyyeva@users.noreply.github.com> Date: Fri, 12 Dec 2025 11:43:35 -0500 Subject: [PATCH 30/35] tidy up lint warnings on windows (#13430) --- app/cmd/app/app_darwin.go | 7 ------- app/cmd/app/app_windows.go | 5 ----- app/dialog/dlgs_windows.go | 2 +- app/wintray/eventloop.go | 8 ++++---- 4 files changed, 5 insertions(+), 17 deletions(-) diff --git a/app/cmd/app/app_darwin.go b/app/cmd/app/app_darwin.go index 2018ce8e..8e886a12 100644 --- a/app/cmd/app/app_darwin.go +++ b/app/cmd/app/app_darwin.go @@ -191,13 +191,6 @@ func LaunchNewApp() { C.launchApp(appName) } -// Send a request to the main app thread to load a UI page -func sendUIRequestMessage(path string) { - p := C.CString(path) - defer C.free(unsafe.Pointer(p)) - C.uiRequest(p) -} - func registerLaunchAgent(hasCompletedFirstRun bool) { // Remove any stale Login Item registrations C.unregisterSelfFromLoginItem() diff --git a/app/cmd/app/app_windows.go b/app/cmd/app/app_windows.go index b563d409..9caeb178 100644 --- a/app/cmd/app/app_windows.go +++ b/app/cmd/app/app_windows.go @@ -263,11 +263,6 @@ func createLoginShortcut() error { return nil } -// Send a request to the main app thread to load a UI page -func sendUIRequestMessage(path string) { - wintray.SendUIRequestMessage(path) -} - func LaunchNewApp() { } diff --git a/app/dialog/dlgs_windows.go b/app/dialog/dlgs_windows.go index c5b175ca..51ba9ee6 100644 --- a/app/dialog/dlgs_windows.go +++ b/app/dialog/dlgs_windows.go @@ -15,7 +15,7 @@ const multiFileBufferSize = w32.MAX_PATH * 10 type WinDlgError int func (e WinDlgError) Error() string { - return fmt.Sprintf("CommDlgExtendedError: %#x", e) + return fmt.Sprintf("CommDlgExtendedError: %#x", int(e)) } func err() error { diff --git a/app/wintray/eventloop.go b/app/wintray/eventloop.go index 15fbd0c3..dda433e2 100644 --- a/app/wintray/eventloop.go +++ b/app/wintray/eventloop.go @@ -158,16 +158,16 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui case uint32(UI_REQUEST_MSG_ID): // Requests for the UI must always come from the main event thread l := int(wParam) - path := unsafe.String((*byte)(unsafe.Pointer(lParam)), l) + path := unsafe.String((*byte)(unsafe.Pointer(lParam)), l) //nolint:govet,gosec t.app.UIRun(path) case WM_COPYDATA: // Handle URL scheme requests from other instances if lParam != 0 { - cds := (*COPYDATASTRUCT)(unsafe.Pointer(lParam)) - if cds.DwData == 1 { // Our identifier for URL scheme messages + cds := (*COPYDATASTRUCT)(unsafe.Pointer(lParam)) //nolint:govet,gosec + if cds.DwData == 1 { // Our identifier for URL scheme messages // Convert the data back to string data := make([]byte, cds.CbData) - copy(data, (*[1 << 30]byte)(unsafe.Pointer(cds.LpData))[:cds.CbData:cds.CbData]) + copy(data, (*[1 << 30]byte)(unsafe.Pointer(cds.LpData))[:cds.CbData:cds.CbData]) 
//nolint:govet,gosec urlScheme := string(data) handleURLSchemeRequest(urlScheme) lResult = 1 // Return non-zero to indicate success From 773089515805bc32391f6662656824ae58d573a4 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Fri, 12 Dec 2025 11:48:43 -0800 Subject: [PATCH 31/35] Enable Ollama engine by default (#13443) This changes the default behavior to use the Ollama engine for supported models, while retaining the ability to disable the Ollama engine and fall back to the Llama engine. Models in the OllamaEngineRequired list will always run on the Ollama engine. --- envconfig/config.go | 4 ++-- llm/server.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/envconfig/config.go b/envconfig/config.go index 238e5e6e..c0b2e2f0 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -199,7 +199,7 @@ var ( // MultiUserCache optimizes prompt caching for multi-user scenarios MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE") // Enable the new Ollama engine - NewEngine = Bool("OLLAMA_NEW_ENGINE") + NewEngine = BoolWithDefault("OLLAMA_NEW_ENGINE") // ContextLength sets the default context length ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096) // Auth enables authentication between the Ollama client and server @@ -291,7 +291,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"}, "OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"}, "OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"}, - "OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"}, + "OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(true), "Enable the new Ollama engine"}, "OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"}, // Informational diff --git a/llm/server.go b/llm/server.go index 5c232f0f..abf6035d 100644 --- a/llm/server.go +++ b/llm/server.go @@ -143,7 +143,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st var llamaModel *llama.Model var textProcessor model.TextProcessor var err error - if envconfig.NewEngine() || f.KV().OllamaEngineRequired() { + if envconfig.NewEngine(true) || f.KV().OllamaEngineRequired() { if len(projectors) == 0 { textProcessor, err = model.NewTextProcessor(modelPath) } else { From 3af5d3b73840b1a2a9c714194da4c32940cf3671 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Fri, 12 Dec 2025 13:27:08 -0800 Subject: [PATCH 32/35] model: force rope factor 1.0 for Gemma 3 (#13445) --- model/models/gemma3/model_text.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index f76fba74..37e688d6 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -90,12 +90,15 @@ func newTextModel(c fs.Config) *TextModel { // Google's Gemma 3 release with sliding window attention does // not use final logit softcapping, and so force it to 0.0 + // The QAT weights for Gemma 3 also included an incorrect + // value for the rope scale, so we need to set it to 1.0 here. // TODO (jmorganca): this should ideally be set to 0.0 in the // model configuration instead of here, as future versions of // models may include both sliding window attention and final // logit softcapping. 
if slices.Contains(m.TextConfig.slidingWindowPattern, true) { m.TextConfig.finalLogitSoftcap = 0.0 + m.TextConfig.ropeScale = 1.0 } if numBlocks == gemma27BLayerCount { From bd6c1d6b49aca86dbb1a59182b293c0d1f7b8db8 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Fri, 12 Dec 2025 13:27:19 -0800 Subject: [PATCH 33/35] flash attn: add auto mode for llama engine (#13052) * flash attn: add auto mode for llama engine If the user does not specify fa in the environment, use auto-mode. * review comments * ensure kv cache quantized types have FA explicitly enabled additional review comments --- fs/ggml/ggml.go | 13 ++++++-- llama/llama.go | 13 +++++--- llm/server.go | 63 ++++++++++++++++++++++++++++-------- ml/backend.go | 2 +- ml/backend/ggml/ggml.go | 6 ++-- ml/device.go | 26 +++++++++++++++ runner/llamarunner/runner.go | 3 +- 7 files changed, 101 insertions(+), 25 deletions(-) diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 4004bbfd..691ea32b 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -13,6 +13,7 @@ import ( "github.com/ollama/ollama/format" "github.com/ollama/ollama/fs/util/bufioutil" + "github.com/ollama/ollama/ml" ) type GGML struct { @@ -550,7 +551,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) { }, nil } -func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) { +func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention ml.FlashAttentionType) (kv []uint64, partialOffload, fullOffload uint64) { context *= uint64(numParallel) embedding := f.KV().EmbeddingLength() @@ -791,7 +792,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri } partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6 - if useFlashAttention { + if useFlashAttention == ml.FlashAttentionEnabled { // rough estimate of graph size with flash attention on partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte } @@ -809,6 +810,14 @@ func (f GGML) SupportsKVCacheType(cacheType string) bool { return slices.Contains([]string{"q8_0", "q4_0"}, cacheType) } +// KVCacheTypeIsQuantized checks if the requested cache type is a quantized type +func (f GGML) KVCacheTypeIsQuantized(cacheType string) bool { + if cacheType == "" || cacheType == "f16" || cacheType == "f32" || cacheType == "bf16" { + return false + } + return true +} + // SupportsFlashAttention checks if the model supports flash attention func (f GGML) SupportsFlashAttention() bool { _, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())] diff --git a/llama/llama.go b/llama/llama.go index 70bf3b9c..49b3f56a 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -118,7 +118,7 @@ type ContextParams struct { c C.struct_llama_context_params } -func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams { +func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention ml.FlashAttentionType, kvCacheType string) ContextParams { params := C.llama_context_default_params() params.n_ctx = C.uint(numCtx) params.n_batch = C.uint(batchSize * numSeqMax) @@ -127,10 +127,13 @@ func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, fla params.n_threads = C.int(threads) params.n_threads_batch = params.n_threads params.embeddings = C.bool(true) - if flashAttention { - 
params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_ENABLED - } else { - params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_DISABLED + switch flashAttention { + case ml.FlashAttentionEnabled: + params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_ENABLED) + case ml.FlashAttentionDisabled: + params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_DISABLED) + case ml.FlashAttentionAuto: + params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_AUTO) } params.type_k = kvCacheTypeFromStr(strings.ToLower(kvCacheType)) params.type_v = kvCacheTypeFromStr(strings.ToLower(kvCacheType)) diff --git a/llm/server.go b/llm/server.go index abf6035d..49af4e1b 100644 --- a/llm/server.go +++ b/llm/server.go @@ -188,6 +188,11 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st if len(projectors) > 0 && llamaModel != nil { loadRequest.ProjectorPath = projectors[0] } + // Determine if the user has forced FA on or off + faUserSet := false + if envconfig.FlashAttention(true) == envconfig.FlashAttention(false) { + faUserSet = true + } fa := envconfig.FlashAttention(f.FlashAttention()) @@ -205,19 +210,51 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st kvct := strings.ToLower(envconfig.KvCacheType()) - if fa { - slog.Info("enabling flash attention") - loadRequest.FlashAttention = true - - // Flash Attention also supports kv cache quantization - // Enable if the requested and kv cache type is supported by the model - if f.SupportsKVCacheType(kvct) { - loadRequest.KvCacheType = kvct - } else { - slog.Warn("kv cache type not supported by model", "type", kvct) + if textProcessor == nil { + flashAttention := ml.FlashAttentionAuto + if faUserSet { + if fa { + flashAttention = ml.FlashAttentionEnabled + } else { + flashAttention = ml.FlashAttentionDisabled + } + } + + if kvct != "" { + if f.KVCacheTypeIsQuantized(kvct) { + if flashAttention != ml.FlashAttentionEnabled { + slog.Warn("OLLAMA_FLASH_ATTENTION must be enabled to use a quantized OLLAMA_KV_CACHE_TYPE", "type", kvct) + loadRequest.KvCacheType = "" + } else if f.SupportsKVCacheType(kvct) { + loadRequest.KvCacheType = kvct + } else { + slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct) + } + } else { + if f.SupportsKVCacheType(kvct) { + loadRequest.KvCacheType = kvct + } else { + slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct) + } + } + } + loadRequest.FlashAttention = flashAttention + } else { + // For Ollama engine, use our SupportsFlashAttention logic + if fa { + slog.Info("enabling flash attention") + loadRequest.FlashAttention = ml.FlashAttentionEnabled + + // Flash Attention also supports kv cache quantization + // Enable if the requested and kv cache type is supported by the model + if f.SupportsKVCacheType(kvct) { + loadRequest.KvCacheType = kvct + } else { + slog.Warn("kv cache type not supported by model", "type", kvct) + } + } else if kvct != "" && kvct != "f16" { + slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct) } - } else if kvct != "" && kvct != "f16" { - slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct) } gpuLibs := ml.LibraryPaths(gpus) @@ -435,7 +472,7 @@ type LoadRequest struct { LoraPath []string Parallel int BatchSize int - FlashAttention bool + FlashAttention ml.FlashAttentionType KvSize int KvCacheType string NumThreads int diff --git a/ml/backend.go b/ml/backend.go index 6e5a059a..1e781fa7 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -74,7 +74,7 @@ type BackendParams struct { 
 	GPULayers GPULayersList
 
 	// FlashAttention indicates that we should use a fused flash attention kernel
-	FlashAttention bool
+	FlashAttention FlashAttentionType
 }
 
 var backends = make(map[string]func(string, BackendParams) (Backend, error))
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index 18bdc91e..a50d8ec9 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -109,7 +109,7 @@ type Backend struct {
 	// btDeviceMemory maps from a buffer type to the memory allocations associated with that device
 	btDeviceMemory map[C.ggml_backend_buffer_type_t]*ml.DeviceMemory
 
-	flashAttention bool
+	flashAttention ml.FlashAttentionType
 
 	// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
 	maxGraphNodes int
@@ -684,7 +684,7 @@ func (b *Backend) NewContextSize(n int) ml.Context {
 }
 
 func (b *Backend) CacheConfig() ml.CacheConfig {
-	if b.flashAttention {
+	if b.flashAttention == ml.FlashAttentionEnabled {
 		return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
 	} else {
 		return ml.CacheConfig{CachePadding: 256, PermutedV: true}
@@ -1676,7 +1676,7 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
 	query := t.Permute(ctx, 0, 2, 1, 3)
 	key = key.Permute(ctx, 0, 2, 1, 3)
 
-	if t.b.flashAttention {
+	if t.b.flashAttention == ml.FlashAttentionEnabled {
 		value = value.Permute(ctx, 0, 2, 1, 3)
 
 		kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
diff --git a/ml/device.go b/ml/device.go
index f892b512..47e180d3 100644
--- a/ml/device.go
+++ b/ml/device.go
@@ -492,6 +492,32 @@ func FlashAttentionSupported(l []DeviceInfo) bool {
 	return true
 }
 
+type FlashAttentionType int32
+
+const (
+	// Aligned with llama_flash_attn_type
+	FlashAttentionAuto     FlashAttentionType = -1
+	FlashAttentionDisabled FlashAttentionType = 0
+	FlashAttentionEnabled  FlashAttentionType = 1
+)
+
+func (f FlashAttentionType) LogValue() slog.Value {
+	return slog.AnyValue(f.String())
+}
+
+func (f FlashAttentionType) String() string {
+	switch f {
+	case FlashAttentionAuto:
+		return "Auto"
+	case FlashAttentionDisabled:
+		return "Disabled"
+	case FlashAttentionEnabled:
+		return "Enabled"
+	default:
+		return "unknown"
+	}
+}
+
 // Given the list of GPUs this instantiation is targeted for,
 // figure out the visible devices environment variables
 // Set mustFilter true to enable filtering of CUDA devices
diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go
index cb4bbe50..de9d718b 100644
--- a/runner/llamarunner/runner.go
+++ b/runner/llamarunner/runner.go
@@ -26,6 +26,7 @@ import (
 	"github.com/ollama/ollama/llama"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/logutil"
+	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/runner/common"
 )
 
@@ -832,7 +833,7 @@ func (s *Server) loadModel(
 	ppath string,
 	kvSize int,
 	kvCacheType string,
-	flashAttention bool,
+	flashAttention ml.FlashAttentionType,
 	threads int,
 	multiUserCache bool,
 ) {

From 1b308e1d2a478e70ef3e31e6b24d687a44b33016 Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan
Date: Fri, 12 Dec 2025 16:29:01 -0800
Subject: [PATCH 34/35] model: fix global layer rope scale values for gemma 3 (#13452)

---
 model/models/gemma3/model_text.go | 25 +++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go
index 37e688d6..759cc6b3 100644
--- a/model/models/gemma3/model_text.go
+++ b/model/models/gemma3/model_text.go
@@ -28,10 +28,10 @@ type TextConfig struct {
 	finalLogitSoftcap float32
 }
 
-func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base float32) ml.Tensor {
+func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base, scale float32) ml.Tensor {
 	ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
 	if o.ropeType == "yarn" {
-		attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
+		attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(scale))))
 		ropeOpts = append(ropeOpts,
 			rope.WithOriginalContextLength(o.ropeOriginalContext),
 			rope.WithExtrapolationFactor(o.ropeExtrapolation),
@@ -41,7 +41,7 @@ func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positi
 		)
 	}
 
-	return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, ropeOpts...)
+	return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./scale, ropeOpts...)
 }
 
 type TextModel struct {
@@ -83,7 +83,7 @@ func newTextModel(c fs.Config) *TextModel {
 			ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1.0),
 			ropeBetaFast:      c.Float("rope.scaling.beta_fast", 64.0),
 			ropeBetaSlow:      c.Float("rope.scaling.beta_slow", 1.0),
-			ropeScale:         c.Float("rope.scaling.factor", 1.0),
+			ropeScale:         c.Float("rope.scaling.factor", 8.0),
 			finalLogitSoftcap: c.Float("final_logit_softcapping", 0.0),
 		},
 	}
@@ -117,31 +117,31 @@ type TextSelfAttention struct {
 	Output *nn.Linear `gguf:"attn_output"`
 }
 
-func (opts *TextConfig) ropeBaseForLayer(layer int) float32 {
+func (opts *TextConfig) ropeValuesForLayer(layer int) (base float32, scale float32) {
 	if opts.slidingWindowPattern != nil && opts.slidingWindowPattern[layer] {
-		return opts.ropeLocalBase
+		return opts.ropeLocalBase, 1.0
 	}
 
 	// Standard Gemma3: only every n-th layer is global,
 	// where n = gemmaGlobalCacheCount, otherwise use
 	// the local rope base
 	if (layer+1)%gemmaGlobalCacheCount > 0 {
-		return opts.ropeLocalBase
+		return opts.ropeLocalBase, 1.0
 	}
 
 	// default to global rope base
-	return opts.ropeBase
+	return opts.ropeBase, opts.ropeScale
 }
 
 func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 
-	ropeBase := opts.ropeBaseForLayer(layer)
+	ropeBase, ropeScale := opts.ropeValuesForLayer(layer)
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
 	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
-	q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase)
+	q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase, ropeScale)
 
 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -152,7 +152,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
 	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
-	k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase)
+	k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase, ropeScale)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -165,7 +165,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.TextConfig.ropeBaseForLayer(layer)), nil
+	ropeBase, ropeScale := m.TextConfig.ropeValuesForLayer(layer)
+	return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase, ropeScale), nil
 }
 
 type TextMLP struct {

From 4ff8a691bcef296aa976e19d0ba9c7b74ae9f27c Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan
Date: Fri, 12 Dec 2025 17:51:56 -0800
Subject: [PATCH 35/35] model: default gemma 3 rope scale to 1.0, apply corrections based on layer counts (#13453)

---
 model/models/gemma3/model_text.go | 60 +++++++++++++++++--------------
 1 file changed, 34 insertions(+), 26 deletions(-)

diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go
index 759cc6b3..e1c0004d 100644
--- a/model/models/gemma3/model_text.go
+++ b/model/models/gemma3/model_text.go
@@ -2,7 +2,6 @@ package gemma3
 
 import (
 	"math"
-	"slices"
 
 	"github.com/ollama/ollama/fs"
 	"github.com/ollama/ollama/kvcache"
@@ -13,19 +12,20 @@ import (
 )
 
 type TextConfig struct {
-	hiddenSize, numHeads, numKVHeads int
-	attnKeyLen, attnValLen           int
-	eps, ropeScale                   float32
-	ropeLocalBase                    float32
-	largeModelScaling                bool
-	slidingWindowPattern             []bool
-	ropeBase                         float32
-	ropeType                         string
-	ropeOriginalContext              int
-	ropeExtrapolation                float32
-	ropeBetaFast                     float32
-	ropeBetaSlow                     float32
-	finalLogitSoftcap                float32
+	hiddenSize, contextLength, numHeads, numKVHeads int
+	attnKeyLen, attnValLen                          int
+	eps, ropeScale                                  float32
+	ropeLocalBase                                   float32
+	largeModelScaling                               bool
+	slidingWindow                                   uint32
+	slidingWindowPattern                            []bool
+	ropeBase                                        float32
+	ropeType                                        string
+	ropeOriginalContext                             int
+	ropeExtrapolation                               float32
+	ropeBetaFast                                    float32
+	ropeBetaSlow                                    float32
+	finalLogitSoftcap                               float32
 }
 
 func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base, scale float32) ml.Tensor {
@@ -55,6 +55,9 @@ type TextModel struct {
 
 const (
 	gemmaGlobalCacheCount = 6
+	gemma1BLayerCount     = 26
+	gemma4BLayerCount     = 34
+	gemma12BLayerCount    = 48
 	gemma27BLayerCount    = 62
 )
 
@@ -70,6 +73,7 @@ func newTextModel(c fs.Config) *TextModel {
 		Layers: make([]TextLayer, numBlocks),
 		TextConfig: &TextConfig{
 			hiddenSize:           int(c.Uint("embedding_length")),
+			contextLength:        int(c.Uint("context_length")),
 			numHeads:             int(c.Uint("attention.head_count")),
 			numKVHeads:           int(c.Uint("attention.head_count_kv")),
 			attnKeyLen:           int(c.Uint("attention.key_length", 256)),
@@ -77,28 +81,32 @@ func newTextModel(c fs.Config) *TextModel {
 			eps:                  c.Float("attention.layer_norm_rms_epsilon", 1e-06),
 			ropeLocalBase:        c.Float("rope.local.freq_base", 10000.0),
 			ropeBase:             c.Float("rope.freq_base", 1000000.0),
+			slidingWindow:        c.Uint("attention.sliding_window"),
 			slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
 			ropeType:             c.String("rope.scaling.type"),
 			ropeOriginalContext:  int(c.Uint("rope.scaling.original_context_length")),
 			ropeExtrapolation:    c.Float("rope.scaling.extrapolation_factor", 1.0),
 			ropeBetaFast:         c.Float("rope.scaling.beta_fast", 64.0),
 			ropeBetaSlow:         c.Float("rope.scaling.beta_slow", 1.0),
-			ropeScale:            c.Float("rope.scaling.factor", 8.0),
+			ropeScale:            c.Float("rope.scaling.factor", 1.0),
 			finalLogitSoftcap:    c.Float("final_logit_softcapping", 0.0),
 		},
 	}
 
-	// Google's Gemma 3 release with sliding window attention does
-	// not use final logit softcapping, and so force it to 0.0
-	// The QAT weights for Gemma 3 also included an incorrect
-	// value for the rope scale, so we need to set it to 1.0 here.
-	// TODO (jmorganca): this should ideally be set to 0.0 in the
-	// model configuration instead of here, as future versions of
-	// models may include both sliding window attention and final
-	// logit softcapping.
-	if slices.Contains(m.TextConfig.slidingWindowPattern, true) {
-		m.TextConfig.finalLogitSoftcap = 0.0
-		m.TextConfig.ropeScale = 1.0
+	// Apply corrections for older versions of the Gemma 3 models
+	// by looking at whether they use sliding window attention and
+	// based on their layer counts.
+	if m.TextConfig.slidingWindow < uint32(m.TextConfig.contextLength) {
+		switch numBlocks {
+		case gemma1BLayerCount:
+			// The 1B model has final logit softcapping set to 30.0
+			// but it should be 0.0
+			m.TextConfig.finalLogitSoftcap = 0.0
+		case gemma4BLayerCount, gemma12BLayerCount, gemma27BLayerCount:
+			// The 4B, 12B, and 27B models have rope scale unset
+			// but it should be set to 8.0
+			m.TextConfig.ropeScale = 8.0
+		}
 	}
 
 	if numBlocks == gemma27BLayerCount {
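
Read together, the two Gemma 3 patches above first normalize the configuration of older checkpoints (PATCH 35) and then choose a rope base and scale per layer (PATCH 34). The following standalone Go sketch illustrates that flow under simplified, assumed field names and placeholder values; it is not the ollama implementation, which additionally honors an explicit sliding-window pattern when the model provides one.

package main

import "fmt"

// Simplified sketch of the Gemma 3 rope handling described in the patches
// above; names and values are placeholders for illustration only.
const (
	gemmaGlobalCacheCount = 6
	gemma1BLayerCount     = 26
	gemma4BLayerCount     = 34
	gemma12BLayerCount    = 48
	gemma27BLayerCount    = 62
)

type textConfig struct {
	slidingWindow     uint32
	contextLength     uint32
	finalLogitSoftcap float32
	ropeLocalBase     float32
	ropeBase          float32
	ropeScale         float32
}

// applyCorrections mirrors the PATCH 35 idea: only older checkpoints whose
// sliding window is smaller than the full context are touched, and the fix
// depends on the transformer block count.
func (c *textConfig) applyCorrections(numBlocks int) {
	if c.slidingWindow >= c.contextLength {
		return
	}
	switch numBlocks {
	case gemma1BLayerCount:
		c.finalLogitSoftcap = 0.0 // shipped as 30.0, but softcapping is unused
	case gemma4BLayerCount, gemma12BLayerCount, gemma27BLayerCount:
		c.ropeScale = 8.0 // shipped unset, but the global layers expect 8.0
	}
}

// ropeValuesForLayer mirrors the PATCH 34 idea: local layers keep a scale of
// 1.0, and only every gemmaGlobalCacheCount-th layer uses the global base
// together with the configured scaling factor.
func (c *textConfig) ropeValuesForLayer(layer int) (base, scale float32) {
	if (layer+1)%gemmaGlobalCacheCount > 0 {
		return c.ropeLocalBase, 1.0
	}
	return c.ropeBase, c.ropeScale
}

func main() {
	cfg := &textConfig{
		slidingWindow: 1024, contextLength: 131072,
		ropeLocalBase: 10000, ropeBase: 1000000, ropeScale: 1.0,
	}
	cfg.applyCorrections(gemma27BLayerCount)
	for layer := 0; layer < 12; layer++ {
		base, scale := cfg.ropeValuesForLayer(layer)
		fmt.Printf("layer %2d: base=%.0f scale=%.1f\n", layer, base, scale)
	}
}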
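
The FlashAttentionType introduced in ml/device.go earlier in this series replaces a plain bool so that callers can distinguish an explicit on/off choice from "decide for me". Below is a rough standalone sketch of how such a tri-state setting might be consumed; the resolve helper is hypothetical and not part of the patch.

package main

import (
	"fmt"
	"log/slog"
)

// Mirrors the tri-state type added in ml/device.go, aligned with
// llama_flash_attn_type: auto / disabled / enabled.
type FlashAttentionType int32

const (
	FlashAttentionAuto     FlashAttentionType = -1
	FlashAttentionDisabled FlashAttentionType = 0
	FlashAttentionEnabled  FlashAttentionType = 1
)

func (f FlashAttentionType) String() string {
	switch f {
	case FlashAttentionAuto:
		return "Auto"
	case FlashAttentionDisabled:
		return "Disabled"
	case FlashAttentionEnabled:
		return "Enabled"
	default:
		return "unknown"
	}
}

func (f FlashAttentionType) LogValue() slog.Value {
	return slog.AnyValue(f.String())
}

// resolve is a hypothetical helper showing why a tri-state beats a bool:
// Auto can fall back to device support instead of forcing a hard choice.
func resolve(requested FlashAttentionType, deviceSupported bool) FlashAttentionType {
	if requested == FlashAttentionAuto {
		if deviceSupported {
			return FlashAttentionEnabled
		}
		return FlashAttentionDisabled
	}
	return requested
}

func main() {
	slog.Info("flash attention", "setting", resolve(FlashAttentionAuto, true))
	fmt.Println(resolve(FlashAttentionDisabled, true))
}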